summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-13 04:49:34 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:36:32 +0000
commitc8ea9a839e28254474000729fa522b51fa860925 (patch)
treeeebdcaf1556ce8882552f21efa686d21813c67f4
parent29dd8155e7bf1698a035b1f17be6e9cc225e7445 (diff)
downloadcrosvm-c8ea9a839e28254474000729fa522b51fa860925.tar
crosvm-c8ea9a839e28254474000729fa522b51fa860925.tar.gz
crosvm-c8ea9a839e28254474000729fa522b51fa860925.tar.bz2
crosvm-c8ea9a839e28254474000729fa522b51fa860925.tar.lz
crosvm-c8ea9a839e28254474000729fa522b51fa860925.tar.xz
crosvm-c8ea9a839e28254474000729fa522b51fa860925.tar.zst
crosvm-c8ea9a839e28254474000729fa522b51fa860925.zip
poly_msg_socket
we're gonna need this to send all of VirtioDevice over a socket
-rw-r--r--Cargo.lock50
-rw-r--r--Cargo.toml1
-rw-r--r--devices/Cargo.toml3
-rw-r--r--devices/src/virtio/controller.rs62
-rw-r--r--devices/src/virtio/virtio_pci_device.rs3
-rw-r--r--src/linux.rs5
-rw-r--r--src/wl.rs25
7 files changed, 129 insertions, 20 deletions
diff --git a/Cargo.lock b/Cargo.lock
index d236770..efbcb10 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -56,6 +56,15 @@ dependencies = [
 ]
 
 [[package]]
+name = "bincode"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
 name = "bit_field"
 version = "0.1.0"
 dependencies = [
@@ -77,6 +86,11 @@ version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 
 [[package]]
+name = "byteorder"
+version = "1.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
+[[package]]
 name = "cc"
 version = "1.0.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -134,6 +148,7 @@ dependencies = [
  "msg_socket 0.1.0",
  "net_util 0.1.0",
  "p9 0.1.0",
+ "poly_msg_socket 0.1.0",
  "protobuf 2.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "protos 0.1.0",
  "rand_ish 0.1.0",
@@ -193,9 +208,12 @@ dependencies = [
  "net_sys 0.1.0",
  "net_util 0.1.0",
  "p9 0.1.0",
+ "poly_msg_socket 0.1.0",
  "protos 0.1.0",
  "remain 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "resources 0.1.0",
+ "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
+ "serde_derive 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
  "sync 0.1.0",
  "sys_util 0.1.0",
  "syscall_defines 0.1.0",
@@ -553,6 +571,16 @@ dependencies = [
 ]
 
 [[package]]
+name = "poly_msg_socket"
+version = "0.1.0"
+dependencies = [
+ "bincode 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "msg_socket 0.1.0",
+ "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
+ "sys_util 0.1.0",
+]
+
+[[package]]
 name = "proc-macro-hack"
 version = "0.5.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -659,6 +687,24 @@ dependencies = [
 ]
 
 [[package]]
+name = "serde"
+version = "1.0.104"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "serde_derive 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.104"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "proc-macro2 1.0.8 (registry+https://github.com/rust-lang/crates.io-index)",
+ "quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "syn 1.0.14 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
 name = "slab"
 version = "0.4.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -814,7 +860,9 @@ dependencies = [
 ]
 
 [metadata]
+"checksum bincode 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "5753e2a71534719bf3f4e57006c3a4f0d2c672a4b676eec84161f763eca87dbf"
 "checksum bitflags 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3d155346769a6855b86399e9bc3814ab343cd3d62c7e985113d46a0ec3c281fd"
+"checksum byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "08c48aae112d48ed9f069b33538ea9e3e90aa263cfa3d1c24309612b1f7472de"
 "checksum cc 1.0.25 (registry+https://github.com/rust-lang/crates.io-index)" = "f159dfd43363c4d08055a07703eb7a3406b0dac4d0584d96965a3262db3c9d16"
 "checksum cfg-if 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "0c4e7bb64a8ebb0d856483e1e682ea3422f883c5f5615a90d51a2c82fe87fdd3"
 "checksum futures 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b6f16056ecbb57525ff698bb955162d0cd03bee84e6241c27ff75c08d8ca5987"
@@ -844,6 +892,8 @@ dependencies = [
 "checksum protoc-rust 2.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "234c97039c32bb58a883d0deafa57db37e59428ce536f3bdfe1c46cffec04113"
 "checksum quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "053a8c8bcc71fcce321828dc897a98ab9760bef03a4fc36693c231e5b3216cfe"
 "checksum remain 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "99c861227fc40c8da6fdaa3d58144ac84c0537080a43eb1d7d45c28f88dcb888"
+"checksum serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)" = "414115f25f818d7dfccec8ee535d76949ae78584fc4f79a6f45a904bf8ab4449"
+"checksum serde_derive 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)" = "128f9e303a5a29922045a830221b8f78ec74a5f544944f3d5984f8ec3895ef64"
 "checksum slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8"
 "checksum syn 1.0.14 (registry+https://github.com/rust-lang/crates.io-index)" = "af6f3550d8dff9ef7dc34d384ac6f107e5d31c8f57d9f28e0081503f547ac8f5"
 "checksum unicode-width 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "882386231c45df4700b275c7ff55b6f3698780a650026380e72dabe76fa46526"
diff --git a/Cargo.toml b/Cargo.toml
index 0821dc0..3123c3d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -70,6 +70,7 @@ minijail-sys = "*" # provided by ebuild
 msg_socket = { path = "msg_socket" }
 net_util = { path = "net_util" }
 p9 = { path = "p9" }
+poly_msg_socket = { path = "poly_msg_socket" }
 protobuf = { version = "2.3", optional = true }
 protos = { path = "protos", optional = true }
 rand_ish = { path = "rand_ish" }
diff --git a/devices/Cargo.toml b/devices/Cargo.toml
index 8bea78d..939364a 100644
--- a/devices/Cargo.toml
+++ b/devices/Cargo.toml
@@ -37,6 +37,7 @@ msg_socket = { path = "../msg_socket" }
 net_sys = { path = "../net_sys" }
 net_util = { path = "../net_util" }
 p9 = { path = "../p9" }
+poly_msg_socket = { path = "../poly_msg_socket" }
 protos = { path = "../protos", optional = true }
 remain = "*"
 resources = { path = "../resources" }
@@ -49,6 +50,8 @@ vfio_sys = { path = "../vfio_sys" }
 vhost = { path = "../vhost" }
 virtio_sys = { path = "../virtio_sys" }
 vm_control = { path = "../vm_control" }
+serde = "*"
+serde_derive = "*"
 
 [dev-dependencies]
 tempfile = { path = "../tempfile" }
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs
index 815bb46..7b1b2cc 100644
--- a/devices/src/virtio/controller.rs
+++ b/devices/src/virtio/controller.rs
@@ -37,12 +37,13 @@ use super::resource_bridge::*;
 use super::{Interrupt, InterruptProxyEvent, Queue, VirtioDevice, TYPE_WL, VIRTIO_F_VERSION_1};
 use vm_control::{MaybeOwnedFd, VmMemoryControlRequestSocket};
 
-use msg_socket::{MsgError, MsgOnSocket, MsgReceiver, MsgSender, MsgSocket};
+use msg_socket::{MsgOnSocket, MsgReceiver, MsgSocket};
+use serde::{Deserialize, Serialize};
 use sys_util::net::UnixSeqpacket;
 use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken, SharedMemory};
 
 #[derive(Debug, MsgOnSocket)]
-pub enum Request {
+pub enum MsgOnSocketRequest {
     Create {
         // wayland_paths: Map<String, PathBuf>,
         vm_socket: MaybeOwnedFd<UnixSeqpacket>,
@@ -60,12 +61,50 @@ pub enum Request {
     Kill,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+pub enum BincodeRequest {
+    WriteConfig { offset: u64, data: Vec<u8> },
+}
+
+pub type Request = poly_msg_socket::Value<MsgOnSocketRequest, BincodeRequest>;
+
+impl From<MsgOnSocketRequest> for Request {
+    fn from(request: MsgOnSocketRequest) -> Self {
+        Self::MsgOnSocket(request)
+    }
+}
+
+impl From<BincodeRequest> for Request {
+    fn from(request: BincodeRequest) -> Self {
+        Self::Bincode(request)
+    }
+}
+
 #[derive(Debug, MsgOnSocket)]
-pub enum Response {
+pub enum MsgOnSocketResponse {
     Kill,
 }
 
-type Socket = MsgSocket<Request, Response>;
+#[derive(Debug, Deserialize, Serialize)]
+pub struct BincodeResponse;
+
+pub type Response = poly_msg_socket::Value<MsgOnSocketResponse, BincodeResponse>;
+
+impl From<MsgOnSocketResponse> for Response {
+    fn from(response: MsgOnSocketResponse) -> Self {
+        Self::MsgOnSocket(response)
+    }
+}
+
+impl From<BincodeResponse> for Response {
+    fn from(response: BincodeResponse) -> Self {
+        Self::Bincode(response)
+    }
+}
+
+use poly_msg_socket::PolyMsgSocket;
+type Socket =
+    PolyMsgSocket<MsgOnSocketRequest, MsgOnSocketResponse, BincodeRequest, BincodeResponse>;
 
 const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01;
 
@@ -96,11 +135,14 @@ impl Worker {
     }
 
     fn handle_response(&mut self) {
+        use poly_msg_socket::Value::*;
         match self.device_socket.recv() {
-            Ok(Response::Kill) => {
+            Ok(MsgOnSocket(MsgOnSocketResponse::Kill)) => {
                 self.shutdown = true;
             }
 
+            Ok(Bincode(BincodeResponse)) => unreachable!(),
+
             Err(e) => {
                 error!("recv failed: {:?}", e);
             }
@@ -122,7 +164,9 @@ impl Worker {
     }
 
     fn kill(&self) {
-        if let Err(e) = self.device_socket.send(&Request::Kill) {
+        if let Err(e) = self.device_socket.send(poly_msg_socket::Value::MsgOnSocket(
+            MsgOnSocketRequest::Kill,
+        )) {
             error!("failed to send Kill message: {}", e);
         }
     }
@@ -180,8 +224,8 @@ impl Controller {
         vm_socket: VmMemoryControlRequestSocket,
         resource_bridge: Option<ResourceRequestSocket>,
         socket: Socket,
-    ) -> Result<Controller, MsgError> {
-        socket.send(&Request::Create {
+    ) -> Result<Controller, poly_msg_socket::Error> {
+        socket.send(MsgOnSocketRequest::Create {
             // wayland_paths,
             vm_socket: MaybeOwnedFd::new_borrowed(&vm_socket),
             // resource_bridge,
@@ -267,7 +311,7 @@ impl VirtioDevice for Controller {
 
             let (ours, theirs) = UnixSeqpacket::pair().expect("pair failed");
 
-            if let Err(e) = socket.send(&Request::Activate {
+            if let Err(e) = socket.send(MsgOnSocketRequest::Activate {
                 shm: MaybeOwnedFd::new_borrowed(&mem),
                 interrupt: MaybeOwnedFd::new_borrowed(&theirs),
                 interrupt_resample_evt: MaybeOwnedFd::new_borrowed(interrupt.get_resample_evt()),
diff --git a/devices/src/virtio/virtio_pci_device.rs b/devices/src/virtio/virtio_pci_device.rs
index 0d44e48..7453e4c 100644
--- a/devices/src/virtio/virtio_pci_device.rs
+++ b/devices/src/virtio/virtio_pci_device.rs
@@ -398,6 +398,9 @@ impl PciDevice for VirtioPciDevice {
 
     fn keep_fds(&self) -> Vec<RawFd> {
         let mut fds = self.device.keep_fds();
+        fds.push(0);
+        fds.push(1);
+        fds.push(2);
         if let Some(interrupt_evt) = &self.interrupt_evt {
             fds.push(interrupt_evt.as_raw_fd());
         }
diff --git a/src/linux.rs b/src/linux.rs
index 6f0e19b..7cb4b9c 100644
--- a/src/linux.rs
+++ b/src/linux.rs
@@ -40,6 +40,7 @@ use io_jail::{self, Minijail};
 use kvm::*;
 use msg_socket::{MsgError, MsgReceiver, MsgResult, MsgSender, MsgSocket};
 use net_util::{Error as NetError, MacAddress, Tap};
+use poly_msg_socket::PolyMsgSocket;
 use remain::sorted;
 use resources::{Alloc, MmioType, SystemAllocator};
 use sync::{Condvar, Mutex};
@@ -85,7 +86,7 @@ pub enum Error {
     BuildVm(<Arch as LinuxArch>::Error),
     ChownTpmStorage(sys_util::Error),
     CloneEventFd(sys_util::Error),
-    ControllerCreate(MsgError),
+    ControllerCreate(poly_msg_socket::Error),
     CreateAc97(devices::PciDeviceError),
     CreateConsole(arch::serial::Error),
     CreateDiskError(disk::Error),
@@ -767,7 +768,7 @@ fn create_wayland_device(
     let mut path = std::env::var("XDG_RUNTIME_DIR").expect("XDG_RUNTIME_DIR missing");
     path.push_str("/crosvm-wl.sock");
     let seq_socket = UnixSeqpacket::connect(&path).expect("connect failed");
-    let msg_socket = MsgSocket::new(seq_socket);
+    let msg_socket = PolyMsgSocket::new(seq_socket);
     let dev = virtio::Controller::create(
         cfg.wayland_socket_paths.clone(),
         socket,
diff --git a/src/wl.rs b/src/wl.rs
index 77cf86c..333be87 100644
--- a/src/wl.rs
+++ b/src/wl.rs
@@ -1,7 +1,11 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
-use devices::virtio::{InterruptProxy, InterruptProxyEvent, Request, Response, VirtioDevice, Wl};
-use msg_socket::{MsgReceiver, MsgSender, MsgSocket};
+use devices::virtio::{
+    BincodeRequest, BincodeResponse, InterruptProxy, InterruptProxyEvent, MsgOnSocketRequest,
+    MsgOnSocketResponse, VirtioDevice, Wl,
+};
+use msg_socket::MsgSocket;
+use poly_msg_socket::PolyMsgSocket;
 use std::collections::BTreeMap;
 use std::fs::remove_file;
 use sys_util::{error, net::UnixSeqpacketListener, warn, GuestMemory};
@@ -11,7 +15,8 @@ pub use aarch64::{arch_memory_regions, MemoryParams};
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 pub use x86_64::{arch_memory_regions, MemoryParams};
 
-type Socket = MsgSocket<Response, Request>;
+type Socket =
+    PolyMsgSocket<MsgOnSocketResponse, MsgOnSocketRequest, BincodeResponse, BincodeRequest>;
 
 fn main() {
     eprintln!("hello world");
@@ -25,10 +30,12 @@ fn main() {
 
     // Receive connection from crosvm.
     let conn = server.accept().expect("accept failed");
-    let msg_socket: Socket = MsgSocket::new(conn);
+    let msg_socket: Socket = PolyMsgSocket::new(conn);
 
     let vm_socket = match msg_socket.recv() {
-        Ok(Request::Create { vm_socket }) => MsgSocket::new(vm_socket.owned()),
+        Ok(poly_msg_socket::Value::MsgOnSocket(MsgOnSocketRequest::Create { vm_socket })) => {
+            MsgSocket::new(vm_socket.owned())
+        }
 
         Ok(msg) => {
             panic!("received unexpected message: {:?}", msg);
@@ -46,12 +53,12 @@ fn main() {
 
     loop {
         match msg_socket.recv() {
-            Ok(Request::Kill) => {
+            Ok(poly_msg_socket::Value::MsgOnSocket(MsgOnSocketRequest::Kill)) => {
                 if let Some(wl) = wl.take() {
                     // Will block until worker shuts down.
                     drop(wl);
 
-                    if let Err(e) = msg_socket.send(&Response::Kill) {
+                    if let Err(e) = msg_socket.send(MsgOnSocketResponse::Kill) {
                         error!("failed to send Response::Kill: {}", e);
                         break;
                     }
@@ -60,7 +67,7 @@ fn main() {
                 }
             }
 
-            Ok(Request::Activate {
+            Ok(poly_msg_socket::Value::MsgOnSocket(MsgOnSocketRequest::Activate {
                 shm,
                 interrupt,
                 interrupt_resample_evt,
@@ -68,7 +75,7 @@ fn main() {
                 out_queue,
                 in_queue_evt,
                 out_queue_evt,
-            }) => {
+            })) => {
                 let shm = shm.owned();
 
                 let regions = arch_memory_regions(MemoryParams {