Diffstat (limited to 'sys_util/src')
-rw-r--r--   sys_util/src/guest_memory.rs     2
-rw-r--r--   sys_util/src/ioctl.rs            4
-rw-r--r--   sys_util/src/mmap.rs           140
3 files changed, 93 insertions(+), 53 deletions(-)
diff --git a/sys_util/src/guest_memory.rs b/sys_util/src/guest_memory.rs
index 60b775a..b13b8cc 100644
--- a/sys_util/src/guest_memory.rs
+++ b/sys_util/src/guest_memory.rs
@@ -13,7 +13,7 @@ use std::result;
 use std::sync::Arc;
 
 use crate::guest_address::GuestAddress;
-use crate::mmap::{self, MemoryMapping};
+use crate::mmap::{self, MappedRegion, MemoryMapping};
 use crate::shm::{MemfdSeals, SharedMemory};
 use crate::{errno, pagesize};
 use data_model::volatile_memory::*;
diff --git a/sys_util/src/ioctl.rs b/sys_util/src/ioctl.rs
index f8c7604..3cf622b 100644
--- a/sys_util/src/ioctl.rs
+++ b/sys_util/src/ioctl.rs
@@ -23,13 +23,13 @@ macro_rules! ioctl_expr {
 macro_rules! ioctl_ioc_nr {
     ($name:ident, $dir:expr, $ty:expr, $nr:expr, $size:expr) => {
         #[allow(non_snake_case)]
-        pub fn $name() -> ::std::os::raw::c_ulong {
+        pub const fn $name() -> ::std::os::raw::c_ulong {
             $crate::ioctl_expr!($dir, $ty, $nr, $size)
         }
     };
     ($name:ident, $dir:expr, $ty:expr, $nr:expr, $size:expr, $($v:ident),+) => {
         #[allow(non_snake_case)]
-        pub fn $name($($v: ::std::os::raw::c_uint),+) -> ::std::os::raw::c_ulong {
+        pub const fn $name($($v: ::std::os::raw::c_uint),+) -> ::std::os::raw::c_ulong {
             $crate::ioctl_expr!($dir, $ty, $nr, $size)
         }
     };
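Because the generated functions are now `const fn`, an ioctl number can be computed at compile time. A minimal sketch of what this enables, using a hypothetical ioctl name and placeholder dir/type/nr/size values (not taken from this diff):

    // Hypothetical definition; 0/0xAE/0x01/0 are placeholder values.
    ioctl_ioc_nr!(EXAMPLE_IOCTL, 0, 0xAE, 0x01, 0);

    // With `const fn`, the number can be bound to a `const` item (and so used in
    // match arms or array sizes) instead of being recomputed at runtime.
    const EXAMPLE_IOCTL_NR: ::std::os::raw::c_ulong = EXAMPLE_IOCTL();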
diff --git a/sys_util/src/mmap.rs b/sys_util/src/mmap.rs
index 64ffe17..0a67d88 100644
--- a/sys_util/src/mmap.rs
+++ b/sys_util/src/mmap.rs
@@ -105,6 +105,58 @@ impl Into<c_int> for Protection {
     }
 }
 
+/// Validates that `offset`..`offset+range_size` lies within the bounds of a memory mapping of
+/// `mmap_size` bytes.  Also checks for any overflow.
+fn validate_includes_range(mmap_size: usize, offset: usize, range_size: usize) -> Result<()> {
+    // Ensure offset + size doesn't overflow
+    let end_offset = offset
+        .checked_add(range_size)
+        .ok_or(Error::InvalidAddress)?;
+    // Ensure offset + size are within the mapping bounds
+    if end_offset <= mmap_size {
+        Ok(())
+    } else {
+        Err(Error::InvalidAddress)
+    }
+}
+
+/// A range of memory that can be msynced, for abstracting over different types of memory mappings.
+///
+/// Safe when implementers guarantee `ptr`..`ptr+size` is an mmapped region owned by this object that
+/// can't be unmapped during the `MappedRegion`'s lifetime.
+pub unsafe trait MappedRegion: Send + Sync {
+    /// Returns a pointer to the beginning of the memory region. Should only be
+    /// used for passing this region to ioctls for setting guest memory.
+    fn as_ptr(&self) -> *mut u8;
+
+    /// Returns the size of the memory region in bytes.
+    fn size(&self) -> usize;
+}
+
+impl dyn MappedRegion {
+    /// Calls msync with MS_SYNC on a mapping of `size` bytes starting at `offset` from the start of
+    /// the region.  `offset`..`offset+size` must be contained within the `MappedRegion`.
+    pub fn msync(&self, offset: usize, size: usize) -> Result<()> {
+        validate_includes_range(self.size(), offset, size)?;
+
+        // Safe because the MemoryMapping/MemoryMappingArena interface ensures our pointer and size
+        // are correct, and we've validated that `offset`..`offset+size` is in the range owned by
+        // this `MappedRegion`.
+        let ret = unsafe {
+            libc::msync(
+                (self.as_ptr() as usize + offset) as *mut libc::c_void,
+                size,
+                libc::MS_SYNC,
+            )
+        };
+        if ret != -1 {
+            Ok(())
+        } else {
+            Err(Error::SystemCallFailed(errno::Error::last()))
+        }
+    }
+}
+
 /// Wraps an anonymous shared memory mapping in the current process.
 #[derive(Debug)]
 pub struct MemoryMapping {
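The `msync` provided on `dyn MappedRegion` above replaces the per-type implementations removed below. A usage sketch, assuming the crate's `MemoryMapping::new(size)` constructor (not shown in this diff):

    // Any mapping type can now be msynced through the trait object.
    let mapping = MemoryMapping::new(pagesize()).unwrap();
    let region: &dyn MappedRegion = &mapping;
    region.msync(0, pagesize()).unwrap();               // whole mapping: OK
    assert!(region.msync(0, pagesize() + 1).is_err());  // past the end: InvalidAddress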
@@ -314,17 +366,6 @@ impl MemoryMapping {
         })
     }
 
-    /// Returns a pointer to the beginning of the memory region. Should only be
-    /// used for passing this region to ioctls for setting guest memory.
-    pub fn as_ptr(&self) -> *mut u8 {
-        self.addr
-    }
-
-    /// Returns the size of the memory region in bytes.
-    pub fn size(&self) -> usize {
-        self.size
-    }
-
     /// Calls msync with MS_SYNC on the mapping.
     pub fn msync(&self) -> Result<()> {
         // This is safe since we use the exact address and length of a known
@@ -607,6 +648,18 @@ impl MemoryMapping {
     }
 }
 
+// Safe because the pointer and size point to a memory range owned by this MemoryMapping that won't
+// be unmapped until it's Dropped.
+unsafe impl MappedRegion for MemoryMapping {
+    fn as_ptr(&self) -> *mut u8 {
+        self.addr
+    }
+
+    fn size(&self) -> usize {
+        self.size
+    }
+}
+
 impl VolatileMemory for MemoryMapping {
     fn get_slice(&self, offset: usize, count: usize) -> VolatileMemoryResult<VolatileSlice> {
         let mem_end = calc_offset(offset, count)?;
@@ -732,7 +785,11 @@ impl MemoryMappingArena {
         prot: Protection,
         fd: Option<(&dyn AsRawFd, u64)>,
     ) -> Result<()> {
-        self.validate_range(offset, size)?;
+        // Ensure offset is page-aligned
+        if offset % pagesize() != 0 {
+            return Err(Error::NotPageAligned);
+        }
+        validate_includes_range(self.size(), offset, size)?;
 
         // This is safe since the range has been validated.
         let mmap = unsafe {
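Only the page-alignment check stays arena-specific; bounds and overflow checking is now shared via `validate_includes_range`. A sketch of the resulting behaviour, using `MemoryMappingArena::new` and `remove` as exercised by the tests below:

    let mut arena = MemoryMappingArena::new(2 * pagesize()).unwrap();
    // Offsets must be page-aligned...
    assert!(arena.remove(1, pagesize()).is_err());               // NotPageAligned
    // ...while an unaligned size is accepted (see the existing test).
    arena.remove(pagesize(), pagesize() - 1).unwrap();
    // A range running past the arena is rejected by the shared bounds check.
    assert!(arena.remove(pagesize(), 2 * pagesize()).is_err());  // InvalidAddress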
@@ -762,50 +819,18 @@ impl MemoryMappingArena {
     pub fn remove(&mut self, offset: usize, size: usize) -> Result<()> {
         self.try_add(offset, size, Protection::read(), None)
     }
+}
 
-    /// Calls msync with MS_SYNC on a mapping of `size` bytes starting at `offset` from the start of
-    /// the arena. `offset` must be page aligned.
-    pub fn msync(&self, offset: usize, size: usize) -> Result<()> {
-        self.validate_range(offset, size)?;
-
-        // Safe because we've validated that this memory range is owned by this `MemoryMappingArena`.
-        let ret =
-            unsafe { libc::msync(self.addr as *mut libc::c_void, self.size(), libc::MS_SYNC) };
-        if ret == -1 {
-            return Err(Error::SystemCallFailed(errno::Error::last()));
-        }
-        Ok(())
-    }
-
-    /// Returns a pointer to the beginning of the memory region.  Should only be
-    /// used for passing this region to ioctls for setting guest memory.
-    pub fn as_ptr(&self) -> *mut u8 {
+// Safe because the pointer and size point to a memory range owned by this MemoryMappingArena that
+// won't be unmapped until it's Dropped.
+unsafe impl MappedRegion for MemoryMappingArena {
+    fn as_ptr(&self) -> *mut u8 {
         self.addr
     }
 
-    /// Returns the size of the memory region in bytes.
-    pub fn size(&self) -> usize {
+    fn size(&self) -> usize {
         self.size
     }
-
-    /// Validates `offset` and `size`.
-    /// Checks that offset..offset+size doesn't overlap with existing mappings.
-    /// Also ensures correct alignment, and checks for any overflow.
-    /// Note: offset..offset+size is considered valid if it _perfectly_ overlaps
-    /// with single other region.
-    fn validate_range(&self, offset: usize, size: usize) -> Result<()> {
-        // Ensure offset is page-aligned
-        if offset % pagesize() != 0 {
-            return Err(Error::NotPageAligned);
-        }
-        // Ensure offset + size doesn't overflow
-        let end_offset = offset.checked_add(size).ok_or(Error::InvalidAddress)?;
-        // Ensure offset + size are within the arena bounds
-        if end_offset > self.size {
-            return Err(Error::InvalidAddress);
-        }
-        Ok(())
-    }
 }
 
 impl From<MemoryMapping> for MemoryMappingArena {
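With both mapping types implementing the same trait, callers can hold them uniformly instead of special-casing each type. An illustrative sketch (the helper and collection here are not part of the crate; `MemoryMapping::new` is assumed from elsewhere in this file):

    // Illustrative only: store heterogeneous mappings behind the shared trait.
    fn total_mapped(regions: &[Box<dyn MappedRegion>]) -> usize {
        regions.iter().map(|r| r.size()).sum()
    }

    let regions: Vec<Box<dyn MappedRegion>> = vec![
        Box::new(MemoryMapping::new(pagesize()).unwrap()),
        Box::new(MemoryMappingArena::new(2 * pagesize()).unwrap()),
    ];
    assert_eq!(total_mapped(&regions), 3 * pagesize());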
@@ -1021,4 +1046,19 @@ mod tests {
         m.remove(0, ps - 1)
             .expect("failed to remove unaligned mapping");
     }
+
+    #[test]
+    fn arena_msync() {
+        let size = 0x40000;
+        let m = MemoryMappingArena::new(size).unwrap();
+        let ps = pagesize();
+        MappedRegion::msync(&m, 0, ps).unwrap();
+        MappedRegion::msync(&m, 0, size).unwrap();
+        MappedRegion::msync(&m, ps, size - ps).unwrap();
+        let res = MappedRegion::msync(&m, ps, size).unwrap_err();
+        match res {
+            Error::InvalidAddress => {}
+            e => panic!("unexpected error: {}", e),
+        }
+    }
 }