diff options
Diffstat (limited to 'devices/src/virtio/descriptor_utils.rs')
-rw-r--r-- | devices/src/virtio/descriptor_utils.rs | 201 |
1 files changed, 121 insertions, 80 deletions
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index b767d42..fcf5793 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use std::borrow::Cow; use std::cmp; use std::convert::TryInto; -use std::ffi::c_void; use std::fmt::{self, Display}; use std::io::{self, Read, Write}; use std::iter::FromIterator; @@ -13,7 +13,7 @@ use std::mem::{size_of, MaybeUninit}; use std::ptr::copy_nonoverlapping; use std::result; -use data_model::{DataInit, Le16, Le32, Le64, VolatileMemory, VolatileMemoryError, VolatileSlice}; +use data_model::{DataInit, Le16, Le32, Le64, VolatileMemoryError, VolatileSlice}; use sys_util::{ FileReadWriteAtVolatile, FileReadWriteVolatile, GuestAddress, GuestMemory, IntoIovec, }; @@ -54,10 +54,9 @@ impl std::error::Error for Error {} #[derive(Clone)] struct DescriptorChainConsumer<'a> { - buffers: Vec<libc::iovec>, + buffers: Vec<VolatileSlice<'a>>, current: usize, bytes_consumed: usize, - mem: PhantomData<&'a GuestMemory>, } impl<'a> DescriptorChainConsumer<'a> { @@ -67,7 +66,7 @@ impl<'a> DescriptorChainConsumer<'a> { // `Reader::new()` and `Writer::new()`). self.get_remaining() .iter() - .fold(0usize, |count, buf| count + buf.iov_len) + .fold(0usize, |count, buf| count + buf.size()) } fn bytes_consumed(&self) -> usize { @@ -78,10 +77,38 @@ impl<'a> DescriptorChainConsumer<'a> { /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume` /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls /// to `consume` will return the same data. - fn get_remaining(&self) -> &[libc::iovec] { + fn get_remaining(&self) -> &[VolatileSlice] { &self.buffers[self.current..] } + /// Like `get_remaining` but guarantees that the combined length of all the returned iovecs is + /// not greater than `count`. The combined length of the returned iovecs may be less than + /// `count` but will always be greater than 0 as long as there is still space left in the + /// `DescriptorChain`. + fn get_remaining_with_count(&self, count: usize) -> Cow<[VolatileSlice]> { + let iovs = self.get_remaining(); + let mut iov_count = 0; + let mut rem = count; + for iov in iovs { + if rem < iov.size() { + break; + } + + iov_count += 1; + rem -= iov.size(); + } + + // Special case where the number of bytes to be copied is smaller than the `size()` of the + // first iovec. + if iov_count == 0 && iovs.len() > 0 && count > 0 { + debug_assert!(count < iovs[0].size()); + // Safe because we know that count is smaller than the length of the first slice. + Cow::Owned(vec![iovs[0].sub_slice(0, count).unwrap()]) + } else { + Cow::Borrowed(&iovs[..iov_count]) + } + } + /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed. /// @@ -99,19 +126,18 @@ impl<'a> DescriptorChainConsumer<'a> { break; } - let consumed = if count < buf.iov_len { + let consumed = if count < buf.size() { // Safe because we know that the iovec pointed to valid memory and we are adding a // value that is smaller than the length of the memory. - buf.iov_base = unsafe { (buf.iov_base as *mut u8).add(count) as *mut c_void }; - buf.iov_len -= count; + *buf = buf.offset(count).unwrap(); count } else { self.current += 1; - buf.iov_len + buf.size() }; - // This shouldn't overflow because `consumed <= buf.iov_len` and we already verified - // that adding all `buf.iov_len` values will not overflow when the Reader/Writer was + // This shouldn't overflow because `consumed <= buf.size()` and we already verified + // that adding all `buf.size()` values will not overflow when the Reader/Writer was // constructed. self.bytes_consumed += consumed; count -= consumed; @@ -126,13 +152,14 @@ impl<'a> DescriptorChainConsumer<'a> { let mut rem = offset; let mut end = self.current; for buf in &mut self.buffers[self.current..] { - if rem < buf.iov_len { - buf.iov_len = rem; + if rem < buf.size() { + // Safe because we are creating a smaller sub-slice. + *buf = buf.sub_slice(0, rem).unwrap(); break; } end += 1; - rem -= buf.iov_len; + rem -= buf.size(); } self.buffers.truncate(end + 1); @@ -140,51 +167,16 @@ impl<'a> DescriptorChainConsumer<'a> { other } - // Temporary method for converting iovecs into VolatileSlices until we can change the - // ReadWriteVolatile traits. The irony here is that the standard implementation of the - // ReadWriteVolatile traits will convert the VolatileSlices back into iovecs. - fn get_volatile_slices(&mut self, mut count: usize) -> Vec<VolatileSlice> { - let bufs = self.get_remaining(); - let mut iovs = Vec::with_capacity(bufs.len()); - for b in bufs { - // Safe because we verified during construction that the memory at `b.iov_base` is - // `b.iov_len` bytes long. The lifetime of the `VolatileSlice` is tied to the lifetime - // of this `DescriptorChainConsumer`, which is in turn tied to the lifetime of the - // `GuestMemory` used to create it and so the memory will be available for the duration - // of the `VolatileSlice`. - let iov = unsafe { - if count < b.iov_len { - VolatileSlice::new( - b.iov_base as *mut u8, - count.try_into().expect("usize doesn't fit in u64"), - ) - } else { - VolatileSlice::new( - b.iov_base as *mut u8, - b.iov_len.try_into().expect("usize doesn't fit in u64"), - ) - } - }; - - count -= iov.size() as usize; - iovs.push(iov); - } - - iovs - } - fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> { let mut iovec = Vec::with_capacity(self.get_remaining().len()); let mut rem = len; for buf in self.get_remaining() { - let iov = if rem < buf.iov_len { - libc::iovec { - iov_base: buf.iov_base, - iov_len: rem, - } + let iov = if rem < buf.size() { + // Safe because we know that `rem` is in-bounds. + buf.sub_slice(0, rem).unwrap().as_iovec() } else { - buf.clone() + buf.as_iovec() }; rem -= iov.iov_len; @@ -249,21 +241,18 @@ impl<'a> Reader<'a> { .checked_add(desc.len as usize) .ok_or(Error::DescriptorChainOverflow)?; - let vs = mem - .get_slice(desc.addr.offset(), desc.len.into()) - .map_err(Error::VolatileMemoryError)?; - Ok(libc::iovec { - iov_base: vs.as_ptr() as *mut c_void, - iov_len: vs.size() as usize, - }) + mem.get_slice_at_addr( + desc.addr, + desc.len.try_into().expect("u32 doesn't fit in usize"), + ) + .map_err(Error::GuestMemoryError) }) - .collect::<Result<Vec<libc::iovec>>>()?; + .collect::<Result<Vec<VolatileSlice>>>()?; Ok(Reader { buffer: DescriptorChainConsumer { buffers, current: 0, bytes_consumed: 0, - mem: PhantomData, }, }) } @@ -311,7 +300,7 @@ impl<'a> Reader<'a> { mut dst: F, count: usize, ) -> io::Result<usize> { - let iovs = self.buffer.get_volatile_slices(count); + let iovs = self.buffer.get_remaining_with_count(count); let written = dst.write_vectored_volatile(&iovs[..])?; self.buffer.consume(written); Ok(written) @@ -327,7 +316,7 @@ impl<'a> Reader<'a> { count: usize, off: u64, ) -> io::Result<usize> { - let iovs = self.buffer.get_volatile_slices(count); + let iovs = self.buffer.get_remaining_with_count(count); let written = dst.write_vectored_at_volatile(&iovs[..], off)?; self.buffer.consume(written); Ok(written) @@ -418,11 +407,11 @@ impl<'a> io::Read for Reader<'a> { break; } - let count = cmp::min(rem.len(), b.iov_len); + let count = cmp::min(rem.len(), b.size()); // Safe because we have already verified that `b` points to valid memory. unsafe { - copy_nonoverlapping(b.iov_base as *const u8, rem.as_mut_ptr(), count); + copy_nonoverlapping(b.as_ptr(), rem.as_mut_ptr(), count); } rem = &mut rem[count..]; total += count; @@ -460,21 +449,18 @@ impl<'a> Writer<'a> { .checked_add(desc.len as usize) .ok_or(Error::DescriptorChainOverflow)?; - let vs = mem - .get_slice(desc.addr.offset(), desc.len.into()) - .map_err(Error::VolatileMemoryError)?; - Ok(libc::iovec { - iov_base: vs.as_ptr() as *mut c_void, - iov_len: vs.size() as usize, - }) + mem.get_slice_at_addr( + desc.addr, + desc.len.try_into().expect("u32 doesn't fit in usize"), + ) + .map_err(Error::GuestMemoryError) }) - .collect::<Result<Vec<libc::iovec>>>()?; + .collect::<Result<Vec<VolatileSlice>>>()?; Ok(Writer { buffer: DescriptorChainConsumer { buffers, current: 0, bytes_consumed: 0, - mem: PhantomData, }, }) } @@ -512,7 +498,7 @@ impl<'a> Writer<'a> { mut src: F, count: usize, ) -> io::Result<usize> { - let iovs = self.buffer.get_volatile_slices(count); + let iovs = self.buffer.get_remaining_with_count(count); let read = src.read_vectored_volatile(&iovs[..])?; self.buffer.consume(read); Ok(read) @@ -528,7 +514,7 @@ impl<'a> Writer<'a> { count: usize, off: u64, ) -> io::Result<usize> { - let iovs = self.buffer.get_volatile_slices(count); + let iovs = self.buffer.get_remaining_with_count(count); let read = src.read_vectored_at_volatile(&iovs[..], off)?; self.buffer.consume(read); Ok(read) @@ -612,10 +598,10 @@ impl<'a> io::Write for Writer<'a> { break; } - let count = cmp::min(rem.len(), b.iov_len); + let count = cmp::min(rem.len(), b.size()); // Safe because we have already verified that `vs` points to valid memory. unsafe { - copy_nonoverlapping(rem.as_ptr(), b.iov_base as *mut u8, count); + copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count); } rem = &rem[count..]; total += count; @@ -1266,4 +1252,59 @@ mod tests { .expect("failed to collect() values"); assert_eq!(vs, vs_read); } + + #[test] + fn get_remaining_with_count() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + + let Reader { mut buffer } = Reader::new(&memory, chain).expect("failed to create Reader"); + + let drain = buffer + .get_remaining_with_count(::std::usize::MAX) + .iter() + .fold(0usize, |total, iov| total + iov.size()); + assert_eq!(drain, 128); + + let exact = buffer + .get_remaining_with_count(32) + .iter() + .fold(0usize, |total, iov| total + iov.size()); + assert!(exact > 0); + assert!(exact <= 32); + + let split = buffer + .get_remaining_with_count(24) + .iter() + .fold(0usize, |total, iov| total + iov.size()); + assert!(split > 0); + assert!(split <= 24); + + buffer.consume(64); + + let first = buffer + .get_remaining_with_count(8) + .iter() + .fold(0usize, |total, iov| total + iov.size()); + assert!(first > 0); + assert!(first <= 8); + } } |