diff --git a/src/caching_downloader.rs b/src/caching_downloader.rs new file mode 100644 index 0000000..cca2241 --- /dev/null +++ b/src/caching_downloader.rs @@ -0,0 +1,115 @@ +use std::{ + collections::HashMap, + fs, mem, + path::{Path, PathBuf}, +}; + +use anyhow::{ensure, Context}; +use downloader::{Download, DownloadSummary, Downloader, Error}; +use rand::random; + +pub struct CachingDownloader { + inner: Downloader, + cache: PathBuf, + base: PathBuf, +} + +impl CachingDownloader { + /// + /// This downloader will put all downloaded files into a cache directory, before moving them to + /// their actual target location. If a download fails, any corresponding file will not be + /// touched. + /// + /// The `inner` is the downloader used to perform the actual download. The `cache` directory is + /// the location to put the downloaded files in. Individual downloads will create subdirectories + /// in it, to avoid conflicts. The `base` is the path relative file names are relative to. + pub fn new(inner: Downloader, cache: &Path, base: &Path) -> anyhow::Result { + let cache_ready = cache.is_dir(); + ensure!( + cache_ready || !cache.exists(), + "cache directory is neither directory nor nonexistant ({})", + cache.display() + ); + + if !cache_ready { + fs::create_dir_all(cache).context(format!( + "failed to create cache directory ({})", + cache.display() + ))?; + } + + ensure!(base.is_dir(), "base directory doesn't exist"); + + Ok(Self { + inner, + cache: cache.to_path_buf(), + base: base.to_path_buf(), + }) + } + + pub fn download( + &mut self, + downloads: &mut [Download], + partiton: Option<&str>, + ) -> anyhow::Result>> { + let partition = partiton + .map(ToString::to_string) + .unwrap_or_else(|| format!("{:0>16x}", random::())); + let cache = self.cache.join(partition); + + ensure!(!cache.exists(), "cache partition exists"); + fs::create_dir(&cache).context("failed to create cache partition")?; + + let mut mapping: HashMap<_, _> = downloads + .iter_mut() + .enumerate() + .map(|(counter, down)| (cache.join(format!("{counter:0>16x}")), down)) + .map(|(cache, down)| (cache.clone(), mem::replace(&mut down.file_name, cache))) + .collect(); + + let mut results: Vec<_> = self + .inner + .download(downloads) + .context("all downloads failed")? + .into_iter() + .map(|r| handle_file(&mut mapping, r, &self.base)) + .collect(); + + for (leftover, _) in mapping { + if let Err(err) = + fs::remove_file(leftover).context("failed to delete leftover cache file") + { + results.push(Err(err)); + } + } + + if let Err(err) = fs::remove_dir(cache).context("failed to delete cache partition") { + results.push(Err(err)); + } + + Ok(results) + } +} + +fn handle_file( + mapping: &mut HashMap, + summary: Result, + base: &Path, +) -> anyhow::Result { + let mut summary = summary.context("download failed")?; + + let cache = mapping + .remove(&summary.file_name) + .context("unknown target location")?; + let mut cache = base.join(cache); + + mem::swap(&mut summary.file_name, &mut cache); + + fs::hard_link(&cache, &summary.file_name) + .or_else(|_| fs::copy(&cache, &summary.file_name).map(|_| ())) + .context("failed to copy downloaded file to target location")?; + + fs::remove_file(cache).context("failed to remove cached file")?; + + Ok(summary) +} diff --git a/src/config.rs b/src/config.rs index 33b3203..bd9e49e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,7 +10,7 @@ use anyhow::{anyhow, Context}; use downloader::Downloader; use serde::{Deserialize, Serialize}; -use crate::{PROJ_DIRS, USER_DIRS}; +use crate::{caching_downloader::CachingDownloader, PROJ_DIRS, USER_DIRS}; /// Global configuration values. #[derive(Debug, Deserialize, Serialize)] @@ -73,6 +73,16 @@ impl Config { .build() .context("failed to build downloader") } + + pub fn default_caching_downloader(&self) -> anyhow::Result { + let dirs = PROJ_DIRS.get().expect("directories not initialized"); + self.caching_downloader(dirs.cache_dir()) + } + + pub fn caching_downloader(&self, cache: &Path) -> anyhow::Result { + self.downloader() + .and_then(|d| CachingDownloader::new(d, cache, &self.base_directory)) + } } impl Default for Config { diff --git a/src/main.rs b/src/main.rs index 548cd18..89f8734 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fs, path::PathBuf, sync::OnceLock}; +use std::{path::PathBuf, sync::OnceLock}; use anyhow::Context; use clap::{Parser, Subcommand}; @@ -9,6 +9,7 @@ use target::Target; use target_list::TargetList; use url::Url; +mod caching_downloader; mod config; mod persistent_state; mod target; @@ -29,7 +30,10 @@ fn main() -> anyhow::Result<()> { .set(user_dirs) .ok() .context("failed to initialize user directories")?; - let proj_dirs = PROJ_DIRS.get_or_init(|| proj_dirs); + PROJ_DIRS + .set(proj_dirs) + .ok() + .context("failedto initialize program directories")?; if let Cmd::License = cli.cmd { println!("{}", include_str!("../LICENSE")); @@ -38,7 +42,9 @@ fn main() -> anyhow::Result<()> { // prepare for operation let cfg = Config::read_from_default_file().context("failed to load config")?; - let mut downloader = cfg.downloader().context("failed to create downloader")?; + let mut downloader = cfg + .default_caching_downloader() + .context("failed to create downloader")?; let mut persistent = PersistentState::read_from_default_file().context("failed to load persistent state")?; @@ -50,82 +56,18 @@ fn main() -> anyhow::Result<()> { println!("{persistent}"); } Cmd::Download { name } => { - let mut cache = proj_dirs.cache_dir().to_path_buf(); - cache.push(&format!("{:0>16x}", rand::random::())); - fs::create_dir_all(&cache).context("failed to create cache dir")?; - let name = name .as_deref() .or(persistent.list()) .context("no list specified or selected")?; - let list = TargetList::load(name).context("failed to load list")?; + let list = TargetList::load(name).context("failed to load target list")?; let mut downloads = list.downloads(); - let mut mapping: HashMap<_, _> = downloads - .iter_mut() - .enumerate() - .map(|(counter, value)| { - let mut cache_path = cache.clone(); - cache_path.push(format!("{counter:0>16x}")); - (cache_path, value) - }) - .map(|(cache, down)| { - let target = std::mem::replace(&mut down.file_name, cache.clone()); - (cache, target) - }) - .collect(); - - let results = downloader - .download(&downloads) - .context("all downloads failed")?; - - for res in results { - if res.is_err() { - eprintln!("{:?}", res.context("download_failed").unwrap_err()); - continue; - } - - let res = res.unwrap(); - let res = mapping - .remove(&res.file_name) - .context("target file name missing") - .map(|target| { - if target.is_absolute() { - target - } else { - let mut path = cfg.base_directory.clone(); - path.push(target); - path - } - }) - .and_then(|target| { - fs::hard_link(&res.file_name, &target) - .context("failed to hard-link result") // for same type as below - .or_else(|_| { - fs::copy(&res.file_name, &target) - .map(|_| ()) - .context("failed to copy result") - }) - .and_then(|_| { - fs::remove_file(&res.file_name) - .context("failed to delete cached result") - }) - }); - + for res in downloader.download(&mut downloads, None)? { if let Err(err) = res { eprintln!("{:?}", err); } } - - for (leftover, _) in mapping { - if let Err(err) = - fs::remove_file(leftover).context("failed to delete leftover cache file") - { - eprintln!("{err:?}"); - } - } - - fs::remove_dir(cache).context("failed to delete cache directory")?; } Cmd::List { cmd } => match cmd { ListCommand::Create {