diff options
Diffstat (limited to 'devices')
-rw-r--r-- | devices/src/virtio/block.rs | 4 | ||||
-rw-r--r-- | devices/src/virtio/descriptor_utils.rs | 373 | ||||
-rw-r--r-- | devices/src/virtio/fs/server.rs | 9 |
3 files changed, 203 insertions, 183 deletions
diff --git a/devices/src/virtio/block.rs b/devices/src/virtio/block.rs index c9dda55..80d5103 100644 --- a/devices/src/virtio/block.rs +++ b/devices/src/virtio/block.rs @@ -267,9 +267,7 @@ impl Worker { let status_offset = available_bytes .checked_sub(1) .ok_or(ExecuteError::MissingStatus)?; - let mut status_writer = writer - .split_at(status_offset) - .map_err(ExecuteError::Descriptor)?; + let mut status_writer = writer.split_at(status_offset); let status = match Block::execute_request( &mut reader, diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index 27d4b1c..b767d42 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -3,7 +3,8 @@ // found in the LICENSE file. use std::cmp; -use std::collections::VecDeque; +use std::convert::TryInto; +use std::ffi::c_void; use std::fmt::{self, Display}; use std::io::{self, Read, Write}; use std::iter::FromIterator; @@ -53,8 +54,10 @@ impl std::error::Error for Error {} #[derive(Clone)] struct DescriptorChainConsumer<'a> { - buffers: VecDeque<VolatileSlice<'a>>, + buffers: Vec<libc::iovec>, + current: usize, bytes_consumed: usize, + mem: PhantomData<&'a GuestMemory>, } impl<'a> DescriptorChainConsumer<'a> { @@ -62,140 +65,136 @@ impl<'a> DescriptorChainConsumer<'a> { // This is guaranteed not to overflow because the total length of the chain // is checked during all creations of `DescriptorChainConsumer` (see // `Reader::new()` and `Writer::new()`). - self.buffers + self.get_remaining() .iter() - .fold(0usize, |count, vs| count + vs.size() as usize) + .fold(0usize, |count, buf| count + buf.iov_len) } fn bytes_consumed(&self) -> usize { self.bytes_consumed } - /// Consumes at most `count` bytes from the `DescriptorChain`. Callers must provide a function - /// that takes a `&[VolatileSlice]` and returns the total number of bytes consumed. This - /// function guarantees that the combined length of all the slices in the `&[VolatileSlice]` is - /// less than or equal to `count`. + /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not + /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume` + /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls + /// to `consume` will return the same data. + fn get_remaining(&self) -> &[libc::iovec] { + &self.buffers[self.current..] + } + + /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than + /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed. /// /// # Errors /// - /// If the provided function returns any error then no bytes are consumed from the buffer and - /// the error is returned to the caller. - fn consume<F>(&mut self, count: usize, f: F) -> io::Result<usize> - where - F: FnOnce(&[VolatileSlice]) -> io::Result<usize>, - { - let mut buflen = 0; - let mut bufs = Vec::with_capacity(self.buffers.len()); - for &vs in &self.buffers { - if buflen >= count { + /// Returns an error if the total bytes consumed by this `DescriptorChainConsumer` overflows a + /// usize. + fn consume(&mut self, mut count: usize) { + // The implementation is adapted from `IoSlice::advance` in libstd. We can't use + // `get_remaining` here because then the compiler complains that `self.current` is already + // borrowed and doesn't allow us to modify it. We also need to borrow the iovecs mutably. + let current = self.current; + for buf in &mut self.buffers[current..] { + if count == 0 { break; } - let rem = count - buflen; - if (rem as u64) < vs.size() { - let buf = vs.sub_slice(0, rem as u64).map_err(|e| { - io::Error::new(io::ErrorKind::InvalidData, Error::VolatileMemoryError(e)) - })?; - bufs.push(buf); - buflen += rem; + let consumed = if count < buf.iov_len { + // Safe because we know that the iovec pointed to valid memory and we are adding a + // value that is smaller than the length of the memory. + buf.iov_base = unsafe { (buf.iov_base as *mut u8).add(count) as *mut c_void }; + buf.iov_len -= count; + count } else { - bufs.push(vs); - buflen += vs.size() as usize; - } + self.current += 1; + buf.iov_len + }; + + // This shouldn't overflow because `consumed <= buf.iov_len` and we already verified + // that adding all `buf.iov_len` values will not overflow when the Reader/Writer was + // constructed. + self.bytes_consumed += consumed; + count -= consumed; } + } - if bufs.is_empty() { - return Ok(0); - } + fn split_at(&mut self, offset: usize) -> DescriptorChainConsumer<'a> { + let mut other = self.clone(); + other.consume(offset); + other.bytes_consumed = 0; - let bytes_consumed = f(&*bufs)?; - - // This can happen if a driver tricks a device into reading/writing more data than - // fits in a `usize`. - let total_bytes_consumed = - self.bytes_consumed - .checked_add(bytes_consumed) - .ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, Error::DescriptorChainOverflow) - })?; - - let mut rem = bytes_consumed; - while let Some(vs) = self.buffers.pop_front() { - if (rem as u64) < vs.size() { - // Split the slice and push the remainder back into the buffer list. Safe because we - // know that `rem` is not out of bounds due to the check and we checked the bounds - // on `vs` when we added it to the buffer list. - self.buffers.push_front(vs.offset(rem as u64).unwrap()); + let mut rem = offset; + let mut end = self.current; + for buf in &mut self.buffers[self.current..] { + if rem < buf.iov_len { + buf.iov_len = rem; break; } - // No need for checked math because we know that `vs.size() <= rem`. - rem -= vs.size() as usize; + end += 1; + rem -= buf.iov_len; } - self.bytes_consumed = total_bytes_consumed; + self.buffers.truncate(end + 1); - Ok(bytes_consumed) + other } - fn split_at(&mut self, offset: usize) -> Result<DescriptorChainConsumer<'a>> { - let mut rem = offset; - let pos = self.buffers.iter().position(|vs| { - if (rem as u64) < vs.size() { - true - } else { - rem -= vs.size() as usize; - false - } - }); - - if let Some(at) = pos { - let mut other = self.buffers.split_off(at); - - if rem > 0 { - // There must be at least one element in `other` because we checked - // its `size` value in the call to `position` above. - let front = other.pop_front().expect("empty VecDeque after split"); - self.buffers.push_back( - front - .sub_slice(0, rem as u64) - .map_err(Error::VolatileMemoryError)?, - ); - other.push_front( - front - .offset(rem as u64) - .map_err(Error::VolatileMemoryError)?, - ); - } + // Temporary method for converting iovecs into VolatileSlices until we can change the + // ReadWriteVolatile traits. The irony here is that the standard implementation of the + // ReadWriteVolatile traits will convert the VolatileSlices back into iovecs. + fn get_volatile_slices(&mut self, mut count: usize) -> Vec<VolatileSlice> { + let bufs = self.get_remaining(); + let mut iovs = Vec::with_capacity(bufs.len()); + for b in bufs { + // Safe because we verified during construction that the memory at `b.iov_base` is + // `b.iov_len` bytes long. The lifetime of the `VolatileSlice` is tied to the lifetime + // of this `DescriptorChainConsumer`, which is in turn tied to the lifetime of the + // `GuestMemory` used to create it and so the memory will be available for the duration + // of the `VolatileSlice`. + let iov = unsafe { + if count < b.iov_len { + VolatileSlice::new( + b.iov_base as *mut u8, + count.try_into().expect("usize doesn't fit in u64"), + ) + } else { + VolatileSlice::new( + b.iov_base as *mut u8, + b.iov_len.try_into().expect("usize doesn't fit in u64"), + ) + } + }; - Ok(DescriptorChainConsumer { - buffers: other, - bytes_consumed: 0, - }) - } else if rem == 0 { - Ok(DescriptorChainConsumer { - buffers: VecDeque::new(), - bytes_consumed: 0, - }) - } else { - Err(Error::SplitOutOfBounds(offset)) + count -= iov.size() as usize; + iovs.push(iov); } + + iovs } fn get_iovec(&mut self, len: usize) -> io::Result<DescriptorIovec<'a>> { - let mut iovec = Vec::new(); + let mut iovec = Vec::with_capacity(self.get_remaining().len()); + + let mut rem = len; + for buf in self.get_remaining() { + let iov = if rem < buf.iov_len { + libc::iovec { + iov_base: buf.iov_base, + iov_len: rem, + } + } else { + buf.clone() + }; - self.consume(len, |bufs| { - let mut total = 0; - for vs in bufs { - iovec.push(libc::iovec { - iov_base: vs.as_ptr() as *mut libc::c_void, - iov_len: vs.size() as usize, - }); - total += vs.size() as usize; + rem -= iov.iov_len; + iovec.push(iov); + + if rem == 0 { + break; } - Ok(total) - })?; + } + self.consume(len); Ok(DescriptorIovec { iovec, @@ -250,14 +249,21 @@ impl<'a> Reader<'a> { .checked_add(desc.len as usize) .ok_or(Error::DescriptorChainOverflow)?; - mem.get_slice(desc.addr.offset(), desc.len.into()) - .map_err(Error::VolatileMemoryError) + let vs = mem + .get_slice(desc.addr.offset(), desc.len.into()) + .map_err(Error::VolatileMemoryError)?; + Ok(libc::iovec { + iov_base: vs.as_ptr() as *mut c_void, + iov_len: vs.size() as usize, + }) }) - .collect::<Result<VecDeque<VolatileSlice<'a>>>>()?; + .collect::<Result<Vec<libc::iovec>>>()?; Ok(Reader { buffer: DescriptorChainConsumer { buffers, + current: 0, bytes_consumed: 0, + mem: PhantomData, }, }) } @@ -305,8 +311,10 @@ impl<'a> Reader<'a> { mut dst: F, count: usize, ) -> io::Result<usize> { - self.buffer - .consume(count, |bufs| dst.write_vectored_volatile(bufs)) + let iovs = self.buffer.get_volatile_slices(count); + let written = dst.write_vectored_volatile(&iovs[..])?; + self.buffer.consume(written); + Ok(written) } /// Reads data from the descriptor chain buffer into a File at offset `off`. @@ -319,8 +327,10 @@ impl<'a> Reader<'a> { count: usize, off: u64, ) -> io::Result<usize> { - self.buffer - .consume(count, |bufs| dst.write_vectored_at_volatile(bufs, off)) + let iovs = self.buffer.get_volatile_slices(count); + let written = dst.write_vectored_at_volatile(&iovs[..], off)?; + self.buffer.consume(written); + Ok(written) } pub fn read_exact_to<F: FileReadWriteVolatile>( @@ -382,12 +392,14 @@ impl<'a> Reader<'a> { self.buffer.bytes_consumed() } - /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. - /// After the split, `self` will be able to read up to `offset` bytes while the returned - /// `Reader` can read up to `available_bytes() - offset` bytes. Returns an error if - /// `offset > self.available_bytes()`. - pub fn split_at(&mut self, offset: usize) -> Result<Reader<'a>> { - self.buffer.split_at(offset).map(|buffer| Reader { buffer }) + /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the + /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read + /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the + /// returned `Reader` will not be able to read any bytes. + pub fn split_at(&mut self, offset: usize) -> Reader<'a> { + Reader { + buffer: self.buffer.split_at(offset), + } } /// Returns a DescriptorIovec for the next `len` bytes of the descriptor chain @@ -399,27 +411,25 @@ impl<'a> Reader<'a> { impl<'a> io::Read for Reader<'a> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - self.buffer.consume(buf.len(), |bufs| { - let mut rem = buf; - let mut total = 0; - for vs in bufs { - // This is guaranteed by the implementation of `consume`. - debug_assert_eq!(vs.size(), cmp::min(rem.len() as u64, vs.size())); - - // Safe because we have already verified that `vs` points to valid memory. - unsafe { - copy_nonoverlapping( - vs.as_ptr() as *const u8, - rem.as_mut_ptr(), - vs.size() as usize, - ); - } - let copied = vs.size() as usize; - rem = &mut rem[copied..]; - total += copied; + let mut rem = buf; + let mut total = 0; + for b in self.buffer.get_remaining() { + if rem.len() == 0 { + break; } - Ok(total) - }) + + let count = cmp::min(rem.len(), b.iov_len); + + // Safe because we have already verified that `b` points to valid memory. + unsafe { + copy_nonoverlapping(b.iov_base as *const u8, rem.as_mut_ptr(), count); + } + rem = &mut rem[count..]; + total += count; + } + + self.buffer.consume(total); + Ok(total) } } @@ -450,14 +460,21 @@ impl<'a> Writer<'a> { .checked_add(desc.len as usize) .ok_or(Error::DescriptorChainOverflow)?; - mem.get_slice(desc.addr.offset(), desc.len.into()) - .map_err(Error::VolatileMemoryError) + let vs = mem + .get_slice(desc.addr.offset(), desc.len.into()) + .map_err(Error::VolatileMemoryError)?; + Ok(libc::iovec { + iov_base: vs.as_ptr() as *mut c_void, + iov_len: vs.size() as usize, + }) }) - .collect::<Result<VecDeque<VolatileSlice<'a>>>>()?; + .collect::<Result<Vec<libc::iovec>>>()?; Ok(Writer { buffer: DescriptorChainConsumer { buffers, + current: 0, bytes_consumed: 0, + mem: PhantomData, }, }) } @@ -495,8 +512,10 @@ impl<'a> Writer<'a> { mut src: F, count: usize, ) -> io::Result<usize> { - self.buffer - .consume(count, |bufs| src.read_vectored_volatile(bufs)) + let iovs = self.buffer.get_volatile_slices(count); + let read = src.read_vectored_volatile(&iovs[..])?; + self.buffer.consume(read); + Ok(read) } /// Writes data to the descriptor chain buffer from a File at offset `off`. @@ -509,8 +528,10 @@ impl<'a> Writer<'a> { count: usize, off: u64, ) -> io::Result<usize> { - self.buffer - .consume(count, |bufs| src.read_vectored_at_volatile(bufs, off)) + let iovs = self.buffer.get_volatile_slices(count); + let read = src.read_vectored_at_volatile(&iovs[..], off)?; + self.buffer.consume(read); + Ok(read) } pub fn write_all_from<F: FileReadWriteVolatile>( @@ -565,12 +586,14 @@ impl<'a> Writer<'a> { self.buffer.bytes_consumed() } - /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. - /// After the split, `self` will be able to write up to `offset` bytes while the returned - /// `Writer` can write up to `available_bytes() - offset` bytes. Returns an error if - /// `offset > self.available_bytes()`. - pub fn split_at(&mut self, offset: usize) -> Result<Writer<'a>> { - self.buffer.split_at(offset).map(|buffer| Writer { buffer }) + /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the + /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can + /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then + /// the returned `Writer` will not be able to write any bytes. + pub fn split_at(&mut self, offset: usize) -> Writer<'a> { + Writer { + buffer: self.buffer.split_at(offset), + } } /// Returns a DescriptorIovec for the next `len` bytes of the descriptor chain @@ -582,23 +605,24 @@ impl<'a> Writer<'a> { impl<'a> io::Write for Writer<'a> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.buffer.consume(buf.len(), |bufs| { - let mut rem = buf; - let mut total = 0; - for vs in bufs { - // This is guaranteed by the implementation of `consume`. - debug_assert_eq!(vs.size(), cmp::min(rem.len() as u64, vs.size())); - - // Safe because we have already verified that `vs` points to valid memory. - unsafe { - copy_nonoverlapping(rem.as_ptr(), vs.as_ptr(), vs.size() as usize); - } - let copied = vs.size() as usize; - rem = &rem[copied..]; - total += copied; + let mut rem = buf; + let mut total = 0; + for b in self.buffer.get_remaining() { + if rem.len() == 0 { + break; } - Ok(total) - }) + + let count = cmp::min(rem.len(), b.iov_len); + // Safe because we have already verified that `vs` points to valid memory. + unsafe { + copy_nonoverlapping(rem.as_ptr(), b.iov_base as *mut u8, count); + } + rem = &rem[count..]; + total += count; + } + + self.buffer.consume(total); + Ok(total) } fn flush(&mut self) -> io::Result<()> { @@ -1031,7 +1055,7 @@ mod tests { .expect("create_descriptor_chain failed"); let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); - let other = reader.split_at(32).expect("failed to split Reader"); + let other = reader.split_at(32); assert_eq!(reader.available_bytes(), 32); assert_eq!(other.available_bytes(), 96); } @@ -1060,7 +1084,7 @@ mod tests { .expect("create_descriptor_chain failed"); let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); - let other = reader.split_at(24).expect("failed to split Reader"); + let other = reader.split_at(24); assert_eq!(reader.available_bytes(), 24); assert_eq!(other.available_bytes(), 104); } @@ -1089,7 +1113,7 @@ mod tests { .expect("create_descriptor_chain failed"); let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); - let other = reader.split_at(128).expect("failed to split Reader"); + let other = reader.split_at(128); assert_eq!(reader.available_bytes(), 128); assert_eq!(other.available_bytes(), 0); } @@ -1118,7 +1142,7 @@ mod tests { .expect("create_descriptor_chain failed"); let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); - let other = reader.split_at(0).expect("failed to split Reader"); + let other = reader.split_at(0); assert_eq!(reader.available_bytes(), 0); assert_eq!(other.available_bytes(), 128); } @@ -1147,9 +1171,12 @@ mod tests { .expect("create_descriptor_chain failed"); let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); - if let Ok(_) = reader.split_at(256) { - panic!("successfully split Reader with out of bounds offset"); - } + let other = reader.split_at(256); + assert_eq!( + other.available_bytes(), + 0, + "Reader returned from out-of-bounds split still has available bytes" + ); } #[test] diff --git a/devices/src/virtio/fs/server.rs b/devices/src/virtio/fs/server.rs index 33b7c98..c1af80c 100644 --- a/devices/src/virtio/fs/server.rs +++ b/devices/src/virtio/fs/server.rs @@ -496,10 +496,7 @@ impl<F: FileSystem + Sync> Server<F> { }; // Split the writer into 2 pieces: one for the `OutHeader` and the rest for the data. - let data_writer = ZCWriter( - w.split_at(size_of::<OutHeader>()) - .map_err(Error::InvalidDescriptorChain)?, - ); + let data_writer = ZCWriter(w.split_at(size_of::<OutHeader>())); match self.fs.read( Context::from(in_header), @@ -910,9 +907,7 @@ impl<F: FileSystem + Sync> Server<F> { } // Skip over enough bytes for the header. - let mut cursor = w - .split_at(size_of::<OutHeader>()) - .map_err(Error::InvalidDescriptorChain)?; + let mut cursor = w.split_at(size_of::<OutHeader>()); let res = if plus { self.fs.readdirplus( |