diff options
author | Alyssa Ross <hi@alyssa.is> | 2020-06-02 03:03:26 +0000 |
---|---|---|
committer | Alyssa Ross <hi@alyssa.is> | 2020-06-14 11:23:24 +0000 |
commit | 28d9682698d287d14cbe67a0ed7acc1427add320 (patch) | |
tree | 669ed98d9b1388b553c8e0f0189678cc68dd4162 /devices/src/virtio/descriptor_utils.rs | |
parent | 460406d10bbfaa890d56d616b4610813da63a312 (diff) | |
parent | 4264464153a7a788ef73c5015ac8bbde5f8ebe1c (diff) | |
download | crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.gz crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.bz2 crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.lz crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.xz crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.tar.zst crosvm-28d9682698d287d14cbe67a0ed7acc1427add320.zip |
Merge remote-tracking branch 'origin/master'
Diffstat (limited to 'devices/src/virtio/descriptor_utils.rs')
-rw-r--r-- | devices/src/virtio/descriptor_utils.rs | 261 |
1 files changed, 131 insertions, 130 deletions
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index d65341b..902e3c3 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use std::borrow::Cow; use std::cmp; use std::convert::TryInto; -use std::ffi::c_void; use std::fmt::{self, Display}; use std::io::{self, Read, Write}; use std::iter::FromIterator; @@ -13,10 +13,8 @@ use std::mem::{size_of, MaybeUninit}; use std::ptr::copy_nonoverlapping; use std::result; -use data_model::{DataInit, Le16, Le32, Le64, VolatileMemory, VolatileMemoryError, VolatileSlice}; -use sys_util::{ - FileReadWriteAtVolatile, FileReadWriteVolatile, GuestAddress, GuestMemory, IntoIovec, -}; +use data_model::{DataInit, Le16, Le32, Le64, VolatileMemoryError, VolatileSlice}; +use sys_util::{FileReadWriteAtVolatile, FileReadWriteVolatile, GuestAddress, GuestMemory}; use super::DescriptorChain; @@ -54,10 +52,9 @@ impl std::error::Error for Error {} #[derive(Clone)] struct DescriptorChainConsumer<'a> { - buffers: Vec<libc::iovec>, + buffers: Vec<VolatileSlice<'a>>, current: usize, bytes_consumed: usize, - mem: PhantomData<&'a GuestMemory>, } impl<'a> DescriptorChainConsumer<'a> { @@ -67,7 +64,7 @@ impl<'a> DescriptorChainConsumer<'a> { // `Reader::new()` and `Writer::new()`). self.get_remaining() .iter() - .fold(0usize, |count, buf| count + buf.iov_len) + .fold(0usize, |count, buf| count + buf.size()) } fn bytes_consumed(&self) -> usize { @@ -78,10 +75,38 @@ impl<'a> DescriptorChainConsumer<'a> { /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume` /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls /// to `consume` will return the same data. - fn get_remaining(&self) -> &[libc::iovec] { + fn get_remaining(&self) -> &[VolatileSlice] { &self.buffers[self.current..] } + /// Like `get_remaining` but guarantees that the combined length of all the returned iovecs is + /// not greater than `count`. The combined length of the returned iovecs may be less than + /// `count` but will always be greater than 0 as long as there is still space left in the + /// `DescriptorChain`. + fn get_remaining_with_count(&self, count: usize) -> Cow<[VolatileSlice]> { + let iovs = self.get_remaining(); + let mut iov_count = 0; + let mut rem = count; + for iov in iovs { + if rem < iov.size() { + break; + } + + iov_count += 1; + rem -= iov.size(); + } + + // Special case where the number of bytes to be copied is smaller than the `size()` of the + // first iovec. + if iov_count == 0 && iovs.len() > 0 && count > 0 { + debug_assert!(count < iovs[0].size()); + // Safe because we know that count is smaller than the length of the first slice. + Cow::Owned(vec![iovs[0].sub_slice(0, count).unwrap()]) + } else { + Cow::Borrowed(&iovs[..iov_count]) + } + } + /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed. /// @@ -99,19 +124,18 @@ impl<'a> DescriptorChainConsumer<'a> { break; } - let consumed = if count < buf.iov_len { + let consumed = if count < buf.size() { // Safe because we know that the iovec pointed to valid memory and we are adding a // value that is smaller than the length of the memory. - buf.iov_base = unsafe { (buf.iov_base as *mut u8).add(count) as *mut c_void }; - buf.iov_len -= count; + *buf = buf.offset(count).unwrap(); count } else { self.current += 1; - buf.iov_len + buf.size() }; - // This shouldn't overflow because `consumed <= buf.iov_len` and we already verified - // that adding all `buf.iov_len` values will not overflow when the Reader/Writer was + // This shouldn't overflow because `consumed <= buf.size()` and we already verified + // that adding all `buf.size()` values will not overflow when the Reader/Writer was // constructed. self.bytes_consumed += consumed; count -= consumed; @@ -126,81 +150,20 @@ impl<'a> DescriptorChainConsumer<'a> { let mut rem = offset; let mut end = self.current; for buf in &mut self.buffers[self.current..] { - if rem < buf.iov_len { - buf.iov_len = rem; + if rem < buf.size() { + // Safe because we are creating a smaller sub-slice. + *buf = buf.sub_slice(0, rem).unwrap(); break; } end += 1; - rem -= buf.iov_len; + rem -= buf.size(); } self.buffers.truncate(end + 1); other } - - // Temporary method for converting iovecs into VolatileSlices until we can change the - // ReadWriteVolatile traits. The irony here is that the standard implementation of the - // ReadWriteVolatile traits will convert the VolatileSlices back into iovecs. - fn get_volatile_slices(&mut self, mut count: usize) -> Vec<VolatileSlice> { - let bufs = self.get_remaining(); - let mut iovs = Vec::with_capacity(bufs.len()); - for b in bufs { - // Safe because we verified during construction that the memory at `b.iov_base` is - // `b.iov_len` bytes long. The lifetime of the `VolatileSlice` is tied to the lifetime - // of this `DescriptorChainConsumer`, which is in turn tied to the lifetime of the - // `GuestMemory` used to create it and so the memory will be available for the duration - // of the `VolatileSlice`. - let iov = unsafe { - if count < b.iov_len { - VolatileSlice::new( - b.iov_base as *mut u8, - count.try_into().expect("usize doesn't fit in u64"), - ) - } else { - VolatileSlice::new( - b.iov_base as *mut u8, - b.iov_len.try_into().expect("usize doesn't fit in u64"), - ) - } - }; - - count -= iov.size() as usize; - iovs.push(iov); - } - - iovs - } - - fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> { - let mut iovec = Vec::with_capacity(self.get_remaining().len()); - - let mut rem = len; - for buf in self.get_remaining() { - let iov = if rem < buf.iov_len { - libc::iovec { - iov_base: buf.iov_base, - iov_len: rem, - } - } else { - buf.clone() - }; - - rem -= iov.iov_len; - iovec.push(iov); - - if rem == 0 { - break; - } - } - self.consume(len); - - Ok(DescriptorIovec { - iovec, - mem: PhantomData, - }) - } } /// Provides high-level interface over the sequence of memory regions @@ -249,21 +212,18 @@ impl<'a> Reader<'a> { .checked_add(desc.len as usize) .ok_or(Error::DescriptorChainOverflow)?; - let vs = mem - .get_slice(desc.addr.offset(), desc.len.into()) - .map_err(Error::VolatileMemoryError)?; - Ok(libc::iovec { - iov_base: vs.as_ptr() as *mut c_void, - iov_len: vs.size() as usize, - }) + mem.get_slice_at_addr( + desc.addr, + desc.len.try_into().expect("u32 doesn't fit in usize"), + ) + .map_err(Error::GuestMemoryError) }) - .collect::<Result<Vec<libc::iovec>>>()?; + .collect::<Result<Vec<VolatileSlice>>>()?; Ok(Reader { buffer: DescriptorChainConsumer { buffers, current: 0, bytes_consumed: 0, - mem: PhantomData, }, }) } @@ -311,7 +271,7 @@ impl<'a> Reader<'a> { mut dst: F, count: usize, ) -> io::Result<usize> { - let iovs = self.buffer.get_volatile_slices(count); + let iovs = self.buffer.get_remaining_with_count(count); let written = dst.write_vectored_volatile(&iovs[..])?; self.buffer.consume(written); Ok(written) @@ -327,7 +287,7 @@ impl<'a> Reader<'a> { count: usize, off: u64, ) -> io::Result<usize> { - let iovs = self.buffer.get_volatile_slices(count); + let iovs = self.buffer.get_remaining_with_count(count); let written = dst.write_vectored_at_volatile(&iovs[..], off)?; self.buffer.consume(written); Ok(written) @@ -392,6 +352,19 @@ impl<'a> Reader<'a> { self.buffer.bytes_consumed() } + /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`. + /// Calling this method does not actually consume any data from the `Reader` and callers should + /// call `consume` to advance the `Reader`. + pub fn get_remaining(&self) -> &[VolatileSlice] { + self.buffer.get_remaining() + } + + /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the + /// remaining data left in this `Reader`, then all remaining data will be consumed. + pub fn consume(&mut self, amt: usize) { + self.buffer.consume(amt) + } + /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the @@ -401,12 +374,6 @@ impl<'a> Reader<'a> { buffer: self.buffer.split_at(offset), } } - - /// Returns a DescriptorIovec for the next `len` bytes of the descriptor chain - /// buffer, which can be used as an IntoIovec. - pub fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> { - self.buffer.get_iovec(len) - } } impl<'a> io::Read for Reader<'a> { @@ -418,11 +385,11 @@ impl<'a> io::Read for Reader<'a> { break; } - let count = cmp::min(rem.len(), b.iov_len); + let count = cmp::min(rem.len(), b.size()); // Safe because we have already verified that `b` points to valid memory. unsafe { - copy_nonoverlapping(b.iov_base as *const u8, rem.as_mut_ptr(), count); + copy_nonoverlapping(b.as_ptr(), rem.as_mut_ptr(), count); } rem = &mut rem[count..]; total += count; @@ -460,21 +427,18 @@ impl<'a> Writer<'a> { .checked_add(desc.len as usize) .ok_or(Error::DescriptorChainOverflow)?; - let vs = mem - .get_slice(desc.addr.offset(), desc.len.into()) - .map_err(Error::VolatileMemoryError)?; - Ok(libc::iovec { - iov_base: vs.as_ptr() as *mut c_void, - iov_len: vs.size() as usize, - }) + mem.get_slice_at_addr( + desc.addr, + desc.len.try_into().expect("u32 doesn't fit in usize"), + ) + .map_err(Error::GuestMemoryError) }) - .collect::<Result<Vec<libc::iovec>>>()?; + .collect::<Result<Vec<VolatileSlice>>>()?; Ok(Writer { buffer: DescriptorChainConsumer { buffers, current: 0, bytes_consumed: 0, - mem: PhantomData, }, }) } @@ -512,7 +476,7 @@ impl<'a> Writer<'a> { mut src: F, count: usize, ) -> io::Result<usize> { - let iovs = self.buffer.get_volatile_slices(count); + let iovs = self.buffer.get_remaining_with_count(count); let read = src.read_vectored_volatile(&iovs[..])?; self.buffer.consume(read); Ok(read) @@ -528,7 +492,7 @@ impl<'a> Writer<'a> { count: usize, off: u64, ) -> io::Result<usize> { - let iovs = self.buffer.get_volatile_slices(count); + let iovs = self.buffer.get_remaining_with_count(count); let read = src.read_vectored_at_volatile(&iovs[..], off)?; self.buffer.consume(read); Ok(read) @@ -595,12 +559,6 @@ impl<'a> Writer<'a> { buffer: self.buffer.split_at(offset), } } - - /// Returns a DescriptorIovec for the next `len` bytes of the descriptor chain - /// buffer, which can be used as an IntoIovec. - pub fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> { - self.buffer.get_iovec(len) - } } impl<'a> io::Write for Writer<'a> { @@ -612,10 +570,10 @@ impl<'a> io::Write for Writer<'a> { break; } - let count = cmp::min(rem.len(), b.iov_len); + let count = cmp::min(rem.len(), b.size()); // Safe because we have already verified that `vs` points to valid memory. unsafe { - copy_nonoverlapping(rem.as_ptr(), b.iov_base as *mut u8, count); + copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count); } rem = &rem[count..]; total += count; @@ -631,18 +589,6 @@ impl<'a> io::Write for Writer<'a> { } } -pub struct DescriptorIovec<'a> { - iovec: Vec<libc::iovec>, - mem: PhantomData<&'a GuestMemory>, -} - -// Safe because the lifetime of DescriptorIovec is tied to the underlying GuestMemory. -unsafe impl<'a> IntoIovec for DescriptorIovec<'a> { - fn into_iovec(&self) -> Vec<libc::iovec> { - self.iovec.clone() - } -} - const VIRTQ_DESC_F_NEXT: u16 = 0x1; const VIRTQ_DESC_F_WRITE: u16 = 0x2; @@ -1266,4 +1212,59 @@ mod tests { .expect("failed to collect() values"); assert_eq!(vs, vs_read); } + + #[test] + fn get_remaining_with_count() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + + let Reader { mut buffer } = Reader::new(&memory, chain).expect("failed to create Reader"); + + let drain = buffer + .get_remaining_with_count(::std::usize::MAX) + .iter() + .fold(0usize, |total, iov| total + iov.size()); + assert_eq!(drain, 128); + + let exact = buffer + .get_remaining_with_count(32) + .iter() + .fold(0usize, |total, iov| total + iov.size()); + assert!(exact > 0); + assert!(exact <= 32); + + let split = buffer + .get_remaining_with_count(24) + .iter() + .fold(0usize, |total, iov| total + iov.size()); + assert!(split > 0); + assert!(split <= 24); + + buffer.consume(64); + + let first = buffer + .get_remaining_with_count(8) + .iter() + .fold(0usize, |total, iov| total + iov.size()); + assert!(first > 0); + assert!(first <= 8); + } } |