summary refs log tree commit diff
path: root/io_uring/src/uring.rs
diff options
context:
space:
mode:
Diffstat (limited to 'io_uring/src/uring.rs')
-rw-r--r--io_uring/src/uring.rs120
1 files changed, 97 insertions, 23 deletions
diff --git a/io_uring/src/uring.rs b/io_uring/src/uring.rs
index 8d569aa..2a78753 100644
--- a/io_uring/src/uring.rs
+++ b/io_uring/src/uring.rs
@@ -2,14 +2,14 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
+use std::collections::BTreeMap;
 use std::fmt;
 use std::fs::File;
-use std::io::IoSlice;
 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
 use std::ptr::null_mut;
 use std::sync::atomic::{AtomicU32, Ordering};
 
-use sys_util::{MemoryMapping, WatchingEvents};
+use sys_util::{MappedRegion, MemoryMapping, WatchingEvents};
 
 use crate::bindings::*;
 use crate::syscalls::*;
@@ -78,11 +78,11 @@ pub struct URingStats {
 /// let f = File::open(Path::new("/dev/zero")).unwrap();
 /// let mut uring = URingContext::new(16).unwrap();
 /// uring
-///   .add_poll_fd(f.as_raw_fd(), WatchingEvents::empty().set_read(), 454)
+///   .add_poll_fd(f.as_raw_fd(), &WatchingEvents::empty().set_read(), 454)
 /// .unwrap();
 /// let (user_data, res) = uring.wait().unwrap().next().unwrap();
-/// assert_eq!(user_data, 454 as UserData);
-/// assert_eq!(res.unwrap(), 1 as i32);
+/// assert_eq!(user_data, 454 as io_uring::UserData);
+/// assert_eq!(res.unwrap(), 1 as u32);
 ///
 /// ```
 pub struct URingContext {
@@ -260,6 +260,21 @@ impl URingContext {
         self.add_rw_op(ptr, len, fd, offset, user_data, IORING_OP_READV as u8)
     }
 
+    /// Like `add_writev`, but accepts an iterator of iovecs for callers that do not already have
+    /// them in a `Vec`. The same safety requirements as `add_writev` apply.
+    pub unsafe fn add_writev_iter<I>(
+        &mut self,
+        iovecs: I,
+        fd: RawFd,
+        offset: u64,
+        user_data: UserData,
+    ) -> Result<()>
+    where
+        I: Iterator<Item = libc::iovec>,
+    {
+        self.add_writev(iovecs.collect(), fd, offset, user_data)
+    }
+
     /// Asynchronously writes to `fd` from the addresses given in `iovecs`.
     /// # Safety
     /// `add_writev` will write to the address given by `iovecs`. This is only safe if the caller
@@ -267,9 +282,10 @@ impl URingContext {
     /// transaction is complete and that completion has been returned from the `wait` function.  In
     /// addition there must not be any mutable references to the data pointed to by `iovecs` until
     /// the operation completes.  Ensure that the fd remains open until the op completes as well.
+    /// The `iovecs` vector is stored in the ring and freed when the op's completion is consumed.
     pub unsafe fn add_writev(
         &mut self,
-        iovecs: &[IoSlice],
+        iovecs: Vec<libc::iovec>,
         fd: RawFd,
         offset: u64,
         user_data: UserData,
@@ -284,7 +300,24 @@ impl URingContext {
             sqe.user_data = user_data;
             sqe.flags = 0;
             sqe.fd = fd;
-        })
+        })?;
+        self.complete_ring.add_op_data(user_data, iovecs);
+        Ok(())
+    }
+
+    /// Like `add_readv`, but accepts an iterator of iovecs for callers that do not already have
+    /// them in a `Vec`. The same safety requirements as `add_readv` apply.
+    pub unsafe fn add_readv_iter<I>(
+        &mut self,
+        iovecs: I,
+        fd: RawFd,
+        offset: u64,
+        user_data: UserData,
+    ) -> Result<()>
+    where
+        I: Iterator<Item = libc::iovec>,
+    {
+        self.add_readv(iovecs.collect(), fd, offset, user_data)
     }
 
     /// Asynchronously reads from `fd` to the addresses given in `iovecs`.
@@ -294,9 +327,10 @@ impl URingContext {
     /// transaction is complete and that completion has been returned from the `wait` function.  In
     /// addition there must not be any references to the data pointed to by `iovecs` until the
     /// operation completes.  Ensure that the fd remains open until the op completes as well.
+    /// The `iovecs` vector is stored in the ring and freed when the op's completion is consumed.
     pub unsafe fn add_readv(
         &mut self,
-        iovecs: &[IoSlice],
+        iovecs: Vec<libc::iovec>,
         fd: RawFd,
         offset: u64,
         user_data: UserData,
@@ -311,7 +345,9 @@ impl URingContext {
             sqe.user_data = user_data;
             sqe.flags = 0;
             sqe.fd = fd;
-        })
+        })?;
+        self.complete_ring.add_op_data(user_data, iovecs);
+        Ok(())
     }
 
     /// Syncs all completed operations, the ordering with in-flight async ops is not
@@ -367,7 +403,7 @@ impl URingContext {
     pub fn add_poll_fd(
         &mut self,
         fd: RawFd,
-        events: WatchingEvents,
+        events: &WatchingEvents,
         user_data: UserData,
     ) -> Result<()> {
         self.prep_next_sqe(|sqe, _iovec| {
@@ -389,7 +425,7 @@ impl URingContext {
     pub fn remove_poll_fd(
         &mut self,
         fd: RawFd,
-        events: WatchingEvents,
+        events: &WatchingEvents,
         user_data: UserData,
     ) -> Result<()> {
         self.prep_next_sqe(|sqe, _iovec| {
@@ -524,6 +560,9 @@ struct CompleteQueueState {
     ring_mask: u32,
     cqes_offset: u32,
     completed: usize,
+    // For ops that pass in arrays of iovecs, the arrays need to stay valid for the duration of
+    // the operation because the kernel might read them at any time.
+    pending_op_addrs: BTreeMap<UserData, Vec<libc::iovec>>,
 }
 
 impl CompleteQueueState {
@@ -541,9 +580,14 @@ impl CompleteQueueState {
             ring_mask,
             cqes_offset: params.cq_off.cqes,
             completed: 0,
+            pending_op_addrs: BTreeMap::new(),
         }
     }
 
+    fn add_op_data(&mut self, user_data: UserData, addrs: Vec<libc::iovec>) {
+        self.pending_op_addrs.insert(user_data, addrs);
+    }
+
     fn get_cqe(&self, head: u32) -> &io_uring_cqe {
         unsafe {
             // Safe because we trust that the kernel has returned enough memory in io_uring_setup
@@ -582,6 +626,9 @@ impl Iterator for CompleteQueueState {
         let user_data = cqe.user_data;
         let res = cqe.res;
 
+        // The op has completed, so free the iovecs that were kept alive for it.
+        let _ = self.pending_op_addrs.remove(&user_data);
+
         // Store the new head and ensure the reads above complete before the kernel sees the
         // update to head, `set_head` uses `Release` ordering
         let new_head = head.wrapping_add(1);
@@ -637,6 +684,7 @@ impl QueuePointers {
 #[cfg(test)]
 mod tests {
     use std::fs::OpenOptions;
+    use std::io::{IoSlice, IoSliceMut};
     use std::io::{Read, Seek, SeekFrom, Write};
     use std::path::{Path, PathBuf};
     use std::time::Duration;
@@ -677,10 +725,18 @@ mod tests {
         offset: u64,
         user_data: UserData,
     ) {
-        let iovecs = [IoSlice::new(buf)];
+        let io_vecs = unsafe {
+            // Safe to transmute from IoSliceMut to iovec.
+            vec![IoSliceMut::new(buf)]
+                .into_iter()
+                .map(|slice| std::mem::transmute::<IoSliceMut, libc::iovec>(slice))
+                .collect::<Vec<libc::iovec>>()
+        };
         let (user_data_ret, res) = unsafe {
             // Safe because the `wait` call waits until the kernel is done with `buf`.
-            uring.add_readv(&iovecs, fd, offset, user_data).unwrap();
+            uring
+                .add_readv_iter(io_vecs.into_iter(), fd, offset, user_data)
+                .unwrap();
             uring.wait().unwrap().next().unwrap()
         };
         assert_eq!(user_data_ret, user_data);
@@ -771,15 +827,27 @@ mod tests {
         const BUF_SIZE: usize = 0x2000;
 
         let mut uring = URingContext::new(queue_size).unwrap();
-        let buf = [0u8; BUF_SIZE];
-        let buf2 = [0u8; BUF_SIZE];
-        let buf3 = [0u8; BUF_SIZE];
-        let io_slices = vec![IoSlice::new(&buf), IoSlice::new(&buf2), IoSlice::new(&buf3)];
-        let total_len = io_slices.iter().fold(0, |a, iovec| a + iovec.len());
+        let mut buf = [0u8; BUF_SIZE];
+        let mut buf2 = [0u8; BUF_SIZE];
+        let mut buf3 = [0u8; BUF_SIZE];
+        let io_vecs = unsafe {
+            // Safe to transmute from IoSliceMut to iovec.
+            vec![
+                IoSliceMut::new(&mut buf),
+                IoSliceMut::new(&mut buf2),
+                IoSliceMut::new(&mut buf3),
+            ]
+            .into_iter()
+            .map(|slice| std::mem::transmute::<IoSliceMut, libc::iovec>(slice))
+            .collect::<Vec<libc::iovec>>()
+        };
+        let total_len = io_vecs.iter().fold(0, |a, iovec| a + iovec.iov_len);
         let f = create_test_file(&temp_dir, total_len as u64 * 2);
         let (user_data_ret, res) = unsafe {
             // Safe because the `wait` call waits until the kernel is done with `buf`.
-            uring.add_readv(&io_slices, f.as_raw_fd(), 0, 55).unwrap();
+            uring
+                .add_readv_iter(io_vecs.into_iter(), f.as_raw_fd(), 0, 55)
+                .unwrap();
             uring.wait().unwrap().next().unwrap()
         };
         assert_eq!(user_data_ret, 55);
@@ -865,13 +933,19 @@ mod tests {
         let buf = [0xaau8; BUF_SIZE];
         let buf2 = [0xffu8; BUF_SIZE];
         let buf3 = [0x55u8; BUF_SIZE];
-        let io_slices = vec![IoSlice::new(&buf), IoSlice::new(&buf2), IoSlice::new(&buf3)];
-        let total_len = io_slices.iter().fold(0, |a, iovec| a + iovec.len());
+        let io_vecs = unsafe {
+            // Safe to transmute from IoSlice to iovec.
+            vec![IoSlice::new(&buf), IoSlice::new(&buf2), IoSlice::new(&buf3)]
+                .into_iter()
+                .map(|slice| std::mem::transmute::<IoSlice, libc::iovec>(slice))
+                .collect::<Vec<libc::iovec>>()
+        };
+        let total_len = io_vecs.iter().fold(0, |a, iovec| a + iovec.iov_len);
         let mut f = create_test_file(&temp_dir, total_len as u64 * 2);
         let (user_data_ret, res) = unsafe {
             // Safe because the `wait` call waits until the kernel is done with `buf`.
             uring
-                .add_writev(&io_slices, f.as_raw_fd(), OFFSET, 55)
+                .add_writev_iter(io_vecs.into_iter(), f.as_raw_fd(), OFFSET, 55)
                 .unwrap();
             uring.wait().unwrap().next().unwrap()
         };
@@ -951,7 +1025,7 @@ mod tests {
         let f = File::open(Path::new("/dev/zero")).unwrap();
         let mut uring = URingContext::new(16).unwrap();
         uring
-            .add_poll_fd(f.as_raw_fd(), WatchingEvents::empty().set_read(), 454)
+            .add_poll_fd(f.as_raw_fd(), &WatchingEvents::empty().set_read(), 454)
             .unwrap();
         let (user_data, res) = uring.wait().unwrap().next().unwrap();
         assert_eq!(user_data, 454 as UserData);