summary refs log tree commit diff
path: root/sys_util/src/guest_memory.rs
diff options
context:
space:
mode:
Diffstat (limited to 'sys_util/src/guest_memory.rs')
-rw-r--r--sys_util/src/guest_memory.rs140
1 files changed, 97 insertions, 43 deletions
diff --git a/sys_util/src/guest_memory.rs b/sys_util/src/guest_memory.rs
index e8f620b..60b775a 100644
--- a/sys_util/src/guest_memory.rs
+++ b/sys_util/src/guest_memory.rs
@@ -7,6 +7,7 @@
 use std::convert::AsRef;
 use std::convert::TryFrom;
 use std::fmt::{self, Display};
+use std::mem::size_of;
 use std::os::unix::io::{AsRawFd, RawFd};
 use std::result;
 use std::sync::Arc;
@@ -87,11 +88,19 @@ struct MemoryRegion {
     memfd_offset: u64,
 }
 
-fn region_end(region: &MemoryRegion) -> GuestAddress {
-    // unchecked_add is safe as the region bounds were checked when it was created.
-    region
-        .guest_base
-        .unchecked_add(region.mapping.size() as u64)
+impl MemoryRegion {
+    fn start(&self) -> GuestAddress {
+        self.guest_base
+    }
+
+    fn end(&self) -> GuestAddress {
+        // unchecked_add is safe as the region bounds were checked when it was created.
+        self.guest_base.unchecked_add(self.mapping.size() as u64)
+    }
+
+    fn contains(&self, addr: GuestAddress) -> bool {
+        addr >= self.guest_base && addr < self.end()
+    }
 }
 
 /// Tracks a memory region and where it is mapped in the guest, along with a shm
@@ -200,8 +209,8 @@ impl GuestMemory {
     pub fn end_addr(&self) -> GuestAddress {
         self.regions
             .iter()
-            .max_by_key(|region| region.guest_base)
-            .map_or(GuestAddress(0), |region| region_end(region))
+            .max_by_key(|region| region.start())
+            .map_or(GuestAddress(0), MemoryRegion::end)
     }
 
     /// Returns the total size of memory in bytes.
@@ -214,9 +223,7 @@ impl GuestMemory {
 
     /// Returns true if the given address is within the memory range available to the guest.
     pub fn address_in_range(&self, addr: GuestAddress) -> bool {
-        self.regions
-            .iter()
-            .any(|region| region.guest_base <= addr && addr < region_end(region))
+        self.regions.iter().any(|region| region.contains(addr))
     }
 
     /// Returns true if the given range (start, end) is overlap with the memory range
@@ -224,7 +231,7 @@ impl GuestMemory {
     pub fn range_overlap(&self, start: GuestAddress, end: GuestAddress) -> bool {
         self.regions
             .iter()
-            .any(|region| region.guest_base < end && start < region_end(region))
+            .any(|region| region.start() < end && start < region.end())
     }
 
     /// Returns the address plus the offset if it is in range.
@@ -267,7 +274,7 @@ impl GuestMemory {
         for (index, region) in self.regions.iter().enumerate() {
             cb(
                 index,
-                region.guest_base,
+                region.start(),
                 region.mapping.size(),
                 region.mapping.as_ptr() as usize,
                 region.memfd_offset,
@@ -442,6 +449,61 @@ impl GuestMemory {
         })
     }
 
+    /// Returns a `VolatileSlice` of `len` bytes starting at `addr`. Returns an error if the slice
+    /// is not a subset of this `GuestMemory`.
+    ///
+    /// # Examples
+    /// * Write `99` to 30 bytes starting at guest address 0x1010.
+    ///
+    /// ```
+    /// # use sys_util::{GuestAddress, GuestMemory, GuestMemoryError, MemoryMapping};
+    /// # fn test_volatile_slice() -> Result<(), GuestMemoryError> {
+    /// #   let start_addr = GuestAddress(0x1000);
+    /// #   let mut gm = GuestMemory::new(&vec![(start_addr, 0x400)])?;
+    ///     let vslice = gm.get_slice_at_addr(GuestAddress(0x1010), 30)?;
+    ///     vslice.write_bytes(99);
+    /// #   Ok(())
+    /// # }
+    /// ```
+    pub fn get_slice_at_addr(&self, addr: GuestAddress, len: usize) -> Result<VolatileSlice> {
+        self.regions
+            .iter()
+            .find(|region| region.contains(addr))
+            .ok_or(Error::InvalidGuestAddress(addr))
+            .and_then(|region| {
+                // The cast to a usize is safe here because we know that `region.contains(addr)` and
+                // it's not possible for a memory region to be larger than what fits in a usize.
+                region
+                    .mapping
+                    .get_slice(addr.offset_from(region.start()) as usize, len)
+                    .map_err(Error::VolatileMemoryAccess)
+            })
+    }
+
+    /// Returns a `VolatileRef` to an object at `addr`. Returns Ok(()) if the object fits, or Err if
+    /// it extends past the end.
+    ///
+    /// # Examples
+    /// * Get a &u64 at offset 0x1010.
+    ///
+    /// ```
+    /// # use sys_util::{GuestAddress, GuestMemory, GuestMemoryError, MemoryMapping};
+    /// # fn test_ref_u64() -> Result<(), GuestMemoryError> {
+    /// #   let start_addr = GuestAddress(0x1000);
+    /// #   let mut gm = GuestMemory::new(&vec![(start_addr, 0x400)])?;
+    ///     gm.write_obj_at_addr(47u64, GuestAddress(0x1010))?;
+    ///     let vref = gm.get_ref_at_addr::<u64>(GuestAddress(0x1010))?;
+    ///     assert_eq!(vref.load(), 47u64);
+    /// #   Ok(())
+    /// # }
+    /// ```
+    pub fn get_ref_at_addr<T: DataInit>(&self, addr: GuestAddress) -> Result<VolatileRef<T>> {
+        let buf = self.get_slice_at_addr(addr, size_of::<T>())?;
+        // Safe because we have know that `buf` is at least `size_of::<T>()` bytes and that the
+        // returned reference will not outlive this `GuestMemory`.
+        Ok(unsafe { VolatileRef::new(buf.as_mut_ptr() as *mut T) })
+    }
+
     /// Reads data from a file descriptor and writes it to guest memory.
     ///
     /// # Arguments
@@ -550,15 +612,16 @@ impl GuestMemory {
     where
         F: FnOnce(&MemoryMapping, usize) -> Result<T>,
     {
-        for region in self.regions.iter() {
-            if guest_addr >= region.guest_base && guest_addr < region_end(region) {
-                return cb(
+        self.regions
+            .iter()
+            .find(|region| region.contains(guest_addr))
+            .ok_or(Error::InvalidGuestAddress(guest_addr))
+            .and_then(|region| {
+                cb(
                     &region.mapping,
-                    guest_addr.offset_from(region.guest_base) as usize,
-                );
-            }
-        }
-        Err(Error::InvalidGuestAddress(guest_addr))
+                    guest_addr.offset_from(region.start()) as usize,
+                )
+            })
     }
 
     /// Convert a GuestAddress into an offset within self.memfd.
@@ -585,25 +648,11 @@ impl GuestMemory {
     /// assert_eq!(offset, 0x3500);
     /// ```
     pub fn offset_from_base(&self, guest_addr: GuestAddress) -> Result<u64> {
-        for region in self.regions.iter() {
-            if guest_addr >= region.guest_base && guest_addr < region_end(region) {
-                return Ok(region.memfd_offset + guest_addr.offset_from(region.guest_base) as u64);
-            }
-        }
-        Err(Error::InvalidGuestAddress(guest_addr))
-    }
-}
-
-impl VolatileMemory for GuestMemory {
-    fn get_slice(&self, offset: u64, count: u64) -> VolatileMemoryResult<VolatileSlice> {
-        for region in self.regions.iter() {
-            if offset >= region.guest_base.0 && offset < region_end(region).0 {
-                return region
-                    .mapping
-                    .get_slice(offset - region.guest_base.0, count);
-            }
-        }
-        Err(VolatileMemoryError::OutOfBounds { addr: offset })
+        self.regions
+            .iter()
+            .find(|region| region.contains(guest_addr))
+            .ok_or(Error::InvalidGuestAddress(guest_addr))
+            .map(|region| region.memfd_offset + guest_addr.offset_from(region.start()))
     }
 }
 
@@ -690,8 +739,11 @@ mod tests {
         gm.write_obj_at_addr(val1, GuestAddress(0x500)).unwrap();
         gm.write_obj_at_addr(val2, GuestAddress(0x1000 + 32))
             .unwrap();
-        let num1: u64 = gm.get_ref(0x500).unwrap().load();
-        let num2: u64 = gm.get_ref(0x1000 + 32).unwrap().load();
+        let num1: u64 = gm.get_ref_at_addr(GuestAddress(0x500)).unwrap().load();
+        let num2: u64 = gm
+            .get_ref_at_addr(GuestAddress(0x1000 + 32))
+            .unwrap()
+            .load();
         assert_eq!(val1, num1);
         assert_eq!(val2, num2);
     }
@@ -704,8 +756,10 @@ mod tests {
 
         let val1: u64 = 0xaa55aa55aa55aa55;
         let val2: u64 = 0x55aa55aa55aa55aa;
-        gm.get_ref(0x500).unwrap().store(val1);
-        gm.get_ref(0x1000 + 32).unwrap().store(val2);
+        gm.get_ref_at_addr(GuestAddress(0x500)).unwrap().store(val1);
+        gm.get_ref_at_addr(GuestAddress(0x1000 + 32))
+            .unwrap()
+            .store(val2);
         let num1: u64 = gm.read_obj_from_addr(GuestAddress(0x500)).unwrap();
         let num2: u64 = gm.read_obj_from_addr(GuestAddress(0x1000 + 32)).unwrap();
         assert_eq!(val1, num1);