summary refs log tree commit diff
path: root/devices/src/virtio/descriptor_utils.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/descriptor_utils.rs')
-rw-r--r--devices/src/virtio/descriptor_utils.rs182
1 files changed, 45 insertions, 137 deletions
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs
index b4eff2d..80cc530 100644
--- a/devices/src/virtio/descriptor_utils.rs
+++ b/devices/src/virtio/descriptor_utils.rs
@@ -53,11 +53,13 @@ struct DescriptorChainConsumer<'a> {
 }
 
 impl<'a> DescriptorChainConsumer<'a> {
-    fn available_bytes(&self) -> Result<usize> {
+    fn available_bytes(&self) -> usize {
+        // This is guaranteed not to overflow because the total length of the chain
+        // is checked during all creations of `DescriptorChainConsumer` (see
+        // `Reader::new()` and `Writer::new()`).
         self.buffers
             .iter()
-            .try_fold(0usize, |count, vs| count.checked_add(vs.size() as usize))
-            .ok_or(Error::DescriptorChainOverflow)
+            .fold(0usize, |count, vs| count + vs.size() as usize)
     }
 
     fn bytes_consumed(&self) -> usize {
@@ -192,10 +194,18 @@ impl<'a> Reader<'a> {
     /// Construct a new Reader wrapper over `desc_chain`.
     pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Result<Reader<'a>> {
         // TODO(jstaron): Update this code to take the indirect descriptors into account.
+        let mut total_len: usize = 0;
         let buffers = desc_chain
             .into_iter()
             .readable()
             .map(|desc| {
+                // Verify that summing the descriptor sizes does not overflow.
+                // This can happen if a driver tricks a device into reading more data than
+                // fits in a `usize`.
+                total_len = total_len
+                    .checked_add(desc.len as usize)
+                    .ok_or(Error::DescriptorChainOverflow)?;
+
                 mem.get_slice(desc.addr.offset(), desc.len.into())
                     .map_err(Error::VolatileMemoryError)
             })
@@ -276,7 +286,7 @@ impl<'a> Reader<'a> {
 
     /// Returns number of bytes available for reading.  May return an error if the combined
     /// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
-    pub fn available_bytes(&self) -> Result<usize> {
+    pub fn available_bytes(&self) -> usize {
         self.buffer.available_bytes()
     }
 
@@ -328,10 +338,18 @@ pub struct Writer<'a> {
 impl<'a> Writer<'a> {
     /// Construct a new Writer wrapper over `desc_chain`.
     pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Result<Writer<'a>> {
+        let mut total_len: usize = 0;
         let buffers = desc_chain
             .into_iter()
             .writable()
             .map(|desc| {
+                // Verify that summing the descriptor sizes does not overflow.
+                // This can happen if a driver tricks a device into writing more data than
+                // fits in a `usize`.
+                total_len = total_len
+                    .checked_add(desc.len as usize)
+                    .ok_or(Error::DescriptorChainOverflow)?;
+
                 mem.get_slice(desc.addr.offset(), desc.len.into())
                     .map_err(Error::VolatileMemoryError)
             })
@@ -351,7 +369,7 @@ impl<'a> Writer<'a> {
 
     /// Returns number of bytes available for writing.  May return an error if the combined
     /// lengths of all the buffers in the DescriptorChain would cause an overflow.
-    pub fn available_bytes(&self) -> Result<usize> {
+    pub fn available_bytes(&self) -> usize {
         self.buffer.available_bytes()
     }
 
@@ -532,12 +550,7 @@ mod tests {
         )
         .expect("create_descriptor_chain failed");
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            106
-        );
+        assert_eq!(reader.available_bytes(), 106);
         assert_eq!(reader.bytes_read(), 0);
 
         let mut buffer = [0 as u8; 64];
@@ -545,12 +558,7 @@ mod tests {
             panic!("read_exact should not fail here");
         }
 
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            42
-        );
+        assert_eq!(reader.available_bytes(), 42);
         assert_eq!(reader.bytes_read(), 64);
 
         match reader.read(&mut buffer) {
@@ -558,12 +566,7 @@ mod tests {
             Ok(length) => assert_eq!(length, 42),
         }
 
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(reader.available_bytes(), 0);
         assert_eq!(reader.bytes_read(), 106);
     }
 
@@ -588,12 +591,7 @@ mod tests {
         )
         .expect("create_descriptor_chain failed");;
         let mut writer = Writer::new(&memory, chain).expect("failed to create Writer");
-        assert_eq!(
-            writer
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            106
-        );
+        assert_eq!(writer.available_bytes(), 106);
         assert_eq!(writer.bytes_written(), 0);
 
         let mut buffer = [0 as u8; 64];
@@ -601,12 +599,7 @@ mod tests {
             panic!("write_all should not fail here");
         }
 
-        assert_eq!(
-            writer
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            42
-        );
+        assert_eq!(writer.available_bytes(), 42);
         assert_eq!(writer.bytes_written(), 64);
 
         match writer.write(&mut buffer) {
@@ -614,12 +607,7 @@ mod tests {
             Ok(length) => assert_eq!(length, 42),
         }
 
-        assert_eq!(
-            writer
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(writer.available_bytes(), 0);
         assert_eq!(writer.bytes_written(), 106);
     }
 
@@ -639,22 +627,12 @@ mod tests {
         )
         .expect("create_descriptor_chain failed");;
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(reader.available_bytes(), 0);
         assert_eq!(reader.bytes_read(), 0);
 
         assert!(reader.read_obj::<u8>().is_err());
 
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(reader.available_bytes(), 0);
         assert_eq!(reader.bytes_read(), 0);
     }
 
@@ -674,22 +652,12 @@ mod tests {
         )
         .expect("create_descriptor_chain failed");;
         let mut writer = Writer::new(&memory, chain).expect("failed to create Writer");
-        assert_eq!(
-            writer
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(writer.available_bytes(), 0);
         assert_eq!(writer.bytes_written(), 0);
 
         assert!(writer.write_obj(0u8).is_err());
 
-        assert_eq!(
-            writer
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(writer.available_bytes(), 0);
         assert_eq!(writer.bytes_written(), 0);
     }
 
@@ -727,12 +695,7 @@ mod tests {
         // Linux doesn't do partial writes if you give a buffer larger than the remaining length of
         // the shared memory. And since we passed an iovec with the full contents of the
         // DescriptorChain we ended up not writing any bytes at all.
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            512
-        );
+        assert_eq!(reader.available_bytes(), 512);
         assert_eq!(reader.bytes_read(), 0);
     }
 
@@ -762,12 +725,7 @@ mod tests {
             .write_all_from(&mut shm, 512)
             .expect_err("successfully wrote more bytes than in SharedMemory");
 
-        assert_eq!(
-            writer
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            128
-        );
+        assert_eq!(writer.available_bytes(), 128);
         assert_eq!(writer.bytes_written(), 384);
     }
 
@@ -813,19 +771,9 @@ mod tests {
             .write_all(&buffer[..68])
             .expect("write should not fail here");
 
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(reader.available_bytes(), 0);
         assert_eq!(reader.bytes_read(), 128);
-        assert_eq!(
-            writer
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(writer.available_bytes(), 0);
         assert_eq!(writer.bytes_written(), 68);
     }
 
@@ -923,18 +871,8 @@ mod tests {
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
         let other = reader.split_at(32).expect("failed to split Reader");
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            32
-        );
-        assert_eq!(
-            other
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            96
-        );
+        assert_eq!(reader.available_bytes(), 32);
+        assert_eq!(other.available_bytes(), 96);
     }
 
     #[test]
@@ -962,18 +900,8 @@ mod tests {
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
         let other = reader.split_at(24).expect("failed to split Reader");
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            24
-        );
-        assert_eq!(
-            other
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            104
-        );
+        assert_eq!(reader.available_bytes(), 24);
+        assert_eq!(other.available_bytes(), 104);
     }
 
     #[test]
@@ -1001,18 +929,8 @@ mod tests {
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
         let other = reader.split_at(128).expect("failed to split Reader");
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            128
-        );
-        assert_eq!(
-            other
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
+        assert_eq!(reader.available_bytes(), 128);
+        assert_eq!(other.available_bytes(), 0);
     }
 
     #[test]
@@ -1040,18 +958,8 @@ mod tests {
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
         let other = reader.split_at(0).expect("failed to split Reader");
-        assert_eq!(
-            reader
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            0
-        );
-        assert_eq!(
-            other
-                .available_bytes()
-                .expect("failed to get available bytes"),
-            128
-        );
+        assert_eq!(reader.available_bytes(), 0);
+        assert_eq!(other.available_bytes(), 128);
     }
 
     #[test]