summary refs log tree commit diff
path: root/devices/src/virtio/descriptor_utils.rs
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-06-02 03:03:26 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-14 11:23:24 +0000
commit28d9682698d287d14cbe67a0ed7acc1427add320 (patch)
tree669ed98d9b1388b553c8e0f0189678cc68dd4162 /devices/src/virtio/descriptor_utils.rs
parent460406d10bbfaa890d56d616b4610813da63a312 (diff)
parent4264464153a7a788ef73c5015ac8bbde5f8ebe1c (diff)
downloadcrosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar
crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.gz
crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.bz2
crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.lz
crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.xz
crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.zst
crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.zip
Merge remote-tracking branch 'origin/master'
Diffstat (limited to 'devices/src/virtio/descriptor_utils.rs')
-rw-r--r--devices/src/virtio/descriptor_utils.rs261
1 files changed, 131 insertions, 130 deletions
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs
index d65341b..902e3c3 100644
--- a/devices/src/virtio/descriptor_utils.rs
+++ b/devices/src/virtio/descriptor_utils.rs
@@ -2,9 +2,9 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
+use std::borrow::Cow;
 use std::cmp;
 use std::convert::TryInto;
-use std::ffi::c_void;
 use std::fmt::{self, Display};
 use std::io::{self, Read, Write};
 use std::iter::FromIterator;
@@ -13,10 +13,8 @@ use std::mem::{size_of, MaybeUninit};
 use std::ptr::copy_nonoverlapping;
 use std::result;
 
-use data_model::{DataInit, Le16, Le32, Le64, VolatileMemory, VolatileMemoryError, VolatileSlice};
-use sys_util::{
-    FileReadWriteAtVolatile, FileReadWriteVolatile, GuestAddress, GuestMemory, IntoIovec,
-};
+use data_model::{DataInit, Le16, Le32, Le64, VolatileMemoryError, VolatileSlice};
+use sys_util::{FileReadWriteAtVolatile, FileReadWriteVolatile, GuestAddress, GuestMemory};
 
 use super::DescriptorChain;
 
@@ -54,10 +52,9 @@ impl std::error::Error for Error {}
 
 #[derive(Clone)]
 struct DescriptorChainConsumer<'a> {
-    buffers: Vec<libc::iovec>,
+    buffers: Vec<VolatileSlice<'a>>,
     current: usize,
     bytes_consumed: usize,
-    mem: PhantomData<&'a GuestMemory>,
 }
 
 impl<'a> DescriptorChainConsumer<'a> {
@@ -67,7 +64,7 @@ impl<'a> DescriptorChainConsumer<'a> {
         // `Reader::new()` and `Writer::new()`).
         self.get_remaining()
             .iter()
-            .fold(0usize, |count, buf| count + buf.iov_len)
+            .fold(0usize, |count, buf| count + buf.size())
     }
 
     fn bytes_consumed(&self) -> usize {
@@ -78,10 +75,38 @@ impl<'a> DescriptorChainConsumer<'a> {
     /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
     /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
     /// to `consume` will return the same data.
-    fn get_remaining(&self) -> &[libc::iovec] {
+    fn get_remaining(&self) -> &[VolatileSlice] {
         &self.buffers[self.current..]
     }
 
+    /// Like `get_remaining` but guarantees that the combined length of all the returned iovecs is
+    /// not greater than `count`. The combined length of the returned iovecs may be less than
+    /// `count` but will always be greater than 0 as long as there is still space left in the
+    /// `DescriptorChain`.
+    fn get_remaining_with_count(&self, count: usize) -> Cow<[VolatileSlice]> {
+        let iovs = self.get_remaining();
+        let mut iov_count = 0;
+        let mut rem = count;
+        for iov in iovs {
+            if rem < iov.size() {
+                break;
+            }
+
+            iov_count += 1;
+            rem -= iov.size();
+        }
+
+        // Special case where the number of bytes to be copied is smaller than the `size()` of the
+        // first iovec.
+        if iov_count == 0 && iovs.len() > 0 && count > 0 {
+            debug_assert!(count < iovs[0].size());
+            // Safe because we know that count is smaller than the length of the first slice.
+            Cow::Owned(vec![iovs[0].sub_slice(0, count).unwrap()])
+        } else {
+            Cow::Borrowed(&iovs[..iov_count])
+        }
+    }
+
     /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
     /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
     ///
@@ -99,19 +124,18 @@ impl<'a> DescriptorChainConsumer<'a> {
                 break;
             }
 
-            let consumed = if count < buf.iov_len {
+            let consumed = if count < buf.size() {
                 // Safe because we know that the iovec pointed to valid memory and we are adding a
                 // value that is smaller than the length of the memory.
-                buf.iov_base = unsafe { (buf.iov_base as *mut u8).add(count) as *mut c_void };
-                buf.iov_len -= count;
+                *buf = buf.offset(count).unwrap();
                 count
             } else {
                 self.current += 1;
-                buf.iov_len
+                buf.size()
             };
 
-            // This shouldn't overflow because `consumed <= buf.iov_len` and we already verified
-            // that adding all `buf.iov_len` values will not overflow when the Reader/Writer was
+            // This shouldn't overflow because `consumed <= buf.size()` and we already verified
+            // that adding all `buf.size()` values will not overflow when the Reader/Writer was
             // constructed.
             self.bytes_consumed += consumed;
             count -= consumed;
@@ -126,81 +150,20 @@ impl<'a> DescriptorChainConsumer<'a> {
         let mut rem = offset;
         let mut end = self.current;
         for buf in &mut self.buffers[self.current..] {
-            if rem < buf.iov_len {
-                buf.iov_len = rem;
+            if rem < buf.size() {
+                // Safe because we are creating a smaller sub-slice.
+                *buf = buf.sub_slice(0, rem).unwrap();
                 break;
             }
 
             end += 1;
-            rem -= buf.iov_len;
+            rem -= buf.size();
         }
 
         self.buffers.truncate(end + 1);
 
         other
     }
-
-    // Temporary method for converting iovecs into VolatileSlices until we can change the
-    // ReadWriteVolatile traits. The irony here is that the standard implementation of the
-    // ReadWriteVolatile traits will convert the VolatileSlices back into iovecs.
-    fn get_volatile_slices(&mut self, mut count: usize) -> Vec<VolatileSlice> {
-        let bufs = self.get_remaining();
-        let mut iovs = Vec::with_capacity(bufs.len());
-        for b in bufs {
-            // Safe because we verified during construction that the memory at `b.iov_base` is
-            // `b.iov_len` bytes long. The lifetime of the `VolatileSlice` is tied to the lifetime
-            // of this `DescriptorChainConsumer`, which is in turn tied to the lifetime of the
-            // `GuestMemory` used to create it and so the memory will be available for the duration
-            // of the `VolatileSlice`.
-            let iov = unsafe {
-                if count < b.iov_len {
-                    VolatileSlice::new(
-                        b.iov_base as *mut u8,
-                        count.try_into().expect("usize doesn't fit in u64"),
-                    )
-                } else {
-                    VolatileSlice::new(
-                        b.iov_base as *mut u8,
-                        b.iov_len.try_into().expect("usize doesn't fit in u64"),
-                    )
-                }
-            };
-
-            count -= iov.size() as usize;
-            iovs.push(iov);
-        }
-
-        iovs
-    }
-
-    fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> {
-        let mut iovec = Vec::with_capacity(self.get_remaining().len());
-
-        let mut rem = len;
-        for buf in self.get_remaining() {
-            let iov = if rem < buf.iov_len {
-                libc::iovec {
-                    iov_base: buf.iov_base,
-                    iov_len: rem,
-                }
-            } else {
-                buf.clone()
-            };
-
-            rem -= iov.iov_len;
-            iovec.push(iov);
-
-            if rem == 0 {
-                break;
-            }
-        }
-        self.consume(len);
-
-        Ok(DescriptorIovec {
-            iovec,
-            mem: PhantomData,
-        })
-    }
 }
 
 /// Provides high-level interface over the sequence of memory regions
@@ -249,21 +212,18 @@ impl<'a> Reader<'a> {
                     .checked_add(desc.len as usize)
                     .ok_or(Error::DescriptorChainOverflow)?;
 
-                let vs = mem
-                    .get_slice(desc.addr.offset(), desc.len.into())
-                    .map_err(Error::VolatileMemoryError)?;
-                Ok(libc::iovec {
-                    iov_base: vs.as_ptr() as *mut c_void,
-                    iov_len: vs.size() as usize,
-                })
+                mem.get_slice_at_addr(
+                    desc.addr,
+                    desc.len.try_into().expect("u32 doesn't fit in usize"),
+                )
+                .map_err(Error::GuestMemoryError)
             })
-            .collect::<Result<Vec<libc::iovec>>>()?;
+            .collect::<Result<Vec<VolatileSlice>>>()?;
         Ok(Reader {
             buffer: DescriptorChainConsumer {
                 buffers,
                 current: 0,
                 bytes_consumed: 0,
-                mem: PhantomData,
             },
         })
     }
@@ -311,7 +271,7 @@ impl<'a> Reader<'a> {
         mut dst: F,
         count: usize,
     ) -> io::Result<usize> {
-        let iovs = self.buffer.get_volatile_slices(count);
+        let iovs = self.buffer.get_remaining_with_count(count);
         let written = dst.write_vectored_volatile(&iovs[..])?;
         self.buffer.consume(written);
         Ok(written)
@@ -327,7 +287,7 @@ impl<'a> Reader<'a> {
         count: usize,
         off: u64,
     ) -> io::Result<usize> {
-        let iovs = self.buffer.get_volatile_slices(count);
+        let iovs = self.buffer.get_remaining_with_count(count);
         let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
         self.buffer.consume(written);
         Ok(written)
@@ -392,6 +352,19 @@ impl<'a> Reader<'a> {
         self.buffer.bytes_consumed()
     }
 
+    /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
+    /// Calling this method does not actually consume any data from the `Reader` and callers should
+    /// call `consume` to advance the `Reader`.
+    pub fn get_remaining(&self) -> &[VolatileSlice] {
+        self.buffer.get_remaining()
+    }
+
+    /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
+    /// remaining data left in this `Reader`, then all remaining data will be consumed.
+    pub fn consume(&mut self, amt: usize) {
+        self.buffer.consume(amt)
+    }
+
     /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
     /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
     /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
@@ -401,12 +374,6 @@ impl<'a> Reader<'a> {
             buffer: self.buffer.split_at(offset),
         }
     }
-
-    /// Returns a DescriptorIovec for the next `len` bytes of the descriptor chain
-    /// buffer, which can be used as an IntoIovec.
-    pub fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> {
-        self.buffer.get_iovec(len)
-    }
 }
 
 impl<'a> io::Read for Reader<'a> {
@@ -418,11 +385,11 @@ impl<'a> io::Read for Reader<'a> {
                 break;
             }
 
-            let count = cmp::min(rem.len(), b.iov_len);
+            let count = cmp::min(rem.len(), b.size());
 
             // Safe because we have already verified that `b` points to valid memory.
             unsafe {
-                copy_nonoverlapping(b.iov_base as *const u8, rem.as_mut_ptr(), count);
+                copy_nonoverlapping(b.as_ptr(), rem.as_mut_ptr(), count);
             }
             rem = &mut rem[count..];
             total += count;
@@ -460,21 +427,18 @@ impl<'a> Writer<'a> {
                     .checked_add(desc.len as usize)
                     .ok_or(Error::DescriptorChainOverflow)?;
 
-                let vs = mem
-                    .get_slice(desc.addr.offset(), desc.len.into())
-                    .map_err(Error::VolatileMemoryError)?;
-                Ok(libc::iovec {
-                    iov_base: vs.as_ptr() as *mut c_void,
-                    iov_len: vs.size() as usize,
-                })
+                mem.get_slice_at_addr(
+                    desc.addr,
+                    desc.len.try_into().expect("u32 doesn't fit in usize"),
+                )
+                .map_err(Error::GuestMemoryError)
             })
-            .collect::<Result<Vec<libc::iovec>>>()?;
+            .collect::<Result<Vec<VolatileSlice>>>()?;
         Ok(Writer {
             buffer: DescriptorChainConsumer {
                 buffers,
                 current: 0,
                 bytes_consumed: 0,
-                mem: PhantomData,
             },
         })
     }
@@ -512,7 +476,7 @@ impl<'a> Writer<'a> {
         mut src: F,
         count: usize,
     ) -> io::Result<usize> {
-        let iovs = self.buffer.get_volatile_slices(count);
+        let iovs = self.buffer.get_remaining_with_count(count);
         let read = src.read_vectored_volatile(&iovs[..])?;
         self.buffer.consume(read);
         Ok(read)
@@ -528,7 +492,7 @@ impl<'a> Writer<'a> {
         count: usize,
         off: u64,
     ) -> io::Result<usize> {
-        let iovs = self.buffer.get_volatile_slices(count);
+        let iovs = self.buffer.get_remaining_with_count(count);
         let read = src.read_vectored_at_volatile(&iovs[..], off)?;
         self.buffer.consume(read);
         Ok(read)
@@ -595,12 +559,6 @@ impl<'a> Writer<'a> {
             buffer: self.buffer.split_at(offset),
         }
     }
-
-    /// Returns a DescriptorIovec for the next `len` bytes of the descriptor chain
-    /// buffer, which can be used as an IntoIovec.
-    pub fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> {
-        self.buffer.get_iovec(len)
-    }
 }
 
 impl<'a> io::Write for Writer<'a> {
@@ -612,10 +570,10 @@ impl<'a> io::Write for Writer<'a> {
                 break;
             }
 
-            let count = cmp::min(rem.len(), b.iov_len);
+            let count = cmp::min(rem.len(), b.size());
             // Safe because we have already verified that `vs` points to valid memory.
             unsafe {
-                copy_nonoverlapping(rem.as_ptr(), b.iov_base as *mut u8, count);
+                copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
             }
             rem = &rem[count..];
             total += count;
@@ -631,18 +589,6 @@ impl<'a> io::Write for Writer<'a> {
     }
 }
 
-pub struct DescriptorIovec<'a> {
-    iovec: Vec<libc::iovec>,
-    mem: PhantomData<&'a GuestMemory>,
-}
-
-// Safe because the lifetime of DescriptorIovec is tied to the underlying GuestMemory.
-unsafe impl<'a> IntoIovec for DescriptorIovec<'a> {
-    fn into_iovec(&self) -> Vec<libc::iovec> {
-        self.iovec.clone()
-    }
-}
-
 const VIRTQ_DESC_F_NEXT: u16 = 0x1;
 const VIRTQ_DESC_F_WRITE: u16 = 0x2;
 
@@ -1266,4 +1212,59 @@ mod tests {
             .expect("failed to collect() values");
         assert_eq!(vs, vs_read);
     }
+
+    #[test]
+    fn get_remaining_with_count() {
+        use DescriptorType::*;
+
+        let memory_start_addr = GuestAddress(0x0);
+        let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap();
+
+        let chain = create_descriptor_chain(
+            &memory,
+            GuestAddress(0x0),
+            GuestAddress(0x100),
+            vec![
+                (Readable, 16),
+                (Readable, 16),
+                (Readable, 96),
+                (Writable, 64),
+                (Writable, 1),
+                (Writable, 3),
+            ],
+            0,
+        )
+        .expect("create_descriptor_chain failed");
+
+        let Reader { mut buffer } = Reader::new(&memory, chain).expect("failed to create Reader");
+
+        let drain = buffer
+            .get_remaining_with_count(::std::usize::MAX)
+            .iter()
+            .fold(0usize, |total, iov| total + iov.size());
+        assert_eq!(drain, 128);
+
+        let exact = buffer
+            .get_remaining_with_count(32)
+            .iter()
+            .fold(0usize, |total, iov| total + iov.size());
+        assert!(exact > 0);
+        assert!(exact <= 32);
+
+        let split = buffer
+            .get_remaining_with_count(24)
+            .iter()
+            .fold(0usize, |total, iov| total + iov.size());
+        assert!(split > 0);
+        assert!(split <= 24);
+
+        buffer.consume(64);
+
+        let first = buffer
+            .get_remaining_with_count(8)
+            .iter()
+            .fold(0usize, |total, iov| total + iov.size());
+        assert!(first > 0);
+        assert!(first <= 8);
+    }
 }