about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/reader/trailer.rs')
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/trailer.rs112
1 files changed, 28 insertions, 84 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
index 958cead42d..3a5bb75e71 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
@@ -1,4 +1,5 @@
 use std::{
+    fmt::Debug,
     future::Future,
     marker::PhantomData,
     ops::Deref,
@@ -8,11 +9,11 @@ use std::{
 
 use tokio::io::{self, AsyncRead, ReadBuf};
 
-/// Trailer represents up to 7 bytes of data read as part of the trailer block(s)
+/// Trailer represents up to 8 bytes of data read as part of the trailer block(s)
 #[derive(Debug)]
 pub(crate) struct Trailer {
     data_len: u8,
-    buf: [u8; 7],
+    buf: [u8; 8],
 }
 
 impl Deref for Trailer {
@@ -27,20 +28,20 @@ impl Deref for Trailer {
 pub(crate) trait Tag {
     /// The expected suffix
     ///
-    /// The first 7 bytes may be ignored, and it must be an 8-byte aligned size.
+    /// The first 8 bytes may be ignored, and it must be an 8-byte aligned size.
     const PATTERN: &'static [u8];
 
     /// Suitably sized buffer for reading [Self::PATTERN]
     ///
     /// HACK: This is a workaround for const generics limitations.
-    type Buf: AsRef<[u8]> + AsMut<[u8]> + Unpin;
+    type Buf: AsRef<[u8]> + AsMut<[u8]> + Debug + Unpin;
 
     /// Make an instance of [Self::Buf]
     fn make_buf() -> Self::Buf;
 }
 
 #[derive(Debug)]
-pub(crate) enum Pad {}
+pub enum Pad {}
 
 impl Tag for Pad {
     const PATTERN: &'static [u8] = &[0; 8];
@@ -58,7 +59,7 @@ pub(crate) struct ReadTrailer<R, T: Tag> {
     data_len: u8,
     filled: u8,
     buf: T::Buf,
-    _phantom: PhantomData<*const T>,
+    _phantom: PhantomData<fn(T) -> T>,
 }
 
 /// read_trailer returns a [Future] that reads a trailer with a given [Tag] from `reader`
@@ -66,7 +67,7 @@ pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>(
     reader: R,
     data_len: u8,
 ) -> ReadTrailer<R, T> {
-    assert!(data_len < 8, "payload in trailer must be less than 8 bytes");
+    assert!(data_len <= 8, "payload in trailer must be <= 8 bytes");
 
     let buf = T::make_buf();
     assert_eq!(buf.as_ref().len(), T::PATTERN.len());
@@ -81,10 +82,16 @@ pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>(
     }
 }
 
+impl<R, T: Tag> ReadTrailer<R, T> {
+    pub fn len(&self) -> u8 {
+        self.data_len
+    }
+}
+
 impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
     type Output = io::Result<Trailer>;
 
-    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
+    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
         let this = &mut *self;
 
         loop {
@@ -101,8 +108,8 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
             }
 
             if this.filled as usize == T::PATTERN.len() {
-                let mut buf = [0; 7];
-                buf.copy_from_slice(&this.buf.as_ref()[..7]);
+                let mut buf = [0; 8];
+                buf.copy_from_slice(&this.buf.as_ref()[..8]);
 
                 return Ok(Trailer {
                     data_len: this.data_len,
@@ -117,10 +124,9 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
             ready!(Pin::new(&mut this.reader).poll_read(cx, &mut buf))?;
 
             this.filled = {
-                let prev_filled = this.filled;
                 let filled = buf.filled().len() as u8;
 
-                if filled == prev_filled {
+                if filled == this.filled {
                     return Err(io::ErrorKind::UnexpectedEof.into()).into();
                 }
 
@@ -130,61 +136,9 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
     }
 }
 
-#[derive(Debug)]
-pub(crate) enum TrailerReader<R> {
-    Reading(ReadTrailer<R, Pad>),
-    Releasing { off: u8, data: Trailer },
-    Done,
-}
-
-impl<R: AsyncRead + Unpin> TrailerReader<R> {
-    pub fn new(reader: R, data_len: u8) -> Self {
-        Self::Reading(read_trailer(reader, data_len))
-    }
-}
-
-impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> {
-    fn poll_read(
-        mut self: Pin<&mut Self>,
-        cx: &mut task::Context,
-        user_buf: &mut ReadBuf,
-    ) -> Poll<io::Result<()>> {
-        let this = &mut *self;
-
-        loop {
-            match this {
-                Self::Reading(fut) => {
-                    *this = Self::Releasing {
-                        off: 0,
-                        data: ready!(Pin::new(fut).poll(cx))?,
-                    };
-                }
-                Self::Releasing { off: 8, .. } => {
-                    *this = Self::Done;
-                }
-                Self::Releasing { off, data } => {
-                    assert_ne!(user_buf.remaining(), 0);
-
-                    let buf = &data[*off as usize..];
-                    let buf = &buf[..usize::min(buf.len(), user_buf.remaining())];
-
-                    user_buf.put_slice(buf);
-                    *off += buf.len() as u8;
-
-                    break;
-                }
-                Self::Done => break,
-            }
-        }
-
-        Ok(()).into()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::time::Duration;
-    use tokio::io::AsyncReadExt;
 
     use super::*;
 
@@ -196,11 +150,8 @@ mod tests {
             .read(&[0xef, 0x00])
             .build();
 
-        let mut reader = TrailerReader::new(reader, 2);
-
-        let mut buf = vec![];
         assert_eq!(
-            reader.read_to_end(&mut buf).await.unwrap_err().kind(),
+            read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(),
             io::ErrorKind::UnexpectedEof
         );
     }
@@ -214,11 +165,8 @@ mod tests {
             .wait(Duration::ZERO)
             .build();
 
-        let mut reader = TrailerReader::new(reader, 2);
-
-        let mut buf = vec![];
         assert_eq!(
-            reader.read_to_end(&mut buf).await.unwrap_err().kind(),
+            read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(),
             io::ErrorKind::InvalidData
         );
     }
@@ -233,21 +181,17 @@ mod tests {
             .read(&[0x00, 0x00, 0x00, 0x00, 0x00])
             .build();
 
-        let mut reader = TrailerReader::new(reader, 2);
-
-        let mut buf = vec![];
-        reader.read_to_end(&mut buf).await.unwrap();
-
-        assert_eq!(buf, &[0xed, 0xef]);
+        assert_eq!(
+            &*read_trailer::<_, Pad>(reader, 2).await.unwrap(),
+            &[0xed, 0xef]
+        );
     }
 
     #[tokio::test]
     async fn no_padding() {
-        let reader = tokio_test::io::Builder::new().build();
-        let mut reader = TrailerReader::new(reader, 0);
-
-        let mut buf = vec![];
-        reader.read_to_end(&mut buf).await.unwrap();
-        assert!(buf.is_empty());
+        assert!(read_trailer::<_, Pad>(io::empty(), 0)
+            .await
+            .unwrap()
+            .is_empty());
     }
 }