summary refs log tree commit diff
path: root/devices/src/virtio/descriptor_utils.rs
diff options
context:
space:
mode:
authorDaniel Verkamp <dverkamp@chromium.org>2019-07-24 14:10:19 -0700
committerCommit Bot <commit-bot@chromium.org>2019-08-13 16:48:42 +0000
commit977f008bc3be607600a98d77ef6984b52dba7eb6 (patch)
tree7df0b60f6912fa8b9e1d9460f6619c348d1f9b81 /devices/src/virtio/descriptor_utils.rs
parent36713056968fb9106ec0da6c0d964293f0425e99 (diff)
downloadcrosvm-977f008bc3be607600a98d77ef6984b52dba7eb6.tar
crosvm-977f008bc3be607600a98d77ef6984b52dba7eb6.tar.gz
crosvm-977f008bc3be607600a98d77ef6984b52dba7eb6.tar.bz2
crosvm-977f008bc3be607600a98d77ef6984b52dba7eb6.tar.lz
crosvm-977f008bc3be607600a98d77ef6984b52dba7eb6.tar.xz
crosvm-977f008bc3be607600a98d77ef6984b52dba7eb6.tar.zst
crosvm-977f008bc3be607600a98d77ef6984b52dba7eb6.zip
devices: virtio: add seek() for descriptor chains
This allows moving the read/write cursor around within a chain of
descriptors through the standard io::Seek interface.

BUG=chromium:990546
TEST=./build_test

Change-Id: I26ed368d3c7592188241a343dfeb922f3423d935
Signed-off-by: Daniel Verkamp <dverkamp@chromium.org>
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/1721369
Tested-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Zach Reizner <zachr@chromium.org>
Diffstat (limited to 'devices/src/virtio/descriptor_utils.rs')
-rw-r--r--devices/src/virtio/descriptor_utils.rs133
1 files changed, 132 insertions, 1 deletions
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs
index 9d18a63..a206002 100644
--- a/devices/src/virtio/descriptor_utils.rs
+++ b/devices/src/virtio/descriptor_utils.rs
@@ -3,6 +3,7 @@
 // found in the LICENSE file.
 
 use std::cmp;
+use std::convert::TryFrom;
 use std::fmt::{self, Display};
 use std::io;
 use std::os::unix::io::AsRawFd;
@@ -44,6 +45,7 @@ enum DescriptorFilter {
 struct DescriptorChainConsumer<'a> {
     offset: usize,
     desc_chain: Option<DescriptorChain<'a>>,
+    desc_chain_start: Option<DescriptorChain<'a>>,
     bytes_consumed: usize,
     avail_bytes: Option<usize>,
     filter: DescriptorFilter,
@@ -56,7 +58,8 @@ impl<'a> DescriptorChainConsumer<'a> {
     ) -> DescriptorChainConsumer<'a> {
         DescriptorChainConsumer {
             offset: 0,
-            desc_chain,
+            desc_chain: desc_chain.clone(),
+            desc_chain_start: desc_chain,
             bytes_consumed: 0,
             avail_bytes: None,
             filter,
@@ -129,6 +132,61 @@ impl<'a> DescriptorChainConsumer<'a> {
         }
         desc_chain
     }
+
+    fn seek_from_start(&mut self, offset: usize) -> Result<()> {
+        if offset < self.bytes_consumed {
+            // Restart from the beginning of the descriptor chain.
+            self.bytes_consumed = 0;
+            self.avail_bytes = None;
+            self.desc_chain = self.desc_chain_start.clone();
+        }
+
+        let mut count = offset - self.bytes_consumed;
+        while count > 0 {
+            let bytes_consumed = self.consume(|_, _| Ok(()), count)?;
+            if bytes_consumed == 0 {
+                break;
+            }
+            count -= bytes_consumed;
+        }
+
+        Ok(())
+    }
+
+    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
+        fn apply_signed_offset(base: usize, offset: i64) -> io::Result<u64> {
+            let base = i64::try_from(base).map_err(|_| {
+                io::Error::new(io::ErrorKind::InvalidData, "seek position out of i64 range")
+            })?;
+            let result = base.checked_add(offset).ok_or(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "seek offset overflowed",
+            ))?;
+            if result < 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "seek offset < 0",
+                ));
+            }
+            Ok(result as u64)
+        }
+
+        let offset = match pos {
+            io::SeekFrom::Start(o) => o,
+            io::SeekFrom::Current(o) => apply_signed_offset(self.bytes_consumed(), o)?,
+            io::SeekFrom::End(o) => {
+                apply_signed_offset(self.bytes_consumed() + self.available_bytes(), o)?
+            }
+        };
+
+        let offset = usize::try_from(offset).map_err(|_| {
+            io::Error::new(io::ErrorKind::InvalidData, "seek offset overflowed usize")
+        })?;
+        self.seek_from_start(offset)
+            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
+
+        Ok(self.bytes_consumed() as u64)
+    }
 }
 
 /// Provides high-level interface over the sequence of memory regions
@@ -335,10 +393,23 @@ impl<'a> io::Write for Writer<'a> {
     }
 }
 
+impl<'a> io::Seek for Reader<'a> {
+    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
+        self.buffer.seek(pos)
+    }
+}
+
+impl<'a> io::Seek for Writer<'a> {
+    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
+        self.buffer.seek(pos)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use data_model::{Le16, Le32, Le64};
+    use std::io::{Seek, SeekFrom};
     use sys_util::{MemfdSeals, SharedMemory};
 
     const VIRTQ_DESC_F_NEXT: u16 = 0x1;
@@ -677,4 +748,64 @@ mod tests {
             Ok(read_secret) => assert_eq!(read_secret, secret),
         }
     }
+
+    #[test]
+    fn reader_seek_simple_chain() {
+        use DescriptorType::*;
+
+        let memory_start_addr = GuestAddress(0x0);
+        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
+
+        let chain = create_descriptor_chain(
+            &memory,
+            GuestAddress(0x0),
+            GuestAddress(0x100),
+            vec![
+                (Readable, 8),
+                (Readable, 16),
+                (Readable, 18),
+                (Readable, 64),
+            ],
+            0,
+        );
+        let mut reader = Reader::new(&memory, chain);
+        assert_eq!(reader.available_bytes(), 106);
+        assert_eq!(reader.bytes_read(), 0);
+
+        // Skip some bytes.  available_bytes() and bytes_read() should update accordingly.
+        reader
+            .seek(SeekFrom::Current(64))
+            .expect("seek should not fail here");
+        assert_eq!(reader.available_bytes(), 42);
+        assert_eq!(reader.bytes_read(), 64);
+
+        // Seek past end of chain - position should point just past the last byte.
+        reader
+            .seek(SeekFrom::Current(64))
+            .expect("seek should not fail here");
+        assert_eq!(reader.available_bytes(), 0);
+        assert_eq!(reader.bytes_read(), 106);
+
+        // Seek back to the beginning.
+        reader
+            .seek(SeekFrom::Start(0))
+            .expect("seek should not fail here");
+        assert_eq!(reader.available_bytes(), 106);
+        assert_eq!(reader.bytes_read(), 0);
+
+        // Seek to one byte before the end.
+        reader
+            .seek(SeekFrom::End(-1))
+            .expect("seek should not fail here");
+        assert_eq!(reader.available_bytes(), 1);
+        assert_eq!(reader.bytes_read(), 105);
+
+        // Read the last byte.
+        let mut buffer = [0 as u8; 1];
+        reader
+            .read_exact(&mut buffer)
+            .expect("read_exact should not fail here");
+        assert_eq!(reader.available_bytes(), 0);
+        assert_eq!(reader.bytes_read(), 106);
+    }
 }