From 5fccbe5939b376d56690cdf8ebcab61f7e852d86 Mon Sep 17 00:00:00 2001 From: Florian Klink Date: Thu, 14 Mar 2024 17:03:38 +0200 Subject: feat(nix-compat/wire): add read_bytes[_unchecked] This introduces a version reading sized byte packets. Both read_bytes, accepting a range of allowed sizes, as well as read_bytes_unchecked, which doesn't care, are added, including tests. Co-Authored-By: picnoir Change-Id: I9fc1c61eb561105e649eecca832af28badfdaaa8 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11150 Autosubmit: flokli Reviewed-by: picnoir picnoir Tested-by: BuildkiteCI --- tvix/nix-compat/src/wire/bytes.rs | 130 ++++++++++++++++++++++++++++++++++++++ tvix/nix-compat/src/wire/mod.rs | 3 + 2 files changed, 133 insertions(+) create mode 100644 tvix/nix-compat/src/wire/bytes.rs diff --git a/tvix/nix-compat/src/wire/bytes.rs b/tvix/nix-compat/src/wire/bytes.rs new file mode 100644 index 0000000000..c720f912ee --- /dev/null +++ b/tvix/nix-compat/src/wire/bytes.rs @@ -0,0 +1,130 @@ +use std::ops::RangeBounds; + +use tokio::io::AsyncReadExt; + +use super::primitive; + +#[allow(dead_code)] +/// Read a limited number of bytes from the AsyncRead. +/// Rejects reading more than `allowed_size` bytes of payload. +/// Internally takes care of dealing with the padding, so the returned Vec +/// only contains the payload. +/// This always buffers the entire contents into memory, we'll add a streaming +/// version later. +pub async fn read_bytes(r: &mut R, allowed_size: S) -> std::io::Result> +where + R: AsyncReadExt + Unpin, + S: RangeBounds, +{ + // read the length field + let len = primitive::read_u64(r).await?; + + if !allowed_size.contains(&len) { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "signalled package size not in allowed range", + )); + } + + // calculate the total length, including padding. + // byte packets are padded to 8 byte blocks each. + let padded_len = if len % 8 == 0 { + len + } else { + len + (8 - len % 8) + }; + + let mut limited_reader = r.take(padded_len); + + let mut buf = Vec::new(); + + let s = limited_reader.read_to_end(&mut buf).await?; + + // make sure we got exactly the number of bytes, and not less. + if s as u64 != padded_len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "got less bytes than expected", + )); + } + + let (_content, padding) = buf.split_at(len as usize); + + // ensure the padding is all zeroes. + if !padding.iter().all(|e| *e == b'\0') { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "padding is not all zeroes", + )); + } + + // return the data without the padding + buf.truncate(len as usize); + Ok(buf) +} + +#[allow(dead_code)] +/// Read an unlimited number of bytes from the AsyncRead. +/// Note this can exhaust memory. +/// Internally uses [read_bytes], which takes care of dealing with the padding, +/// so the returned Vec only contains the payload. +pub async fn read_bytes_unchecked(r: &mut R) -> std::io::Result> { + read_bytes(r, 0u64..).await +} + +#[cfg(test)] +mod tests { + use tokio_test::io::Builder; + + use super::*; + use hex_literal::hex; + + #[tokio::test] + async fn test_read_8_bytes_unchecked() { + let mut mock = Builder::new() + .read(&8u64.to_le_bytes()) + .read(&12345678u64.to_le_bytes()) + .build(); + + assert_eq!( + &12345678u64.to_le_bytes(), + read_bytes_unchecked(&mut mock).await.unwrap().as_slice() + ); + } + + #[tokio::test] + async fn test_read_9_bytes_unchecked() { + let mut mock = Builder::new() + .read(&9u64.to_le_bytes()) + .read(&hex!("01020304050607080900000000000000")) + .build(); + + assert_eq!( + hex!("010203040506070809"), + read_bytes_unchecked(&mut mock).await.unwrap().as_slice() + ); + } + + #[tokio::test] + async fn test_read_0_bytes_unchecked() { + // A empty byte packet is essentially just the 0 length field. + // No data is read, and there's zero padding. + let mut mock = Builder::new().read(&0u64.to_le_bytes()).build(); + + assert_eq!( + hex!(""), + read_bytes_unchecked(&mut mock).await.unwrap().as_slice() + ); + } + + #[tokio::test] + /// Ensure we don't read any further than the size field if the length + /// doesn't match the range we want to accept. + async fn test_reject_too_large() { + let mut mock = Builder::new().read(&100u64.to_le_bytes()).build(); + + read_bytes(&mut mock, 10..10) + .await + .expect_err("expect this to fail"); + } +} diff --git a/tvix/nix-compat/src/wire/mod.rs b/tvix/nix-compat/src/wire/mod.rs index e0b184c78a..9444ebbcfe 100644 --- a/tvix/nix-compat/src/wire/mod.rs +++ b/tvix/nix-compat/src/wire/mod.rs @@ -1,5 +1,8 @@ //! Module parsing and emitting the wire format used by Nix, both in the //! nix-daemon protocol as well as in the NAR format. +#[cfg(feature = "async")] +pub mod bytes; + #[cfg(feature = "async")] pub mod primitive; -- cgit 1.4.1