summary refs log tree commit diff
path: root/sys_util/src/write_zeroes.rs
diff options
context:
space:
mode:
Diffstat (limited to 'sys_util/src/write_zeroes.rs')
-rw-r--r--sys_util/src/write_zeroes.rs35
1 files changed, 29 insertions, 6 deletions
diff --git a/sys_util/src/write_zeroes.rs b/sys_util/src/write_zeroes.rs
index e3b531e..0e733c7 100644
--- a/sys_util/src/write_zeroes.rs
+++ b/sys_util/src/write_zeroes.rs
@@ -4,7 +4,7 @@
 
 use std::cmp::min;
 use std::fs::File;
-use std::io::{self, Seek, SeekFrom, Write};
+use std::io::{self, Error, ErrorKind, Seek, SeekFrom, Write};
 
 use crate::fallocate;
 use crate::FallocateMode;
@@ -24,8 +24,31 @@ impl PunchHole for File {
 
 /// A trait for writing zeroes to a stream.
 pub trait WriteZeroes {
-    /// Write `length` bytes of zeroes to the stream, returning how many bytes were written.
+    /// Write up to `length` bytes of zeroes to the stream, returning how many bytes were written.
     fn write_zeroes(&mut self, length: usize) -> io::Result<usize>;
+
+    /// Write zeroes to the stream until `length` bytes have been written.
+    ///
+    /// This method will continuously call `write_zeroes` until the requested
+    /// `length` is satisfied or an error is encountered.
+    fn write_zeroes_all(&mut self, mut length: usize) -> io::Result<()> {
+        while length > 0 {
+            match self.write_zeroes(length) {
+                Ok(0) => return Err(Error::from(ErrorKind::WriteZero)),
+                Ok(bytes_written) => {
+                    length = length
+                        .checked_sub(bytes_written)
+                        .ok_or(Error::from(ErrorKind::Other))?
+                }
+                Err(e) => {
+                    if e.kind() != ErrorKind::Interrupted {
+                        return Err(e);
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
 }
 
 impl<T: PunchHole + Seek + Write> WriteZeroes for T {
@@ -98,8 +121,8 @@ mod tests {
 
         // Overwrite some of the data with zeroes
         f.seek(SeekFrom::Start(2345)).unwrap();
-        f.write_zeroes(4321).expect("write_zeroes failed");
-        // Verify seek position after write_zeroes()
+        f.write_zeroes_all(4321).expect("write_zeroes failed");
+        // Verify seek position after write_zeroes_all()
         assert_eq!(f.seek(SeekFrom::Current(0)).unwrap(), 2345 + 4321);
 
         // Read back the data and verify that it is now zero
@@ -147,8 +170,8 @@ mod tests {
 
         // Overwrite some of the data with zeroes
         f.seek(SeekFrom::Start(0)).unwrap();
-        f.write_zeroes(0x10001).expect("write_zeroes failed");
-        // Verify seek position after write_zeroes()
+        f.write_zeroes_all(0x10001).expect("write_zeroes failed");
+        // Verify seek position after write_zeroes_all()
         assert_eq!(f.seek(SeekFrom::Current(0)).unwrap(), 0x10001);
 
         // Read back the data and verify that it is now zero