diff options
Diffstat (limited to 'devices/src/virtio/controller.rs')
-rw-r--r-- | devices/src/virtio/controller.rs | 316 |
1 files changed, 103 insertions, 213 deletions
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs index 0a59072..644ece1 100644 --- a/devices/src/virtio/controller.rs +++ b/devices/src/virtio/controller.rs @@ -28,8 +28,6 @@ //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket. -use std::os::unix::prelude::*; - use std::collections::BTreeMap as Map; use std::fmt::{self, Formatter}; use std::os::unix::io::{AsRawFd, RawFd}; @@ -38,10 +36,9 @@ use std::sync::Arc; use std::thread; use msg_socket::{MsgReceiver, MsgSocket}; +use msg_socket2::de::{EnumAccessWithFds, SeqAccessWithFds, VariantAccessWithFds, VisitorWithFds}; use msg_socket2::ser::{SerializeStructVariantFds, SerializeTupleVariantFds}; use msg_socket2::{DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds}; -use msg_socket2_derive::SerializeWithFds; -use serde::de::{Deserializer, EnumAccess, SeqAccess, VariantAccess}; use serde::ser::{SerializeStructVariant, SerializeTupleVariant, Serializer}; use serde::{Deserialize, Serialize}; use sys_util::net::UnixSeqpacket; @@ -136,11 +133,9 @@ impl SerializeWithFds for Request { vm_socket: _, memory_params, } => { - let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 1)?; - - sv.skip_field("vm_socket")?; + let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 2)?; + sv.serialize_field("vm_socket", &())?; sv.serialize_field("memory_params", memory_params)?; - sv.end() } @@ -180,16 +175,14 @@ impl SerializeWithFds for Request { in_queue_evt: _, out_queue_evt: _, } => { - let mut sv = serializer.serialize_struct_variant("Request", 8, "Activate", 2)?; - - sv.skip_field("shm")?; - sv.skip_field("interrupt")?; - sv.skip_field("interrupt_resample_evt")?; + let mut sv = serializer.serialize_struct_variant("Request", 8, "Activate", 7)?; + sv.serialize_field("shm", &())?; + sv.serialize_field("interrupt", &())?; + sv.serialize_field("interrupt_resample_evt", &())?; sv.serialize_field("in_queue", in_queue)?; sv.serialize_field("out_queue", out_queue)?; - sv.skip_field("in_queue_evt")?; - sv.skip_field("out_queue_evt")?; - + sv.serialize_field("in_queue_evt", &())?; + sv.serialize_field("out_queue_evt", &())?; sv.end() } @@ -284,26 +277,20 @@ impl SerializeWithFds for Request { } impl<'de> DeserializeWithFds<'de> for Request { - fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error> - where - I: Iterator<Item = RawFd>, - De: Deserializer<'de>, - { - struct Visitor<'iter, Iter> { - fds: &'iter mut Iter, - } - - impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> - where - Iter: Iterator<Item = RawFd>, - { + fn deserialize<D: DeserializerWithFds<'de>>(deserializer: D) -> Result<Self, D::Error> { + struct Visitor; + + impl<'de> VisitorWithFds<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { write!(f, "enum Request") } - fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> { + fn visit_enum<A: EnumAccessWithFds<'de>>( + self, + data: A, + ) -> Result<Self::Value, A::Error> { #[derive(Debug, Deserialize)] enum Variant { Create, @@ -323,54 +310,38 @@ impl<'de> DeserializeWithFds<'de> for Request { match data.variant()? { (Variant::Create, variant) => { - struct Visitor<'iter, Iter> { - fds: &'iter mut Iter, - } + struct Visitor; - impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> - where - Iter: Iterator<Item = RawFd>, - { + impl<'de> VisitorWithFds<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { write!(f, "struct variant Request::Create") } - fn visit_seq<A: SeqAccess<'de>>( + fn visit_seq<A: SeqAccessWithFds<'de>>( self, mut seq: A, ) -> Result<Request, A::Error> { use serde::de::Error; + fn too_short<E: Error>(len: usize) -> E { + E::invalid_length( + len, + &"struct variant Request::Create with 2 elements", + ) + } + Ok(Request::Create { - vm_socket: match self.fds.next() { - Some(vm_socket) => MaybeOwnedFd::Owned(unsafe { - UnixSeqpacket::from_raw_fd(vm_socket) - }), - None => { - return Err(Error::invalid_length( - 0, - &"struct variant Request::Create with 2 elements", - )) - } - }, - - memory_params: match seq.next_element()? { - Some(memory_params) => memory_params, - None => { - return Err(Error::invalid_length( - 1, - &"struct variant Request::Create with 2 elements", - )) - } - }, + vm_socket: seq.next_element()?.ok_or_else(|| too_short(0))?, + memory_params: seq + .next_element()? + .ok_or_else(|| too_short(1))?, }) } } - let visitor = Visitor { fds: self.fds }; - variant.struct_variant(&["memory_params"], visitor) + variant.struct_variant(&["vm_socket", "memory_params"], Visitor) } (Variant::DebugLabel, variant) => { @@ -400,35 +371,29 @@ impl<'de> DeserializeWithFds<'de> for Request { (Variant::ReadConfig, variant) => { struct Visitor; - impl<'de> serde::de::Visitor<'de> for Visitor { + impl<'de> VisitorWithFds<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { write!(f, "struct variant Request::ReadConfig") } - fn visit_seq<A: SeqAccess<'de>>( + fn visit_seq<A: SeqAccessWithFds<'de>>( self, mut seq: A, ) -> Result<Request, A::Error> { use serde::de::Error; + fn too_short<E: Error>(len: usize) -> E { + E::invalid_length( + len, + &"struct variant Request::ReadConfig with 2 elements", + ) + } + Ok(Request::ReadConfig { - offset: match seq.next_element()? { - Some(offset) => offset, - None => return Err(Error::invalid_length( - 0, - &"struct variant Request::ReadConfig with 2 elements", - )), - }, - - len: match seq.next_element()? { - Some(len) => len, - None => return Err(Error::invalid_length( - 1, - &"struct variant Request::ReadConfig with 2 elements", - )), - }, + offset: seq.next_element()?.ok_or_else(|| too_short(0))?, + len: seq.next_element()?.ok_or_else(|| too_short(1))?, }) } } @@ -439,35 +404,29 @@ impl<'de> DeserializeWithFds<'de> for Request { (Variant::WriteConfig, variant) => { struct Visitor; - impl<'de> serde::de::Visitor<'de> for Visitor { + impl<'de> VisitorWithFds<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { write!(f, "struct variant Request::WriteConfig") } - fn visit_seq<A: SeqAccess<'de>>( + fn visit_seq<A: SeqAccessWithFds<'de>>( self, mut seq: A, ) -> Result<Request, A::Error> { use serde::de::Error; + fn too_short<E: Error>(len: usize) -> E { + E::invalid_length( + len, + &"struct variant Request::WriteConfig with 2 elements", + ) + } + Ok(Request::WriteConfig { - offset: match seq.next_element()? { - Some(offset) => offset, - None => return Err(Error::invalid_length( - 0, - &"struct variant Request::WriteConfig with 2 elements", - )), - }, - - data: match seq.next_element()? { - Some(data) => data, - None => return Err(Error::invalid_length( - 1, - &"struct variant Request::WriteConfig with 2 elements", - )), - }, + offset: seq.next_element()?.ok_or_else(|| too_short(0))?, + data: seq.next_element()?.ok_or_else(|| too_short(1))?, }) } } @@ -476,114 +435,58 @@ impl<'de> DeserializeWithFds<'de> for Request { } (Variant::Activate, variant) => { - struct Visitor<'iter, Iter> { - fds: &'iter mut Iter, - } + struct Visitor; - impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> - where - Iter: Iterator<Item = RawFd>, - { + impl<'de> VisitorWithFds<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { write!(f, "struct variant Request::Activate") } - fn visit_seq<A: SeqAccess<'de>>( + fn visit_seq<A: SeqAccessWithFds<'de>>( self, mut seq: A, ) -> Result<Request, A::Error> { use serde::de::Error; + fn too_short<E: Error>(len: usize) -> E { + E::invalid_length( + len, + &"struct variant Request::Activate with 7 elements", + ) + } + Ok(Request::Activate { - shm: match self.fds.next() { - Some(shm) => MaybeOwnedFd::Owned(unsafe { - FromRawFd::from_raw_fd(shm) - }), - None => { - return Err(Error::invalid_length( - 0, - &"struct variant Request::Activate with 7 elements", - )) - } - }, - - interrupt: match self.fds.next() { - Some(interrupt) => MaybeOwnedFd::Owned(unsafe { - UnixSeqpacket::from_raw_fd(interrupt) - }), - None => { - return Err(Error::invalid_length( - 1, - &"struct variant Request::Activate with 7 elements", - )) - } - }, - - interrupt_resample_evt: match self.fds.next() { - Some(interrupt_resample_evt) => { - MaybeOwnedFd::Owned(unsafe { - EventFd::from_raw_fd(interrupt_resample_evt) - }) - } - None => { - return Err(Error::invalid_length( - 2, - &"struct variant Request::Activate with 7 elements", - )) - } - }, - - in_queue: match seq.next_element()? { - Some(in_queue) => in_queue, - None => { - return Err(Error::invalid_length( - 3, - &"struct variant Request::Activate with 7 elements", - )) - } - }, - - out_queue: match seq.next_element()? { - Some(out_queue) => out_queue, - None => { - return Err(Error::invalid_length( - 4, - &"struct variant Request::Activate with 7 elements", - )) - } - }, - - in_queue_evt: match self.fds.next() { - Some(in_queue_evt) => MaybeOwnedFd::Owned(unsafe { - EventFd::from_raw_fd(in_queue_evt) - }), - None => { - return Err(Error::invalid_length( - 5, - &"struct variant Request::Activate with 7 elements", - )) - } - }, - - out_queue_evt: match self.fds.next() { - Some(out_queue_evt) => MaybeOwnedFd::Owned(unsafe { - EventFd::from_raw_fd(out_queue_evt) - }), - None => { - return Err(Error::invalid_length( - 6, - &"struct variant Request::Activate with 7 elements", - )) - } - }, + shm: seq.next_element()?.ok_or_else(|| too_short(0))?, + interrupt: seq.next_element()?.ok_or_else(|| too_short(1))?, + interrupt_resample_evt: seq + .next_element()? + .ok_or_else(|| too_short(2))?, + in_queue: seq.next_element()?.ok_or_else(|| too_short(3))?, + out_queue: seq.next_element()?.ok_or_else(|| too_short(4))?, + in_queue_evt: seq + .next_element()? + .ok_or_else(|| too_short(5))?, + out_queue_evt: seq + .next_element()? + .ok_or_else(|| too_short(6))?, }) } } - let visitor = Visitor { fds: self.fds }; - variant.struct_variant(&["in_queue", "out_queue"], visitor) + variant.struct_variant( + &[ + "shm", + "interrupt", + "interrupt_resample_evt", + "in_queue", + "out_queue", + "in_queue_evt", + "out_queue_evt", + ], + Visitor, + ) } (Variant::Reset, variant) => { @@ -594,28 +497,29 @@ impl<'de> DeserializeWithFds<'de> for Request { (Variant::GetDeviceBars, variant) => { struct Visitor; - impl<'de> serde::de::Visitor<'de> for Visitor { + impl<'de> VisitorWithFds<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { write!(f, "struct variant Request::GetDeviceBars") } - fn visit_seq<A: SeqAccess<'de>>( + fn visit_seq<A: SeqAccessWithFds<'de>>( self, mut seq: A, ) -> Result<Request, A::Error> { use serde::de::Error; - Ok(Request::GetDeviceBars(match seq.next_element()? { - Some(address) => address, - None => { - return Err(Error::invalid_length( - 0, - &"struct variant Request::GetDeviceBars with 1 element", - )) - } - })) + fn too_short<E: Error>(len: usize) -> E { + E::invalid_length( + len, + &"struct variant Request::GetDeviceBars with 2 elements", + ) + } + + Ok(Request::GetDeviceBars( + seq.next_element()?.ok_or_else(|| too_short(0))?, + )) } } @@ -635,11 +539,6 @@ impl<'de> DeserializeWithFds<'de> for Request { } } - let DeserializerWithFds { - mut fds, - deserializer, - } = deserializer; - let visitor = Visitor { fds: &mut fds }; deserializer.deserialize_enum( "Request", &[ @@ -657,12 +556,12 @@ impl<'de> DeserializeWithFds<'de> for Request { "GetDeviceCaps", "Kill", ], - visitor, + Visitor, ) } } -#[derive(Debug, Deserialize, Serialize, SerializeWithFds)] +#[derive(Debug, Deserialize, DeserializeWithFds, Serialize, SerializeWithFds)] #[msg_socket2(strategy = "serde")] pub enum Response { DebugLabel(String), @@ -676,15 +575,6 @@ pub enum Response { Kill, } -impl<'de> DeserializeWithFds<'de> for Response { - fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error> - where - De: Deserializer<'de>, - { - Deserialize::deserialize(deserializer.deserializer) - } -} - type Socket = msg_socket2::Socket<Request, Response>; // TODO: support arbitrary number of queues |