From ba46b1818a5065fd84d969bc4bf2e6f7aa98d9c7 Mon Sep 17 00:00:00 2001 From: edef Date: Thu, 25 Apr 2024 23:27:58 +0000 Subject: fix(nix-compat/bytes): make BytesReader less hazardous We now *never* return the final bytes until we've read the padding in full, so read_exact is safe to use. This is implemented by TrailerReader, which splits the phases of reading (and validating) the final 8-byte block, and providing the contained payload bytes to the caller. Change-Id: I0d05946a8af9c260a18d71d2b763ba7a68b3c27f Reviewed-on: https://cl.tvl.fyi/c/depot/+/11518 Tested-by: BuildkiteCI Reviewed-by: flokli --- tvix/nix-compat/src/wire/bytes/reader/trailer.rs | 175 +++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 tvix/nix-compat/src/wire/bytes/reader/trailer.rs (limited to 'tvix/nix-compat/src/wire/bytes/reader/trailer.rs') diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs new file mode 100644 index 0000000000..d2b867c2c3 --- /dev/null +++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs @@ -0,0 +1,175 @@ +use std::{ + pin::Pin, + task::{self, ready, Poll}, +}; + +use tokio::io::{self, AsyncRead, ReadBuf}; + +#[derive(Debug)] +pub enum TrailerReader { + Reading { + reader: R, + user_len: u8, + filled: u8, + buf: [u8; 8], + }, + Releasing { + off: u8, + len: u8, + buf: [u8; 8], + }, + Done, +} + +impl TrailerReader { + pub fn new(reader: R, user_len: u8) -> Self { + if user_len == 0 { + return Self::Done; + } + + assert!(user_len < 8, "payload in trailer must be less than 8 bytes"); + Self::Reading { + reader, + user_len, + filled: 0, + buf: [0; 8], + } + } +} + +impl AsyncRead for TrailerReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context, + user_buf: &mut ReadBuf, + ) -> Poll> { + let this = self.get_mut(); + + loop { + match this { + &mut Self::Reading { + reader: _, + user_len, + filled: 8, + buf, + } => { + *this = Self::Releasing { + off: 0, + len: user_len, + buf, + }; + } + Self::Reading { + reader, + user_len, + filled, + buf, + } => { + let mut read_buf = ReadBuf::new(&mut buf[..]); + read_buf.advance(*filled as usize); + ready!(Pin::new(reader).poll_read(cx, &mut read_buf))?; + + let new_filled = read_buf.filled().len() as u8; + if *filled == new_filled { + return Err(io::ErrorKind::UnexpectedEof.into()).into(); + } + + *filled = new_filled; + + // ensure the padding is all zeroes + if (u64::from_le_bytes(*buf) >> (*user_len * 8)) != 0 { + return Err(io::ErrorKind::InvalidData.into()).into(); + } + } + Self::Releasing { off: 8, .. } => { + *this = Self::Done; + } + Self::Releasing { off, len, buf } => { + assert_ne!(user_buf.remaining(), 0); + + let buf = &buf[*off as usize..*len as usize]; + let buf = &buf[..usize::min(buf.len(), user_buf.remaining())]; + + user_buf.put_slice(buf); + *off += buf.len() as u8; + + break; + } + Self::Done => break, + } + } + + Ok(()).into() + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + use tokio::io::AsyncReadExt; + + use super::*; + + #[tokio::test] + async fn unexpected_eof() { + let reader = tokio_test::io::Builder::new() + .read(&[0xed]) + .wait(Duration::ZERO) + .read(&[0xef, 0x00]) + .build(); + + let mut reader = TrailerReader::new(reader, 2); + + let mut buf = vec![]; + assert_eq!( + reader.read_to_end(&mut buf).await.unwrap_err().kind(), + io::ErrorKind::UnexpectedEof + ); + } + + #[tokio::test] + async fn invalid_padding() { + let reader = tokio_test::io::Builder::new() + .read(&[0xed]) + .wait(Duration::ZERO) + .read(&[0xef, 0x01, 0x00]) + .wait(Duration::ZERO) + .build(); + + let mut reader = TrailerReader::new(reader, 2); + + let mut buf = vec![]; + assert_eq!( + reader.read_to_end(&mut buf).await.unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } + + #[tokio::test] + async fn success() { + let reader = tokio_test::io::Builder::new() + .read(&[0xed]) + .wait(Duration::ZERO) + .read(&[0xef, 0x00]) + .wait(Duration::ZERO) + .read(&[0x00, 0x00, 0x00, 0x00, 0x00]) + .build(); + + let mut reader = TrailerReader::new(reader, 2); + + let mut buf = vec![]; + reader.read_to_end(&mut buf).await.unwrap(); + + assert_eq!(buf, &[0xed, 0xef]); + } + + #[tokio::test] + async fn no_padding() { + let reader = tokio_test::io::Builder::new().build(); + let mut reader = TrailerReader::new(reader, 0); + + let mut buf = vec![]; + reader.read_to_end(&mut buf).await.unwrap(); + assert!(buf.is_empty()); + } +} -- cgit 1.4.1