diff options
Diffstat (limited to 'servers/src/memfd.rs')
-rw-r--r-- | servers/src/memfd.rs | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/servers/src/memfd.rs b/servers/src/memfd.rs new file mode 100644 index 0000000..ea6e148 --- /dev/null +++ b/servers/src/memfd.rs @@ -0,0 +1,171 @@ +// Copyright 2020 Alyssa Ross. All rights reserved. +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::convert::TryInto; +use std::error::Error; +use std::io::prelude::*; +use std::io::IoSlice; +use std::net::Shutdown; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::os::unix::prelude::*; +use std::thread::{self, JoinHandle}; + +use sys_util::{error, EventFd, MemfdSeals, PollContext, PollToken, ScmSocket, SharedMemory}; + +use crate::Server; + +fn create_memfd(name: [u8; 224], size: u64) -> Result<SharedMemory, Box<dyn Error>> { + let mut full_name: Vec<u8> = b"crosvm-guest-memfd-".to_vec(); + full_name.extend( + name.iter() + .map(|x| *x) + .take_while(|b| *b != 0 && *b != b'/'), + ); + + let size: usize = size.try_into()?; + + let mut seals = MemfdSeals::new(); + seals.set_grow_seal(); + seals.set_shrink_seal(); + seals.set_seal_seal(); + + let mut memfd = SharedMemory::named(full_name)?; + memfd.set_size(size as u64)?; + memfd.add_seals(seals)?; + + Ok(memfd) +} + +fn do_request(mut conn: UnixStream) -> Result<(), Box<dyn Error>> { + let mut name = [0; 224]; + conn.read_exact(&mut name)?; + + let mut size = [0; 8]; + conn.read_exact(&mut size)?; + let size = u64::from_le_bytes(size); + + let _ = conn.shutdown(Shutdown::Read); + + match create_memfd(name, size) { + Ok(memfd) => conn.send_with_fd(&[IoSlice::new(&[0x00])], memfd.as_raw_fd())?, + Err(_) => conn.write(&[0x01])?, + }; + + Ok(()) +} + +fn run(wl_socket: UnixListener, kill_evt: EventFd) { + #[derive(Debug, PollToken)] + enum Token { + Socket, + Kill, + } + + let poll_ctx = + match PollContext::build_with(&[(&wl_socket, Token::Socket), (&kill_evt, Token::Kill)]) { + Ok(pc) => pc, + Err(e) => { + error!("failed creating PollContext: {}", e); + return; + } + }; + + 'poll: loop { + let events = match poll_ctx.wait() { + Ok(v) => v, + Err(e) => { + error!("failed polling for events: {}", e); + break; + } + }; + + for event in &events { + match dbg!(event.token()) { + Token::Socket => { + let conn = match wl_socket.accept() { + Ok((conn, _)) => conn, + Err(e) => { + error!("Failed to accept memfd connection: {}", e); + break 'poll; + } + }; + + if let Err(e) = do_request(conn) { + error!("Failed to service memfd request: {}", e); + break 'poll; + } + } + + Token::Kill => break 'poll, + } + } + } +} + +#[derive(Debug)] +pub struct MemfdServer { + kill_evt: Option<EventFd>, + worker_thread: Option<JoinHandle<()>>, + wl_socket: Option<UnixListener>, +} + +impl MemfdServer { + pub fn new(wl_socket: UnixListener) -> Self { + Self { + kill_evt: None, + worker_thread: None, + wl_socket: Some(wl_socket), + } + } +} + +impl Drop for MemfdServer { + fn drop(&mut self) { + if let Some(kill_evt) = self.kill_evt.take() { + // Ignore the result because there is nothing we can do about it. + let _ = kill_evt.write(1); + } + + if let Some(worker_thread) = self.worker_thread.take() { + let _ = worker_thread.join(); + } + } +} + +impl Server for MemfdServer { + fn keep_fds(&self) -> Vec<RawFd> { + self.wl_socket.iter().map(AsRawFd::as_raw_fd).collect() + } + + fn activate(&mut self) { + let wl_socket = match self.wl_socket.take() { + Some(wl) => wl, + None => return, + }; + + let (self_kill_evt, kill_evt) = match EventFd::new().and_then(|e| Ok((e.try_clone()?, e))) { + Ok(v) => v, + Err(e) => { + error!("failed creating kill EventFd pair: {}", e); + return; + } + }; + self.kill_evt = Some(self_kill_evt); + + let worker_result = thread::Builder::new() + .name("memfd-server".to_string()) + .spawn(move || run(wl_socket, kill_evt)); + + match worker_result { + Err(e) => { + error!("failed to spawn memfd server worker: {}", e); + return; + } + Ok(join_handle) => { + self.worker_thread = Some(join_handle); + } + } + } +} |