diff options
Diffstat (limited to 'devices/src/virtio/descriptor_utils.rs')
-rw-r--r-- | devices/src/virtio/descriptor_utils.rs | 965 |
1 files changed, 583 insertions, 382 deletions
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index 409134e..3c8ff87 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -3,23 +3,24 @@ // found in the LICENSE file. use std::cmp; -use std::convert::TryFrom; +use std::collections::VecDeque; use std::fmt::{self, Display}; -use std::io; -use std::os::unix::io::AsRawFd; +use std::io::{self, Read, Write}; +use std::mem::{size_of, MaybeUninit}; use std::result; -use data_model::{DataInit, Le16, Le32, Le64, VolatileMemory, VolatileMemoryError}; -use sys_util::guest_memory::Error as GuestMemoryError; -use sys_util::{FileReadWriteVolatile, GuestAddress, GuestMemory}; +use data_model::{DataInit, Le16, Le32, Le64, VolatileMemory, VolatileMemoryError, VolatileSlice}; +use sys_util::{FileReadWriteAtVolatile, FileReadWriteVolatile, GuestAddress, GuestMemory}; use super::DescriptorChain; #[derive(Debug)] pub enum Error { + DescriptorChainOverflow, GuestMemoryError(sys_util::GuestMemoryError), InvalidChain, IoError(io::Error), + SplitOutOfBounds(usize), VolatileMemoryError(VolatileMemoryError), } @@ -28,9 +29,14 @@ impl Display for Error { use self::Error::*; match self { + DescriptorChainOverflow => write!( + f, + "the combined length of all the buffers in a `DescriptorChain` would overflow" + ), GuestMemoryError(e) => write!(f, "descriptor guest memory error: {}", e), InvalidChain => write!(f, "invalid descriptor chain"), IoError(e) => write!(f, "descriptor I/O error: {}", e), + SplitOutOfBounds(off) => write!(f, "`DescriptorChain` split is out of bounds: {}", off), VolatileMemoryError(e) => write!(f, "volatile memory error: {}", e), } } @@ -40,157 +46,133 @@ pub type Result<T> = result::Result<T, Error>; impl std::error::Error for Error {} -#[derive(Clone, PartialEq, Eq)] -enum DescriptorFilter { - OnlyReadable, - OnlyWritable, -} - #[derive(Clone)] struct DescriptorChainConsumer<'a> { - offset: usize, - desc_chain: Option<DescriptorChain<'a>>, - desc_chain_start: Option<DescriptorChain<'a>>, + buffers: VecDeque<VolatileSlice<'a>>, bytes_consumed: usize, - avail_bytes: Option<usize>, - filter: DescriptorFilter, } impl<'a> DescriptorChainConsumer<'a> { - fn new( - desc_chain: Option<DescriptorChain<'a>>, - filter: DescriptorFilter, - ) -> DescriptorChainConsumer<'a> { - DescriptorChainConsumer { - offset: 0, - desc_chain: desc_chain.clone(), - desc_chain_start: desc_chain, - bytes_consumed: 0, - avail_bytes: None, - filter, - } - } - - fn available_bytes(&mut self) -> usize { - if let Some(bytes) = self.avail_bytes { - bytes - } else { - let mut chain = self.desc_chain.clone(); - let mut count = 0; - while let Some(desc) = chain { - count += desc.len as usize; - chain = self.advance(desc); - } - let bytes = count - self.offset; - self.avail_bytes = Some(bytes); - bytes - } + fn available_bytes(&self) -> Result<usize> { + self.buffers + .iter() + .try_fold(0usize, |count, vs| count.checked_add(vs.size() as usize)) + .ok_or(Error::DescriptorChainOverflow) } fn bytes_consumed(&self) -> usize { self.bytes_consumed } - fn consume<F>(&mut self, mut fnc: F, mut count: usize) -> Result<usize> + /// Consumes at most `count` bytes from the `DescriptorChain`. Callers must provide a function + /// that takes a `&[VolatileSlice]` and returns the total number of bytes consumed. This + /// function guarantees that the combined length of all the slices in the `&[VolatileSlice]` is + /// less than or equal to `count`. + /// + /// # Errors + /// + /// If the provided function returns any error then no bytes are consumed from the buffer and + /// the error is returned to the caller. + fn consume<F>(&mut self, count: usize, f: F) -> io::Result<usize> where - F: FnMut(GuestAddress, usize) -> Result<()>, + F: FnOnce(&[VolatileSlice]) -> io::Result<usize>, { - let mut bytes_consumed = 0; - while count > 0 { - if let Some(current) = &self.desc_chain { - let addr = current - .addr - .checked_add(self.offset as u64) - .ok_or_else(|| { - Error::GuestMemoryError(GuestMemoryError::InvalidGuestAddress(current.addr)) - })?; - let len = cmp::min(count, current.len as usize - self.offset); - fnc(addr, len)?; - - self.offset += len; - self.avail_bytes = self.avail_bytes.map(|av| av - len); - self.bytes_consumed += len; - bytes_consumed += len; - count -= len; - - if self.offset == current.len as usize { - self.offset = 0; - if let Some(desc_chain) = self.desc_chain.take() { - self.desc_chain = self.advance(desc_chain); - } - } - } else { - // Nothing left to read. + let mut buflen = 0; + let mut bufs = Vec::with_capacity(self.buffers.len()); + for &vs in &self.buffers { + if buflen >= count { break; } - } - Ok(bytes_consumed) - } - fn advance(&self, desc_chain: DescriptorChain<'a>) -> Option<DescriptorChain<'a>> { - let mut desc_chain = desc_chain.next_descriptor(); - // TODO(jstaron): Update this code to take the indirect descriptors into account. - if self.filter == DescriptorFilter::OnlyReadable { - // When encounter first write-only descriptor set `desc_chain` to None to stop - // further processing. - desc_chain = desc_chain.filter(DescriptorChain::is_read_only); + let rem = count - buflen; + if (rem as u64) < vs.size() { + let buf = vs.sub_slice(0, rem as u64).map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, Error::VolatileMemoryError(e)) + })?; + bufs.push(buf); + buflen += rem; + } else { + bufs.push(vs); + buflen += vs.size() as usize; + } } - desc_chain - } - fn seek_from_start(&mut self, offset: usize) -> Result<()> { - if offset < self.bytes_consumed { - // Restart from the beginning of the descriptor chain. - self.bytes_consumed = 0; - self.avail_bytes = None; - self.desc_chain = self.desc_chain_start.clone(); + if bufs.is_empty() { + return Ok(0); } - let mut count = offset - self.bytes_consumed; - while count > 0 { - let bytes_consumed = self.consume(|_, _| Ok(()), count)?; - if bytes_consumed == 0 { + let bytes_consumed = f(&*bufs)?; + + // This can happen if a driver tricks a device into reading/writing more data than + // fits in a `usize`. + let total_bytes_consumed = + self.bytes_consumed + .checked_add(bytes_consumed) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, Error::DescriptorChainOverflow) + })?; + + let mut rem = bytes_consumed; + while let Some(vs) = self.buffers.pop_front() { + if (rem as u64) < vs.size() { + // Split the slice and push the remainder back into the buffer list. Safe because we + // know that `rem` is not out of bounds due to the check and we checked the bounds + // on `vs` when we added it to the buffer list. + self.buffers.push_front(vs.offset(rem as u64).unwrap()); break; } - count -= bytes_consumed; + + // No need for checked math because we know that `vs.size() <= rem`. + rem -= vs.size() as usize; } - Ok(()) + self.bytes_consumed = total_bytes_consumed; + + Ok(bytes_consumed) } - fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { - fn apply_signed_offset(base: usize, offset: i64) -> io::Result<u64> { - let base = i64::try_from(base).map_err(|_| { - io::Error::new(io::ErrorKind::InvalidData, "seek position out of i64 range") - })?; - let result = base.checked_add(offset).ok_or(io::Error::new( - io::ErrorKind::InvalidData, - "seek offset overflowed", - ))?; - if result < 0 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "seek offset < 0", - )); + fn split_at(&mut self, offset: usize) -> Result<DescriptorChainConsumer<'a>> { + let mut rem = offset; + let pos = self.buffers.iter().position(|vs| { + if (rem as u64) < vs.size() { + true + } else { + rem -= vs.size() as usize; + false } - Ok(result as u64) - } - - let offset = match pos { - io::SeekFrom::Start(o) => o, - io::SeekFrom::Current(o) => apply_signed_offset(self.bytes_consumed(), o)?, - io::SeekFrom::End(o) => { - apply_signed_offset(self.bytes_consumed() + self.available_bytes(), o)? + }); + + if let Some(at) = pos { + let mut other = self.buffers.split_off(at); + + if rem > 0 { + // There must be at least one element in `other` because we checked + // its `size` value in the call to `position` above. + let front = other.pop_front().expect("empty VecDeque after split"); + self.buffers.push_back( + front + .sub_slice(0, rem as u64) + .map_err(Error::VolatileMemoryError)?, + ); + other.push_front( + front + .offset(rem as u64) + .map_err(Error::VolatileMemoryError)?, + ); } - }; - let offset = usize::try_from(offset).map_err(|_| { - io::Error::new(io::ErrorKind::InvalidData, "seek offset overflowed usize") - })?; - self.seek_from_start(offset) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; - - Ok(self.bytes_consumed() as u64) + Ok(DescriptorChainConsumer { + buffers: other, + bytes_consumed: 0, + }) + } else if rem == 0 { + Ok(DescriptorChainConsumer { + buffers: VecDeque::new(), + bytes_consumed: 0, + }) + } else { + Err(Error::SplitOutOfBounds(offset)) + } } } @@ -203,107 +185,98 @@ impl<'a> DescriptorChainConsumer<'a> { /// descriptor is encountered. #[derive(Clone)] pub struct Reader<'a> { - mem: &'a GuestMemory, buffer: DescriptorChainConsumer<'a>, } impl<'a> Reader<'a> { /// Construct a new Reader wrapper over `desc_chain`. - pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Reader<'a> { + pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Result<Reader<'a>> { // TODO(jstaron): Update this code to take the indirect descriptors into account. - let desc_chain = if desc_chain.is_read_only() { - Some(desc_chain) - } else { - None - }; - Reader { - mem, - buffer: DescriptorChainConsumer::new(desc_chain, DescriptorFilter::OnlyReadable), - } - } - - /// Reads to a slice from the descriptor chain buffer. - /// Reads as many bytes as necessary to completely fill - /// the specified slice or to consume all bytes from the - /// descriptor chain buffer. Returns number of copied bytes. - pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> { - let mem = self.mem; - let len = buf.len(); - let mut read_count = 0; - self.buffer.consume( - move |addr, count| { - let result = mem.read_exact_at_addr(&mut buf[read_count..read_count + count], addr); - if result.is_ok() { - read_count += count; - } - result.map_err(Error::GuestMemoryError) + let buffers = desc_chain + .into_iter() + .readable() + .map(|desc| { + mem.get_slice(desc.addr.offset(), desc.len.into()) + .map_err(Error::VolatileMemoryError) + }) + .collect::<Result<VecDeque<VolatileSlice<'a>>>>()?; + Ok(Reader { + buffer: DescriptorChainConsumer { + buffers, + bytes_consumed: 0, }, - len, - ) - } - - /// Reads to a slice from the descriptor chain. - /// Returns an error if there isn't enough data in the - /// descriptor chain buffer to fill the entire slice. Part of - /// the slice may have been filled nevertheless. - pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> { - let count = self.read(buf)?; - if count == buf.len() { - Ok(()) - } else { - Err(Error::GuestMemoryError(GuestMemoryError::ShortRead { - expected: buf.len(), - completed: count, - })) - } + }) } /// Reads an object from the descriptor chain buffer. - pub fn read_obj<T: DataInit + Default>(&mut self) -> Result<T> { - let mut object: T = Default::default(); - self.read_exact(object.as_mut_slice()).map(|_| object) + pub fn read_obj<T: DataInit>(&mut self) -> io::Result<T> { + let mut obj = MaybeUninit::<T>::uninit(); + + // Safe because `MaybeUninit` guarantees that the pointer is valid for + // `size_of::<T>()` bytes. + let buf = unsafe { + ::std::slice::from_raw_parts_mut(obj.as_mut_ptr() as *mut u8, size_of::<T>()) + }; + + self.read_exact(buf)?; + + // Safe because any type that implements `DataInit` can be considered initialized + // even if it is filled with random data. + Ok(unsafe { obj.assume_init() }) } /// Reads data from the descriptor chain buffer into a file descriptor. /// Returns the number of bytes read from the descriptor chain buffer. /// The number of bytes read can be less than `count` if there isn't /// enough data in the descriptor chain buffer. - pub fn read_to(&mut self, dst: &dyn AsRawFd, count: usize) -> Result<usize> { - let mem = self.mem; - self.buffer.consume( - |addr, count| { - mem.write_from_memory(addr, dst, count) - .map_err(Error::GuestMemoryError) - }, - count, - ) + pub fn read_to<F: FileReadWriteVolatile>( + &mut self, + mut dst: F, + count: usize, + ) -> io::Result<usize> { + self.buffer + .consume(count, |bufs| dst.write_vectored_volatile(bufs)) } - /// Reads data from the descriptor chain buffer into a FileReadWriteVolatile. + /// Reads data from the descriptor chain buffer into a File at offset `off`. /// Returns the number of bytes read from the descriptor chain buffer. /// The number of bytes read can be less than `count` if there isn't /// enough data in the descriptor chain buffer. - pub fn read_to_volatile<T: FileReadWriteVolatile + ?Sized>( + pub fn read_to_at<F: FileReadWriteAtVolatile>( &mut self, - dst: &mut T, + mut dst: F, count: usize, - ) -> Result<usize> { - let mem = self.mem; - self.buffer.consume( - |addr, count| { - let mem_volatile_slice = mem - .get_slice(addr.offset(), count as u64) - .map_err(Error::VolatileMemoryError)?; - dst.write_all_volatile(mem_volatile_slice) - .map_err(Error::IoError)?; - Ok(()) - }, - count, - ) + off: u64, + ) -> io::Result<usize> { + self.buffer + .consume(count, |bufs| dst.write_vectored_at_volatile(bufs, off)) + } + + pub fn read_exact_to<F: FileReadWriteVolatile>( + &mut self, + mut dst: F, + mut count: usize, + ) -> io::Result<()> { + while count > 0 { + match self.read_to(&mut dst, count) { + Ok(0) => { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "failed to fill whole buffer", + )) + } + Ok(n) => count -= n, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + Err(e) => return Err(e), + } + } + + Ok(()) } - /// Returns number of bytes available for reading. - pub fn available_bytes(&mut self) -> usize { + /// Returns number of bytes available for reading. May return an error if the combined + /// lengths of all the buffers in the DescriptorChain would cause an integer overflow. + pub fn available_bytes(&self) -> Result<usize> { self.buffer.available_bytes() } @@ -311,6 +284,29 @@ impl<'a> Reader<'a> { pub fn bytes_read(&self) -> usize { self.buffer.bytes_consumed() } + + /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. + /// After the split, `self` will be able to read up to `offset` bytes while the returned + /// `Reader` can read up to `available_bytes() - offset` bytes. Returns an error if + /// `offset > self.available_bytes()`. + pub fn split_at(&mut self, offset: usize) -> Result<Reader<'a>> { + self.buffer.split_at(offset).map(|buffer| Reader { buffer }) + } +} + +impl<'a> io::Read for Reader<'a> { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.buffer.consume(buf.len(), |bufs| { + if let Some(vs) = bufs.first() { + // This is guaranteed by the implementation of `consume`. + debug_assert_eq!(vs.size(), cmp::min(buf.len() as u64, vs.size())); + vs.copy_to(buf); + Ok(vs.size() as usize) + } else { + Ok(0) + } + }) + } } /// Provides high-level interface over the sequence of memory regions @@ -320,65 +316,38 @@ impl<'a> Reader<'a> { /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1). /// Writer will start iterating the descriptors from the first writable one and will /// assume that all following descriptors are writable. +#[derive(Clone)] pub struct Writer<'a> { - mem: &'a GuestMemory, buffer: DescriptorChainConsumer<'a>, } impl<'a> Writer<'a> { /// Construct a new Writer wrapper over `desc_chain`. - pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Writer<'a> { - // Skip all readable descriptors and get first writable one. - let desc_chain = desc_chain.into_iter().writable().next(); - Writer { - mem, - buffer: DescriptorChainConsumer::new(desc_chain, DescriptorFilter::OnlyWritable), - } - } - - /// Writes a slice to the descriptor chain buffer. - /// Returns the number of bytes written. The number of bytes written - /// can be less than the length of the slice if there isn't enough - /// space in the descriptor chain buffer. - pub fn write(&mut self, buf: &[u8]) -> Result<usize> { - let mem = self.mem; - let len = buf.len(); - let mut write_count = 0; - self.buffer.consume( - move |addr, count| { - let result = mem.write_all_at_addr(&buf[write_count..write_count + count], addr); - if result.is_ok() { - write_count += count; - } - result.map_err(Error::GuestMemoryError) + pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Result<Writer<'a>> { + let buffers = desc_chain + .into_iter() + .writable() + .map(|desc| { + mem.get_slice(desc.addr.offset(), desc.len.into()) + .map_err(Error::VolatileMemoryError) + }) + .collect::<Result<VecDeque<VolatileSlice<'a>>>>()?; + Ok(Writer { + buffer: DescriptorChainConsumer { + buffers, + bytes_consumed: 0, }, - len, - ) - } - - /// Writes the entire contents of a slice to descriptor chain buffer. - /// Returns an error if there isn't enough room in the descriptor chain buffer - /// to complete the entire write. Part of the data may have been written - /// nevertheless. - pub fn write_all(&mut self, buf: &[u8]) -> Result<()> { - let count = self.write(buf)?; - if count == buf.len() { - Ok(()) - } else { - Err(Error::GuestMemoryError(GuestMemoryError::ShortRead { - expected: buf.len(), - completed: count, - })) - } + }) } /// Writes an object to the descriptor chain buffer. - pub fn write_obj<T: DataInit>(&mut self, val: T) -> Result<()> { + pub fn write_obj<T: DataInit>(&mut self, val: T) -> io::Result<()> { self.write_all(val.as_slice()) } - /// Returns number of bytes available for writing. - pub fn available_bytes(&mut self) -> usize { + /// Returns number of bytes available for writing. May return an error if the combined + /// lengths of all the buffers in the DescriptorChain would cause an overflow. + pub fn available_bytes(&self) -> Result<usize> { self.buffer.available_bytes() } @@ -386,57 +355,77 @@ impl<'a> Writer<'a> { /// Returns the number of bytes written to the descriptor chain buffer. /// The number of bytes written can be less than `count` if /// there isn't enough data in the descriptor chain buffer. - pub fn write_from(&mut self, src: &dyn AsRawFd, count: usize) -> Result<usize> { - let mem = self.mem; - self.buffer.consume( - |addr, count| { - mem.read_to_memory(addr, src, count) - .map_err(Error::GuestMemoryError) - }, - count, - ) + pub fn write_from<F: FileReadWriteVolatile>( + &mut self, + mut src: F, + count: usize, + ) -> io::Result<usize> { + self.buffer + .consume(count, |bufs| src.read_vectored_volatile(bufs)) } - /// Writes data to the descriptor chain buffer from a FileReadWriteVolatile. + /// Writes data to the descriptor chain buffer from a File at offset `off`. /// Returns the number of bytes written to the descriptor chain buffer. /// The number of bytes written can be less than `count` if /// there isn't enough data in the descriptor chain buffer. - pub fn write_from_volatile<T: FileReadWriteVolatile + ?Sized>( + pub fn write_from_at<F: FileReadWriteAtVolatile>( &mut self, - src: &mut T, + mut src: F, count: usize, - ) -> Result<usize> { - let mem = self.mem; - self.buffer.consume( - |addr, count| { - let mem_volatile_slice = mem - .get_slice(addr.offset(), count as u64) - .map_err(Error::VolatileMemoryError)?; - src.read_exact_volatile(mem_volatile_slice) - .map_err(Error::IoError)?; - Ok(()) - }, - count, - ) + off: u64, + ) -> io::Result<usize> { + self.buffer + .consume(count, |bufs| src.read_vectored_at_volatile(bufs, off)) + } + + pub fn write_all_from<F: FileReadWriteVolatile>( + &mut self, + mut src: F, + mut count: usize, + ) -> io::Result<()> { + while count > 0 { + match self.write_from(&mut src, count) { + Ok(0) => { + return Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write whole buffer", + )) + } + Ok(n) => count -= n, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + Err(e) => return Err(e), + } + } + + Ok(()) } /// Returns number of bytes already written to the descriptor chain buffer. pub fn bytes_written(&self) -> usize { self.buffer.bytes_consumed() } -} -impl<'a> io::Read for Reader<'a> { - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - self.read(buf) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. + /// After the split, `self` will be able to write up to `offset` bytes while the returned + /// `Writer` can write up to `available_bytes() - offset` bytes. Returns an error if + /// `offset > self.available_bytes()`. + pub fn split_at(&mut self, offset: usize) -> Result<Writer<'a>> { + self.buffer.split_at(offset).map(|buffer| Writer { buffer }) } } impl<'a> io::Write for Writer<'a> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.write(buf) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + self.buffer.consume(buf.len(), |bufs| { + if let Some(vs) = bufs.first() { + // This is guaranteed by the implementation of `consume`. + debug_assert_eq!(vs.size(), cmp::min(buf.len() as u64, vs.size())); + vs.copy_from(buf); + Ok(vs.size() as usize) + } else { + Ok(0) + } + }) } fn flush(&mut self) -> io::Result<()> { @@ -445,18 +434,6 @@ impl<'a> io::Write for Writer<'a> { } } -impl<'a> io::Seek for Reader<'a> { - fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { - self.buffer.seek(pos) - } -} - -impl<'a> io::Seek for Writer<'a> { - fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { - self.buffer.seek(pos) - } -} - const VIRTQ_DESC_F_NEXT: u16 = 0x1; const VIRTQ_DESC_F_WRITE: u16 = 0x2; @@ -524,7 +501,6 @@ pub fn create_descriptor_chain( #[cfg(test)] mod tests { use super::*; - use std::io::{Seek, SeekFrom}; use sys_util::{MemfdSeals, SharedMemory}; #[test] @@ -547,8 +523,13 @@ mod tests { 0, ) .expect("create_descriptor_chain failed"); - let mut reader = Reader::new(&memory, chain); - assert_eq!(reader.available_bytes(), 106); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 106 + ); assert_eq!(reader.bytes_read(), 0); let mut buffer = [0 as u8; 64]; @@ -556,7 +537,12 @@ mod tests { panic!("read_exact should not fail here"); } - assert_eq!(reader.available_bytes(), 42); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 42 + ); assert_eq!(reader.bytes_read(), 64); match reader.read(&mut buffer) { @@ -564,7 +550,12 @@ mod tests { Ok(length) => assert_eq!(length, 42), } - assert_eq!(reader.available_bytes(), 0); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(reader.bytes_read(), 106); } @@ -588,8 +579,13 @@ mod tests { 0, ) .expect("create_descriptor_chain failed");; - let mut writer = Writer::new(&memory, chain); - assert_eq!(writer.available_bytes(), 106); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 106 + ); assert_eq!(writer.bytes_written(), 0); let mut buffer = [0 as u8; 64]; @@ -597,7 +593,12 @@ mod tests { panic!("write_all should not fail here"); } - assert_eq!(writer.available_bytes(), 42); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 42 + ); assert_eq!(writer.bytes_written(), 64); match writer.write(&mut buffer) { @@ -605,7 +606,12 @@ mod tests { Ok(length) => assert_eq!(length, 42), } - assert_eq!(writer.available_bytes(), 0); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(writer.bytes_written(), 106); } @@ -624,13 +630,23 @@ mod tests { 0, ) .expect("create_descriptor_chain failed");; - let mut reader = Reader::new(&memory, chain); - assert_eq!(reader.available_bytes(), 0); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(reader.bytes_read(), 0); assert!(reader.read_obj::<u8>().is_err()); - assert_eq!(reader.available_bytes(), 0); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(reader.bytes_read(), 0); } @@ -649,13 +665,23 @@ mod tests { 0, ) .expect("create_descriptor_chain failed");; - let mut writer = Writer::new(&memory, chain); - assert_eq!(writer.available_bytes(), 0); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(writer.bytes_written(), 0); assert!(writer.write_obj(0u8).is_err()); - assert_eq!(writer.available_bytes(), 0); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(writer.bytes_written(), 0); } @@ -675,7 +701,7 @@ mod tests { ) .expect("create_descriptor_chain failed");; - let mut reader = Reader::new(&memory, chain); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); // GuestMemory's write_from_memory requires raw file descriptor. let mut shm = SharedMemory::anon().unwrap(); @@ -686,14 +712,20 @@ mod tests { fd_seals.set_grow_seal(); shm.add_seals(fd_seals).unwrap(); - if let Ok(_) = reader.read_to(&shm, 512) { - panic!("read_to should fail here, got Ok(_) instead"); - } - - assert!(reader.available_bytes() < 512); - assert!(reader.available_bytes() > 0); - assert!(reader.bytes_read() < 512); - assert!(reader.bytes_read() > 0); + reader + .read_exact_to(&mut shm, 512) + .expect_err("successfully read more bytes than SharedMemory size"); + + // Linux doesn't do partial writes if you give a buffer larger than the remaining length of + // the shared memory. And since we passed an iovec with the full contents of the + // DescriptorChain we ended up not writing any bytes at all. + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 512 + ); + assert_eq!(reader.bytes_read(), 0); } #[test] @@ -710,22 +742,25 @@ mod tests { vec![(Writable, 256), (Writable, 256)], 0, ) - .expect("create_descriptor_chain failed");; + .expect("create_descriptor_chain failed"); - let mut writer = Writer::new(&memory, chain); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); // GuestMemory's read_to_memory requires raw file descriptor. let mut shm = SharedMemory::anon().unwrap(); shm.set_size(384).unwrap(); - if let Ok(_) = writer.write_from(&shm, 512) { - panic!("write_from should fail here, got Ok(_) instead"); - } + writer + .write_all_from(&mut shm, 512) + .expect_err("successfully wrote more bytes than in SharedMemory"); - assert!(writer.available_bytes() < 512); - assert!(writer.available_bytes() > 0); - assert!(writer.bytes_written() < 512); - assert!(writer.bytes_written() > 0); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 128 + ); + assert_eq!(writer.bytes_written(), 384); } #[test] @@ -749,28 +784,40 @@ mod tests { ], 0, ) - .expect("create_descriptor_chain failed");; - let mut reader = Reader::new(&memory, chain.clone()); - let mut writer = Writer::new(&memory, chain); + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain.clone()).expect("failed to create Reader"); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); assert_eq!(reader.bytes_read(), 0); assert_eq!(writer.bytes_written(), 0); - let mut buffer = [0 as u8; 200]; + let mut buffer = Vec::with_capacity(200); - match reader.read(&mut buffer) { - Err(_) => panic!("read should not fail here"), - Ok(length) => assert_eq!(length, 128), - } + assert_eq!( + reader + .read_to_end(&mut buffer) + .expect("read should not fail here"), + 128 + ); - match writer.write(&mut buffer) { - Err(_) => panic!("write should not fail here"), - Ok(length) => assert_eq!(length, 68), - } + // The writable descriptors are only 68 bytes long. + writer + .write_all(&buffer[..68]) + .expect("write should not fail here"); - assert_eq!(reader.available_bytes(), 0); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(reader.bytes_read(), 128); - assert_eq!(writer.available_bytes(), 0); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(writer.bytes_written(), 68); } @@ -791,8 +838,8 @@ mod tests { vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)], 123, ) - .expect("create_descriptor_chain failed");; - let mut writer = Writer::new(&memory, chain_writer); + .expect("create_descriptor_chain failed"); + let mut writer = Writer::new(&memory, chain_writer).expect("failed to create Writer"); if let Err(_) = writer.write_obj(secret) { panic!("write_obj should not fail here"); } @@ -806,7 +853,7 @@ mod tests { 123, ) .expect("create_descriptor_chain failed"); - let mut reader = Reader::new(&memory, chain_reader); + let mut reader = Reader::new(&memory, chain_reader).expect("failed to create Reader"); match reader.read_obj::<Le32>() { Err(_) => panic!("read_obj should not fail here"), Ok(read_secret) => assert_eq!(read_secret, secret), @@ -814,63 +861,217 @@ mod tests { } #[test] - fn reader_seek_simple_chain() { + fn reader_unexpected_eof() { use DescriptorType::*; let memory_start_addr = GuestAddress(0x0); - let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap(); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![(Readable, 256), (Readable, 256)], + 0, + ) + .expect("create_descriptor_chain failed"); + + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let mut buf = Vec::with_capacity(1024); + buf.resize(1024, 0); + + assert_eq!( + reader + .read_exact(&mut buf[..]) + .expect_err("read more bytes than available") + .kind(), + io::ErrorKind::UnexpectedEof + ); + } + + #[test] + fn split_border() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); let chain = create_descriptor_chain( &memory, GuestAddress(0x0), GuestAddress(0x100), vec![ - (Readable, 8), (Readable, 16), - (Readable, 18), - (Readable, 64), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), ], 0, ) .expect("create_descriptor_chain failed");; - let mut reader = Reader::new(&memory, chain); - assert_eq!(reader.available_bytes(), 106); - assert_eq!(reader.bytes_read(), 0); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let other = reader.split_at(32).expect("failed to split Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 32 + ); + assert_eq!( + other + .available_bytes() + .expect("failed to get available bytes"), + 96 + ); + } - // Skip some bytes. available_bytes() and bytes_read() should update accordingly. - reader - .seek(SeekFrom::Current(64)) - .expect("seek should not fail here"); - assert_eq!(reader.available_bytes(), 42); - assert_eq!(reader.bytes_read(), 64); + #[test] + fn split_middle() { + use DescriptorType::*; - // Seek past end of chain - position should point just past the last byte. - reader - .seek(SeekFrom::Current(64)) - .expect("seek should not fail here"); - assert_eq!(reader.available_bytes(), 0); - assert_eq!(reader.bytes_read(), 106); + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); - // Seek back to the beginning. - reader - .seek(SeekFrom::Start(0)) - .expect("seek should not fail here"); - assert_eq!(reader.available_bytes(), 106); - assert_eq!(reader.bytes_read(), 0); + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let other = reader.split_at(24).expect("failed to split Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 24 + ); + assert_eq!( + other + .available_bytes() + .expect("failed to get available bytes"), + 104 + ); + } - // Seek to one byte before the end. - reader - .seek(SeekFrom::End(-1)) - .expect("seek should not fail here"); - assert_eq!(reader.available_bytes(), 1); - assert_eq!(reader.bytes_read(), 105); + #[test] + fn split_end() { + use DescriptorType::*; - // Read the last byte. - let mut buffer = [0 as u8; 1]; - reader - .read_exact(&mut buffer) - .expect("read_exact should not fail here"); - assert_eq!(reader.available_bytes(), 0); - assert_eq!(reader.bytes_read(), 106); + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let other = reader.split_at(128).expect("failed to split Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 128 + ); + assert_eq!( + other + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); + } + + #[test] + fn split_beginning() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let other = reader.split_at(0).expect("failed to split Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); + assert_eq!( + other + .available_bytes() + .expect("failed to get available bytes"), + 128 + ); + } + + #[test] + fn split_outofbounds() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + if let Ok(_) = reader.split_at(256) { + panic!("successfully split Reader with out of bounds offset"); + } } } |