// path: root/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
use std::{
    pin::Pin,
    task::{self, ready, Poll},
};

use tokio::io::{self, AsyncRead, ReadBuf};

/// Reads a fixed 8-byte trailer word and yields only its first `user_len`
/// bytes, requiring every remaining byte of the word to be zero padding.
///
/// The reader is a small state machine: it starts in [`TrailerReader::Reading`],
/// moves to [`TrailerReader::Releasing`] once all 8 bytes have arrived, and ends
/// in [`TrailerReader::Done`].
#[derive(Debug)]
pub enum TrailerReader<R> {
    /// Still collecting the 8-byte trailer word from `reader`.
    Reading {
        /// Inner source of the trailer bytes.
        reader: R,
        /// Number of leading payload bytes in the word (`new` sets this to 1..=7).
        user_len: u8,
        /// Number of bytes of `buf` received so far.
        filled: u8,
        /// The trailer word being assembled; starts zeroed.
        buf: [u8; 8],
    },
    /// The word has been fully read and its padding validated;
    /// `buf[off..len]` are the payload bytes not yet handed to the caller.
    Releasing {
        /// Index of the next payload byte to hand out.
        off: u8,
        /// Number of payload bytes at the start of `buf`.
        len: u8,
        /// The complete trailer word.
        buf: [u8; 8],
    },
    /// Nothing left to yield; reads complete without producing bytes (EOF).
    Done,
}

// `new` does not touch the reader, so it needs no `AsyncRead` bound; only the
// `AsyncRead` impl requires one.
impl<R> TrailerReader<R> {
    /// Creates a reader that yields the first `user_len` bytes of the next
    /// 8-byte word read from `reader`.
    ///
    /// With `user_len == 0` there is nothing to read: the result is already
    /// [`TrailerReader::Done`] and `reader` is dropped untouched.
    ///
    /// # Panics
    ///
    /// Panics if `user_len` is 8 or more, since the payload must leave at least
    /// one byte of padding in the word.
    pub fn new(reader: R, user_len: u8) -> Self {
        if user_len == 0 {
            return Self::Done;
        }

        assert!(user_len < 8, "payload in trailer must be less than 8 bytes");
        Self::Reading {
            reader,
            user_len,
            filled: 0,
            buf: [0; 8],
        }
    }
}

impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> {
    /// Drives the trailer state machine and copies payload bytes into `user_buf`.
    ///
    /// # Errors
    ///
    /// - `UnexpectedEof` if the inner reader stops producing bytes before all
    ///   8 bytes of the trailer word have been received.
    /// - `InvalidData` if any received byte after the first `user_len` bytes
    ///   is non-zero.
    ///
    /// A zero-capacity `user_buf` is accepted: the call then completes without
    /// copying anything.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
        user_buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        let this = &mut *self;

        loop {
            match this {
                // All 8 bytes are in and their padding was validated as they
                // arrived: drop the inner reader and start releasing payload.
                &mut Self::Reading {
                    reader: _,
                    user_len,
                    filled: 8,
                    buf,
                } => {
                    *this = Self::Releasing {
                        off: 0,
                        len: user_len,
                        buf,
                    };
                }
                Self::Reading {
                    reader,
                    user_len,
                    filled,
                    buf,
                } => {
                    let mut read_buf = ReadBuf::new(&mut buf[..]);
                    read_buf.advance(*filled as usize);
                    ready!(Pin::new(reader).poll_read(cx, &mut read_buf))?;

                    let new_filled = read_buf.filled().len() as u8;
                    // No progress means the inner reader reached EOF early.
                    if *filled == new_filled {
                        return Err(io::ErrorKind::UnexpectedEof.into()).into();
                    }

                    *filled = new_filled;

                    // Ensure the padding received so far is all zeroes. Only the
                    // filled prefix is inspected: the inner reader may legally
                    // have written to the unfilled tail of `buf`.
                    let padding_is_zero = buf[..new_filled as usize]
                        .iter()
                        .skip(*user_len as usize)
                        .all(|&b| b == 0);
                    if !padding_is_zero {
                        return Err(io::ErrorKind::InvalidData.into()).into();
                    }
                }
                // Every payload byte has been handed out (`off` never exceeds
                // `len`, so this is the terminal condition of `Releasing`).
                &mut Self::Releasing { off, len, .. } if off == len => {
                    *this = Self::Done;
                }
                Self::Releasing { off, len, buf } => {
                    let pending = &buf[*off as usize..*len as usize];
                    // A zero-capacity `user_buf` simply copies nothing.
                    let n = pending.len().min(user_buf.remaining());

                    user_buf.put_slice(&pending[..n]);
                    *off += n as u8;

                    break;
                }
                Self::Done => break,
            }
        }

        Ok(()).into()
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;
    use tokio::io::AsyncReadExt;

    use super::*;

    /// The inner reader ends before the 8-byte trailer word is complete.
    #[tokio::test]
    async fn unexpected_eof() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x00])
            .build();

        let mut reader = TrailerReader::new(reader, 2);

        let mut buf = vec![];
        assert_eq!(
            reader.read_to_end(&mut buf).await.unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    /// A non-zero byte in the padding area is rejected.
    #[tokio::test]
    async fn invalid_padding() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x01, 0x00])
            .wait(Duration::ZERO)
            .build();

        let mut reader = TrailerReader::new(reader, 2);

        let mut buf = vec![];
        assert_eq!(
            reader.read_to_end(&mut buf).await.unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
    }

    /// Only the payload bytes are yielded, even when the word arrives in pieces.
    #[tokio::test]
    async fn success() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x00])
            .wait(Duration::ZERO)
            .read(&[0x00, 0x00, 0x00, 0x00, 0x00])
            .build();

        let mut reader = TrailerReader::new(reader, 2);

        let mut buf = vec![];
        reader.read_to_end(&mut buf).await.unwrap();

        assert_eq!(buf, &[0xed, 0xef]);
    }

    /// A zero-length payload yields nothing.
    #[tokio::test]
    async fn no_padding() {
        let reader = tokio_test::io::Builder::new().build();
        let mut reader = TrailerReader::new(reader, 0);

        let mut buf = vec![];
        reader.read_to_end(&mut buf).await.unwrap();
        assert!(buf.is_empty());
    }

    /// The payload can be drained through a buffer smaller than the payload,
    /// after which the reader reports EOF.
    #[tokio::test]
    async fn small_user_buffer() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed, 0xef, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .build();

        let mut reader = TrailerReader::new(reader, 2);

        let mut byte = [0u8; 1];
        assert_eq!(reader.read(&mut byte).await.unwrap(), 1);
        assert_eq!(byte, [0xed]);
        assert_eq!(reader.read(&mut byte).await.unwrap(), 1);
        assert_eq!(byte, [0xef]);
        assert_eq!(reader.read(&mut byte).await.unwrap(), 0);
    }

    /// `new` rejects payloads that would leave no room for padding.
    #[test]
    #[should_panic(expected = "payload in trailer must be less than 8 bytes")]
    fn oversized_payload() {
        let _ = TrailerReader::new(io::empty(), 8);
    }
}