summary refs log tree commit diff
path: root/devices/src/virtio/balloon.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/balloon.rs')
-rw-r--r--devices/src/virtio/balloon.rs39
1 files changed, 22 insertions, 17 deletions
diff --git a/devices/src/virtio/balloon.rs b/devices/src/virtio/balloon.rs
index 8109395..633b3fc 100644
--- a/devices/src/virtio/balloon.rs
+++ b/devices/src/virtio/balloon.rs
@@ -17,7 +17,7 @@ use sys_util::{
 use vm_control::{BalloonControlCommand, BalloonControlResponseSocket};
 
 use super::{
-    copy_config, DescriptorChain, Queue, VirtioDevice, INTERRUPT_STATUS_CONFIG_CHANGED,
+    copy_config, Queue, Reader, VirtioDevice, INTERRUPT_STATUS_CONFIG_CHANGED,
     INTERRUPT_STATUS_USED_RING, TYPE_BALLOON, VIRTIO_F_VERSION_1,
 };
 
@@ -81,10 +81,6 @@ struct Worker {
     command_socket: BalloonControlResponseSocket,
 }
 
-fn valid_inflate_desc(desc: &DescriptorChain) -> bool {
-    !desc.is_write_only() && desc.len % 4 == 0
-}
-
 impl Worker {
     fn process_inflate_deflate(&mut self, inflate: bool) -> bool {
         let queue = if inflate {
@@ -95,19 +91,28 @@ impl Worker {
 
         let mut needs_interrupt = false;
         while let Some(avail_desc) = queue.pop(&self.mem) {
-            if inflate && valid_inflate_desc(&avail_desc) {
-                let num_addrs = avail_desc.len / 4;
-                for i in 0..num_addrs as usize {
-                    let addr = match avail_desc.addr.checked_add((i * 4) as u64) {
-                        Some(a) => a,
-                        None => break,
-                    };
-                    let guest_input: u32 = match self.mem.read_obj_from_addr(addr) {
-                        Ok(a) => a,
-                        Err(_) => continue,
+            let index = avail_desc.index;
+
+            if inflate {
+                let mut reader = Reader::new(&self.mem, avail_desc);
+                let data_length = reader.available_bytes();
+
+                if data_length % 4 != 0 {
+                    error!("invalid inflate buffer size: {}", data_length);
+                    continue;
+                }
+
+                let num_addrs = data_length / 4;
+                for _ in 0..num_addrs as usize {
+                    let guest_input = match reader.read_obj::<Le32>() {
+                        Ok(a) => a.to_native(),
+                        Err(err) => {
+                            error!("error while reading unused pages: {}", err);
+                            break;
+                        }
                     };
                     let guest_address =
-                        GuestAddress((guest_input as u64) << VIRTIO_BALLOON_PFN_SHIFT);
+                        GuestAddress((u64::from(guest_input)) << VIRTIO_BALLOON_PFN_SHIFT);
 
                     if self
                         .mem
@@ -120,7 +125,7 @@ impl Worker {
                 }
             }
 
-            queue.add_used(&self.mem, avail_desc.index, 0);
+            queue.add_used(&self.mem, index, 0);
             needs_interrupt = true;
         }