From 977f008bc3be607600a98d77ef6984b52dba7eb6 Mon Sep 17 00:00:00 2001 From: Daniel Verkamp Date: Wed, 24 Jul 2019 14:10:19 -0700 Subject: devices: virtio: add seek() for descriptor chains This allows moving the read/write cursor around within a chain of descriptors through the standard io::Seek interface. BUG=chromium:990546 TEST=./build_test Change-Id: I26ed368d3c7592188241a343dfeb922f3423d935 Signed-off-by: Daniel Verkamp Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/1721369 Tested-by: kokoro Reviewed-by: Zach Reizner --- devices/src/virtio/descriptor_utils.rs | 133 ++++++++++++++++++++++++++++++++- 1 file changed, 132 insertions(+), 1 deletion(-) (limited to 'devices/src/virtio/descriptor_utils.rs') diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index 9d18a63..a206002 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -3,6 +3,7 @@ // found in the LICENSE file. use std::cmp; +use std::convert::TryFrom; use std::fmt::{self, Display}; use std::io; use std::os::unix::io::AsRawFd; @@ -44,6 +45,7 @@ enum DescriptorFilter { struct DescriptorChainConsumer<'a> { offset: usize, desc_chain: Option>, + desc_chain_start: Option>, bytes_consumed: usize, avail_bytes: Option, filter: DescriptorFilter, @@ -56,7 +58,8 @@ impl<'a> DescriptorChainConsumer<'a> { ) -> DescriptorChainConsumer<'a> { DescriptorChainConsumer { offset: 0, - desc_chain, + desc_chain: desc_chain.clone(), + desc_chain_start: desc_chain, bytes_consumed: 0, avail_bytes: None, filter, @@ -129,6 +132,61 @@ impl<'a> DescriptorChainConsumer<'a> { } desc_chain } + + fn seek_from_start(&mut self, offset: usize) -> Result<()> { + if offset < self.bytes_consumed { + // Restart from the beginning of the descriptor chain. + self.bytes_consumed = 0; + self.avail_bytes = None; + self.desc_chain = self.desc_chain_start.clone(); + } + + let mut count = offset - self.bytes_consumed; + while count > 0 { + let bytes_consumed = self.consume(|_, _| Ok(()), count)?; + if bytes_consumed == 0 { + break; + } + count -= bytes_consumed; + } + + Ok(()) + } + + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + fn apply_signed_offset(base: usize, offset: i64) -> io::Result { + let base = i64::try_from(base).map_err(|_| { + io::Error::new(io::ErrorKind::InvalidData, "seek position out of i64 range") + })?; + let result = base.checked_add(offset).ok_or(io::Error::new( + io::ErrorKind::InvalidData, + "seek offset overflowed", + ))?; + if result < 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "seek offset < 0", + )); + } + Ok(result as u64) + } + + let offset = match pos { + io::SeekFrom::Start(o) => o, + io::SeekFrom::Current(o) => apply_signed_offset(self.bytes_consumed(), o)?, + io::SeekFrom::End(o) => { + apply_signed_offset(self.bytes_consumed() + self.available_bytes(), o)? + } + }; + + let offset = usize::try_from(offset).map_err(|_| { + io::Error::new(io::ErrorKind::InvalidData, "seek offset overflowed usize") + })?; + self.seek_from_start(offset) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + + Ok(self.bytes_consumed() as u64) + } } /// Provides high-level interface over the sequence of memory regions @@ -335,10 +393,23 @@ impl<'a> io::Write for Writer<'a> { } } +impl<'a> io::Seek for Reader<'a> { + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + self.buffer.seek(pos) + } +} + +impl<'a> io::Seek for Writer<'a> { + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + self.buffer.seek(pos) + } +} + #[cfg(test)] mod tests { use super::*; use data_model::{Le16, Le32, Le64}; + use std::io::{Seek, SeekFrom}; use sys_util::{MemfdSeals, SharedMemory}; const VIRTQ_DESC_F_NEXT: u16 = 0x1; @@ -677,4 +748,64 @@ mod tests { Ok(read_secret) => assert_eq!(read_secret, secret), } } + + #[test] + fn reader_seek_simple_chain() { + use DescriptorType::*; + + let memory_start_addr = GuestAddress(0x0); + let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap(); + + let chain = create_descriptor_chain( + &memory, + GuestAddress(0x0), + GuestAddress(0x100), + vec![ + (Readable, 8), + (Readable, 16), + (Readable, 18), + (Readable, 64), + ], + 0, + ); + let mut reader = Reader::new(&memory, chain); + assert_eq!(reader.available_bytes(), 106); + assert_eq!(reader.bytes_read(), 0); + + // Skip some bytes. available_bytes() and bytes_read() should update accordingly. + reader + .seek(SeekFrom::Current(64)) + .expect("seek should not fail here"); + assert_eq!(reader.available_bytes(), 42); + assert_eq!(reader.bytes_read(), 64); + + // Seek past end of chain - position should point just past the last byte. + reader + .seek(SeekFrom::Current(64)) + .expect("seek should not fail here"); + assert_eq!(reader.available_bytes(), 0); + assert_eq!(reader.bytes_read(), 106); + + // Seek back to the beginning. + reader + .seek(SeekFrom::Start(0)) + .expect("seek should not fail here"); + assert_eq!(reader.available_bytes(), 106); + assert_eq!(reader.bytes_read(), 0); + + // Seek to one byte before the end. + reader + .seek(SeekFrom::End(-1)) + .expect("seek should not fail here"); + assert_eq!(reader.available_bytes(), 1); + assert_eq!(reader.bytes_read(), 105); + + // Read the last byte. + let mut buffer = [0 as u8; 1]; + reader + .read_exact(&mut buffer) + .expect("read_exact should not fail here"); + assert_eq!(reader.available_bytes(), 0); + assert_eq!(reader.bytes_read(), 106); + } } -- cgit 1.4.1