summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--devices/src/virtio/block.rs4
-rw-r--r--devices/src/virtio/descriptor_utils.rs373
-rw-r--r--devices/src/virtio/fs/server.rs9
3 files changed, 203 insertions, 183 deletions
diff --git a/devices/src/virtio/block.rs b/devices/src/virtio/block.rs
index c9dda55..80d5103 100644
--- a/devices/src/virtio/block.rs
+++ b/devices/src/virtio/block.rs
@@ -267,9 +267,7 @@ impl Worker {
         let status_offset = available_bytes
             .checked_sub(1)
             .ok_or(ExecuteError::MissingStatus)?;
-        let mut status_writer = writer
-            .split_at(status_offset)
-            .map_err(ExecuteError::Descriptor)?;
+        let mut status_writer = writer.split_at(status_offset);
 
         let status = match Block::execute_request(
             &mut reader,
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs
index 27d4b1c..b767d42 100644
--- a/devices/src/virtio/descriptor_utils.rs
+++ b/devices/src/virtio/descriptor_utils.rs
@@ -3,7 +3,8 @@
 // found in the LICENSE file.
 
 use std::cmp;
-use std::collections::VecDeque;
+use std::convert::TryInto;
+use std::ffi::c_void;
 use std::fmt::{self, Display};
 use std::io::{self, Read, Write};
 use std::iter::FromIterator;
@@ -53,8 +54,10 @@ impl std::error::Error for Error {}
 
 #[derive(Clone)]
 struct DescriptorChainConsumer<'a> {
-    buffers: VecDeque<VolatileSlice<'a>>,
+    buffers: Vec<libc::iovec>,
+    current: usize,
     bytes_consumed: usize,
+    mem: PhantomData<&'a GuestMemory>,
 }
 
 impl<'a> DescriptorChainConsumer<'a> {
@@ -62,140 +65,136 @@ impl<'a> DescriptorChainConsumer<'a> {
         // This is guaranteed not to overflow because the total length of the chain
         // is checked during all creations of `DescriptorChainConsumer` (see
         // `Reader::new()` and `Writer::new()`).
-        self.buffers
+        self.get_remaining()
             .iter()
-            .fold(0usize, |count, vs| count + vs.size() as usize)
+            .fold(0usize, |count, buf| count + buf.iov_len)
     }
 
     fn bytes_consumed(&self) -> usize {
         self.bytes_consumed
     }
 
-    /// Consumes at most `count` bytes from the `DescriptorChain`. Callers must provide a function
-    /// that takes a `&[VolatileSlice]` and returns the total number of bytes consumed. This
-    /// function guarantees that the combined length of all the slices in the `&[VolatileSlice]` is
-    /// less than or equal to `count`.
+    /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
+    /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
+    /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
+    /// to `consume` will return the same data.
+    fn get_remaining(&self) -> &[libc::iovec] {
+        &self.buffers[self.current..]
+    }
+
+    /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
+    /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
     ///
     /// # Errors
     ///
-    /// If the provided function returns any error then no bytes are consumed from the buffer and
-    /// the error is returned to the caller.
-    fn consume<F>(&mut self, count: usize, f: F) -> io::Result<usize>
-    where
-        F: FnOnce(&[VolatileSlice]) -> io::Result<usize>,
-    {
-        let mut buflen = 0;
-        let mut bufs = Vec::with_capacity(self.buffers.len());
-        for &vs in &self.buffers {
-            if buflen >= count {
+    /// Returns an error if the total bytes consumed by this `DescriptorChainConsumer` overflows a
+    /// usize.
+    fn consume(&mut self, mut count: usize) {
+        // The implementation is adapted from `IoSlice::advance` in libstd. We can't use
+        // `get_remaining` here because then the compiler complains that `self.current` is already
+        // borrowed and doesn't allow us to modify it.  We also need to borrow the iovecs mutably.
+        let current = self.current;
+        for buf in &mut self.buffers[current..] {
+            if count == 0 {
                 break;
             }
 
-            let rem = count - buflen;
-            if (rem as u64) < vs.size() {
-                let buf = vs.sub_slice(0, rem as u64).map_err(|e| {
-                    io::Error::new(io::ErrorKind::InvalidData, Error::VolatileMemoryError(e))
-                })?;
-                bufs.push(buf);
-                buflen += rem;
+            let consumed = if count < buf.iov_len {
+                // Safe because we know that the iovec pointed to valid memory and we are adding a
+                // value that is smaller than the length of the memory.
+                buf.iov_base = unsafe { (buf.iov_base as *mut u8).add(count) as *mut c_void };
+                buf.iov_len -= count;
+                count
             } else {
-                bufs.push(vs);
-                buflen += vs.size() as usize;
-            }
+                self.current += 1;
+                buf.iov_len
+            };
+
+            // This shouldn't overflow because `consumed <= buf.iov_len` and we already verified
+            // that adding all `buf.iov_len` values will not overflow when the Reader/Writer was
+            // constructed.
+            self.bytes_consumed += consumed;
+            count -= consumed;
         }
+    }
 
-        if bufs.is_empty() {
-            return Ok(0);
-        }
+    fn split_at(&mut self, offset: usize) -> DescriptorChainConsumer<'a> {
+        let mut other = self.clone();
+        other.consume(offset);
+        other.bytes_consumed = 0;
 
-        let bytes_consumed = f(&*bufs)?;
-
-        // This can happen if a driver tricks a device into reading/writing more data than
-        // fits in a `usize`.
-        let total_bytes_consumed =
-            self.bytes_consumed
-                .checked_add(bytes_consumed)
-                .ok_or_else(|| {
-                    io::Error::new(io::ErrorKind::InvalidData, Error::DescriptorChainOverflow)
-                })?;
-
-        let mut rem = bytes_consumed;
-        while let Some(vs) = self.buffers.pop_front() {
-            if (rem as u64) < vs.size() {
-                // Split the slice and push the remainder back into the buffer list. Safe because we
-                // know that `rem` is not out of bounds due to the check and we checked the bounds
-                // on `vs` when we added it to the buffer list.
-                self.buffers.push_front(vs.offset(rem as u64).unwrap());
+        let mut rem = offset;
+        let mut end = self.current;
+        for buf in &mut self.buffers[self.current..] {
+            if rem < buf.iov_len {
+                buf.iov_len = rem;
                 break;
             }
 
-            // No need for checked math because we know that `vs.size() <= rem`.
-            rem -= vs.size() as usize;
+            end += 1;
+            rem -= buf.iov_len;
         }
 
-        self.bytes_consumed = total_bytes_consumed;
+        self.buffers.truncate(end + 1);
 
-        Ok(bytes_consumed)
+        other
     }
 
-    fn split_at(&mut self, offset: usize) -> Result<DescriptorChainConsumer<'a>> {
-        let mut rem = offset;
-        let pos = self.buffers.iter().position(|vs| {
-            if (rem as u64) < vs.size() {
-                true
-            } else {
-                rem -= vs.size() as usize;
-                false
-            }
-        });
-
-        if let Some(at) = pos {
-            let mut other = self.buffers.split_off(at);
-
-            if rem > 0 {
-                // There must be at least one element in `other` because we checked
-                // its `size` value in the call to `position` above.
-                let front = other.pop_front().expect("empty VecDeque after split");
-                self.buffers.push_back(
-                    front
-                        .sub_slice(0, rem as u64)
-                        .map_err(Error::VolatileMemoryError)?,
-                );
-                other.push_front(
-                    front
-                        .offset(rem as u64)
-                        .map_err(Error::VolatileMemoryError)?,
-                );
-            }
+    // Temporary method for converting iovecs into VolatileSlices until we can change the
+    // ReadWriteVolatile traits. The irony here is that the standard implementation of the
+    // ReadWriteVolatile traits will convert the VolatileSlices back into iovecs.
+    fn get_volatile_slices(&mut self, mut count: usize) -> Vec<VolatileSlice> {
+        let bufs = self.get_remaining();
+        let mut iovs = Vec::with_capacity(bufs.len());
+        for b in bufs {
+            // Safe because we verified during construction that the memory at `b.iov_base` is
+            // `b.iov_len` bytes long. The lifetime of the `VolatileSlice` is tied to the lifetime
+            // of this `DescriptorChainConsumer`, which is in turn tied to the lifetime of the
+            // `GuestMemory` used to create it and so the memory will be available for the duration
+            // of the `VolatileSlice`.
+            let iov = unsafe {
+                if count < b.iov_len {
+                    VolatileSlice::new(
+                        b.iov_base as *mut u8,
+                        count.try_into().expect("usize doesn't fit in u64"),
+                    )
+                } else {
+                    VolatileSlice::new(
+                        b.iov_base as *mut u8,
+                        b.iov_len.try_into().expect("usize doesn't fit in u64"),
+                    )
+                }
+            };
 
-            Ok(DescriptorChainConsumer {
-                buffers: other,
-                bytes_consumed: 0,
-            })
-        } else if rem == 0 {
-            Ok(DescriptorChainConsumer {
-                buffers: VecDeque::new(),
-                bytes_consumed: 0,
-            })
-        } else {
-            Err(Error::SplitOutOfBounds(offset))
+            count -= iov.size() as usize;
+            iovs.push(iov);
         }
+
+        iovs
     }
 
     fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> {
-        let mut iovec = Vec::new();
+        let mut iovec = Vec::with_capacity(self.get_remaining().len());
+
+        let mut rem = len;
+        for buf in self.get_remaining() {
+            let iov = if rem < buf.iov_len {
+                libc::iovec {
+                    iov_base: buf.iov_base,
+                    iov_len: rem,
+                }
+            } else {
+                buf.clone()
+            };
 
-        self.consume(len, |bufs| {
-            let mut total = 0;
-            for vs in bufs {
-                iovec.push(libc::iovec {
-                    iov_base: vs.as_ptr() as *mut libc::c_void,
-                    iov_len: vs.size() as usize,
-                });
-                total += vs.size() as usize;
+            rem -= iov.iov_len;
+            iovec.push(iov);
+
+            if rem == 0 {
+                break;
             }
-            Ok(total)
-        })?;
+        }
+        self.consume(len);
 
         Ok(DescriptorIovec {
             iovec,
@@ -250,14 +249,21 @@ impl<'a> Reader<'a> {
                     .checked_add(desc.len as usize)
                     .ok_or(Error::DescriptorChainOverflow)?;
 
-                mem.get_slice(desc.addr.offset(), desc.len.into())
-                    .map_err(Error::VolatileMemoryError)
+                let vs = mem
+                    .get_slice(desc.addr.offset(), desc.len.into())
+                    .map_err(Error::VolatileMemoryError)?;
+                Ok(libc::iovec {
+                    iov_base: vs.as_ptr() as *mut c_void,
+                    iov_len: vs.size() as usize,
+                })
             })
-            .collect::<Result<VecDeque<VolatileSlice<'a>>>>()?;
+            .collect::<Result<Vec<libc::iovec>>>()?;
         Ok(Reader {
             buffer: DescriptorChainConsumer {
                 buffers,
+                current: 0,
                 bytes_consumed: 0,
+                mem: PhantomData,
             },
         })
     }
@@ -305,8 +311,10 @@ impl<'a> Reader<'a> {
         mut dst: F,
         count: usize,
     ) -> io::Result<usize> {
-        self.buffer
-            .consume(count, |bufs| dst.write_vectored_volatile(bufs))
+        let iovs = self.buffer.get_volatile_slices(count);
+        let written = dst.write_vectored_volatile(&iovs[..])?;
+        self.buffer.consume(written);
+        Ok(written)
     }
 
     /// Reads data from the descriptor chain buffer into a File at offset `off`.
@@ -319,8 +327,10 @@ impl<'a> Reader<'a> {
         count: usize,
         off: u64,
     ) -> io::Result<usize> {
-        self.buffer
-            .consume(count, |bufs| dst.write_vectored_at_volatile(bufs, off))
+        let iovs = self.buffer.get_volatile_slices(count);
+        let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
+        self.buffer.consume(written);
+        Ok(written)
     }
 
     pub fn read_exact_to<F: FileReadWriteVolatile>(
@@ -382,12 +392,14 @@ impl<'a> Reader<'a> {
         self.buffer.bytes_consumed()
     }
 
-    /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer.
-    /// After the split, `self` will be able to read up to `offset` bytes while the returned
-    /// `Reader` can read up to `available_bytes() - offset` bytes.  Returns an error if
-    /// `offset > self.available_bytes()`.
-    pub fn split_at(&mut self, offset: usize) -> Result<Reader<'a>> {
-        self.buffer.split_at(offset).map(|buffer| Reader { buffer })
+    /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
+    /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
+    /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
+    /// returned `Reader` will not be able to read any bytes.
+    pub fn split_at(&mut self, offset: usize) -> Reader<'a> {
+        Reader {
+            buffer: self.buffer.split_at(offset),
+        }
     }
 
     /// Returns a DescriptorIovec for the next `len` bytes of the descriptor chain
@@ -399,27 +411,25 @@ impl<'a> Reader<'a> {
 
 impl<'a> io::Read for Reader<'a> {
     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
-        self.buffer.consume(buf.len(), |bufs| {
-            let mut rem = buf;
-            let mut total = 0;
-            for vs in bufs {
-                // This is guaranteed by the implementation of `consume`.
-                debug_assert_eq!(vs.size(), cmp::min(rem.len() as u64, vs.size()));
-
-                // Safe because we have already verified that `vs` points to valid memory.
-                unsafe {
-                    copy_nonoverlapping(
-                        vs.as_ptr() as *const u8,
-                        rem.as_mut_ptr(),
-                        vs.size() as usize,
-                    );
-                }
-                let copied = vs.size() as usize;
-                rem = &mut rem[copied..];
-                total += copied;
+        let mut rem = buf;
+        let mut total = 0;
+        for b in self.buffer.get_remaining() {
+            if rem.len() == 0 {
+                break;
             }
-            Ok(total)
-        })
+
+            let count = cmp::min(rem.len(), b.iov_len);
+
+            // Safe because we have already verified that `b` points to valid memory.
+            unsafe {
+                copy_nonoverlapping(b.iov_base as *const u8, rem.as_mut_ptr(), count);
+            }
+            rem = &mut rem[count..];
+            total += count;
+        }
+
+        self.buffer.consume(total);
+        Ok(total)
     }
 }
 
@@ -450,14 +460,21 @@ impl<'a> Writer<'a> {
                     .checked_add(desc.len as usize)
                     .ok_or(Error::DescriptorChainOverflow)?;
 
-                mem.get_slice(desc.addr.offset(), desc.len.into())
-                    .map_err(Error::VolatileMemoryError)
+                let vs = mem
+                    .get_slice(desc.addr.offset(), desc.len.into())
+                    .map_err(Error::VolatileMemoryError)?;
+                Ok(libc::iovec {
+                    iov_base: vs.as_ptr() as *mut c_void,
+                    iov_len: vs.size() as usize,
+                })
             })
-            .collect::<Result<VecDeque<VolatileSlice<'a>>>>()?;
+            .collect::<Result<Vec<libc::iovec>>>()?;
         Ok(Writer {
             buffer: DescriptorChainConsumer {
                 buffers,
+                current: 0,
                 bytes_consumed: 0,
+                mem: PhantomData,
             },
         })
     }
@@ -495,8 +512,10 @@ impl<'a> Writer<'a> {
         mut src: F,
         count: usize,
     ) -> io::Result<usize> {
-        self.buffer
-            .consume(count, |bufs| src.read_vectored_volatile(bufs))
+        let iovs = self.buffer.get_volatile_slices(count);
+        let read = src.read_vectored_volatile(&iovs[..])?;
+        self.buffer.consume(read);
+        Ok(read)
     }
 
     /// Writes data to the descriptor chain buffer from a File at offset `off`.
@@ -509,8 +528,10 @@ impl<'a> Writer<'a> {
         count: usize,
         off: u64,
     ) -> io::Result<usize> {
-        self.buffer
-            .consume(count, |bufs| src.read_vectored_at_volatile(bufs, off))
+        let iovs = self.buffer.get_volatile_slices(count);
+        let read = src.read_vectored_at_volatile(&iovs[..], off)?;
+        self.buffer.consume(read);
+        Ok(read)
     }
 
     pub fn write_all_from<F: FileReadWriteVolatile>(
@@ -565,12 +586,14 @@ impl<'a> Writer<'a> {
         self.buffer.bytes_consumed()
     }
 
-    /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer.
-    /// After the split, `self` will be able to write up to `offset` bytes while the returned
-    /// `Writer` can write up to `available_bytes() - offset` bytes.  Returns an error if
-    /// `offset > self.available_bytes()`.
-    pub fn split_at(&mut self, offset: usize) -> Result<Writer<'a>> {
-        self.buffer.split_at(offset).map(|buffer| Writer { buffer })
+    /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
+    /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
+    /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
+    /// the returned `Writer` will not be able to write any bytes.
+    pub fn split_at(&mut self, offset: usize) -> Writer<'a> {
+        Writer {
+            buffer: self.buffer.split_at(offset),
+        }
     }
 
     /// Returns a DescriptorIovec for the next `len` bytes of the descriptor chain
@@ -582,23 +605,24 @@ impl<'a> Writer<'a> {
 
 impl<'a> io::Write for Writer<'a> {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
-        self.buffer.consume(buf.len(), |bufs| {
-            let mut rem = buf;
-            let mut total = 0;
-            for vs in bufs {
-                // This is guaranteed by the implementation of `consume`.
-                debug_assert_eq!(vs.size(), cmp::min(rem.len() as u64, vs.size()));
-
-                // Safe because we have already verified that `vs` points to valid memory.
-                unsafe {
-                    copy_nonoverlapping(rem.as_ptr(), vs.as_ptr(), vs.size() as usize);
-                }
-                let copied = vs.size() as usize;
-                rem = &rem[copied..];
-                total += copied;
+        let mut rem = buf;
+        let mut total = 0;
+        for b in self.buffer.get_remaining() {
+            if rem.len() == 0 {
+                break;
             }
-            Ok(total)
-        })
+
+            let count = cmp::min(rem.len(), b.iov_len);
+            // Safe because we have already verified that `vs` points to valid memory.
+            unsafe {
+                copy_nonoverlapping(rem.as_ptr(), b.iov_base as *mut u8, count);
+            }
+            rem = &rem[count..];
+            total += count;
+        }
+
+        self.buffer.consume(total);
+        Ok(total)
     }
 
     fn flush(&mut self) -> io::Result<()> {
@@ -1031,7 +1055,7 @@ mod tests {
         .expect("create_descriptor_chain failed");
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
-        let other = reader.split_at(32).expect("failed to split Reader");
+        let other = reader.split_at(32);
         assert_eq!(reader.available_bytes(), 32);
         assert_eq!(other.available_bytes(), 96);
     }
@@ -1060,7 +1084,7 @@ mod tests {
         .expect("create_descriptor_chain failed");
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
-        let other = reader.split_at(24).expect("failed to split Reader");
+        let other = reader.split_at(24);
         assert_eq!(reader.available_bytes(), 24);
         assert_eq!(other.available_bytes(), 104);
     }
@@ -1089,7 +1113,7 @@ mod tests {
         .expect("create_descriptor_chain failed");
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
-        let other = reader.split_at(128).expect("failed to split Reader");
+        let other = reader.split_at(128);
         assert_eq!(reader.available_bytes(), 128);
         assert_eq!(other.available_bytes(), 0);
     }
@@ -1118,7 +1142,7 @@ mod tests {
         .expect("create_descriptor_chain failed");
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
-        let other = reader.split_at(0).expect("failed to split Reader");
+        let other = reader.split_at(0);
         assert_eq!(reader.available_bytes(), 0);
         assert_eq!(other.available_bytes(), 128);
     }
@@ -1147,9 +1171,12 @@ mod tests {
         .expect("create_descriptor_chain failed");
         let mut reader = Reader::new(&memory, chain).expect("failed to create Reader");
 
-        if let Ok(_) = reader.split_at(256) {
-            panic!("successfully split Reader with out of bounds offset");
-        }
+        let other = reader.split_at(256);
+        assert_eq!(
+            other.available_bytes(),
+            0,
+            "Reader returned from out-of-bounds split still has available bytes"
+        );
     }
 
     #[test]
diff --git a/devices/src/virtio/fs/server.rs b/devices/src/virtio/fs/server.rs
index 33b7c98..c1af80c 100644
--- a/devices/src/virtio/fs/server.rs
+++ b/devices/src/virtio/fs/server.rs
@@ -496,10 +496,7 @@ impl<F: FileSystem + Sync> Server<F> {
         };
 
         // Split the writer into 2 pieces: one for the `OutHeader` and the rest for the data.
-        let data_writer = ZCWriter(
-            w.split_at(size_of::<OutHeader>())
-                .map_err(Error::InvalidDescriptorChain)?,
-        );
+        let data_writer = ZCWriter(w.split_at(size_of::<OutHeader>()));
 
         match self.fs.read(
             Context::from(in_header),
@@ -910,9 +907,7 @@ impl<F: FileSystem + Sync> Server<F> {
         }
 
         // Skip over enough bytes for the header.
-        let mut cursor = w
-            .split_at(size_of::<OutHeader>())
-            .map_err(Error::InvalidDescriptorChain)?;
+        let mut cursor = w.split_at(size_of::<OutHeader>());
 
         let res = if plus {
             self.fs.readdirplus(