diff options
-rw-r--r-- | devices/src/virtio/balloon.rs | 22 | ||||
-rw-r--r-- | devices/src/virtio/block.rs | 111 | ||||
-rw-r--r-- | devices/src/virtio/descriptor_utils.rs | 965 | ||||
-rw-r--r-- | devices/src/virtio/gpu/mod.rs | 78 | ||||
-rw-r--r-- | devices/src/virtio/gpu/protocol.rs | 25 | ||||
-rw-r--r-- | devices/src/virtio/input/mod.rs | 117 | ||||
-rw-r--r-- | devices/src/virtio/net.rs | 78 | ||||
-rw-r--r-- | devices/src/virtio/p9.rs | 14 | ||||
-rw-r--r-- | devices/src/virtio/pmem.rs | 76 | ||||
-rw-r--r-- | devices/src/virtio/rng.rs | 9 | ||||
-rw-r--r-- | devices/src/virtio/tpm.rs | 23 | ||||
-rw-r--r-- | sys_util/src/file_traits.rs | 4 | ||||
-rw-r--r-- | sys_util/src/guest_memory.rs | 9 |
13 files changed, 984 insertions, 547 deletions
diff --git a/devices/src/virtio/balloon.rs b/devices/src/virtio/balloon.rs index 633b3fc..ec16f88 100644 --- a/devices/src/virtio/balloon.rs +++ b/devices/src/virtio/balloon.rs @@ -94,11 +94,29 @@ impl Worker { let index = avail_desc.index; if inflate { - let mut reader = Reader::new(&self.mem, avail_desc); - let data_length = reader.available_bytes(); + let mut reader = match Reader::new(&self.mem, avail_desc) { + Ok(r) => r, + Err(e) => { + error!("balloon: failed to create reader: {}", e); + queue.add_used(&self.mem, index, 0); + needs_interrupt = true; + continue; + } + }; + let data_length = match reader.available_bytes() { + Ok(l) => l, + Err(e) => { + error!("balloon: failed to get available bytes: {}", e); + queue.add_used(&self.mem, index, 0); + needs_interrupt = true; + continue; + } + }; if data_length % 4 != 0 { error!("invalid inflate buffer size: {}", data_length); + queue.add_used(&self.mem, index, 0); + needs_interrupt = true; continue; } diff --git a/devices/src/virtio/block.rs b/devices/src/virtio/block.rs index 73fd416..b171864 100644 --- a/devices/src/virtio/block.rs +++ b/devices/src/virtio/block.rs @@ -3,7 +3,7 @@ // found in the LICENSE file. use std::fmt::{self, Display}; -use std::io::{self, Seek, SeekFrom}; +use std::io::{self, Seek, SeekFrom, Write}; use std::mem::size_of; use std::os::unix::io::{AsRawFd, RawFd}; use std::result; @@ -129,12 +129,14 @@ unsafe impl DataInit for virtio_blk_discard_write_zeroes {} #[derive(Debug)] enum ExecuteError { Descriptor(DescriptorError), + Read(io::Error), + WriteStatus(io::Error), /// Error arming the flush timer. Flush(io::Error), ReadIo { length: usize, sector: u64, - desc_error: DescriptorError, + desc_error: io::Error, }, ShortRead { sector: u64, @@ -149,7 +151,7 @@ enum ExecuteError { WriteIo { length: usize, sector: u64, - desc_error: DescriptorError, + desc_error: io::Error, }, ShortWrite { sector: u64, @@ -176,6 +178,8 @@ impl Display for ExecuteError { match self { Descriptor(e) => write!(f, "virtio descriptor error: {}", e), + Read(e) => write!(f, "failed to read message: {}", e), + WriteStatus(e) => write!(f, "failed to write request status: {}", e), Flush(e) => write!(f, "failed to flush: {}", e), ReadIo { length, @@ -247,6 +251,8 @@ impl ExecuteError { fn status(&self) -> u8 { match self { ExecuteError::Descriptor(_) => VIRTIO_BLK_S_IOERR, + ExecuteError::Read(_) => VIRTIO_BLK_S_IOERR, + ExecuteError::WriteStatus(_) => VIRTIO_BLK_S_IOERR, ExecuteError::Flush(_) => VIRTIO_BLK_S_IOERR, ExecuteError::ReadIo { .. } => VIRTIO_BLK_S_IOERR, ExecuteError::ShortRead { .. } => VIRTIO_BLK_S_IOERR, @@ -275,6 +281,50 @@ struct Worker { } impl Worker { + fn process_one_request( + avail_desc: DescriptorChain, + read_only: bool, + disk: &mut DiskFile, + disk_size: u64, + flush_timer: &mut TimerFd, + flush_timer_armed: &mut bool, + mem: &GuestMemory, + ) -> result::Result<usize, ExecuteError> { + let mut status_writer = + Writer::new(mem, avail_desc.clone()).map_err(ExecuteError::Descriptor)?; + let available_bytes = status_writer + .available_bytes() + .map_err(ExecuteError::Descriptor)?; + let status_offset = available_bytes + .checked_sub(1) + .ok_or(ExecuteError::MissingStatus)?; + + status_writer = status_writer + .split_at(status_offset) + .map_err(ExecuteError::Descriptor)?; + + let status = match Block::execute_request( + avail_desc, + read_only, + disk, + disk_size, + flush_timer, + flush_timer_armed, + mem, + ) { + Ok(()) => VIRTIO_BLK_S_OK, + Err(e) => { + error!("failed executing disk request: {}", e); + e.status() + } + }; + + status_writer + .write_all(&[status]) + .map_err(ExecuteError::WriteStatus)?; + Ok(available_bytes) + } + fn process_queue( &mut self, queue_index: usize, @@ -288,9 +338,8 @@ impl Worker { let mut needs_interrupt = false; while let Some(avail_desc) = queue.pop(&self.mem) { let desc_index = avail_desc.index; - let mut status_writer = Writer::new(&self.mem, avail_desc.clone()); - let status = match Block::execute_request( + let len = match Worker::process_one_request( avail_desc, self.read_only, &mut *self.disk_image, @@ -299,24 +348,11 @@ impl Worker { flush_timer_armed, &self.mem, ) { - Ok(()) => VIRTIO_BLK_S_OK, + Ok(len) => len, Err(e) => { - error!("failed executing disk request: {}", e); - e.status() - } - }; - - let len = if let Ok(status_offset) = status_writer.seek(SeekFrom::End(-1)) { - match status_writer.write_all(&[status]) { - Ok(_) => status_offset + 1, - Err(e) => { - error!("failed to write status: {}", e); - 0 - } + error!("block: failed to handle request: {}", e); + 0 } - } else { - error!("failed to seek to status location"); - 0 }; queue.add_used(&self.mem, desc_index, len as u32); @@ -541,11 +577,10 @@ impl Block { flush_timer_armed: &mut bool, mem: &GuestMemory, ) -> result::Result<(), ExecuteError> { - let mut reader = Reader::new(mem, avail_desc.clone()); - let mut writer = Writer::new(mem, avail_desc); + let mut reader = Reader::new(mem, avail_desc.clone()).map_err(ExecuteError::Descriptor)?; + let mut writer = Writer::new(mem, avail_desc).map_err(ExecuteError::Descriptor)?; - let req_header: virtio_blk_req_header = - reader.read_obj().map_err(ExecuteError::Descriptor)?; + let req_header: virtio_blk_req_header = reader.read_obj().map_err(ExecuteError::Read)?; let req_type = req_header.req_type.to_native(); let sector = req_header.sector.to_native(); @@ -580,6 +615,7 @@ impl Block { // The last byte of writer is virtio_blk_req::status, so subtract it from data_len. let data_len = writer .available_bytes() + .map_err(ExecuteError::Descriptor)? .checked_sub(1) .ok_or(ExecuteError::MissingStatus)?; let offset = sector @@ -588,14 +624,13 @@ impl Block { check_range(offset, data_len as u64, disk_size)?; disk.seek(SeekFrom::Start(offset)) .map_err(|e| ExecuteError::Seek { ioerr: e, sector })?; - let actual_length = - writer - .write_from_volatile(disk, data_len) - .map_err(|desc_error| ExecuteError::ReadIo { - length: data_len, - sector, - desc_error, - })?; + let actual_length = writer.write_from(disk, data_len).map_err(|desc_error| { + ExecuteError::ReadIo { + length: data_len, + sector, + desc_error, + } + })?; if actual_length < data_len { return Err(ExecuteError::ShortRead { sector, @@ -605,7 +640,7 @@ impl Block { } } VIRTIO_BLK_T_OUT => { - let data_len = reader.available_bytes(); + let data_len = reader.available_bytes().map_err(ExecuteError::Descriptor)?; let offset = sector .checked_shl(u32::from(SECTOR_SHIFT)) .ok_or(ExecuteError::OutOfRange)?; @@ -614,7 +649,7 @@ impl Block { .map_err(|e| ExecuteError::Seek { ioerr: e, sector })?; let actual_length = reader - .read_to_volatile(disk, data_len) + .read_to(disk, data_len) .map_err(|desc_error| ExecuteError::WriteIo { length: data_len, sector, @@ -635,9 +670,11 @@ impl Block { } } VIRTIO_BLK_T_DISCARD | VIRTIO_BLK_T_WRITE_ZEROES => { - while reader.available_bytes() >= size_of::<virtio_blk_discard_write_zeroes>() { + while reader.available_bytes().map_err(ExecuteError::Descriptor)? + >= size_of::<virtio_blk_discard_write_zeroes>() + { let seg: virtio_blk_discard_write_zeroes = - reader.read_obj().map_err(ExecuteError::Descriptor)?; + reader.read_obj().map_err(ExecuteError::Read)?; let sector = seg.sector.to_native(); let num_sectors = seg.num_sectors.to_native(); diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index 409134e..3c8ff87 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -3,23 +3,24 @@ // found in the LICENSE file. use std::cmp; -use std::convert::TryFrom; +use std::collections::VecDeque; use std::fmt::{self, Display}; -use std::io; -use std::os::unix::io::AsRawFd; +use std::io::{self, Read, Write}; +use std::mem::{size_of, MaybeUninit}; use std::result; -use data_model::{DataInit, Le16, Le32, Le64, VolatileMemory, VolatileMemoryError}; -use sys_util::guest_memory::Error as GuestMemoryError; -use sys_util::{FileReadWriteVolatile, GuestAddress, GuestMemory}; +use data_model::{DataInit, Le16, Le32, Le64, VolatileMemory, VolatileMemoryError, VolatileSlice}; +use sys_util::{FileReadWriteAtVolatile, FileReadWriteVolatile, GuestAddress, GuestMemory}; use super::DescriptorChain; #[derive(Debug)] pub enum Error { + DescriptorChainOverflow, GuestMemoryError(sys_util::GuestMemoryError), InvalidChain, IoError(io::Error), + SplitOutOfBounds(usize), VolatileMemoryError(VolatileMemoryError), } @@ -28,9 +29,14 @@ impl Display for Error { use self::Error::*; match self { + DescriptorChainOverflow => write!( + f, + "the combined length of all the buffers in a `DescriptorChain` would overflow" + ), GuestMemoryError(e) => write!(f, "descriptor guest memory error: {}", e), InvalidChain => write!(f, "invalid descriptor chain"), IoError(e) => write!(f, "descriptor I/O error: {}", e), + SplitOutOfBounds(off) => write!(f, "`DescriptorChain` split is out of bounds: {}", off), VolatileMemoryError(e) => write!(f, "volatile memory error: {}", e), } } @@ -40,157 +46,133 @@ pub type Result<T> = result::Result<T, Error>; impl std::error::Error for Error {} -#[derive(Clone, PartialEq, Eq)] -enum DescriptorFilter { - OnlyReadable, - OnlyWritable, -} - #[derive(Clone)] struct DescriptorChainConsumer<'a> { - offset: usize, - desc_chain: Option<DescriptorChain<'a>>, - desc_chain_start: Option<DescriptorChain<'a>>, + buffers: VecDeque<VolatileSlice<'a>>, bytes_consumed: usize, - avail_bytes: Option<usize>, - filter: DescriptorFilter, } impl<'a> DescriptorChainConsumer<'a> { - fn new( - desc_chain: Option<DescriptorChain<'a>>, - filter: DescriptorFilter, - ) -> DescriptorChainConsumer<'a> { - DescriptorChainConsumer { - offset: 0, - desc_chain: desc_chain.clone(), - desc_chain_start: desc_chain, - bytes_consumed: 0, - avail_bytes: None, - filter, - } - } - - fn available_bytes(&mut self) -> usize { - if let Some(bytes) = self.avail_bytes { - bytes - } else { - let mut chain = self.desc_chain.clone(); - let mut count = 0; - while let Some(desc) = chain { - count += desc.len as usize; - chain = self.advance(desc); - } - let bytes = count - self.offset; - self.avail_bytes = Some(bytes); - bytes - } + fn available_bytes(&self) -> Result<usize> { + self.buffers + .iter() + .try_fold(0usize, |count, vs| count.checked_add(vs.size() as usize)) + .ok_or(Error::DescriptorChainOverflow) } fn bytes_consumed(&self) -> usize { self.bytes_consumed } - fn consume<F>(&mut self, mut fnc: F, mut count: usize) -> Result<usize> + /// Consumes at most `count` bytes from the `DescriptorChain`. Callers must provide a function + /// that takes a `&[VolatileSlice]` and returns the total number of bytes consumed. This + /// function guarantees that the combined length of all the slices in the `&[VolatileSlice]` is + /// less than or equal to `count`. + /// + /// # Errors + /// + /// If the provided function returns any error then no bytes are consumed from the buffer and + /// the error is returned to the caller. + fn consume<F>(&mut self, count: usize, f: F) -> io::Result<usize> where - F: FnMut(GuestAddress, usize) -> Result<()>, + F: FnOnce(&[VolatileSlice]) -> io::Result<usize>, { - let mut bytes_consumed = 0; - while count > 0 { - if let Some(current) = &self.desc_chain { - let addr = current - .addr - .checked_add(self.offset as u64) - .ok_or_else(|| { - Error::GuestMemoryError(GuestMemoryError::InvalidGuestAddress(current.addr)) - })?; - let len = cmp::min(count, current.len as usize - self.offset); - fnc(addr, len)?; - - self.offset += len; - self.avail_bytes = self.avail_bytes.map(|av| av - len); - self.bytes_consumed += len; - bytes_consumed += len; - count -= len; - - if self.offset == current.len as usize { - self.offset = 0; - if let Some(desc_chain) = self.desc_chain.take() { - self.desc_chain = self.advance(desc_chain); - } - } - } else { - // Nothing left to read. + let mut buflen = 0; + let mut bufs = Vec::with_capacity(self.buffers.len()); + for &vs in &self.buffers { + if buflen >= count { break; } - } - Ok(bytes_consumed) - } - fn advance(&self, desc_chain: DescriptorChain<'a>) -> Option<DescriptorChain<'a>> { - let mut desc_chain = desc_chain.next_descriptor(); - // TODO(jstaron): Update this code to take the indirect descriptors into account. - if self.filter == DescriptorFilter::OnlyReadable { - // When encounter first write-only descriptor set `desc_chain` to None to stop - // further processing. - desc_chain = desc_chain.filter(DescriptorChain::is_read_only); + let rem = count - buflen; + if (rem as u64) < vs.size() { + let buf = vs.sub_slice(0, rem as u64).map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, Error::VolatileMemoryError(e)) + })?; + bufs.push(buf); + buflen += rem; + } else { + bufs.push(vs); + buflen += vs.size() as usize; + } } - desc_chain - } - fn seek_from_start(&mut self, offset: usize) -> Result<()> { - if offset < self.bytes_consumed { - // Restart from the beginning of the descriptor chain. - self.bytes_consumed = 0; - self.avail_bytes = None; - self.desc_chain = self.desc_chain_start.clone(); + if bufs.is_empty() { + return Ok(0); } - let mut count = offset - self.bytes_consumed; - while count > 0 { - let bytes_consumed = self.consume(|_, _| Ok(()), count)?; - if bytes_consumed == 0 { + let bytes_consumed = f(&*bufs)?; + + // This can happen if a driver tricks a device into reading/writing more data than + // fits in a `usize`. + let total_bytes_consumed = + self.bytes_consumed + .checked_add(bytes_consumed) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, Error::DescriptorChainOverflow) + })?; + + let mut rem = bytes_consumed; + while let Some(vs) = self.buffers.pop_front() { + if (rem as u64) < vs.size() { + // Split the slice and push the remainder back into the buffer list. Safe because we + // know that `rem` is not out of bounds due to the check and we checked the bounds + // on `vs` when we added it to the buffer list. + self.buffers.push_front(vs.offset(rem as u64).unwrap()); break; } - count -= bytes_consumed; + + // No need for checked math because we know that `vs.size() <= rem`. + rem -= vs.size() as usize; } - Ok(()) + self.bytes_consumed = total_bytes_consumed; + + Ok(bytes_consumed) } - fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { - fn apply_signed_offset(base: usize, offset: i64) -> io::Result<u64> { - let base = i64::try_from(base).map_err(|_| { - io::Error::new(io::ErrorKind::InvalidData, "seek position out of i64 range") - })?; - let result = base.checked_add(offset).ok_or(io::Error::new( - io::ErrorKind::InvalidData, - "seek offset overflowed", - ))?; - if result < 0 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "seek offset < 0", - )); + fn split_at(&mut self, offset: usize) -> Result<DescriptorChainConsumer<'a>> { + let mut rem = offset; + let pos = self.buffers.iter().position(|vs| { + if (rem as u64) < vs.size() { + true + } else { + rem -= vs.size() as usize; + false } - Ok(result as u64) - } - - let offset = match pos { - io::SeekFrom::Start(o) => o, - io::SeekFrom::Current(o) => apply_signed_offset(self.bytes_consumed(), o)?, - io::SeekFrom::End(o) => { - apply_signed_offset(self.bytes_consumed() + self.available_bytes(), o)? + }); + + if let Some(at) = pos { + let mut other = self.buffers.split_off(at); + + if rem > 0 { + // There must be at least one element in `other` because we checked + // its `size` value in the call to `position` above. + let front = other.pop_front().expect("empty VecDeque after split"); + self.buffers.push_back( + front + .sub_slice(0, rem as u64) + .map_err(Error::VolatileMemoryError)?, + ); + other.push_front( + front + .offset(rem as u64) + .map_err(Error::VolatileMemoryError)?, + ); } - }; - let offset = usize::try_from(offset).map_err(|_| { - io::Error::new(io::ErrorKind::InvalidData, "seek offset overflowed usize") - })?; - self.seek_from_start(offset) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; - - Ok(self.bytes_consumed() as u64) + Ok(DescriptorChainConsumer { + buffers: other, + bytes_consumed: 0, + }) + } else if rem == 0 { + Ok(DescriptorChainConsumer { + buffers: VecDeque::new(), + bytes_consumed: 0, + }) + } else { + Err(Error::SplitOutOfBounds(offset)) + } } } @@ -203,107 +185,98 @@ impl<'a> DescriptorChainConsumer<'a> { /// descriptor is encountered. #[derive(Clone)] pub struct Reader<'a> { - mem: &'a GuestMemory, buffer: DescriptorChainConsumer<'a>, } impl<'a> Reader<'a> { /// Construct a new Reader wrapper over `desc_chain`. - pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Reader<'a> { + pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Result<Reader<'a>> { // TODO(jstaron): Update this code to take the indirect descriptors into account. - let desc_chain = if desc_chain.is_read_only() { - Some(desc_chain) - } else { - None - }; - Reader { - mem, - buffer: DescriptorChainConsumer::new(desc_chain, DescriptorFilter::OnlyReadable), - } - } - - /// Reads to a slice from the descriptor chain buffer. - /// Reads as many bytes as necessary to completely fill - /// the specified slice or to consume all bytes from the - /// descriptor chain buffer. Returns number of copied bytes. - pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> { - let mem = self.mem; - let len = buf.len(); - let mut read_count = 0; - self.buffer.consume( - move |addr, count| { - let result = mem.read_exact_at_addr(&mut buf[read_count..read_count + count], addr); - if result.is_ok() { - read_count += count; - } - result.map_err(Error::GuestMemoryError) + let buffers = desc_chain + .into_iter() + .readable() + .map(|desc| { + mem.get_slice(desc.addr.offset(), desc.len.into()) + .map_err(Error::VolatileMemoryError) + }) + .collect::<Result<VecDeque<VolatileSlice<'a>>>>()?; + Ok(Reader { + buffer: DescriptorChainConsumer { + buffers, + bytes_consumed: 0, }, - len, - ) - } - - /// Reads to a slice from the descriptor chain. - /// Returns an error if there isn't enough data in the - /// descriptor chain buffer to fill the entire slice. Part of - /// the slice may have been filled nevertheless. - pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> { - let count = self.read(buf)?; - if count == buf.len() { - Ok(()) - } else { - Err(Error::GuestMemoryError(GuestMemoryError::ShortRead { - expected: buf.len(), - completed: count, - })) - } + }) } /// Reads an object from the descriptor chain buffer. - pub fn read_obj<T: DataInit + Default>(&mut self) -> Result<T> { - let mut object: T = Default::default(); - self.read_exact(object.as_mut_slice()).map(|_| object) + pub fn read_obj<T: DataInit>(&mut self) -> io::Result<T> { + let mut obj = MaybeUninit::<T>::uninit(); + + // Safe because `MaybeUninit` guarantees that the pointer is valid for + // `size_of::<T>()` bytes. + let buf = unsafe { + ::std::slice::from_raw_parts_mut(obj.as_mut_ptr() as *mut u8, size_of::<T>()) + }; + + self.read_exact(buf)?; + + // Safe because any type that implements `DataInit` can be considered initialized + // even if it is filled with random data. + Ok(unsafe { obj.assume_init() }) } /// Reads data from the descriptor chain buffer into a file descriptor. /// Returns the number of bytes read from the descriptor chain buffer. /// The number of bytes read can be less than `count` if there isn't /// enough data in the descriptor chain buffer. - pub fn read_to(&mut self, dst: &dyn AsRawFd, count: usize) -> Result<usize> { - let mem = self.mem; - self.buffer.consume( - |addr, count| { - mem.write_from_memory(addr, dst, count) - .map_err(Error::GuestMemoryError) - }, - count, - ) + pub fn read_to<F: FileReadWriteVolatile>( + &mut self, + mut dst: F, + count: usize, + ) -> io::Result<usize> { + self.buffer + .consume(count, |bufs| dst.write_vectored_volatile(bufs)) } - /// Reads data from the descriptor chain buffer into a FileReadWriteVolatile. + /// Reads data from the descriptor chain buffer into a File at offset `off`. /// Returns the number of bytes read from the descriptor chain buffer. /// The number of bytes read can be less than `count` if there isn't /// enough data in the descriptor chain buffer. - pub fn read_to_volatile<T: FileReadWriteVolatile + ?Sized>( + pub fn read_to_at<F: FileReadWriteAtVolatile>( &mut self, - dst: &mut T, + mut dst: F, count: usize, - ) -> Result<usize> { - let mem = self.mem; - self.buffer.consume( - |addr, count| { - let mem_volatile_slice = mem - .get_slice(addr.offset(), count as u64) - .map_err(Error::VolatileMemoryError)?; - dst.write_all_volatile(mem_volatile_slice) - .map_err(Error::IoError)?; - Ok(()) - }, - count, - ) + off: u64, + ) -> io::Result<usize> { + self.buffer + .consume(count, |bufs| dst.write_vectored_at_volatile(bufs, off)) + } + + pub fn read_exact_to<F: FileReadWriteVolatile>( + &mut self, + mut dst: F, + mut count: usize, + ) -> io::Result<()> { + while count > 0 { + match self.read_to(&mut dst, count) { + Ok(0) => { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "failed to fill whole buffer", + )) + } + Ok(n) => count -= n, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + Err(e) => return Err(e), + } + } + + Ok(()) } - /// Returns number of bytes available for reading. - pub fn available_bytes(&mut self) -> usize { + /// Returns number of bytes available for reading. May return an error if the combined + /// lengths of all the buffers in the DescriptorChain would cause an integer overflow. + pub fn available_bytes(&self) -> Result<usize> { self.buffer.available_bytes() } @@ -311,6 +284,29 @@ impl<'a> Reader<'a> { pub fn bytes_read(&self) -> usize { self.buffer.bytes_consumed() } + + /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. + /// After the split, `self` will be able to read up to `offset` bytes while the returned + /// `Reader` can read up to `available_bytes() - offset` bytes. Returns an error if + /// `offset > self.available_bytes()`. + pub fn split_at(&mut self, offset: usize) -> Result<Reader<'a>> { + self.buffer.split_at(offset).map(|buffer| Reader { buffer }) + } +} + +impl<'a> io::Read for Reader<'a> { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.buffer.consume(buf.len(), |bufs| { + if let Some(vs) = bufs.first() { + // This is guaranteed by the implementation of `consume`. + debug_assert_eq!(vs.size(), cmp::min(buf.len() as u64, vs.size())); + vs.copy_to(buf); + Ok(vs.size() as usize) + } else { + Ok(0) + } + }) + } } /// Provides high-level interface over the sequence of memory regions @@ -320,65 +316,38 @@ impl<'a> Reader<'a> { /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1). /// Writer will start iterating the descriptors from the first writable one and will /// assume that all following descriptors are writable. +#[derive(Clone)] pub struct Writer<'a> { - mem: &'a GuestMemory, buffer: DescriptorChainConsumer<'a>, } impl<'a> Writer<'a> { /// Construct a new Writer wrapper over `desc_chain`. - pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Writer<'a> { - // Skip all readable descriptors and get first writable one. - let desc_chain = desc_chain.into_iter().writable().next(); - Writer { - mem, - buffer: DescriptorChainConsumer::new(desc_chain, DescriptorFilter::OnlyWritable), - } - } - - /// Writes a slice to the descriptor chain buffer. - /// Returns the number of bytes written. The number of bytes written - /// can be less than the length of the slice if there isn't enough - /// space in the descriptor chain buffer. - pub fn write(&mut self, buf: &[u8]) -> Result<usize> { - let mem = self.mem; - let len = buf.len(); - let mut write_count = 0; - self.buffer.consume( - move |addr, count| { - let result = mem.write_all_at_addr(&buf[write_count..write_count + count], addr); - if result.is_ok() { - write_count += count; - } - result.map_err(Error::GuestMemoryError) + pub fn new(mem: &'a GuestMemory, desc_chain: DescriptorChain<'a>) -> Result<Writer<'a>> { + let buffers = desc_chain + .into_iter() + .writable() + .map(|desc| { + mem.get_slice(desc.addr.offset(), desc.len.into()) + .map_err(Error::VolatileMemoryError) + }) + .collect::<Result<VecDeque<VolatileSlice<'a>>>>()?; + Ok(Writer { + buffer: DescriptorChainConsumer { + buffers, + bytes_consumed: 0, }, - len, - ) - } - - /// Writes the entire contents of a slice to descriptor chain buffer. - /// Returns an error if there isn't enough room in the descriptor chain buffer - /// to complete the entire write. Part of the data may have been written - /// nevertheless. - pub fn write_all(&mut self, buf: &[u8]) -> Result<()> { - let count = self.write(buf)?; - if count == buf.len() { - Ok(()) - } else { - Err(Error::GuestMemoryError(GuestMemoryError::ShortRead { - expected: buf.len(), - completed: count, - })) - } + }) } /// Writes an object to the descriptor chain buffer. - pub fn write_obj<T: DataInit>(&mut self, val: T) -> Result<()> { + pub fn write_obj<T: DataInit>(&mut self, val: T) -> io::Result<()> { self.write_all(val.as_slice()) } - /// Returns number of bytes available for writing. - pub fn available_bytes(&mut self) -> usize { + /// Returns number of bytes available for writing. May return an error if the combined + /// lengths of all the buffers in the DescriptorChain would cause an overflow. + pub fn available_bytes(&self) -> Result<usize> { self.buffer.available_bytes() } @@ -386,57 +355,77 @@ impl<'a> Writer<'a> { /// Returns the number of bytes written to the descriptor chain buffer. /// The number of bytes written can be less than `count` if /// there isn't enough data in the descriptor chain buffer. - pub fn write_from(&mut self, src: &dyn AsRawFd, count: usize) -> Result<usize> { - let mem = self.mem; - self.buffer.consume( - |addr, count| { - mem.read_to_memory(addr, src, count) - .map_err(Error::GuestMemoryError) - }, - count, - ) + pub fn write_from<F: FileReadWriteVolatile>( + &mut self, + mut src: F, + count: usize, + ) -> io::Result<usize> { + self.buffer + .consume(count, |bufs| src.read_vectored_volatile(bufs)) } - /// Writes data to the descriptor chain buffer from a FileReadWriteVolatile. + /// Writes data to the descriptor chain buffer from a File at offset `off`. /// Returns the number of bytes written to the descriptor chain buffer. /// The number of bytes written can be less than `count` if /// there isn't enough data in the descriptor chain buffer. - pub fn write_from_volatile<T: FileReadWriteVolatile + ?Sized>( + pub fn write_from_at<F: FileReadWriteAtVolatile>( &mut self, - src: &mut T, + mut src: F, count: usize, - ) -> Result<usize> { - let mem = self.mem; - self.buffer.consume( - |addr, count| { - let mem_volatile_slice = mem - .get_slice(addr.offset(), count as u64) - .map_err(Error::VolatileMemoryError)?; - src.read_exact_volatile(mem_volatile_slice) - .map_err(Error::IoError)?; - Ok(()) - }, - count, - ) + off: u64, + ) -> io::Result<usize> { + self.buffer + .consume(count, |bufs| src.read_vectored_at_volatile(bufs, off)) + } + + pub fn write_all_from<F: FileReadWriteVolatile>( + &mut self, + mut src: F, + mut count: usize, + ) -> io::Result<()> { + while count > 0 { + match self.write_from(&mut src, count) { + Ok(0) => { + return Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write whole buffer", + )) + } + Ok(n) => count -= n, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + Err(e) => return Err(e), + } + } + + Ok(()) } /// Returns number of bytes already written to the descriptor chain buffer. pub fn bytes_written(&self) -> usize { self.buffer.bytes_consumed() } -} -impl<'a> io::Read for Reader<'a> { - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - self.read(buf) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. + /// After the split, `self` will be able to write up to `offset` bytes while the returned + /// `Writer` can write up to `available_bytes() - offset` bytes. Returns an error if + /// `offset > self.available_bytes()`. + pub fn split_at(&mut self, offset: usize) -> Result<Writer<'a>> { + self.buffer.split_at(offset).map(|buffer| Writer { buffer }) } } impl<'a> io::Write for Writer<'a> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.write(buf) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + self.buffer.consume(buf.len(), |bufs| { + if let Some(vs) = bufs.first() { + // This is guaranteed by the implementation of `consume`. + debug_assert_eq!(vs.size(), cmp::min(buf.len() as u64, vs.size())); + vs.copy_from(buf); + Ok(vs.size() as usize) + } else { + Ok(0) + } + }) } fn flush(&mut self) -> io::Result<()> { @@ -445,18 +434,6 @@ impl<'a> io::Write for Writer<'a> { } } -impl<'a> io::Seek for Reader<'a> { - fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { - self.buffer.seek(pos) - } -} - -impl<'a> io::Seek for Writer<'a> { - fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { - self.buffer.seek(pos) - } -} - const VIRTQ_DESC_F_NEXT: u16 = 0x1; const VIRTQ_DESC_F_WRITE: u16 = 0x2; @@ -524,7 +501,6 @@ pub fn create_descriptor_chain( #[cfg(test)] mod tests { use super::*; - use std::io::{Seek, SeekFrom}; use sys_util::{MemfdSeals, SharedMemory}; #[test] @@ -547,8 +523,13 @@ mod tests { 0, ) .expect("create_descriptor_chain failed"); - let mut reader = Reader::new(&memory, chain); - assert_eq!(reader.available_bytes(), 106); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 106 + ); assert_eq!(reader.bytes_read(), 0); let mut buffer = [0 as u8; 64]; @@ -556,7 +537,12 @@ mod tests { panic!("read_exact should not fail here"); } - assert_eq!(reader.available_bytes(), 42); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 42 + ); assert_eq!(reader.bytes_read(), 64); match reader.read(&mut buffer) { @@ -564,7 +550,12 @@ mod tests { Ok(length) => assert_eq!(length, 42), } - assert_eq!(reader.available_bytes(), 0); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(reader.bytes_read(), 106); } @@ -588,8 +579,13 @@ mod tests { 0, ) .expect("create_descriptor_chain failed");; - let mut writer = Writer::new(&memory, chain); - assert_eq!(writer.available_bytes(), 106); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 106 + ); assert_eq!(writer.bytes_written(), 0); let mut buffer = [0 as u8; 64]; @@ -597,7 +593,12 @@ mod tests { panic!("write_all should not fail here"); } - assert_eq!(writer.available_bytes(), 42); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 42 + ); assert_eq!(writer.bytes_written(), 64); match writer.write(&mut buffer) { @@ -605,7 +606,12 @@ mod tests { Ok(length) => assert_eq!(length, 42), } - assert_eq!(writer.available_bytes(), 0); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(writer.bytes_written(), 106); } @@ -624,13 +630,23 @@ mod tests { 0, ) .expect("create_descriptor_chain failed");; - let mut reader = Reader::new(&memory, chain); - assert_eq!(reader.available_bytes(), 0); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(reader.bytes_read(), 0); assert!(reader.read_obj::<u8>().is_err()); - assert_eq!(reader.available_bytes(), 0); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(reader.bytes_read(), 0); } @@ -649,13 +665,23 @@ mod tests { 0, ) .expect("create_descriptor_chain failed");; - let mut writer = Writer::new(&memory, chain); - assert_eq!(writer.available_bytes(), 0); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(writer.bytes_written(), 0); assert!(writer.write_obj(0u8).is_err()); - assert_eq!(writer.available_bytes(), 0); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(writer.bytes_written(), 0); } @@ -675,7 +701,7 @@ mod tests { ) .expect("create_descriptor_chain failed");; - let mut reader = Reader::new(&memory, chain); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); // GuestMemory's write_from_memory requires raw file descriptor. let mut shm = SharedMemory::anon().unwrap(); @@ -686,14 +712,20 @@ mod tests { fd_seals.set_grow_seal(); shm.add_seals(fd_seals).unwrap(); - if let Ok(_) = reader.read_to(&shm, 512) { - panic!("read_to should fail here, got Ok(_) instead"); - } - - assert!(reader.available_bytes() < 512); - assert!(reader.available_bytes() > 0); - assert!(reader.bytes_read() < 512); - assert!(reader.bytes_read() > 0); + reader + .read_exact_to(&mut shm, 512) + .expect_err("successfully read more bytes than SharedMemory size"); + + // Linux doesn't do partial writes if you give a buffer larger than the remaining length of + // the shared memory. And since we passed an iovec with the full contents of the + // DescriptorChain we ended up not writing any bytes at all. + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 512 + ); + assert_eq!(reader.bytes_read(), 0); } #[test] @@ -710,22 +742,25 @@ mod tests { vec![(Writable, 256), (Writable, 256)], 0, ) - .expect("create_descriptor_chain failed");; + .expect("create_descriptor_chain failed"); - let mut writer = Writer::new(&memory, chain); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); // GuestMemory's read_to_memory requires raw file descriptor. let mut shm = SharedMemory::anon().unwrap(); shm.set_size(384).unwrap(); - if let Ok(_) = writer.write_from(&shm, 512) { - panic!("write_from should fail here, got Ok(_) instead"); - } + writer + .write_all_from(&mut shm, 512) + .expect_err("successfully wrote more bytes than in SharedMemory"); - assert!(writer.available_bytes() < 512); - assert!(writer.available_bytes() > 0); - assert!(writer.bytes_written() < 512); - assert!(writer.bytes_written() > 0); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 128 + ); + assert_eq!(writer.bytes_written(), 384); } #[test] @@ -749,28 +784,40 @@ mod tests { ], 0, ) - .expect("create_descriptor_chain failed");; - let mut reader = Reader::new(&memory, chain.clone()); - let mut writer = Writer::new(&memory, chain); + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain.clone()).expect("failed to create Reader"); + let mut writer = Writer::new(&memory, chain).expect("failed to create Writer"); assert_eq!(reader.bytes_read(), 0); assert_eq!(writer.bytes_written(), 0); - let mut buffer = [0 as u8; 200]; + let mut buffer = Vec::with_capacity(200); - match reader.read(&mut buffer) { - Err(_) => panic!("read should not fail here"), - Ok(length) => assert_eq!(length, 128), - } + assert_eq!( + reader + .read_to_end(&mut buffer) + .expect("read should not fail here"), + 128 + ); - match writer.write(&mut buffer) { - Err(_) => panic!("write should not fail here"), - Ok(length) => assert_eq!(length, 68), - } + // The writable descriptors are only 68 bytes long. + writer + .write_all(&buffer[..68]) + .expect("write should not fail here"); - assert_eq!(reader.available_bytes(), 0); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(reader.bytes_read(), 128); - assert_eq!(writer.available_bytes(), 0); + assert_eq!( + writer + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); assert_eq!(writer.bytes_written(), 68); } @@ -791,8 +838,8 @@ mod tests { vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)], 123, ) - .expect("create_descriptor_chain failed");; - let mut writer = Writer::new(&memory, chain_writer); + .expect("create_descriptor_chain failed"); + let mut writer = Writer::new(&memory, chain_writer).expect("failed to create Writer"); if let Err(_) = writer.write_obj(secret) { panic!("write_obj should not fail here"); } @@ -806,7 +853,7 @@ mod tests { 123, ) .expect("create_descriptor_chain failed"); - let mut reader = Reader::new(&memory, chain_reader); + let mut reader = Reader::new(&memory, chain_reader).expect("failed to create Reader"); match reader.read_obj::<Le32>() { Err(_) => panic!("read_obj should not fail here"), Ok(read_secret) => assert_eq!(read_secret, secret), @@ -814,63 +861,217 @@ mod tests { } #[test] - fn reader_seek_simple_chain() { + fn reader_unexpected_eof() { use DescriptorType::*; let memory_start_addr = GuestAddress(0x0); - let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap(); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![(Readable, 256), (Readable, 256)], + 0, + ) + .expect("create_descriptor_chain failed"); + + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let mut buf = Vec::with_capacity(1024); + buf.resize(1024, 0); + + assert_eq!( + reader + .read_exact(&mut buf[..]) + .expect_err("read more bytes than available") + .kind(), + io::ErrorKind::UnexpectedEof + ); + } + + #[test] + fn split_border() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); let chain = create_descriptor_chain( &memory, GuestAddress(0x0), GuestAddress(0x100), vec![ - (Readable, 8), (Readable, 16), - (Readable, 18), - (Readable, 64), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), ], 0, ) .expect("create_descriptor_chain failed");; - let mut reader = Reader::new(&memory, chain); - assert_eq!(reader.available_bytes(), 106); - assert_eq!(reader.bytes_read(), 0); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let other = reader.split_at(32).expect("failed to split Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 32 + ); + assert_eq!( + other + .available_bytes() + .expect("failed to get available bytes"), + 96 + ); + } - // Skip some bytes. available_bytes() and bytes_read() should update accordingly. - reader - .seek(SeekFrom::Current(64)) - .expect("seek should not fail here"); - assert_eq!(reader.available_bytes(), 42); - assert_eq!(reader.bytes_read(), 64); + #[test] + fn split_middle() { + use DescriptorType::*; - // Seek past end of chain - position should point just past the last byte. - reader - .seek(SeekFrom::Current(64)) - .expect("seek should not fail here"); - assert_eq!(reader.available_bytes(), 0); - assert_eq!(reader.bytes_read(), 106); + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); - // Seek back to the beginning. - reader - .seek(SeekFrom::Start(0)) - .expect("seek should not fail here"); - assert_eq!(reader.available_bytes(), 106); - assert_eq!(reader.bytes_read(), 0); + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let other = reader.split_at(24).expect("failed to split Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 24 + ); + assert_eq!( + other + .available_bytes() + .expect("failed to get available bytes"), + 104 + ); + } - // Seek to one byte before the end. - reader - .seek(SeekFrom::End(-1)) - .expect("seek should not fail here"); - assert_eq!(reader.available_bytes(), 1); - assert_eq!(reader.bytes_read(), 105); + #[test] + fn split_end() { + use DescriptorType::*; - // Read the last byte. - let mut buffer = [0 as u8; 1]; - reader - .read_exact(&mut buffer) - .expect("read_exact should not fail here"); - assert_eq!(reader.available_bytes(), 0); - assert_eq!(reader.bytes_read(), 106); + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let other = reader.split_at(128).expect("failed to split Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 128 + ); + assert_eq!( + other + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); + } + + #[test] + fn split_beginning() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + let other = reader.split_at(0).expect("failed to split Reader"); + assert_eq!( + reader + .available_bytes() + .expect("failed to get available bytes"), + 0 + ); + assert_eq!( + other + .available_bytes() + .expect("failed to get available bytes"), + 128 + ); + } + + #[test] + fn split_outofbounds() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 16), + (Readable, 16), + (Readable, 96), + (Writable, 64), + (Writable, 1), + (Writable, 3), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + let mut reader = Reader::new(&memory, chain).expect("failed to create Reader"); + + if let Ok(_) = reader.split_at(256) { + panic!("successfully split Reader with out of bounds offset"); + } } } diff --git a/devices/src/virtio/gpu/mod.rs b/devices/src/virtio/gpu/mod.rs index 29142ea..97c84d5 100644 --- a/devices/src/virtio/gpu/mod.rs +++ b/devices/src/virtio/gpu/mod.rs @@ -8,6 +8,7 @@ mod protocol; use std::cell::RefCell; use std::collections::VecDeque; use std::i64; +use std::io::Read; use std::mem::{self, size_of}; use std::num::NonZeroU8; use std::os::unix::io::{AsRawFd, RawFd}; @@ -124,7 +125,14 @@ impl Frontend { mem, ), GpuCommand::ResourceAttachBacking(info) => { - if reader.available_bytes() != 0 { + let available_bytes = match reader.available_bytes() { + Ok(count) => count, + Err(e) => { + debug!("invalid descriptor: {}", e); + 0 + } + }; + if available_bytes != 0 { let entry_count = info.nr_entries.to_native() as usize; let mut iovecs = Vec::with_capacity(entry_count); for _ in 0..entry_count { @@ -247,10 +255,17 @@ impl Frontend { ) } GpuCommand::CmdSubmit3d(info) => { - if reader.available_bytes() != 0 { + let available_bytes = match reader.available_bytes() { + Ok(count) => count, + Err(e) => { + debug!("invalid descriptor: {}", e); + 0 + } + }; + if available_bytes != 0 { let cmd_size = info.size.to_native() as usize; let mut cmd_buf = vec![0; cmd_size]; - if reader.read(&mut cmd_buf[..]).is_ok() { + if reader.read_exact(&mut cmd_buf[..]).is_ok() { self.backend .submit_command(info.hdr.ctx_id.to_native(), &mut cmd_buf[..]) } else { @@ -263,7 +278,14 @@ impl Frontend { } } GpuCommand::AllocationMetadata(info) => { - if reader.available_bytes() != 0 { + let available_bytes = match reader.available_bytes() { + Ok(count) => count, + Err(e) => { + debug!("invalid descriptor: {}", e); + 0 + } + }; + if available_bytes != 0 { let id = info.request_id.to_native(); let request_size = info.request_size.to_native(); let response_size = info.response_size.to_native(); @@ -275,7 +297,7 @@ impl Frontend { let mut request_buf = vec![0; request_size as usize]; let response_buf = vec![0; response_size as usize]; - if reader.read(&mut request_buf[..]).is_ok() { + if reader.read_exact(&mut request_buf[..]).is_ok() { self.backend .allocation_metadata(id, request_buf, response_buf) } else { @@ -286,7 +308,14 @@ impl Frontend { } } GpuCommand::ResourceCreateV2(info) => { - if reader.available_bytes() != 0 { + let available_bytes = match reader.available_bytes() { + Ok(count) => count, + Err(e) => { + debug!("invalid descriptor: {}", e); + 0 + } + }; + if available_bytes != 0 { let resource_id = info.resource_id.to_native(); let guest_memory_type = info.guest_memory_type.to_native(); let size = info.size.to_native(); @@ -314,7 +343,7 @@ impl Frontend { } } - match reader.read(&mut args[..]) { + match reader.read_exact(&mut args[..]) { Ok(_) => self.backend.resource_create_v2( resource_id, guest_memory_type, @@ -346,13 +375,23 @@ impl Frontend { let mut signal_used = false; while let Some(desc) = queue.pop(mem) { if Frontend::validate_desc(&desc) { - let mut reader = Reader::new(mem, desc.clone()); - let mut writer = Writer::new(mem, desc.clone()); - if let Some(ret_desc) = - self.process_descriptor(mem, desc.index, &mut reader, &mut writer) - { - queue.add_used(&mem, ret_desc.index, ret_desc.len); - signal_used = true; + match ( + Reader::new(mem, desc.clone()), + Writer::new(mem, desc.clone()), + ) { + (Ok(mut reader), Ok(mut writer)) => { + if let Some(ret_desc) = + self.process_descriptor(mem, desc.index, &mut reader, &mut writer) + { + queue.add_used(&mem, ret_desc.index, ret_desc.len); + signal_used = true; + } + } + (_, Err(e)) | (Err(e), _) => { + debug!("invalid descriptor: {}", e); + queue.add_used(&mem, desc.index, 0); + signal_used = true; + } } } else { let likely_type = mem.read_obj_from_addr(desc.addr).unwrap_or(Le32::from(0)); @@ -391,7 +430,16 @@ impl Frontend { if resp.is_err() { debug!("{:?} -> {:?}", gpu_cmd, resp); } - if writer.available_bytes() != 0 { + + let available_bytes = match writer.available_bytes() { + Ok(count) => count, + Err(e) => { + debug!("invalid descriptor: {}", e); + 0 + } + }; + + if available_bytes != 0 { let mut fence_id = 0; let mut ctx_id = 0; let mut flags = 0; diff --git a/devices/src/virtio/gpu/protocol.rs b/devices/src/virtio/gpu/protocol.rs index c6773dd..dacc850 100644 --- a/devices/src/virtio/gpu/protocol.rs +++ b/devices/src/virtio/gpu/protocol.rs @@ -7,6 +7,7 @@ use std::cmp::min; use std::fmt::{self, Display}; +use std::io::{self, Write}; use std::marker::PhantomData; use std::mem::{size_of, size_of_val}; use std::str::from_utf8; @@ -589,6 +590,8 @@ pub enum GpuCommandDecodeError { Memory(DescriptorError), /// The type of the command was invalid. InvalidType(u32), + /// An I/O error occurred. + IO(io::Error), } impl Display for GpuCommandDecodeError { @@ -602,6 +605,7 @@ impl Display for GpuCommandDecodeError { e, ), InvalidType(n) => write!(f, "invalid command type ({})", n), + IO(e) => write!(f, "an I/O error occurred: {}", e), } } } @@ -612,6 +616,12 @@ impl From<DescriptorError> for GpuCommandDecodeError { } } +impl From<io::Error> for GpuCommandDecodeError { + fn from(e: io::Error) -> GpuCommandDecodeError { + GpuCommandDecodeError::IO(e) + } +} + impl fmt::Debug for GpuCommand { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use self::GpuCommand::*; @@ -754,6 +764,8 @@ pub enum GpuResponseEncodeError { TooManyDisplays(usize), /// More planes than are valid were in a `OkResourcePlaneInfo`. TooManyPlanes(usize), + /// An I/O error occurred. + IO(io::Error), } impl Display for GpuResponseEncodeError { @@ -768,6 +780,7 @@ impl Display for GpuResponseEncodeError { ), TooManyDisplays(n) => write!(f, "{} is more displays than are valid", n), TooManyPlanes(n) => write!(f, "{} is more planes than are valid", n), + IO(e) => write!(f, "an I/O error occurred: {}", e), } } } @@ -778,6 +791,12 @@ impl From<DescriptorError> for GpuResponseEncodeError { } } +impl From<io::Error> for GpuResponseEncodeError { + fn from(e: io::Error) -> GpuResponseEncodeError { + GpuResponseEncodeError::IO(e) + } +} + impl GpuResponse { /// Encodes a this `GpuResponse` into `resp` and the given set of metadata. pub fn encode( @@ -823,7 +842,7 @@ impl GpuResponse { } GpuResponse::OkCapset(ref data) => { resp.write_obj(hdr)?; - resp.write(data)?; + resp.write_all(data)?; size_of_val(&hdr) + data.len() } GpuResponse::OkResourcePlaneInfo { @@ -847,7 +866,7 @@ impl GpuResponse { strides, offsets, }; - if resp.available_bytes() >= size_of_val(&plane_info) { + if resp.available_bytes()? >= size_of_val(&plane_info) { resp.write_obj(plane_info)?; size_of_val(&plane_info) } else { @@ -869,7 +888,7 @@ impl GpuResponse { }; resp.write_obj(resp_info)?; - resp.write(&res_info.response)?; + resp.write_all(&res_info.response)?; size_of_val(&resp_info) + res_info.response.len() } _ => { diff --git a/devices/src/virtio/input/mod.rs b/devices/src/virtio/input/mod.rs index ce99d52..2459d24 100644 --- a/devices/src/virtio/input/mod.rs +++ b/devices/src/virtio/input/mod.rs @@ -17,7 +17,8 @@ use sys_util::{error, warn, EventFd, GuestMemory, PollContext, PollToken}; use self::event_source::{input_event, EvdevEventSource, EventSource, SocketEventSource}; use super::{ - copy_config, Queue, Reader, VirtioDevice, Writer, INTERRUPT_STATUS_USED_RING, TYPE_INPUT, + copy_config, DescriptorChain, DescriptorError, Queue, Reader, VirtioDevice, Writer, + INTERRUPT_STATUS_USED_RING, TYPE_INPUT, }; use std::collections::BTreeMap; use std::fmt::{self, Display}; @@ -54,7 +55,14 @@ pub enum InputError { EvdevGrabError(sys_util::Error), // Detected error on guest side GuestError(String), + // Virtio descriptor error + Descriptor(DescriptorError), + // Error while reading from virtqueue + ReadQueue(std::io::Error), + // Error while writing to virtqueue + WriteQueue(std::io::Error), } + pub type Result<T> = std::result::Result<T, InputError>; impl Display for InputError { @@ -76,6 +84,9 @@ impl Display for InputError { } EvdevGrabError(e) => write!(f, "failed to grab event device: {}", e), GuestError(s) => write!(f, "detected error on guest side: {}", s), + Descriptor(e) => write!(f, "virtio descriptor error: {}", e), + ReadQueue(e) => write!(f, "failed to read from virtqueue: {}", e), + WriteQueue(e) => write!(f, "failed to write to virtqueue: {}", e), } } } @@ -365,6 +376,27 @@ impl<T: EventSource> Worker<T> { self.interrupt_evt.write(1).unwrap(); } + // Fills a virtqueue with events from the source. Returns the number of bytes written. + fn fill_event_virtqueue( + event_source: &mut T, + avail_desc: DescriptorChain, + mem: &GuestMemory, + ) -> Result<usize> { + let mut writer = Writer::new(mem, avail_desc).map_err(InputError::Descriptor)?; + + while writer.available_bytes().map_err(InputError::Descriptor)? + >= virtio_input_event::EVENT_SIZE + { + if let Some(evt) = event_source.pop_available_event() { + writer.write_obj(evt).map_err(InputError::WriteQueue)?; + } else { + break; + } + } + + Ok(writer.bytes_written()) + } + // Send events from the source to the guest fn send_events(&mut self) -> bool { let mut needs_interrupt = false; @@ -377,26 +409,23 @@ impl<T: EventSource> Worker<T> { } Some(avail_desc) => { let avail_desc_index = avail_desc.index; - let mut writer = Writer::new(&self.guest_memory, avail_desc); - while writer.available_bytes() >= virtio_input_event::EVENT_SIZE { - match self.event_source.pop_available_event() { - Some(evt) => { - if let Err(e) = writer.write_obj(evt) { - // An error here would mean the address and length given - // in the queue descriptor are wrong: Don't try to write - // to this buffer anymore. - error!("Could not write event to guest memory: {}", e); - break; - } - } - None => break, + + let bytes_written = match Worker::fill_event_virtqueue( + &mut self.event_source, + avail_desc, + &self.guest_memory, + ) { + Ok(count) => count, + Err(e) => { + error!("Input: failed to send events to guest: {}", e); + break; } - } + }; self.event_queue.add_used( &self.guest_memory, avail_desc_index, - writer.bytes_written() as u32, + bytes_written as u32, ); needs_interrupt = true; } @@ -406,38 +435,42 @@ impl<T: EventSource> Worker<T> { needs_interrupt } + // Sends events from the guest to the source. Returns the number of bytes read. + fn read_event_virtqueue( + avail_desc: DescriptorChain, + event_source: &mut T, + mem: &GuestMemory, + ) -> Result<usize> { + let mut reader = Reader::new(mem, avail_desc).map_err(InputError::Descriptor)?; + while reader.available_bytes().map_err(InputError::Descriptor)? + >= virtio_input_event::EVENT_SIZE + { + let evt: virtio_input_event = reader.read_obj().map_err(InputError::ReadQueue)?; + event_source.send_event(&evt)?; + } + + Ok(reader.bytes_read()) + } + fn process_status_queue(&mut self) -> Result<bool> { let mut needs_interrupt = false; while let Some(avail_desc) = self.status_queue.pop(&self.guest_memory) { let avail_desc_index = avail_desc.index; - let mut reader = Reader::new(&self.guest_memory, avail_desc); - if reader.available_bytes() % virtio_input_event::EVENT_SIZE != 0 { - warn!( - "Ignoring buffer of unexpected size on status queue: {:0}", - reader.available_bytes(), - ); - } else { - while reader.available_bytes() >= virtio_input_event::EVENT_SIZE { - match reader.read_obj::<virtio_input_event>() { - Ok(evt) => { - self.event_source.send_event(&evt)?; - } - Err(e) => { - // An error here would mean the address or length in the buffer - // descriptor was wrong: Don't try to read from this buffer - // anymore. - error!("Unable to read status event from guest memory: {}", e); - break; - } - } - } - } - self.status_queue.add_used( + let bytes_read = match Worker::read_event_virtqueue( + avail_desc, + &mut self.event_source, &self.guest_memory, - avail_desc_index, - reader.bytes_read() as u32, - ); + ) { + Ok(count) => count, + Err(e) => { + error!("Input: failed to read events from virtqueue: {}", e); + break; + } + }; + + self.status_queue + .add_used(&self.guest_memory, avail_desc_index, bytes_read as u32); needs_interrupt = true; } diff --git a/devices/src/virtio/net.rs b/devices/src/virtio/net.rs index 7680b66..6bc2e25 100644 --- a/devices/src/virtio/net.rs +++ b/devices/src/virtio/net.rs @@ -3,6 +3,7 @@ // found in the LICENSE file. use std::fmt::{self, Display}; +use std::io::{self, Read, Write}; use std::mem; use std::net::Ipv4Addr; use std::os::unix::io::{AsRawFd, RawFd}; @@ -13,15 +14,12 @@ use std::thread; use libc::{EAGAIN, EEXIST}; use net_sys; use net_util::{Error as TapError, MacAddress, TapT}; -use sys_util::guest_memory::Error as MemoryError; use sys_util::Error as SysError; use sys_util::{error, warn, EventFd, GuestMemory, PollContext, PollToken}; use virtio_sys::virtio_net::virtio_net_hdr_v1; use virtio_sys::{vhost, virtio_net}; -use super::{ - DescriptorError, Queue, Reader, VirtioDevice, Writer, INTERRUPT_STATUS_USED_RING, TYPE_NET, -}; +use super::{Queue, Reader, VirtioDevice, Writer, INTERRUPT_STATUS_USED_RING, TYPE_NET}; /// The maximum buffer size when segmentation offload is enabled. This /// includes the 12-byte virtio net header. @@ -122,23 +120,30 @@ where }; let index = desc_chain.index; - let mut writer = Writer::new(&self.mem, desc_chain); - - match writer.write_all(&self.rx_buf[0..self.rx_count]) { - Ok(()) => (), - Err(DescriptorError::GuestMemoryError(MemoryError::ShortWrite { .. })) => { - warn!( - "net: rx: buffer is too small to hold frame of size {}", - self.rx_count - ); + let bytes_written = match Writer::new(&self.mem, desc_chain) { + Ok(mut writer) => { + match writer.write_all(&self.rx_buf[0..self.rx_count]) { + Ok(()) => (), + Err(ref e) if e.kind() == io::ErrorKind::WriteZero => { + warn!( + "net: rx: buffer is too small to hold frame of size {}", + self.rx_count + ); + } + Err(e) => { + warn!("net: rx: failed to write slice: {}", e); + } + }; + + writer.bytes_written() as u32 } Err(e) => { - warn!("net: rx: failed to write slice: {}", e); + error!("net: failed to create Writer: {}", e); + 0 } - } + }; - self.rx_queue - .add_used(&self.mem, index, writer.bytes_written() as u32); + self.rx_queue.add_used(&self.mem, index, bytes_written); // Interrupt the guest immediately for received frames to // reduce latency. @@ -174,21 +179,38 @@ where fn process_tx(&mut self) { let mut frame = [0u8; MAX_BUFFER_SIZE]; + // Reads up to `buf.len()` bytes or until there is no more data in `r`, whichever + // is smaller. + fn read_to_end(mut r: Reader, buf: &mut [u8]) -> io::Result<usize> { + let mut count = 0; + while count < buf.len() { + match r.read(&mut buf[count..]) { + Ok(0) => break, + Ok(n) => count += n, + Err(e) => return Err(e), + } + } + + Ok(count) + } + while let Some(desc_chain) = self.tx_queue.pop(&self.mem) { let index = desc_chain.index; - let mut reader = Reader::new(&self.mem, desc_chain); - - match reader.read(&mut frame) { - // We need to copy frame into continuous buffer before writing it to tap - // because tap requires frame to complete in a single write. - Ok(read_count) => { - if let Err(err) = self.tap.write_all(&frame[..read_count]) { - error!("net: tx: failed to write to tap: {}", err); + + match Reader::new(&self.mem, desc_chain) { + Ok(reader) => { + match read_to_end(reader, &mut frame[..]) { + Ok(len) => { + // We need to copy frame into continuous buffer before writing it to tap + // because tap requires frame to complete in a single write. + if let Err(err) = self.tap.write_all(&frame[..len]) { + error!("net: tx: failed to write to tap: {}", err); + } + } + Err(e) => error!("net: tx: failed to read frame into buffer: {}", e), } } - Err(err) => { - error!("net: tx: failed to read frame into buffer: {}", err); - } + Err(e) => error!("net: failed to create Reader: {}", e), } self.tx_queue.add_used(&self.mem, index, 0); diff --git a/devices/src/virtio/p9.rs b/devices/src/virtio/p9.rs index 6968733..6d89a45 100644 --- a/devices/src/virtio/p9.rs +++ b/devices/src/virtio/p9.rs @@ -17,7 +17,8 @@ use sys_util::{error, warn, Error as SysError, EventFd, GuestMemory, PollContext use virtio_sys::vhost::VIRTIO_F_VERSION_1; use super::{ - copy_config, Queue, Reader, VirtioDevice, Writer, INTERRUPT_STATUS_USED_RING, TYPE_9P, + copy_config, DescriptorError, Queue, Reader, VirtioDevice, Writer, INTERRUPT_STATUS_USED_RING, + TYPE_9P, }; const QUEUE_SIZE: u16 = 128; @@ -45,6 +46,8 @@ pub enum P9Error { NoWritableDescriptors, /// Failed to signal the virio used queue. SignalUsedQueue(SysError), + /// A DescriptorChain contains invalid data. + InvalidDescriptorChain(DescriptorError), /// An internal I/O error occurred. Internal(io::Error), } @@ -73,6 +76,9 @@ impl Display for P9Error { NoReadableDescriptors => write!(f, "request does not have any readable descriptors"), NoWritableDescriptors => write!(f, "request does not have any writable descriptors"), SignalUsedQueue(err) => write!(f, "failed to signal used queue: {}", err), + InvalidDescriptorChain(err) => { + write!(f, "DescriptorChain contains invalid data: {}", err) + } Internal(err) => write!(f, "P9 internal server error: {}", err), } } @@ -98,8 +104,10 @@ impl Worker { fn process_queue(&mut self) -> P9Result<()> { while let Some(avail_desc) = self.queue.pop(&self.mem) { - let mut reader = Reader::new(&self.mem, avail_desc.clone()); - let mut writer = Writer::new(&self.mem, avail_desc.clone()); + let mut reader = Reader::new(&self.mem, avail_desc.clone()) + .map_err(P9Error::InvalidDescriptorChain)?; + let mut writer = Writer::new(&self.mem, avail_desc.clone()) + .map_err(P9Error::InvalidDescriptorChain)?; self.server .handle_message(&mut reader, &mut writer) diff --git a/devices/src/virtio/pmem.rs b/devices/src/virtio/pmem.rs index 1553ff2..4df295b 100644 --- a/devices/src/virtio/pmem.rs +++ b/devices/src/virtio/pmem.rs @@ -2,7 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use std::fmt::{self, Display}; use std::fs::File; +use std::io; use std::os::unix::io::{AsRawFd, RawFd}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -14,8 +16,8 @@ use sys_util::{error, EventFd, GuestAddress, GuestMemory, PollContext, PollToken use data_model::{DataInit, Le32, Le64}; use super::{ - copy_config, Queue, Reader, VirtioDevice, Writer, INTERRUPT_STATUS_USED_RING, TYPE_PMEM, - VIRTIO_F_VERSION_1, + copy_config, DescriptorChain, DescriptorError, Queue, Reader, VirtioDevice, Writer, + INTERRUPT_STATUS_USED_RING, TYPE_PMEM, VIRTIO_F_VERSION_1, }; const QUEUE_SIZE: u16 = 256; @@ -53,6 +55,32 @@ struct virtio_pmem_req { // Safe because it only has data and has no implicit padding. unsafe impl DataInit for virtio_pmem_req {} +#[derive(Debug)] +enum Error { + /// Invalid virtio descriptor chain. + Descriptor(DescriptorError), + /// Failed to read from virtqueue. + ReadQueue(io::Error), + /// Failed to write to virtqueue. + WriteQueue(io::Error), +} + +impl Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::Error::*; + + match self { + Descriptor(e) => write!(f, "virtio descriptor error: {}", e), + ReadQueue(e) => write!(f, "failed to read from virtqueue: {}", e), + WriteQueue(e) => write!(f, "failed to write to virtqueue: {}", e), + } + } +} + +impl ::std::error::Error for Error {} + +type Result<T> = ::std::result::Result<T, Error>; + struct Worker { queue: Queue, memory: GuestMemory, @@ -79,33 +107,39 @@ impl Worker { } } + fn handle_request(&self, avail_desc: DescriptorChain) -> Result<usize> { + let mut reader = + Reader::new(&self.memory, avail_desc.clone()).map_err(Error::Descriptor)?; + let mut writer = Writer::new(&self.memory, avail_desc).map_err(Error::Descriptor)?; + + let status_code = reader + .read_obj() + .map(|request| self.execute_request(request)) + .map_err(Error::ReadQueue)?; + + let response = virtio_pmem_resp { + status_code: status_code.into(), + }; + + writer.write_obj(response).map_err(Error::WriteQueue)?; + + Ok(writer.bytes_written()) + } + fn process_queue(&mut self) -> bool { let mut needs_interrupt = false; while let Some(avail_desc) = self.queue.pop(&self.memory) { let avail_desc_index = avail_desc.index; - let mut reader = Reader::new(&self.memory, avail_desc.clone()); - let mut writer = Writer::new(&self.memory, avail_desc); - let status_code = match reader.read_obj::<virtio_pmem_req>() { - Ok(request) => self.execute_request(request), + let bytes_written = match self.handle_request(avail_desc) { + Ok(count) => count, Err(e) => { - error!("failed to read virtio_pmem_req: {}", e); - VIRTIO_PMEM_RESP_TYPE_EIO + error!("pmem: unable to handle request: {}", e); + 0 } }; - - let response = virtio_pmem_resp { - status_code: status_code.into(), - }; - if let Err(e) = writer.write_obj(response) { - error!("failed to write virtio_pmem_resp: {}", e); - } - - self.queue.add_used( - &self.memory, - avail_desc_index, - writer.bytes_written() as u32, - ); + self.queue + .add_used(&self.memory, avail_desc_index, bytes_written as u32); needs_interrupt = true; } diff --git a/devices/src/virtio/rng.rs b/devices/src/virtio/rng.rs index 4e6448f..17dca69 100644 --- a/devices/src/virtio/rng.rs +++ b/devices/src/virtio/rng.rs @@ -51,15 +51,18 @@ impl Worker { let mut needs_interrupt = false; while let Some(avail_desc) = queue.pop(&self.mem) { let index = avail_desc.index; - let mut writer = Writer::new(&self.mem, avail_desc); - // Fill the entire descriptor chain buffer with random bytes. - let written = match writer.write_from(&self.random_file, std::usize::MAX) { + let random_file = &mut self.random_file; + let written = match Writer::new(&self.mem, avail_desc) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) + .and_then(|mut writer| writer.write_from(random_file, std::usize::MAX)) + { Ok(n) => n, Err(e) => { warn!("Failed to write random data to the guest: {}", e); 0 } }; + queue.add_used(&self.mem, index, written as u32); needs_interrupt = true; } diff --git a/devices/src/virtio/tpm.rs b/devices/src/virtio/tpm.rs index a3cdf48..4e557dc 100644 --- a/devices/src/virtio/tpm.rs +++ b/devices/src/virtio/tpm.rs @@ -5,6 +5,7 @@ use std::env; use std::fmt::{self, Display}; use std::fs; +use std::io::{self, Read, Write}; use std::ops::BitOrAssign; use std::os::unix::io::RawFd; use std::path::PathBuf; @@ -48,16 +49,17 @@ struct Device { impl Device { fn perform_work(&mut self, mem: &GuestMemory, desc: DescriptorChain) -> Result<u32> { - let mut reader = Reader::new(mem, desc.clone()); - let mut writer = Writer::new(mem, desc); + let mut reader = Reader::new(mem, desc.clone()).map_err(Error::Descriptor)?; + let mut writer = Writer::new(mem, desc).map_err(Error::Descriptor)?; - if reader.available_bytes() > TPM_BUFSIZE { + let available_bytes = reader.available_bytes().map_err(Error::Descriptor)?; + if available_bytes > TPM_BUFSIZE { return Err(Error::CommandTooLong { - size: reader.available_bytes(), + size: available_bytes, }); } - let mut command = vec![0u8; reader.available_bytes() as usize]; + let mut command = vec![0u8; available_bytes]; reader.read_exact(&mut command).map_err(Error::Read)?; let response = self.simulator.execute_command(&command); @@ -68,9 +70,10 @@ impl Device { }); } - if response.len() > writer.available_bytes() { + let writer_len = writer.available_bytes().map_err(Error::Descriptor)?; + if response.len() > writer_len { return Err(Error::BufferTooSmall { - size: writer.available_bytes(), + size: writer_len, required: response.len(), }); } @@ -287,10 +290,11 @@ type Result<T> = std::result::Result<T, Error>; enum Error { CommandTooLong { size: usize }, - Read(DescriptorError), + Descriptor(DescriptorError), + Read(io::Error), ResponseTooLong { size: usize }, BufferTooSmall { size: usize, required: usize }, - Write(DescriptorError), + Write(io::Error), } impl Display for Error { @@ -303,6 +307,7 @@ impl Display for Error { "vtpm command is too long: {} > {} bytes", size, TPM_BUFSIZE ), + Descriptor(e) => write!(f, "virtio descriptor error: {}", e), Read(e) => write!(f, "vtpm failed to read from guest memory: {}", e), ResponseTooLong { size } => write!( f, diff --git a/sys_util/src/file_traits.rs b/sys_util/src/file_traits.rs index d35bc4f..f296d9b 100644 --- a/sys_util/src/file_traits.rs +++ b/sys_util/src/file_traits.rs @@ -111,7 +111,7 @@ pub trait FileReadWriteVolatile { } } -impl<'a, T: FileReadWriteVolatile> FileReadWriteVolatile for &'a mut T { +impl<'a, T: FileReadWriteVolatile + ?Sized> FileReadWriteVolatile for &'a mut T { fn read_volatile(&mut self, slice: VolatileSlice) -> Result<usize> { (**self).read_volatile(slice) } @@ -208,7 +208,7 @@ pub trait FileReadWriteAtVolatile { } } -impl<'a, T: FileReadWriteAtVolatile> FileReadWriteAtVolatile for &'a mut T { +impl<'a, T: FileReadWriteAtVolatile + ?Sized> FileReadWriteAtVolatile for &'a mut T { fn read_at_volatile(&mut self, slice: VolatileSlice, offset: u64) -> Result<usize> { (**self).read_at_volatile(slice, offset) } diff --git a/sys_util/src/guest_memory.rs b/sys_util/src/guest_memory.rs index ac7722e..1246a9c 100644 --- a/sys_util/src/guest_memory.rs +++ b/sys_util/src/guest_memory.rs @@ -18,6 +18,7 @@ use data_model::DataInit; #[derive(Debug)] pub enum Error { + DescriptorChainOverflow, InvalidGuestAddress(GuestAddress), MemoryAccess(GuestAddress, mmap::Error), MemoryMappingFailed(mmap::Error), @@ -28,6 +29,8 @@ pub enum Error { MemoryAddSealsFailed(errno::Error), ShortWrite { expected: usize, completed: usize }, ShortRead { expected: usize, completed: usize }, + SplitOutOfBounds(usize), + VolatileMemoryAccess(VolatileMemoryError), } pub type Result<T> = result::Result<T, Error>; @@ -38,6 +41,10 @@ impl Display for Error { use self::Error::*; match self { + DescriptorChainOverflow => write!( + f, + "the combined length of all the buffers in a DescriptorChain is too large" + ), InvalidGuestAddress(addr) => write!(f, "invalid guest address {}", addr), MemoryAccess(addr, e) => { write!(f, "invalid guest memory access at addr={}: {}", addr, e) @@ -64,6 +71,8 @@ impl Display for Error { "incomplete read of {} instead of {} bytes", completed, expected, ), + SplitOutOfBounds(off) => write!(f, "DescriptorChain split is out of bounds: {}", off), + VolatileMemoryAccess(e) => e.fmt(f), } } } |