summary refs log tree commit diff
path: root/vhost/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vhost/src/lib.rs')
-rw-r--r--vhost/src/lib.rs31
1 files changed, 21 insertions, 10 deletions
diff --git a/vhost/src/lib.rs b/vhost/src/lib.rs
index c1fb4cd..04ba655 100644
--- a/vhost/src/lib.rs
+++ b/vhost/src/lib.rs
@@ -9,14 +9,16 @@ pub use crate::net::Net;
 pub use crate::net::NetT;
 pub use crate::vsock::Vsock;
 
+use std::alloc::Layout;
 use std::fmt::{self, Display};
 use std::io::Error as IoError;
 use std::mem;
 use std::os::unix::io::AsRawFd;
 use std::ptr::null;
 
+use assertions::const_assert;
 use sys_util::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref};
-use sys_util::{EventFd, GuestAddress, GuestMemory, GuestMemoryError};
+use sys_util::{EventFd, GuestAddress, GuestMemory, GuestMemoryError, LayoutAllocation};
 
 #[derive(Debug)]
 pub enum Error {
@@ -108,14 +110,21 @@ pub trait Vhost: AsRawFd + std::marker::Sized {
 
     /// Set the guest memory mappings for vhost to use.
     fn set_mem_table(&self) -> Result<()> {
+        const SIZE_OF_MEMORY: usize = mem::size_of::<virtio_sys::vhost_memory>();
+        const SIZE_OF_REGION: usize = mem::size_of::<virtio_sys::vhost_memory_region>();
+        const ALIGN_OF_MEMORY: usize = mem::align_of::<virtio_sys::vhost_memory>();
+        const ALIGN_OF_REGION: usize = mem::align_of::<virtio_sys::vhost_memory_region>();
+        const_assert!(ALIGN_OF_MEMORY >= ALIGN_OF_REGION);
+
         let num_regions = self.mem().num_regions() as usize;
-        let vec_size_bytes = mem::size_of::<virtio_sys::vhost_memory>()
-            + (num_regions * mem::size_of::<virtio_sys::vhost_memory_region>());
-        let mut bytes: Vec<u8> = vec![0; vec_size_bytes];
-        // Convert bytes pointer to a vhost_memory mut ref. The vector has been
-        // sized correctly to ensure it can hold vhost_memory and N regions.
-        let vhost_memory: &mut virtio_sys::vhost_memory =
-            unsafe { &mut *(bytes.as_mut_ptr() as *mut virtio_sys::vhost_memory) };
+        let size = SIZE_OF_MEMORY + num_regions * SIZE_OF_REGION;
+        let layout = Layout::from_size_align(size, ALIGN_OF_MEMORY).expect("impossible layout");
+        let mut allocation = LayoutAllocation::zeroed(layout);
+
+        // Safe to obtain an exclusive reference because there are no other
+        // references to the allocation yet and all-zero is a valid bit pattern.
+        let vhost_memory = unsafe { allocation.as_mut::<virtio_sys::vhost_memory>() };
+
         vhost_memory.nregions = num_regions as u32;
         // regions is a zero-length array, so taking a mut slice requires that
         // we correctly specify the size to match the amount of backing memory.
@@ -136,12 +145,14 @@ pub trait Vhost: AsRawFd + std::marker::Sized {
         // This ioctl is called with a pointer that is valid for the lifetime
         // of this function. The kernel will make its own copy of the memory
         // tables. As always, check the return value.
-        let ret =
-            unsafe { ioctl_with_ptr(self, virtio_sys::VHOST_SET_MEM_TABLE(), bytes.as_ptr()) };
+        let ret = unsafe { ioctl_with_ptr(self, virtio_sys::VHOST_SET_MEM_TABLE(), vhost_memory) };
         if ret < 0 {
             return ioctl_result();
         }
+
         Ok(())
+
+        // vhost_memory allocation is deallocated.
     }
 
     /// Set the number of descriptors in the vring.