summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--devices/src/virtio/descriptor_utils.rs80
1 files changed, 68 insertions, 12 deletions
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs
index 3c8ff87..b4eff2d 100644
--- a/devices/src/virtio/descriptor_utils.rs
+++ b/devices/src/virtio/descriptor_utils.rs
@@ -297,14 +297,18 @@ impl<'a> Reader<'a> {
 impl<'a> io::Read for Reader<'a> {
     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
         self.buffer.consume(buf.len(), |bufs| {
-            if let Some(vs) = bufs.first() {
+            let mut rem = buf;
+            let mut total = 0;
+            for vs in bufs {
                 // This is guaranteed by the implementation of `consume`.
-                debug_assert_eq!(vs.size(), cmp::min(buf.len() as u64, vs.size()));
-                vs.copy_to(buf);
-                Ok(vs.size() as usize)
-            } else {
-                Ok(0)
+                debug_assert_eq!(vs.size(), cmp::min(rem.len() as u64, vs.size()));
+
+                vs.copy_to(rem);
+                let copied = vs.size() as usize;
+                rem = &mut rem[copied..];
+                total += copied;
             }
+            Ok(total)
         })
     }
 }
@@ -417,14 +421,18 @@ impl<'a> Writer<'a> {
 impl<'a> io::Write for Writer<'a> {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
         self.buffer.consume(buf.len(), |bufs| {
-            if let Some(vs) = bufs.first() {
+            let mut rem = buf;
+            let mut total = 0;
+            for vs in bufs {
                 // This is guaranteed by the implementation of `consume`.
-                debug_assert_eq!(vs.size(), cmp::min(buf.len() as u64, vs.size()));
-                vs.copy_from(buf);
-                Ok(vs.size() as usize)
-            } else {
-                Ok(0)
+                debug_assert_eq!(vs.size(), cmp::min(rem.len() as u64, vs.size()));
+
+                vs.copy_from(rem);
+                let copied = vs.size() as usize;
+                rem = &rem[copied..];
+                total += copied;
             }
+            Ok(total)
         })
     }
 
@@ -1074,4 +1082,52 @@ mod tests {
             panic!("successfully split Reader with out of bounds offset");
         }
     }
+
+    #[test]
+    fn read_full() {
+        use DescriptorType::*;
+
+        let memory_start_addr = GuestAddress(0x0);
+        let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap();
+
+        let chain = create_descriptor_chain(
+            &memory,
+            GuestAddress(0x0),
+            GuestAddress(0x100),
+            vec![(Readable, 16), (Readable, 16), (Readable, 16)],
+            0,
+        )
+        .expect("create_descriptor_chain failed");
+        let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
+
+        let mut buf = vec![0u8; 64];
+        assert_eq!(
+            reader.read(&mut buf[..]).expect("failed to read to buffer"),
+            48
+        );
+    }
+
+    #[test]
+    fn write_full() {
+        use DescriptorType::*;
+
+        let memory_start_addr = GuestAddress(0x0);
+        let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap();
+
+        let chain = create_descriptor_chain(
+            &memory,
+            GuestAddress(0x0),
+            GuestAddress(0x100),
+            vec![(Writable, 16), (Writable, 16), (Writable, 16)],
+            0,
+        )
+        .expect("create_descriptor_chain failed");
+        let mut writer = Writer::new(&memory, chain).expect("failed to create Writer");
+
+        let buf = vec![0xdeu8; 64];
+        assert_eq!(
+            writer.write(&buf[..]).expect("failed to write from buffer"),
+            48
+        );
+    }
 }