summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml2
-rw-r--r--devices/src/virtio/balloon.rs26
-rw-r--r--seccomp/arm/balloon_device.policy2
-rw-r--r--seccomp/x86_64/balloon_device.policy4
-rw-r--r--src/linux.rs151
-rw-r--r--src/main.rs1
-rw-r--r--sys_util/src/guest_memory.rs4
-rw-r--r--sys_util/src/lib.rs2
-rw-r--r--sys_util/src/mmap.rs7
-rw-r--r--sys_util/src/timerfd.rs142
10 files changed, 313 insertions, 28 deletions
diff --git a/Cargo.toml b/Cargo.toml
index ec58da2..708c9b0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -38,6 +38,7 @@ protobuf = { version = "=1.4.3", optional = true }
 qcow_utils = { path = "qcow_utils" }
 resources = { path = "resources" }
 p9 = { path = "p9" }
+rand = "=0.3.20"
 
 [target.'cfg(target_arch = "x86_64")'.dependencies]
 x86_64 = { path = "x86_64" }
@@ -46,5 +47,4 @@ x86_64 = { path = "x86_64" }
 aarch64 = { path = "aarch64" }
 
 [dev-dependencies]
-rand = "=0.3.20"
 sys_util = { path = "sys_util" }
diff --git a/devices/src/virtio/balloon.rs b/devices/src/virtio/balloon.rs
index 2194154..2943359 100644
--- a/devices/src/virtio/balloon.rs
+++ b/devices/src/virtio/balloon.rs
@@ -5,13 +5,14 @@
 use std;
 use std::cmp;
 use std::io::Write;
+use std::mem;
 use std::os::unix::io::{AsRawFd, RawFd};
 use std::os::unix::net::UnixDatagram;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::thread;
 
-use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt};
 use sys_util::{self, EventFd, GuestAddress, GuestMemory, PollContext, PollToken};
 
 use super::{VirtioDevice, Queue, DescriptorChain, INTERRUPT_STATUS_CONFIG_CHANGED,
@@ -85,7 +86,7 @@ impl Worker {
                             GuestAddress((guest_input as u64) << VIRTIO_BALLOON_PFN_SHIFT);
 
                         if self.mem
-                            .dont_need_range(guest_address, 1 << VIRTIO_BALLOON_PFN_SHIFT)
+                            .remove_range(guest_address, 1 << VIRTIO_BALLOON_PFN_SHIFT)
                             .is_err()
                         {
                             warn!("Marking pages unused failed {:?}", guest_address);
@@ -174,20 +175,15 @@ impl Worker {
                         needs_interrupt |= self.process_inflate_deflate(false);
                     }
                     Token::CommandSocket => {
-                        let mut buf = [0u8; 4];
+                        let mut buf = [0u8; mem::size_of::<u64>()];
                         if let Ok(count) = self.command_socket.recv(&mut buf) {
-                            if count == 4 {
-                                let mut buf = &buf[0..];
-                                let increment: i32 = buf.read_i32::<LittleEndian>().unwrap();
-                                let num_pages = self.config.num_pages.load(Ordering::Relaxed) as
-                                    i32;
-                                if increment < 0 && increment.abs() > num_pages {
-                                    continue;
-                                }
-                                self.config.num_pages.fetch_add(
-                                    increment as usize,
-                                    Ordering::Relaxed,
-                                );
+                            // Ignore any malformed messages that are not exactly 8 bytes long.
+                            if count == mem::size_of::<u64>() {
+                                let num_bytes = LittleEndian::read_u64(&buf);
+                                let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as usize;
+                                info!("ballon config changed to consume {} pages", num_pages);
+
+                                self.config.num_pages.store(num_pages, Ordering::Relaxed);
                                 self.signal_config_changed();
                             }
                         }
diff --git a/seccomp/arm/balloon_device.policy b/seccomp/arm/balloon_device.policy
index d5011be..42e5a09 100644
--- a/seccomp/arm/balloon_device.policy
+++ b/seccomp/arm/balloon_device.policy
@@ -6,11 +6,11 @@ close: 1
 exit_group: 1
 futex: 1
 gettimeofday: 1
-madvise: 1
 # Disallow mmap with PROT_EXEC set.  The syntax here doesn't allow bit
 # negation, thus the manually negated mask constant.
 mmap2: arg2 in 0xfffffffb
 mprotect: arg2 in 0xfffffffb
+madvise: arg2 == MADV_DONTDUMP || arg2 == MADV_DONTNEED || arg2 == MADV_REMOVE
 munmap: 1
 read: 1
 recv: 1
diff --git a/seccomp/x86_64/balloon_device.policy b/seccomp/x86_64/balloon_device.policy
index 8060374..b10f9ef 100644
--- a/seccomp/x86_64/balloon_device.policy
+++ b/seccomp/x86_64/balloon_device.policy
@@ -5,13 +5,11 @@
 close: 1
 exit_group: 1
 futex: 1
-madvise: 1
 # Disallow mmap with PROT_EXEC set.  The syntax here doesn't allow bit
 # negation, thus the manually negated mask constant.
 mmap: arg2 in 0xfffffffb
 mprotect: arg2 in 0xfffffffb
-# Allow MADV_DONTDUMP only.
-madvise: arg2 == 0x00000010
+madvise: arg2 == MADV_DONTDUMP || arg2 == MADV_DONTNEED || arg2 == MADV_REMOVE
 munmap: 1
 read: 1
 recvfrom: 1
diff --git a/src/linux.rs b/src/linux.rs
index 46ea35b..594383e 100644
--- a/src/linux.rs
+++ b/src/linux.rs
@@ -3,22 +3,28 @@
 // found in the LICENSE file.
 
 use std;
+use std::cmp::min;
 use std::ffi::{CString, CStr};
 use std::fmt;
 use std::error;
 use std::fs::{File, OpenOptions, remove_file};
-use std::io::{self, stdin};
+use std::io::{self, Read, stdin};
+use std::mem;
 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
 use std::os::unix::net::UnixDatagram;
 use std::path::{Path, PathBuf};
+use std::str;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::{Arc, Mutex, Barrier};
 use std::thread;
+use std::time::Duration;
 use std::thread::JoinHandle;
 
-use libc;
-use libc::c_int;
+use libc::{self, c_int};
+use rand::thread_rng;
+use rand::distributions::{IndependentSample, Range};
 
+use byteorder::{ByteOrder, LittleEndian};
 use devices;
 use io_jail::{self, Minijail};
 use kernel_cmdline;
@@ -55,6 +61,7 @@ pub enum Error {
     CreatePollContext(sys_util::Error),
     CreateSignalFd(sys_util::SignalFdError),
     CreateSocket(io::Error),
+    CreateTimerFd(sys_util::Error),
     CreateVcpu(sys_util::Error),
     CreateVm(Box<error::Error>),
     DeviceJail(io_jail::Error),
@@ -67,8 +74,12 @@ pub enum Error {
     NetDeviceNew(devices::virtio::NetError),
     NoVarEmpty,
     OpenKernel(PathBuf, io::Error),
+    OpenLowMem(io::Error),
     PollContextAdd(sys_util::Error),
+    PollContextDelete(sys_util::Error),
     QcowDeviceCreate(qcow::Error),
+    ReadLowmemAvailable(io::Error),
+    ReadLowmemMargin(io::Error),
     RegisterBalloon(MmioRegisterError),
     RegisterBlock(MmioRegisterError),
     RegisterGpu(MmioRegisterError),
@@ -77,11 +88,13 @@ pub enum Error {
     RegisterSignalHandler(sys_util::Error),
     RegisterVsock(MmioRegisterError),
     RegisterWayland(MmioRegisterError),
+    ResetTimerFd(sys_util::Error),
     RngDeviceNew(devices::virtio::RngError),
     SettingGidMap(io_jail::Error),
     SettingUidMap(io_jail::Error),
     SignalFd(sys_util::SignalFdError),
     SpawnVcpu(io::Error),
+    TimerFd(sys_util::Error),
     VhostNetDeviceNew(devices::virtio::vhost::Error),
     VhostVsockDeviceNew(devices::virtio::vhost::Error),
     WaylandDeviceNew(sys_util::Error),
@@ -110,6 +123,7 @@ impl fmt::Display for Error {
             &Error::CreatePollContext(ref e) => write!(f, "failed to create poll context: {:?}", e),
             &Error::CreateSignalFd(ref e) => write!(f, "failed to create signalfd: {:?}", e),
             &Error::CreateSocket(ref e) => write!(f, "failed to create socket: {}", e),
+            &Error::CreateTimerFd(ref e) => write!(f, "failed to create timerfd: {}", e),
             &Error::CreateVcpu(ref e) => write!(f, "failed to create VCPU: {:?}", e),
             &Error::CreateVm(ref e) => write!(f, "failed to create KVM VM object: {:?}", e),
             &Error::DeviceJail(ref e) => write!(f, "failed to jail device: {}", e),
@@ -126,10 +140,20 @@ impl fmt::Display for Error {
             &Error::OpenKernel(ref p, ref e) => {
                 write!(f, "failed to open kernel image {:?}: {}", p, e)
             }
+            &Error::OpenLowMem(ref e) => write!(f, "failed to open /dev/chromeos-low-mem: {}", e),
             &Error::PollContextAdd(ref e) => write!(f, "failed to add fd to poll context: {:?}", e),
+            &Error::PollContextDelete(ref e) => {
+                write!(f, "failed to remove fd from poll context: {:?}", e)
+            }
             &Error::QcowDeviceCreate(ref e) => {
                 write!(f, "failed to read qcow formatted file {:?}", e)
             }
+            &Error::ReadLowmemAvailable(ref e) => {
+                write!(f, "failed to read /sys/kernel/mm/chromeos-low_mem/available: {}", e)
+            }
+            &Error::ReadLowmemMargin(ref e) => {
+                write!(f, "failed to read /sys/kernel/mm/chromeos-low_mem/margin: {}", e)
+            }
             &Error::RegisterBalloon(ref e) => {
                 write!(f, "error registering balloon device: {:?}", e)
             },
@@ -144,11 +168,13 @@ impl fmt::Display for Error {
                 write!(f, "error registering virtual socket device: {:?}", e)
             }
             &Error::RegisterWayland(ref e) => write!(f, "error registering wayland device: {}", e),
+            &Error::ResetTimerFd(ref e) => write!(f, "failed to reset timerfd: {}", e),
             &Error::RngDeviceNew(ref e) => write!(f, "failed to set up rng: {:?}", e),
             &Error::SettingGidMap(ref e) => write!(f, "error setting GID map: {}", e),
             &Error::SettingUidMap(ref e) => write!(f, "error setting UID map: {}", e),
             &Error::SignalFd(ref e) => write!(f, "failed to read signal fd: {:?}", e),
             &Error::SpawnVcpu(ref e) => write!(f, "failed to spawn VCPU thread: {:?}", e),
+            &Error::TimerFd(ref e) => write!(f, "failed to read timer fd: {:?}", e),
             &Error::VhostNetDeviceNew(ref e) => {
                 write!(f, "failed to set up vhost networking: {:?}", e)
             }
@@ -687,6 +713,18 @@ fn run_vcpu(vcpu: Vcpu,
         .map_err(Error::SpawnVcpu)
 }
 
+// Reads the contents of a file and converts them into a u64.
+fn file_to_u64<P: AsRef<Path>>(path: P) -> io::Result<u64> {
+    let mut file = File::open(path)?;
+
+    let mut buf = [0u8; 32];
+    let count = file.read(&mut buf)?;
+
+    let content = str::from_utf8(&buf[..count])
+        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+    content.trim().parse().map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
+}
+
 fn run_control(mut vm: Vm,
                control_sockets: Vec<UnlinkUnixDatagram>,
                mut resources: SystemAllocator,
@@ -700,11 +738,34 @@ fn run_control(mut vm: Vm,
                -> Result<()> {
     const MAX_VM_FD_RECV: usize = 1;
 
+    // Paths to get the currently available memory and the low memory threshold.
+    const LOWMEM_MARGIN: &'static str = "/sys/kernel/mm/chromeos-low_mem/margin";
+    const LOWMEM_AVAILABLE: &'static str = "/sys/kernel/mm/chromeos-low_mem/available";
+
+    // One GB expressed in bytes.  Used below to work out how much of the VM's memory the
+    // balloon is allowed to claim back when the system is low on memory.
+    const ONE_GB: u64 = (1 << 30);
+
+    let max_balloon_memory = match vm.get_memory().memory_size() {
+        // If the VM has at least 1.5 GB, the balloon driver can consume all but the last 1 GB.
+        n if n >= (ONE_GB / 2) * 3 => n - ONE_GB,
+        // Otherwise, if the VM has at least 512 MB the balloon driver will consume at most
+        // half of it.
+        n if n >= (ONE_GB / 2) => n / 2,
+        // Otherwise, the VM is too small for us to take memory away from it.
+        _ => 0,
+    };
+    let mut current_balloon_memory: u64 = 0;
+    let balloon_memory_increment: u64 = max_balloon_memory / 16;
+
     #[derive(PollToken)]
     enum Token {
         Exit,
         Stdin,
         ChildSignal,
+        CheckAvailableMemory,
+        LowMemory,
+        LowmemTimer,
         VmControl { index: usize },
     }
 
@@ -724,6 +785,25 @@ fn run_control(mut vm: Vm,
         poll_ctx.add(socket.as_ref(), Token::VmControl{ index }).map_err(Error::PollContextAdd)?;
     }
 
+    // Watch for low memory notifications and take memory back from the VM.
+    let low_mem = File::open("/dev/chromeos-low-mem").map_err(Error::OpenLowMem)?;
+    poll_ctx.add(&low_mem, Token::LowMemory).map_err(Error::PollContextAdd)?;
+
+    // Used to rate limit balloon requests.
+    let mut lowmem_timer = TimerFd::new().map_err(Error::CreateTimerFd)?;
+    poll_ctx.add(&lowmem_timer, Token::LowmemTimer).map_err(Error::PollContextAdd)?;
+
+    // Used to check whether it's ok to start giving memory back to the VM.
+    let mut freemem_timer = TimerFd::new().map_err(Error::CreateTimerFd)?;
+    poll_ctx.add(&freemem_timer, Token::CheckAvailableMemory).map_err(Error::PollContextAdd)?;
+
+    // Used to add jitter to timer values so that we don't have a thundering herd problem when
+    // multiple VMs are running.
+    let mut rng = thread_rng();
+    let lowmem_jitter_ms = Range::new(0, 200);
+    let freemem_jitter_secs = Range::new(0, 12);
+    let interval_jitter_secs = Range::new(0, 6);
+
     let mut scm = Scm::new(MAX_VM_FD_RECV);
 
     'poll: loop {
@@ -776,6 +856,68 @@ fn run_control(mut vm: Vm,
                         break 'poll;
                     }
                 }
+                Token::CheckAvailableMemory => {
+                    // Acknowledge the timer.
+                    freemem_timer.wait().map_err(Error::TimerFd)?;
+                    if current_balloon_memory == 0 {
+                        // Nothing to see here.
+                        if let Err(e) = freemem_timer.clear() {
+                            warn!("unable to clear available memory check timer: {}", e);
+                        }
+                        continue;
+                    }
+
+                    // Otherwise see if we can free up some memory.
+                    let margin = file_to_u64(LOWMEM_MARGIN).map_err(Error::ReadLowmemMargin)?;
+                    let available = file_to_u64(LOWMEM_AVAILABLE).map_err(Error::ReadLowmemAvailable)?;
+
+                    // `available` and `margin` are specified in MB while `balloon_memory_increment` is in
+                    // bytes.  So to correctly compare them we need to turn the increment value into MB.
+                    if available >= margin + 2*(balloon_memory_increment >> 20) {
+                        current_balloon_memory = if current_balloon_memory >= balloon_memory_increment {
+                            current_balloon_memory - balloon_memory_increment
+                        } else {
+                            0
+                        };
+                        let mut buf = [0u8; mem::size_of::<u64>()];
+                        LittleEndian::write_u64(&mut buf, current_balloon_memory);
+                        if let Err(e) = balloon_host_socket.send(&buf) {
+                            warn!("failed to send memory value to balloon device: {}", e);
+                        }
+                    }
+                }
+                Token::LowMemory => {
+                    let old_balloon_memory = current_balloon_memory;
+                    current_balloon_memory = min(current_balloon_memory + balloon_memory_increment, max_balloon_memory);
+                    if current_balloon_memory != old_balloon_memory {
+                        let mut buf = [0u8; mem::size_of::<u64>()];
+                        LittleEndian::write_u64(&mut buf, current_balloon_memory);
+                        if let Err(e) = balloon_host_socket.send(&buf) {
+                            warn!("failed to send memory value to balloon device: {}", e);
+                        }
+                    }
+
+                    // Stop polling the lowmem device until the timer fires.
+                    poll_ctx.delete(&low_mem).map_err(Error::PollContextDelete)?;
+
+                    // Add some jitter to the timer so that if there are multiple VMs running they don't
+                    // all start ballooning at exactly the same time.
+                    let lowmem_dur = Duration::from_millis(1000 + lowmem_jitter_ms.ind_sample(&mut rng));
+                    lowmem_timer.reset(lowmem_dur, None).map_err(Error::ResetTimerFd)?;
+
+                    // Also start a timer to check when we can start giving memory back.  Do the first check
+                    // after a minute (with jitter) and subsequent checks after every 30 seconds (with jitter).
+                    let freemem_dur = Duration::from_secs(60 + freemem_jitter_secs.ind_sample(&mut rng));
+                    let freemem_int = Duration::from_secs(30 + interval_jitter_secs.ind_sample(&mut rng));
+                    freemem_timer.reset(freemem_dur, Some(freemem_int)).map_err(Error::ResetTimerFd)?;
+                }
+                Token::LowmemTimer => {
+                    // Acknowledge the timer.
+                    lowmem_timer.wait().map_err(Error::TimerFd)?;
+
+                    // Start polling the lowmem device again.
+                    poll_ctx.add(&low_mem, Token::LowMemory).map_err(Error::PollContextAdd)?;
+                }
                 Token::VmControl { index } => {
                     if let Some(socket) = control_sockets.get(index as usize) {
                         match VmRequest::recv(&mut scm, socket.as_ref()) {
@@ -811,6 +953,9 @@ fn run_control(mut vm: Vm,
                         let _ = poll_ctx.delete(&stdin_handle);
                     },
                     Token::ChildSignal => {},
+                    Token::CheckAvailableMemory => {},
+                    Token::LowMemory => {},
+                    Token::LowmemTimer => {},
                     Token::VmControl { index } => {
                         if let Some(socket) = control_sockets.get(index as usize) {
                             let _ = poll_ctx.delete(socket.as_ref());
diff --git a/src/main.rs b/src/main.rs
index fc9c18c..c85d720 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -31,6 +31,7 @@ extern crate plugin_proto;
 extern crate protobuf;
 #[cfg(feature = "wl-dmabuf")]
 extern crate gpu_buffer;
+extern crate rand;
 
 pub mod argument;
 pub mod linux;
diff --git a/sys_util/src/guest_memory.rs b/sys_util/src/guest_memory.rs
index 31500ca..04ef124 100644
--- a/sys_util/src/guest_memory.rs
+++ b/sys_util/src/guest_memory.rs
@@ -132,10 +132,10 @@ impl GuestMemory {
     }
 
     /// Madvise away the address range in the host that is associated with the given guest range.
-    pub fn dont_need_range(&self, addr: GuestAddress, count: u64) -> Result<()> {
+    pub fn remove_range(&self, addr: GuestAddress, count: u64) -> Result<()> {
         self.do_in_region(addr, move |mapping, offset| {
             mapping
-                .dont_need_range(offset, count as usize)
+                .remove_range(offset, count as usize)
                 .map_err(|e| Error::MemoryAccess(addr, e))
         })
     }
diff --git a/sys_util/src/lib.rs b/sys_util/src/lib.rs
index caef4c2..30b293d 100644
--- a/sys_util/src/lib.rs
+++ b/sys_util/src/lib.rs
@@ -33,6 +33,7 @@ mod signalfd;
 mod sock_ctrl_msg;
 mod passwd;
 mod file_flags;
+mod timerfd;
 
 pub use mmap::*;
 pub use shm::*;
@@ -53,6 +54,7 @@ pub use sock_ctrl_msg::*;
 pub use passwd::*;
 pub use poll_token_derive::*;
 pub use file_flags::*;
+pub use timerfd::*;
 
 pub use mmap::Error as MmapError;
 pub use guest_memory::Error as GuestMemoryError;
diff --git a/sys_util/src/mmap.rs b/sys_util/src/mmap.rs
index 4ce72b6..ced48b1 100644
--- a/sys_util/src/mmap.rs
+++ b/sys_util/src/mmap.rs
@@ -316,8 +316,9 @@ impl MemoryMapping {
         Ok(())
     }
 
-    /// Uses madvise to tell the kernel the specified range won't be needed soon.
-    pub fn dont_need_range(&self, mem_offset: usize, count: usize) -> Result<()> {
+    /// Uses madvise to tell the kernel to remove the specified range.  Subsequent reads
+    /// to the pages in the range will return zero bytes.
+    pub fn remove_range(&self, mem_offset: usize, count: usize) -> Result<()> {
         self.range_end(mem_offset, count)
             .map_err(|_| Error::InvalidRange(mem_offset, count))?;
         let ret = unsafe {
@@ -325,7 +326,7 @@ impl MemoryMapping {
             // Next time it is read, it may return zero pages.
             libc::madvise((self.addr as usize + mem_offset) as *mut _,
                           count,
-                          libc::MADV_DONTNEED)
+                          libc::MADV_REMOVE)
         };
         if ret < 0 {
             Err(Error::InvalidRange(mem_offset, count))
diff --git a/sys_util/src/timerfd.rs b/sys_util/src/timerfd.rs
new file mode 100644
index 0000000..7d2c76a
--- /dev/null
+++ b/sys_util/src/timerfd.rs
@@ -0,0 +1,142 @@
+// Copyright 2018 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::fs::File;
+use std::mem;
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
+use std::ptr;
+use std::time::Duration;
+
+use libc::{self, CLOCK_MONOTONIC, TFD_CLOEXEC, timerfd_create, timerfd_settime};
+
+use {Result, errno_result};
+
+/// A safe wrapper around a Linux timerfd (man 2 timerfd_create).
+pub struct TimerFd(File);
+
+impl TimerFd {
+    /// Creates a new timerfd.  The timer is initially disarmed and must be armed by calling
+    /// `reset`.
+    pub fn new() -> Result<TimerFd> {
+        // Safe because this doesn't modify any memory and we check the return value.
+        let ret = unsafe { timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC) };
+        if ret < 0 {
+            return errno_result();
+        }
+
+        // Safe because we uniquely own the file descriptor.
+        Ok(TimerFd(unsafe { File::from_raw_fd(ret) } ))
+    }
+
+    /// Sets the timer to expire after `dur`.  If `interval` is not `None` it represents
+    /// the period for repeated expirations after the initial expiration.  Otherwise
+    /// the timer will expire just once.  Cancels any existing duration and repeating interval.
+    pub fn reset(&mut self, dur: Duration, interval: Option<Duration>) -> Result<()> {
+        // Safe because we are zero-initializing a struct with only primitive member fields.
+        let mut spec: libc::itimerspec = unsafe { mem::zeroed() };
+        spec.it_value.tv_sec = dur.as_secs() as libc::time_t;
+        spec.it_value.tv_nsec = dur.subsec_nanos() as libc::c_long;
+
+        if let Some(int) = interval {
+            spec.it_interval.tv_sec = int.as_secs() as libc::time_t;
+            spec.it_interval.tv_nsec = int.subsec_nanos() as libc::c_long;
+        }
+
+        // Safe because this doesn't modify any memory and we check the return value.
+        let ret = unsafe { timerfd_settime(self.as_raw_fd(), 0, &spec, ptr::null_mut()) };
+        if ret < 0 {
+            return errno_result();
+        }
+
+        Ok(())
+    }
+
+    /// Waits until the timer expires.  The return value represents the number of times the timer
+    /// has expired since the last time `wait` was called.  If the timer has not yet expired once
+    /// this call will block until it does.
+    pub fn wait(&mut self) -> Result<u64> {
+        let mut count = 0u64;
+
+        // Safe because this will only modify |count| and we check the return value.
+        let ret = unsafe {
+            libc::read(self.as_raw_fd(),
+                       &mut count as *mut _ as *mut libc::c_void,
+                       mem::size_of_val(&count))
+        };
+        if ret < 0 {
+            return errno_result();
+        }
+
+        // The bytes in the buffer are guaranteed to be in native byte-order so we don't need to
+        // use from_le or from_be.
+        Ok(count)
+    }
+
+    /// Disarms the timer.
+    pub fn clear(&mut self) -> Result<()> {
+        // Safe because we are zero-initializing a struct with only primitive member fields.
+        let spec: libc::itimerspec = unsafe { mem::zeroed() };
+
+        // Safe because this doesn't modify any memory and we check the return value.
+        let ret = unsafe { timerfd_settime(self.as_raw_fd(),  0, &spec, ptr::null_mut()) };
+        if ret < 0 {
+            return errno_result();
+        }
+
+        Ok(())
+    }
+}
+
+impl AsRawFd for TimerFd {
+    fn as_raw_fd(&self) -> RawFd {
+        self.0.as_raw_fd()
+    }
+}
+
+impl FromRawFd for TimerFd {
+    unsafe fn from_raw_fd(fd: RawFd) -> Self {
+        TimerFd(File::from_raw_fd(fd))
+    }
+}
+
+impl IntoRawFd for TimerFd {
+    fn into_raw_fd(self) -> RawFd {
+        self.0.into_raw_fd()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::thread::sleep;
+    use std::time::{Duration, Instant};
+
+    #[test]
+    fn one_shot() {
+        let mut tfd = TimerFd::new().expect("failed to create timerfd");
+
+        let dur = Duration::from_millis(200);
+        let now = Instant::now();
+        tfd.reset(dur.clone(), None).expect("failed to arm timer");
+
+        let count = tfd.wait().expect("unable to wait for timer");
+
+        assert_eq!(count, 1);
+        assert!(now.elapsed() >= dur);
+    }
+
+    #[test]
+    fn repeating() {
+        let mut tfd = TimerFd::new().expect("failed to create timerfd");
+
+        let dur = Duration::from_millis(200);
+        let interval = Duration::from_millis(100);
+        tfd.reset(dur.clone(), Some(interval)).expect("failed to arm timer");
+
+        sleep(dur * 3);
+
+        let count = tfd.wait().expect("unable to wait for timer");
+        assert!(count >= 5, "count = {}", count);
+    }
+}