diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..90f212f --- /dev/null +++ b/.env.example @@ -0,0 +1 @@ +DATABASE_URL=sqlite:./run/database.sqlite \ No newline at end of file diff --git a/.gitignore b/.gitignore index c19116f..e310c9d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target .idea/ -/run/ \ No newline at end of file +/run/ +*.env \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index e93c26b..03f296c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -122,7 +122,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -133,7 +133,7 @@ checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -252,22 +252,13 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" -[[package]] -name = "bytesize" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3c8f83209414aacf0eeae3cf730b18d6981697fba62f200fcfb92b9f082acba" -dependencies = [ - "serde", -] - [[package]] name = "camino" version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" dependencies = [ - "serde", + "serde", ] [[package]] @@ -342,7 +333,7 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" dependencies = [ - "crossbeam-utils", + "crossbeam-utils", ] [[package]] @@ -374,8 +365,8 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" dependencies = [ - "core-foundation-sys", - "libc", + "core-foundation-sys", + "libc", ] [[package]] @@ -454,7 +445,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn", + "syn", ] [[package]] @@ -465,7 +456,21 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn", + "syn", +] + +[[package]] +name = "dashmap" +version = "7.0.0-rc2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4a1e35a65fe0538a60167f0ada6e195ad5d477f6ddae273943596d4a1a5730b" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "equivalent", + "hashbrown", + "lock_api", + "parking_lot_core", ] [[package]] @@ -518,7 +523,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn", + "syn", ] [[package]] @@ -541,7 +546,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -624,9 +629,9 @@ version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", + "concurrent-queue", + "parking", + "pin-project-lite", ] [[package]] @@ -698,7 +703,7 @@ version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17e4d821c226048ee351e9f7cf8554b5f226ca41b35ffe632eddaa3a8934da2f" dependencies = [ - "serde", + "serde", ] [[package]] @@ -767,7 +772,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -900,8 +905,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ "allocator-api2", - "equivalent", - "foldhash", + "equivalent", + "foldhash", ] [[package]] @@ -910,7 +915,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown", + "hashbrown", ] [[package]] @@ -1085,12 +1090,12 @@ dependencies = [ "http 1.3.1", "hyper 1.6.0", "hyper-util", - "rustls", + "rustls", "rustls-pki-types", "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots", ] [[package]] @@ -1250,7 +1255,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -1293,7 +1298,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058" dependencies = [ "equivalent", - "hashbrown", + "hashbrown", "serde", ] @@ -1353,7 +1358,7 @@ checksum = "8d16e75759ee0aa64c57a56acbf43916987b20c77373cb7e808979e02b93c9f9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -1485,23 +1490,26 @@ dependencies = [ name = "minna_caos" version = "0.1.0" dependencies = [ - "bytesize", - "camino", + "async-trait", + "camino", "color-eyre", - "constant_time_eq", + "constant_time_eq", + "dashmap", "env_logger", "figment", - "fstr", + "fstr", "log", - "nanoid", "once_cell", "opendal", + "rand 0.9.0", "regex", + "replace_with", "rocket", "serde", - "serde_regex", + "serde_json", "sqlx", "tokio", + "tokio-util", "validator", ] @@ -1535,15 +1543,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "nanoid" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8" -dependencies = [ - "rand 0.8.5", -] - [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1641,7 +1640,7 @@ dependencies = [ "anyhow", "async-trait", "backon", - "base64", + "base64", "bytes", "chrono", "futures", @@ -1726,7 +1725,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn", + "syn", ] [[package]] @@ -1810,7 +1809,7 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy", + "zerocopy", ] [[package]] @@ -1832,7 +1831,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -1852,7 +1851,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", "version_check", "yansi", ] @@ -1879,9 +1878,9 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls", + "rustls", "socket2", - "thiserror", + "thiserror", "tokio", "tracing", "web-time", @@ -1898,10 +1897,10 @@ dependencies = [ "rand 0.9.0", "ring", "rustc-hash", - "rustls", + "rustls", "rustls-pki-types", "slab", - "thiserror", + "thiserror", "tinyvec", "tracing", "web-time", @@ -1955,7 +1954,7 @@ checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", - "zerocopy", + "zerocopy", ] [[package]] @@ -2022,7 +2021,7 @@ checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -2069,13 +2068,19 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "replace_with" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a8614ee435691de62bcffcf4a66d91b3594bf1428a5722e79103249a095690" + [[package]] name = "reqwest" version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" dependencies = [ - "base64", + "base64", "bytes", "futures-core", "futures-util", @@ -2093,8 +2098,8 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls", - "rustls-pemfile", + "rustls", + "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", @@ -2110,7 +2115,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots", "windows-registry", ] @@ -2178,7 +2183,7 @@ dependencies = [ "proc-macro2", "quote", "rocket_http", - "syn", + "syn", "unicode-xid", "version_check", ] @@ -2264,7 +2269,7 @@ dependencies = [ "once_cell", "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki", "subtle", "zeroize", ] @@ -2275,10 +2280,10 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ - "openssl-probe", - "rustls-pki-types", - "schannel", - "security-framework", + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", ] [[package]] @@ -2328,7 +2333,7 @@ version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.59.0", ] [[package]] @@ -2349,11 +2354,11 @@ version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ - "bitflags", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", ] [[package]] @@ -2362,8 +2367,8 @@ version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ - "core-foundation-sys", - "libc", + "core-foundation-sys", + "libc", ] [[package]] @@ -2383,7 +2388,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -2398,16 +2403,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_regex" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8136f1a4ea815d7eac4101cfd0b16dc0cb5e1fe1b8609dfd728058656b7badf" -dependencies = [ - "regex", - "serde", -] - [[package]] name = "serde_spanned" version = "0.6.8" @@ -2500,7 +2495,7 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" dependencies = [ - "serde", + "serde", ] [[package]] @@ -2560,21 +2555,23 @@ dependencies = [ "futures-intrusive", "futures-io", "futures-util", - "hashbrown", + "hashbrown", "hashlink", "indexmap", "log", "memchr", "once_cell", "percent-encoding", - "rustls", - "rustls-native-certs", - "rustls-pemfile", + "rustls", + "rustls-native-certs", + "rustls-pemfile", "serde", "serde_json", "sha2", "smallvec", - "thiserror", + "thiserror", + "tokio", + "tokio-stream", "tracing", "url", ] @@ -2589,7 +2586,7 @@ dependencies = [ "quote", "sqlx-core", "sqlx-macros-core", - "syn", + "syn", ] [[package]] @@ -2610,10 +2607,11 @@ dependencies = [ "sha2", "sqlx-core", "sqlx-mysql", - "sqlx-postgres", + "sqlx-postgres", "sqlx-sqlite", - "syn", + "syn", "tempfile", + "tokio", "url", ] @@ -2624,7 +2622,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4560278f0e00ce64938540546f59f590d60beee33fffbd3b9cd47851e5fff233" dependencies = [ "atoi", - "base64", + "base64", "bitflags", "byteorder", "bytes", @@ -2654,7 +2652,7 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror", + "thiserror", "tracing", "whoami", ] @@ -2666,7 +2664,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5b98a57f363ed6764d5b3a12bfedf62f07aa16e1856a7ddc2a0bb190a959613" dependencies = [ "atoi", - "base64", + "base64", "bitflags", "byteorder", "crc", @@ -2691,7 +2689,7 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror", + "thiserror", "tracing", "whoami", ] @@ -2713,7 +2711,7 @@ dependencies = [ "log", "percent-encoding", "serde", - "serde_urlencoded", + "serde_urlencoded", "sqlx-core", "tracing", "url", @@ -2794,7 +2792,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -2816,7 +2814,7 @@ version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl", + "thiserror-impl", ] [[package]] @@ -2827,7 +2825,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -2922,7 +2920,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -2931,7 +2929,7 @@ version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ - "rustls", + "rustls", "tokio", ] @@ -3040,7 +3038,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -3228,7 +3226,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -3301,7 +3299,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn", "wasm-bindgen-shared", ] @@ -3336,7 +3334,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3748,7 +3746,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", "synstructure", ] @@ -3758,7 +3756,7 @@ version = "0.8.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" dependencies = [ - "zerocopy-derive", + "zerocopy-derive", ] [[package]] @@ -3769,7 +3767,7 @@ checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] [[package]] @@ -3789,7 +3787,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", "synstructure", ] @@ -3818,5 +3816,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn", ] diff --git a/Cargo.toml b/Cargo.toml index ed6d032..b0c7754 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,8 +3,11 @@ name = "minna_caos" version = "0.1.0" edition = "2024" +[profile.dev.package.sqlx-macros] +opt-level = 3 + [dependencies] -sqlx = { version = "0.8.3", features = ["tls-rustls-ring-native-roots", "sqlite"] } +sqlx = { version = "0.8.3", features = ["tls-rustls-ring-native-roots", "sqlite", "runtime-tokio"] } rocket = { version = "0.5.1", default-features = false, features = ["http2", "json"] } opendal = { version = "0.52.0", features = ["services-fs"] } tokio = { version = "1.44.1", features = ["rt-multi-thread", "macros", "parking_lot"] } @@ -16,9 +19,12 @@ serde = { version = "1.0.219", features = ["derive"] } validator = { version = "0.20.0", features = ["derive"] } once_cell = "1.21.1" regex = "1.11.1" -bytesize = { version = "2.0.1", features = ["serde"] } -serde_regex = "1.1.0" +serde_json = "1.0.140" constant_time_eq = "0.4.2" fstr = { version = "0.2.13", features = ["serde"] } camino = { version = "1.1.9", features = ["serde1"] } -nanoid = "0.4.0" +dashmap = "7.0.0-rc2" +tokio-util = "0.7.14" +replace_with = "0.1.7" +async-trait = "0.1.88" +rand = "0.9.0" \ No newline at end of file diff --git a/README.md b/README.md index 921adfb..afd824d 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,14 @@ - app to caos: `GET /staging-area/{upload_id}`, returns metadata (including `{hash}`) as soon as the upload is complete - app to caos: `POST /staging-area/{upload_id}/accept` with target bucket IDs + ## Roadmap - basic uploading - upload expiration - media type detection +- graceful shutdown - metadata endpoints - accepting uploads +- add code comments - more storage backends \ No newline at end of file diff --git a/migrations/20250321201214_initial.sql b/migrations/20250321201214_initial.sql index 73765fa..65930e1 100644 --- a/migrations/20250321201214_initial.sql +++ b/migrations/20250321201214_initial.sql @@ -16,19 +16,11 @@ create table object_replicas foreign key (hash) references objects (hash) on delete restrict on update restrict ) strict; -create table ongoing_uploads -( - id text not null, - current_size integer not null, -- in bytes - total_size integer, -- in bytes, or null if the upload was not started yet - primary key (id) -) without rowid, strict; - -create table finished_uploads +create table uploads ( id text not null, - size integer not null, -- in bytes - hash text not null, -- BLAKE3, 265 bits, base 16 - media_type text not null, -- RFC 6838 format - primary key (id) -) without rowid, strict; \ No newline at end of file + total_size integer not null, -- in bytes + hash text, -- null if the upload is not finished yet or the hash simply was not calculated yet + primary key (id), + foreign key (hash) references objects (hash) on delete restrict on update restrict +) without rowid, strict; diff --git a/run/config.toml b/run/config.toml index 740643f..4af7f40 100644 --- a/run/config.toml +++ b/run/config.toml @@ -1,5 +1,8 @@ http_address = "0.0.0.0" http_port = 8001 +api_secret = "Xt99Hp%wU%zf&vczQ%bJPbr2$owC#wuM#7fxEy%Uc%pp4Thdk7V$4kxMJFupvNKk" +database_file = "./database.sqlite" +staging_directory = "./data/staging" [[buckets]] id = "local" diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..1cbd329 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +max_width = 160 \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 4f12e3a..1f38eee 100644 --- a/src/config.rs +++ b/src/config.rs @@ -29,13 +29,7 @@ fn validate_buckets(buckets: &Vec) -> Result<(), ValidationError> for bucket_config in buckets { if !ids.insert(&bucket_config.id) { - return Err(ValidationError::new("duplicate_id").with_message( - format!( - "There is more than one bucket with this ID: {}", - bucket_config.id - ) - .into(), - )); + return Err(ValidationError::new("duplicate_id").with_message(format!("There is more than one bucket with this ID: {}", bucket_config.id).into())); }; } @@ -43,21 +37,26 @@ fn validate_buckets(buckets: &Vec) -> Result<(), ValidationError> } // a-zA-z0-9 and _, but not "staging" -static BUCKET_ID_PATTERN: Lazy = Lazy::new(|| Regex::new(r"^(?!staging$)\w*$").unwrap()); +static BUCKET_ID_PATTERN: Lazy = Lazy::new(|| Regex::new(r"^\w+$").unwrap()); #[derive(Debug, Serialize, Deserialize, Validate)] pub struct ConfigBucket { - #[validate(length(min = 1, max = 32), regex(path = *BUCKET_ID_PATTERN))] + #[validate(length(min = 1, max = 32), regex(path = *BUCKET_ID_PATTERN), custom(function = "validate_config_bucket_id"))] pub id: String, #[validate(length(min = 1, max = 128))] pub display_name: String, - pub size_limit: Option, - #[serde(with = "serde_regex")] - pub media_type_pattern: Option, #[serde(flatten)] pub backend: ConfigBucketBackend, } +fn validate_config_bucket_id(value: &str) -> Result<(), ValidationError> { + if value == "staging" { + return Err(ValidationError::new("illegal_bucket_id").with_message("Illegal bucket ID: staging".into())); + } + + Ok(()) +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "backend", rename_all = "snake_case")] pub enum ConfigBucketBackend { @@ -72,19 +71,11 @@ pub struct ConfigBucketBackendFilesystem { pub fn load_config() -> Result { let figment = Figment::new() .merge(figment::providers::Toml::file("config.toml")) - .merge(figment::providers::Env::prefixed("CAOS_").only(&[ - "HTTP_ADDRESS", - "HTTP_PORT", - "API_SECRET", - ])); + .merge(figment::providers::Env::prefixed("CAOS_").only(&["HTTP_ADDRESS", "HTTP_PORT", "API_SECRET"])); - let config = figment - .extract::() - .wrap_err("Failed to load configuration.")?; + let config = figment.extract::().wrap_err("Failed to load configuration.")?; - config - .validate() - .wrap_err("Failed to validate configuration.")?; + config.validate().wrap_err("Failed to validate configuration.")?; Ok(config) } diff --git a/src/http_api.rs b/src/http_api.rs deleted file mode 100644 index 73883d5..0000000 --- a/src/http_api.rs +++ /dev/null @@ -1,125 +0,0 @@ -use crate::config::Config; -use color_eyre::Result; -use fstr::FStr; -use nanoid::nanoid; -use rocket::form::validate::Len; -use rocket::http::Status; -use rocket::outcome::Outcome::Success; -use rocket::request::{FromRequest, Outcome}; -use rocket::response::Responder; -use rocket::serde::json::Json; -use rocket::{Request, State, post, response, routes}; -use serde::{Deserialize, Serialize}; -use sqlx::SqlitePool; -use std::borrow::Cow; - -pub async fn start_http_api_server(config: &Config, database: SqlitePool) -> Result<()> { - let rocket_app = rocket::custom(rocket::config::Config { - address: config.http_address, - port: config.http_port, - ident: rocket::config::Ident::try_new("minna-caos".to_owned()).unwrap(), - ip_header: if config.trust_http_reverse_proxy { - Some("X-Forwarded-For".into()) - } else { - None - }, - shutdown: rocket::config::Shutdown { - grace: 5, - mercy: 5, - ..rocket::config::Shutdown::default() - }, - keep_alive: 10, - ..rocket::Config::default() - }); - - rocket_app - .manage(CorrectApiSecret(config.api_secret.clone())) - .manage(database) - .mount("/", routes![create_upload]) - .launch() - .await?; - - Ok(()) -} - -#[derive(Debug)] -enum ApiError { - BodyValidationFailed { - path: Cow<'static, str>, - message: Cow<'static, str>, - }, -} - -impl<'r> Responder<'r, 'static> for ApiError { - fn respond_to(self, _: &Request<'_>) -> response::Result<'static> { - todo!() - } -} - -#[derive(Debug, Deserialize)] -struct CreateUploadRequest { - size: u64, -} - -#[derive(Debug, Serialize)] -struct CreateUploadResponse { - upload_id: String, -} - -#[post("/uploads", data = "")] -async fn create_upload( - _accessor: AuthorizedApiAccessor, - database: &State, - request: Json, -) -> Result, ApiError> { - let id = nanoid!(); - - let total_size: i64 = request - .size - .try_into() - .map_err(|_| ApiError::BodyValidationFailed { - path: "size".into(), - message: "".into(), - })?; - - sqlx::query!( - "INSERT INTO ongoing_uploads (id, total_size, current_size) VALUES(?, ?, 0)", - id, - total_size - ) - .execute(database.inner()) - .await - .unwrap(); - - Ok(Json(CreateUploadResponse { upload_id: id })) -} - -struct CorrectApiSecret(FStr<64>); - -struct AuthorizedApiAccessor(); - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for AuthorizedApiAccessor { - type Error = (); - - async fn from_request(request: &'r Request<'_>) -> Outcome { - let provided_secret = request - .headers() - .get_one("Authorization") - .map(|v| v.strip_prefix("Bearer ")) - .take_if(|v| v.len() == 64) - .flatten(); - - let correct_secret = request.rocket().state::().unwrap().0; - if let Some(provided_secret) = provided_secret { - if constant_time_eq::constant_time_eq( - provided_secret.as_bytes(), - correct_secret.as_bytes(), - ) { - return Success(AuthorizedApiAccessor()); - } - } - - Outcome::Error((Status::Forbidden, ())) - } -} diff --git a/src/http_api/api_error.rs b/src/http_api/api_error.rs new file mode 100644 index 0000000..e9d3d97 --- /dev/null +++ b/src/http_api/api_error.rs @@ -0,0 +1,24 @@ +use color_eyre::Report; +use rocket::response::Responder; +use rocket::{Request, response}; +use std::borrow::Cow; + +#[derive(Debug)] +pub enum ApiError { + Internal { report: Report }, + HeaderValidationFailed { name: Cow<'static, str>, message: Cow<'static, str> }, + BodyValidationFailed { path: Cow<'static, str>, message: Cow<'static, str> }, + ResourceNotFound { resource_type: Cow<'static, str>, id: Cow<'static, str> }, +} + +impl From for ApiError { + fn from(report: Report) -> Self { + ApiError::Internal { report } + } +} + +impl<'r> Responder<'r, 'static> for ApiError { + fn respond_to(self, _: &Request<'_>) -> response::Result<'static> { + todo!() + } +} diff --git a/src/http_api/auth.rs b/src/http_api/auth.rs new file mode 100644 index 0000000..908ecbe --- /dev/null +++ b/src/http_api/auth.rs @@ -0,0 +1,33 @@ +use fstr::FStr; +use rocket::Request; +use rocket::form::validate::Len; +use rocket::http::Status; +use rocket::outcome::Outcome::Success; +use rocket::request::{FromRequest, Outcome}; + +pub struct CorrectApiSecret(pub FStr<64>); + +pub struct AuthorizedApiAccessor(); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for AuthorizedApiAccessor { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let provided_secret = request + .headers() + .get_one("Authorization") + .map(|v| v.strip_prefix("Bearer ")) + .take_if(|v| v.len() == 64) + .flatten(); + + let correct_secret = request.rocket().state::().unwrap().0; + if let Some(provided_secret) = provided_secret { + if constant_time_eq::constant_time_eq(provided_secret.as_bytes(), correct_secret.as_bytes()) { + return Success(AuthorizedApiAccessor()); + } + } + + Outcome::Error((Status::Forbidden, ())) + } +} diff --git a/src/http_api/mod.rs b/src/http_api/mod.rs new file mode 100644 index 0000000..6c7ca4b --- /dev/null +++ b/src/http_api/mod.rs @@ -0,0 +1,173 @@ +mod api_error; +mod auth; +mod stream_upload_payload_to_file; +mod upload_headers; + +use crate::http_api::api_error::ApiError; +use crate::http_api::auth::{AuthorizedApiAccessor, CorrectApiSecret}; +use crate::http_api::stream_upload_payload_to_file::{StreamUploadPayloadToFileOutcome, stream_upload_payload_to_file}; +use crate::http_api::upload_headers::{SuppliedOptionalContentLength, SuppliedUploadComplete, SuppliedUploadOffset}; +use crate::upload_manager::{UploadId, UploadManager}; +use color_eyre::{Report, Result}; +use fstr::FStr; +use rocket::data::{DataStream, ToByteUnit}; +use rocket::http::{ContentType, MediaType, Status}; +use rocket::serde::json::Json; +use rocket::{Data, Request, Response, State, patch, post, routes}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::borrow::Cow; +use std::io::ErrorKind; +use std::net::IpAddr; +use tokio::fs::File; +use tokio::io::AsyncSeekExt; +use tokio_util::bytes::Buf; + +pub async fn start_http_api_server(upload_manager: UploadManager, address: IpAddr, port: u16, trust_reverse_proxy: bool, api_secret: FStr<64>) -> Result<()> { + let rocket_app = rocket::custom(rocket::config::Config { + address, + port, + ident: rocket::config::Ident::try_new("minna-caos".to_owned()).unwrap(), + ip_header: if trust_reverse_proxy { Some("X-Forwarded-For".into()) } else { None }, + shutdown: rocket::config::Shutdown { + grace: 5, + mercy: 5, + ..rocket::config::Shutdown::default() + }, + keep_alive: 10, + ..rocket::Config::default() + }); + + rocket_app + .manage(CorrectApiSecret(api_secret)) + .manage(upload_manager) + .mount("/", routes![create_upload, append_upload]) + .launch() + .await?; + + Ok(()) +} + +#[derive(Debug, Deserialize)] +struct CreateUploadPayload { + size: u64, +} + +#[derive(Debug, Serialize)] +struct CreateUploadResponse { + upload_id: UploadId, +} + +#[post("/uploads", data = "")] +async fn create_upload( + _accessor: AuthorizedApiAccessor, + upload_manager: &State, + payload: Json, +) -> Result, ApiError> { + if payload.size < 1 || payload.size > (2 ^ 63 - 1) { + return Err(ApiError::BodyValidationFailed { + path: "size".into(), + message: "size must be in 1..(2^63 - 1)".into(), + }); + } + + let upload = upload_manager.create_upload(payload.size).await?; + + Ok(Json(CreateUploadResponse { upload_id: *upload.id() })) +} + +const PARTIAL_UPLOAD_MEDIA_TYPE: MediaType = MediaType::const_new("application", "partial-upload", &[]); + +#[derive(Debug)] +enum AppendUploadResponse { + RequestSuperseded, + UploadOffsetMismatch { expected: u64 }, + InconsistentUploadLength { expected: u64, detail: Cow<'static, str> }, + StreamToFileOutcome(StreamUploadPayloadToFileOutcome), +} + +#[patch("/uploads/", data = "")] +async fn append_upload( + upload_id: &str, + upload_manager: &State, + payload: Data<'_>, + supplied_content_type: Option<&ContentType>, + supplied_content_length: SuppliedOptionalContentLength, + supplied_upload_offset: SuppliedUploadOffset, + supplied_upload_complete: SuppliedUploadComplete, +) -> Result { + if !supplied_content_type.map(|c| c.exact_eq(&PARTIAL_UPLOAD_MEDIA_TYPE)).unwrap_or(false) { + return Err(ApiError::HeaderValidationFailed { + name: "content-type".into(), + message: format!("must be {}", PARTIAL_UPLOAD_MEDIA_TYPE.to_string()).into(), + }); + } + + let upload = if let Some(upload) = upload_manager.get_upload_by_id(upload_id) { + upload + } else { + return Err(ApiError::ResourceNotFound { + resource_type: "upload".into(), + id: upload_id.to_owned().into(), + }); + }; + + let mut file_acquisition = if let Some(file) = upload.file().acquire().await { + file + } else { + return Ok(AppendUploadResponse::RequestSuperseded); + }; + + let release_request_token = file_acquisition.release_request_token(); + let mut file = file_acquisition.inner().get_or_open().await.map_err(Report::new)?; + + let total_size = upload.total_size(); + let current_offset = file.stream_position().await.map_err(Report::new)?; + let remaining_content_length = total_size - current_offset; + + if supplied_upload_offset.0 != current_offset { + return Ok(AppendUploadResponse::UploadOffsetMismatch { expected: current_offset }); + } + + let payload_length_limit = if let Some(supplied_content_length) = supplied_content_length.0 { + if supplied_upload_complete.0 { + if remaining_content_length != supplied_content_length { + return Ok(AppendUploadResponse::InconsistentUploadLength { + expected: total_size, + detail: "Upload-Complete is set to true, and Content-Length is set, \ + but the value of Content-Length does not equal the length of the remaining content." + .into(), + }); + } + } else { + if supplied_content_length >= remaining_content_length { + return Ok(AppendUploadResponse::InconsistentUploadLength { + expected: total_size, + detail: "Upload-Complete is set to false, and Content-Length is set, \ + but the value of Content-Length is not smaller than the length of the remaining content." + .into(), + }); + } + } + + supplied_content_length + } else { + remaining_content_length + }; + + let outcome = tokio::select! { + o = stream_upload_payload_to_file( + payload.open(payload_length_limit.bytes()), + &mut file, + remaining_content_length, + supplied_content_length.0, + supplied_upload_complete.0 + ) => Some(o), + _ = release_request_token.cancelled() => None + }; + + file.sync_all().await.map_err(Report::new)?; + file_acquisition.release().await; + + todo!() +} diff --git a/src/http_api/stream_upload_payload_to_file.rs b/src/http_api/stream_upload_payload_to_file.rs new file mode 100644 index 0000000..c6dea2e --- /dev/null +++ b/src/http_api/stream_upload_payload_to_file.rs @@ -0,0 +1,46 @@ +use rocket::data::DataStream; +use std::io::ErrorKind; +use tokio::fs::File; + +#[derive(Debug)] +pub enum StreamUploadPayloadToFileOutcome { + StoppedUnexpectedly, + TooMuchData, + Success, +} + +pub async fn stream_upload_payload_to_file( + stream: DataStream<'_>, + file: &mut File, + remaining_content_length: u64, + supplied_content_length: Option, + supplied_upload_complete: bool, +) -> Result { + match stream.stream_to(file).await { + Ok(n) => { + if let Some(supplied_content_length) = supplied_content_length { + if n.written < supplied_content_length { + return Ok(StreamUploadPayloadToFileOutcome::StoppedUnexpectedly); + } + } else { + if supplied_upload_complete { + if n.written < remaining_content_length { + return Ok(StreamUploadPayloadToFileOutcome::StoppedUnexpectedly); + } + } + } + + if !n.complete { + return Ok(StreamUploadPayloadToFileOutcome::TooMuchData); + } + + Ok(StreamUploadPayloadToFileOutcome::Success) + } + Err(error) => match error.kind() { + ErrorKind::TimedOut => Ok(StreamUploadPayloadToFileOutcome::StoppedUnexpectedly), + ErrorKind::BrokenPipe => Ok(StreamUploadPayloadToFileOutcome::StoppedUnexpectedly), + ErrorKind::ConnectionReset => Ok(StreamUploadPayloadToFileOutcome::StoppedUnexpectedly), + _ => Err(error), + }, + } +} diff --git a/src/http_api/upload_headers.rs b/src/http_api/upload_headers.rs new file mode 100644 index 0000000..40ea0b9 --- /dev/null +++ b/src/http_api/upload_headers.rs @@ -0,0 +1,90 @@ +use crate::http_api::api_error::ApiError; +use rocket::Request; +use rocket::http::Status; +use rocket::request::{FromRequest, Outcome}; +use std::str::FromStr; + +pub struct SuppliedUploadOffset(pub u64); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for SuppliedUploadOffset { + type Error = ApiError; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let mut value_iterator = request.headers().get("upload-offset"); + + if let Some(value) = value_iterator.next() { + if let Ok(value) = u64::from_str(value) { + if value_iterator.next().is_none() { + return Outcome::Success(SuppliedUploadOffset(value)); + } + } + }; + + Outcome::Error(( + Status::BadRequest, + ApiError::HeaderValidationFailed { + name: "Upload-Offset".into(), + message: "must be a single 64-bit unsigned decimal number".into(), + }, + )) + } +} + +pub struct SuppliedOptionalContentLength(pub Option); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for SuppliedOptionalContentLength { + type Error = ApiError; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let mut value_iterator = request.headers().get("content-length"); + + if let Some(value) = value_iterator.next() { + if let Ok(value) = u64::from_str(value) { + if value_iterator.next().is_none() { + return Outcome::Success(SuppliedOptionalContentLength(Some(value))); + } + } + } else { + return Outcome::Success(SuppliedOptionalContentLength(None)); + }; + + Outcome::Error(( + Status::BadRequest, + ApiError::HeaderValidationFailed { + name: "Content-Length".into(), + message: "must be a single 64-bit unsigned decimal number".into(), + }, + )) + } +} + +pub struct SuppliedUploadComplete(pub bool); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for SuppliedUploadComplete { + type Error = ApiError; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let mut value_iterator = request.headers().get("upload-complete"); + + if let Some(value) = value_iterator.next() { + if value_iterator.next().is_none() { + if value == "?1" { + return Outcome::Success(SuppliedUploadComplete(true)); + } else if value == "?0" { + return Outcome::Success(SuppliedUploadComplete(false)); + } + } + }; + + Outcome::Error(( + Status::BadRequest, + ApiError::HeaderValidationFailed { + name: "Upload-Complete".into(), + message: "must be `?1` (true) or `?0` (false)".into(), + }, + )) + } +} diff --git a/src/main.rs b/src/main.rs index 8c19125..b8a4166 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,14 @@ +extern crate core; + mod config; mod http_api; +mod processing_worker; +mod upload_manager; +mod util; use crate::config::{ConfigBucket, ConfigBucketBackend, load_config}; use crate::http_api::start_http_api_server; +use crate::upload_manager::UploadManager; use camino::Utf8Path; use color_eyre::Result; use color_eyre::eyre::{WrapErr, eyre}; @@ -26,9 +32,18 @@ async fn main() -> Result<()> { .await .wrap_err("Failed to open the database connection.")?; + let upload_manager = UploadManager::create(database.clone(), config.staging_directory).await?; + log::info!("Initialization successful."); - start_http_api_server(&config, database).await?; + start_http_api_server( + upload_manager, + config.http_address, + config.http_port, + config.trust_http_reverse_proxy, + config.api_secret, + ) + .await?; Ok(()) } @@ -58,12 +73,7 @@ async fn initialize_buckets(bucket_configs: &Vec) -> Result<()> { Err(error) if error.kind() == ErrorKind::NotFound => { fs::create_dir_all(&filesystem_backend_config.path) .await - .wrap_err_with(|| { - format!( - "Could not create directory: {}", - filesystem_backend_config.path - ) - })?; + .wrap_err_with(|| format!("Could not create directory: {}", filesystem_backend_config.path))?; filesystem_backend_config.path.canonicalize_utf8()? } @@ -76,12 +86,9 @@ async fn initialize_buckets(bucket_configs: &Vec) -> Result<()> { )); } - check_directory_writable(&path).await.wrap_err_with(|| { - format!( - "The writable check for the {} bucket failed.", - &bucket_config.id - ) - })?; + check_directory_writable(&path) + .await + .wrap_err_with(|| format!("The writable check for the {} bucket failed.", &bucket_config.id))?; filesystem_backend_paths.insert(path); } @@ -93,9 +100,7 @@ async fn initialize_buckets(bucket_configs: &Vec) -> Result<()> { async fn check_directory_writable(directory_path: &Utf8Path) -> Result<()> { let path = directory_path.join("./minna-caos-write-check"); - let _ = fs::File::create(&path) - .await - .wrap_err("Writable check failed.")?; + let _ = fs::File::create(&path).await.wrap_err("Writable check failed.")?; fs::remove_file(path).await?; Ok(()) } diff --git a/src/processing_worker.rs b/src/processing_worker.rs new file mode 100644 index 0000000..e164482 --- /dev/null +++ b/src/processing_worker.rs @@ -0,0 +1,4 @@ +use crate::upload_manager::UploadId; +use sqlx::SqlitePool; + +pub async fn do_processing_work(tasks_receiver: tokio::sync::mpsc::UnboundedReceiver, database: SqlitePool) {} diff --git a/src/upload_manager.rs b/src/upload_manager.rs new file mode 100644 index 0000000..d77a34b --- /dev/null +++ b/src/upload_manager.rs @@ -0,0 +1,142 @@ +use crate::processing_worker::do_processing_work; +use crate::util::acquirable::{Acquirable, Acquisition}; +use crate::util::id::generate_id; +use camino::Utf8PathBuf; +use color_eyre::Result; +use dashmap::DashMap; +use fstr::FStr; +use sqlx::SqlitePool; +use std::fmt::Debug; +use std::sync::Arc; +use tokio::fs::{File, OpenOptions}; + +pub const UPLOAD_ID_LENGTH: usize = 16; +pub type UploadId = FStr; + +#[derive(Debug)] +pub struct UploadManager { + database: SqlitePool, + staging_directory_path: Utf8PathBuf, + ongoing_uploads: DashMap>, + small_file_processing_tasks_sender: tokio::sync::mpsc::UnboundedSender, + large_file_processing_tasks_sender: tokio::sync::mpsc::UnboundedSender, +} + +impl UploadManager { + pub async fn create(database: SqlitePool, staging_directory_path: Utf8PathBuf) -> Result { + log::info!("Loading unfinished uploads…"); + let ongoing_uploads = sqlx::query!("SELECT id, total_size FROM uploads") + .map(|row| { + let staging_file_path = staging_directory_path.join(&row.id); + let id = UploadId::from_str_lossy(&row.id, b'_'); + + ( + id, + Arc::new(UnfinishedUpload { + id, + total_size: row.total_size as u64, + file: Acquirable::new(FileReference::new(staging_file_path)), + }), + ) + }) + .fetch_all(&database) + .await?; + + log::info!("Starting upload processing…"); + + let (small_file_processing_tasks_sender, small_file_processing_tasks_receiver) = tokio::sync::mpsc::unbounded_channel(); + tokio::spawn(do_processing_work(small_file_processing_tasks_receiver, database.clone())); + + let (large_file_processing_tasks_sender, large_file_processing_tasks_receiver) = tokio::sync::mpsc::unbounded_channel(); + tokio::spawn(do_processing_work(large_file_processing_tasks_receiver, database.clone())); + + Ok(UploadManager { + database, + staging_directory_path, + ongoing_uploads: DashMap::from_iter(ongoing_uploads.into_iter()), + small_file_processing_tasks_sender, + large_file_processing_tasks_sender, + }) + } + + pub async fn create_upload(&self, total_size: u64) -> Result> { + let id: UploadId = generate_id(); + + { + let id = id.as_str(); + let total_size = total_size as i64; + sqlx::query!("INSERT INTO uploads (id, total_size) VALUES (?, ?)", id, total_size) + .execute(&self.database) + .await?; + } + + let upload = Arc::new(UnfinishedUpload { + id, + total_size, + file: Acquirable::new(FileReference::new(self.staging_directory_path.join(id.as_str()))), + }); + + self.ongoing_uploads.insert(id, Arc::clone(&upload)); + + Ok(upload) + } + + pub fn get_upload_by_id(&self, id: &str) -> Option> { + self.ongoing_uploads.get(id).map(|a| Arc::clone(a.value())) + } +} + +#[derive(Debug)] +pub struct UnfinishedUpload { + id: UploadId, + total_size: u64, + file: Acquirable, +} + +impl UnfinishedUpload { + pub fn id(&self) -> &UploadId { + &self.id + } + + pub fn file(&self) -> &Acquirable { + &self.file + } + + pub fn total_size(&self) -> u64 { + self.total_size + } + + pub async fn mark_as_finished(&self, file_acquisition: Acquisition) { + file_acquisition.destroy().await; + } +} + +#[derive(Debug)] +pub struct FileReference { + path: Utf8PathBuf, + file: Option, +} + +impl FileReference { + pub fn new(path: Utf8PathBuf) -> FileReference { + FileReference { path, file: None } + } + + pub async fn get_or_open(&mut self) -> Result<&mut File, std::io::Error> { + let file = &mut self.file; + if let Some(file) = file { + Ok(file) + } else { + *file = Some(OpenOptions::new().read(true).append(true).open(&self.path).await?); + Ok(unsafe { file.as_mut().unwrap_unchecked() }) + } + } + + pub fn is_open(&self) -> bool { + self.file.is_some() + } + + pub fn close(&mut self) -> bool { + if let Some(_file) = self.file.take() { true } else { false } + } +} diff --git a/src/util/acquirable.rs b/src/util/acquirable.rs new file mode 100644 index 0000000..447c58f --- /dev/null +++ b/src/util/acquirable.rs @@ -0,0 +1,143 @@ +use replace_with::{replace_with_or_abort, replace_with_or_abort_and_return}; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; + +#[derive(Debug)] +pub struct Acquirable { + state: Arc>>, +} + +#[derive(Debug)] +pub enum AcquirableState { + Available { + inner: T, + }, + Acquired { + release_request_token: CancellationToken, + data_return_channel_sender: tokio::sync::oneshot::Sender<(T, CancellationToken)>, + }, + Destroyed, +} + +#[must_use] +pub struct Acquisition { + inner: T, + acquirable_state: Arc>>, + release_request_token: CancellationToken, +} + +impl Acquirable { + pub fn new(inner: T) -> Acquirable { + Acquirable { + state: Arc::new(Mutex::new(AcquirableState::Available { inner })), + } + } + + pub async fn acquire(&self) -> Option> { + let mut state = self.state.lock().await; + + enum Outcome { + Acquired(Acquisition), + Waiting { + data_return_channel_receiver: tokio::sync::oneshot::Receiver<(T, CancellationToken)>, + }, + Destroyed, + } + + let outcome = replace_with_or_abort_and_return(&mut *state, |state| match state { + AcquirableState::Available { inner } => { + let release_request_token = CancellationToken::new(); + let (data_return_channel_sender, data_return_channel_receiver) = tokio::sync::oneshot::channel(); + drop(data_return_channel_receiver); + + ( + Outcome::Acquired(Acquisition { + inner, + acquirable_state: Arc::clone(&self.state), + release_request_token: release_request_token.clone(), + }), + AcquirableState::Acquired { + release_request_token, + data_return_channel_sender, + }, + ) + } + AcquirableState::Acquired { release_request_token, .. } => { + release_request_token.cancel(); + let (data_return_channel_sender, data_return_channel_receiver) = tokio::sync::oneshot::channel(); + + ( + Outcome::Waiting { data_return_channel_receiver }, + AcquirableState::Acquired { + release_request_token, + data_return_channel_sender, + }, + ) + } + AcquirableState::Destroyed => (Outcome::Destroyed, AcquirableState::Destroyed), + }); + + drop(state); + + match outcome { + Outcome::Acquired(acquisition) => Some(acquisition), + Outcome::Waiting { data_return_channel_receiver } => { + let data = data_return_channel_receiver.await; + + match data { + Ok((data, release_request_token)) => Some(Acquisition { + inner: data, + acquirable_state: Arc::clone(&self.state), + release_request_token, + }), + Err(_) => None, + } + } + Outcome::Destroyed => None, + } + } +} + +impl Acquisition { + pub fn inner(&mut self) -> &mut T { + &mut self.inner + } + + pub fn release_request_token(&self) -> CancellationToken { + self.release_request_token.clone() + } + + pub async fn release(self) { + let mut state = self.acquirable_state.lock().await; + + replace_with_or_abort(&mut *state, |state| match state { + AcquirableState::Acquired { + data_return_channel_sender, .. + } => { + let release_request_token = CancellationToken::new(); + match data_return_channel_sender.send((self.inner, release_request_token.clone())) { + Ok(_) => { + let (data_return_channel_sender, data_return_channel_receiver) = tokio::sync::oneshot::channel(); + drop(data_return_channel_receiver); + + AcquirableState::Acquired { + release_request_token, + data_return_channel_sender, + } + } + Err((data, _)) => AcquirableState::Available { inner: data }, + } + } + _ => unreachable!(), + }); + } + + /// Consume the acquisition without releasing it. The corresponding Acquirable will forever stay in the acquired state. + /// + /// All outstanding calls to Acquirable::acquire will return None. + pub async fn destroy(self) { + let mut state = self.acquirable_state.lock().await; + *state = AcquirableState::Destroyed; + } +} diff --git a/src/util/id.rs b/src/util/id.rs new file mode 100644 index 0000000..3e62921 --- /dev/null +++ b/src/util/id.rs @@ -0,0 +1,8 @@ +use fstr::FStr; +use rand::Rng; +use rand::distr::Alphanumeric; + +pub fn generate_id() -> FStr { + let bytes: [u8; N] = std::array::from_fn(|_| rand::rng().sample(&Alphanumeric)); + unsafe { FStr::from_inner_unchecked(bytes) } +} diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 0000000..63dac0d --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1,2 @@ +pub mod acquirable; +pub mod id;