about summary refs log tree commit diff
path: root/tvix/castore/src/import/archive.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore/src/import/archive.rs')
-rw-r--r--tvix/castore/src/import/archive.rs100
1 files changed, 90 insertions, 10 deletions
diff --git a/tvix/castore/src/import/archive.rs b/tvix/castore/src/import/archive.rs
index eab695d650..84e280fafb 100644
--- a/tvix/castore/src/import/archive.rs
+++ b/tvix/castore/src/import/archive.rs
@@ -1,28 +1,45 @@
+use std::io::{Cursor, Write};
+use std::sync::Arc;
 use std::{collections::HashMap, path::PathBuf};
 
 use petgraph::graph::{DiGraph, NodeIndex};
 use petgraph::visit::{DfsPostOrder, EdgeRef};
 use petgraph::Direction;
 use tokio::io::AsyncRead;
+use tokio::sync::Semaphore;
+use tokio::task::JoinSet;
 use tokio_stream::StreamExt;
 use tokio_tar::Archive;
+use tokio_util::io::InspectReader;
 use tracing::{instrument, warn, Level};
 
 use crate::blobservice::BlobService;
 use crate::directoryservice::DirectoryService;
 use crate::import::{ingest_entries, Error, IngestionEntry};
 use crate::proto::node::Node;
+use crate::B3Digest;
+
+/// Files smaller than this threshold, in bytes, are uploaded to the [BlobService] in the
+/// background.
+///
+/// This is a u32 since we acquire a weighted semaphore using the size of the blob.
+/// [Semaphore::acquire_many_owned] takes a u32, so we need to ensure the size of
+/// the blob can be represented using a u32 and will not cause an overflow.
+const CONCURRENT_BLOB_UPLOAD_THRESHOLD: u32 = 1024 * 1024;
+
+/// The maximum amount of bytes allowed to be buffered in memory to perform async blob uploads.
+const MAX_TARBALL_BUFFER_SIZE: usize = 128 * 1024 * 1024;
 
 /// Ingests elements from the given tar [`Archive`] into a the passed [`BlobService`] and
 /// [`DirectoryService`].
 #[instrument(skip_all, ret(level = Level::TRACE), err)]
-pub async fn ingest_archive<'a, BS, DS, R>(
+pub async fn ingest_archive<BS, DS, R>(
     blob_service: BS,
     directory_service: DS,
     mut archive: Archive<R>,
 ) -> Result<Node, Error>
 where
-    BS: AsRef<dyn BlobService> + Clone,
+    BS: BlobService + Clone + 'static,
     DS: AsRef<dyn DirectoryService>,
     R: AsyncRead + Unpin,
 {
@@ -33,21 +50,80 @@ where
     // In the first phase, collect up all the regular files and symlinks.
     let mut nodes = IngestionEntryGraph::new();
 
+    let semaphore = Arc::new(Semaphore::new(MAX_TARBALL_BUFFER_SIZE));
+    let mut async_blob_uploads: JoinSet<Result<(), Error>> = JoinSet::new();
+
     let mut entries_iter = archive.entries().map_err(Error::Archive)?;
     while let Some(mut entry) = entries_iter.try_next().await.map_err(Error::Archive)? {
         let path: PathBuf = entry.path().map_err(Error::Archive)?.into();
 
-        let entry = match entry.header().entry_type() {
+        let header = entry.header();
+        let entry = match header.entry_type() {
             tokio_tar::EntryType::Regular
             | tokio_tar::EntryType::GNUSparse
             | tokio_tar::EntryType::Continuous => {
-                // TODO: If the same path is overwritten in the tarball, we may leave
-                // an unreferenced blob after uploading.
-                let mut writer = blob_service.as_ref().open_write().await;
-                let size = tokio::io::copy(&mut entry, &mut writer)
-                    .await
-                    .map_err(Error::Archive)?;
-                let digest = writer.close().await.map_err(Error::Archive)?;
+                let header_size = header.size().map_err(Error::Archive)?;
+
+                // If the blob is small enough, read it off the wire, compute the digest,
+                // and upload it to the [BlobService] in the background.
+                let (size, digest) = if header_size <= CONCURRENT_BLOB_UPLOAD_THRESHOLD as u64 {
+                    let mut buffer = Vec::with_capacity(header_size as usize);
+                    let mut hasher = blake3::Hasher::new();
+                    let mut reader = InspectReader::new(&mut entry, |bytes| {
+                        hasher.write_all(bytes).unwrap();
+                    });
+
+                    // Ensure that we don't buffer into memory until we've acquired a permit.
+                    // This prevents consuming too much memory when performing concurrent
+                    // blob uploads.
+                    let permit = semaphore
+                        .clone()
+                        // This cast is safe because ensure the header_size is less than
+                        // CONCURRENT_BLOB_UPLOAD_THRESHOLD which is a u32.
+                        .acquire_many_owned(header_size as u32)
+                        .await
+                        .unwrap();
+                    let size = tokio::io::copy(&mut reader, &mut buffer)
+                        .await
+                        .map_err(Error::Archive)?;
+
+                    let digest: B3Digest = hasher.finalize().as_bytes().into();
+
+                    {
+                        let blob_service = blob_service.clone();
+                        let digest = digest.clone();
+                        async_blob_uploads.spawn({
+                            async move {
+                                let mut writer = blob_service.open_write().await;
+
+                                tokio::io::copy(&mut Cursor::new(buffer), &mut writer)
+                                    .await
+                                    .map_err(Error::Archive)?;
+
+                                let blob_digest = writer.close().await.map_err(Error::Archive)?;
+
+                                assert_eq!(digest, blob_digest, "Tvix bug: blob digest mismatch");
+
+                                // Make sure we hold the permit until we finish writing the blob
+                                // to the [BlobService].
+                                drop(permit);
+                                Ok(())
+                            }
+                        });
+                    }
+
+                    (size, digest)
+                } else {
+                    let mut writer = blob_service.open_write().await;
+
+                    let size = tokio::io::copy(&mut entry, &mut writer)
+                        .await
+                        .map_err(Error::Archive)?;
+
+                    let digest = writer.close().await.map_err(Error::Archive)?;
+
+                    (size, digest)
+                };
 
                 IngestionEntry::Regular {
                     path,
@@ -77,6 +153,10 @@ where
         nodes.add(entry)?;
     }
 
+    while let Some(result) = async_blob_uploads.join_next().await {
+        result.expect("task panicked")?;
+    }
+
     ingest_entries(
         directory_service,
         futures::stream::iter(nodes.finalize()?.into_iter().map(Ok)),