summary refs log tree commit diff
path: root/servers/src/memfd.rs
diff options
context:
space:
mode:
Diffstat (limited to 'servers/src/memfd.rs')
-rw-r--r--servers/src/memfd.rs171
1 files changed, 171 insertions, 0 deletions
diff --git a/servers/src/memfd.rs b/servers/src/memfd.rs
new file mode 100644
index 0000000..ea6e148
--- /dev/null
+++ b/servers/src/memfd.rs
@@ -0,0 +1,171 @@
+// Copyright 2020 Alyssa Ross. All rights reserved.
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::convert::TryInto;
+use std::error::Error;
+use std::io::prelude::*;
+use std::io::IoSlice;
+use std::net::Shutdown;
+use std::os::unix::net::{UnixListener, UnixStream};
+use std::os::unix::prelude::*;
+use std::thread::{self, JoinHandle};
+
+use sys_util::{error, EventFd, MemfdSeals, PollContext, PollToken, ScmSocket, SharedMemory};
+
+use crate::Server;
+
+fn create_memfd(name: [u8; 224], size: u64) -> Result<SharedMemory, Box<dyn Error>> {
+    let mut full_name: Vec<u8> = b"crosvm-guest-memfd-".to_vec();
+    full_name.extend(
+        name.iter()
+            .map(|x| *x)
+            .take_while(|b| *b != 0 && *b != b'/'),
+    );
+
+    let size: usize = size.try_into()?;
+
+    let mut seals = MemfdSeals::new();
+    seals.set_grow_seal();
+    seals.set_shrink_seal();
+    seals.set_seal_seal();
+
+    let mut memfd = SharedMemory::named(full_name)?;
+    memfd.set_size(size as u64)?;
+    memfd.add_seals(seals)?;
+
+    Ok(memfd)
+}
+
+fn do_request(mut conn: UnixStream) -> Result<(), Box<dyn Error>> {
+    let mut name = [0; 224];
+    conn.read_exact(&mut name)?;
+
+    let mut size = [0; 8];
+    conn.read_exact(&mut size)?;
+    let size = u64::from_le_bytes(size);
+
+    let _ = conn.shutdown(Shutdown::Read);
+
+    match create_memfd(name, size) {
+        Ok(memfd) => conn.send_with_fd(&[IoSlice::new(&[0x00])], memfd.as_raw_fd())?,
+        Err(_) => conn.write(&[0x01])?,
+    };
+
+    Ok(())
+}
+
+fn run(wl_socket: UnixListener, kill_evt: EventFd) {
+    #[derive(Debug, PollToken)]
+    enum Token {
+        Socket,
+        Kill,
+    }
+
+    let poll_ctx =
+        match PollContext::build_with(&[(&wl_socket, Token::Socket), (&kill_evt, Token::Kill)]) {
+            Ok(pc) => pc,
+            Err(e) => {
+                error!("failed creating PollContext: {}", e);
+                return;
+            }
+        };
+
+    'poll: loop {
+        let events = match poll_ctx.wait() {
+            Ok(v) => v,
+            Err(e) => {
+                error!("failed polling for events: {}", e);
+                break;
+            }
+        };
+
+        for event in &events {
+            match dbg!(event.token()) {
+                Token::Socket => {
+                    let conn = match wl_socket.accept() {
+                        Ok((conn, _)) => conn,
+                        Err(e) => {
+                            error!("Failed to accept memfd connection: {}", e);
+                            break 'poll;
+                        }
+                    };
+
+                    if let Err(e) = do_request(conn) {
+                        error!("Failed to service memfd request: {}", e);
+                        break 'poll;
+                    }
+                }
+
+                Token::Kill => break 'poll,
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct MemfdServer {
+    kill_evt: Option<EventFd>,
+    worker_thread: Option<JoinHandle<()>>,
+    wl_socket: Option<UnixListener>,
+}
+
+impl MemfdServer {
+    pub fn new(wl_socket: UnixListener) -> Self {
+        Self {
+            kill_evt: None,
+            worker_thread: None,
+            wl_socket: Some(wl_socket),
+        }
+    }
+}
+
+impl Drop for MemfdServer {
+    fn drop(&mut self) {
+        if let Some(kill_evt) = self.kill_evt.take() {
+            // Ignore the result because there is nothing we can do about it.
+            let _ = kill_evt.write(1);
+        }
+
+        if let Some(worker_thread) = self.worker_thread.take() {
+            let _ = worker_thread.join();
+        }
+    }
+}
+
+impl Server for MemfdServer {
+    fn keep_fds(&self) -> Vec<RawFd> {
+        self.wl_socket.iter().map(AsRawFd::as_raw_fd).collect()
+    }
+
+    fn activate(&mut self) {
+        let wl_socket = match self.wl_socket.take() {
+            Some(wl) => wl,
+            None => return,
+        };
+
+        let (self_kill_evt, kill_evt) = match EventFd::new().and_then(|e| Ok((e.try_clone()?, e))) {
+            Ok(v) => v,
+            Err(e) => {
+                error!("failed creating kill EventFd pair: {}", e);
+                return;
+            }
+        };
+        self.kill_evt = Some(self_kill_evt);
+
+        let worker_result = thread::Builder::new()
+            .name("memfd-server".to_string())
+            .spawn(move || run(wl_socket, kill_evt));
+
+        match worker_result {
+            Err(e) => {
+                error!("failed to spawn memfd server worker: {}", e);
+                return;
+            }
+            Ok(join_handle) => {
+                self.worker_thread = Some(join_handle);
+            }
+        }
+    }
+}