diff --git a/Cargo.lock b/Cargo.lock index 6df3e14..7300ec0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -643,6 +643,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.25" @@ -659,6 +674,17 @@ version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac" +[[package]] +name = "futures-executor" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-intrusive" version = "0.4.2" @@ -670,6 +696,23 @@ dependencies = [ "parking_lot 0.11.2", ] +[[package]] +name = "futures-io" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" + +[[package]] +name = "futures-macro" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.25" @@ -688,11 +731,16 @@ version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1080,6 +1128,7 @@ dependencies = [ "async-trait", "chrono", "dotenvy", + "futures", "jsonwebtoken", "log", "pretty_env_logger", diff --git a/Cargo.toml b/Cargo.toml index a03a542..ba2ee36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ argon2 = "0.4" async-trait = "0.1.59" chrono = { version = "0.4", features = [ "alloc", "clock", "serde" ] } dotenvy = "0.15.6" +futures = "0.3" jsonwebtoken = { version = "8", default-features = false } log = "0.4" pretty_env_logger = "0.4" diff --git a/src/core/mod.rs b/src/core/mod.rs index 96a795d..a3a1c93 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,8 +1,9 @@ use actix_web::body::BoxBody; -use actix_web::http::StatusCode; +use actix_web::http::{header, StatusCode}; use actix_web::{HttpResponse, ResponseError}; use chrono::prelude::*; use serde::{Serialize, Serializer}; +use std::time::{SystemTime, UNIX_EPOCH}; use std::{fmt, io}; use crate::util::validate; @@ -29,6 +30,15 @@ pub fn utc_now() -> NaiveDateTime { Utc::now().naive_utc() } +pub fn unix_now() -> i64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect( + "You've either broken spacetime, or your system clock is a bit off.", + ) + .as_secs() as i64 +} + impl ResponseError for Error { fn status_code(&self) -> StatusCode { match self { diff --git a/src/main.rs b/src/main.rs index b2cce61..c41f14c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,9 @@ mod data; /// Asynchronous background workers. mod job; +/// Middleware for request handlers. +mod middle; + /// Database models and validation. mod model; @@ -65,6 +68,7 @@ async fn main() -> std::io::Result<()> { HttpServer::new(move || { App::new() .wrap(actix_web::middleware::DefaultHeaders::new().add(("Server", "nyano"))) + .wrap(middle::auth::Auth::new(state.clone())) .app_data(state.clone()) .configure(route::configure) }) diff --git a/src/middle/auth.rs b/src/middle/auth.rs new file mode 100644 index 0000000..49b237c --- /dev/null +++ b/src/middle/auth.rs @@ -0,0 +1,113 @@ +use actix_web::dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform}; +use actix_web::{FromRequest, HttpMessage, HttpRequest}; +use futures::future::LocalBoxFuture; +use futures::FutureExt; +use std::future::{ready, Ready}; +use std::rc::Rc; +use std::task::{Context, Poll}; + +use crate::core::*; +use crate::model::Account; +use crate::state::AppState; +use crate::util::token; + +/// Factory for [`AuthMiddleware`]. +pub struct Auth { + state: AppState, +} + +impl Auth { + pub fn new(state: AppState) -> Auth { + Auth { state } + } +} + +impl Transform for Auth +where + S: Service, Error = actix_web::Error> + 'static, +{ + type Response = ServiceResponse; + type Error = actix_web::Error; + type Transform = AuthMiddleware; + type InitError = (); + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(AuthMiddleware { + service: Rc::new(service), + state: self.state.clone(), + })) + } +} + +pub struct AuthMiddleware { + service: Rc, + state: AppState, +} + +impl Service for AuthMiddleware +where + S: Service, Error = actix_web::Error> + 'static, +{ + type Response = ServiceResponse; + type Error = actix_web::Error; + type Future = LocalBoxFuture<'static, std::result::Result>; + + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(ctx) + } + + fn call(&self, req: ServiceRequest) -> Self::Future { + let service = self.service.clone(); + let state = self.state.clone(); + + async move { + let account = if let Some(token) = req.headers().get("Authorization") { + let token = extract_token(token.to_str().unwrap())?; + debug!("token = \"{}\"", token); + let account = token::validate(&state, token).await?; + Some(account) + } else { + None + }; + req.extensions_mut().insert(account); + service.call(req).await + } + .boxed_local() + } +} + +pub struct AuthData(Option); + +impl AuthData { + pub fn maybe(&self) -> Option<&Account> { + self.0.as_ref() + } + + pub fn require(&self) -> Result<&Account> { + self.maybe().ok_or(Error::BadCredentials) + } +} + +impl FromRequest for AuthData { + type Error = Error; + type Future = Ready>; + + fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future { + let val: Option> = req.extensions_mut().remove(); + ready(match val { + Some(a) => Ok(AuthData(a)), + None => Err(Error::BadCredentials), + }) + } +} + +fn extract_token(header: &str) -> Result<&str> { + const PREFIX: &'static str = "Bearer "; + + if header.starts_with(PREFIX) { + Ok(&header[PREFIX.len()..]) + } else { + Err(Error::BadCredentials) + } +} diff --git a/src/middle/mod.rs b/src/middle/mod.rs new file mode 100644 index 0000000..da75fb1 --- /dev/null +++ b/src/middle/mod.rs @@ -0,0 +1,3 @@ +pub mod auth; + +pub use auth::AuthData; diff --git a/src/route/api/v1/accounts.rs b/src/route/api/v1/accounts.rs index ba64fee..2a62bfd 100644 --- a/src/route/api/v1/accounts.rs +++ b/src/route/api/v1/accounts.rs @@ -2,11 +2,18 @@ use actix_web::{get, post, web, HttpResponse}; use serde::Deserialize; use crate::core::*; +use crate::middle::AuthData; use crate::model::{NewAccount, NewUser}; use crate::state::AppState; use crate::util::password; use crate::util::validate::{ResultBuilder, Validate}; +#[get("/self")] +async fn get_self(account: AuthData) -> Result { + let account = account.require()?; + Ok(HttpResponse::Ok().json(account)) +} + #[derive(Deserialize)] struct SignupData { username: String, @@ -75,5 +82,8 @@ async fn get_notes(path: web::Path, state: AppState) -> Result } pub fn configure(cfg: &mut web::ServiceConfig) { - cfg.service(get_by_id).service(get_notes).service(signup); + cfg.service(get_self) + .service(get_by_id) + .service(get_notes) + .service(signup); } diff --git a/src/util/token.rs b/src/util/token.rs index 2671da5..8758c4d 100644 --- a/src/util/token.rs +++ b/src/util/token.rs @@ -1,4 +1,3 @@ -use chrono::{DateTime, Duration, Utc}; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; @@ -6,18 +5,20 @@ use crate::core::*; use crate::model::Account; use crate::state::AppState; -const VALID_DAYS: i64 = 30; +const VALID_SECS: i64 = 30 * 86400; // 30 days #[derive(Debug, Serialize, Deserialize)] struct Claims { - acct: String, // store as a string because JSON can't be trusted - exp: DateTime, + /// Account ID, stored as a string because JSON can't be trusted + sub: String, + /// Expiry date in UNIX time (seconds) + exp: i64, } pub fn issue(state: &AppState, account: &Account) -> Result { let claims = Claims { - acct: format!("{}", account.id), - exp: Utc::now() + Duration::days(VALID_DAYS), + sub: format!("{}", account.id), + exp: unix_now() + 30 * 86400, }; let encoding_key = EncodingKey::from_secret(state.config.jwt_secret.as_slice()); let header = Header { @@ -29,14 +30,14 @@ pub fn issue(state: &AppState, account: &Account) -> Result { Ok(encode(&header, &claims, &encoding_key)?) } -async fn validate(state: &AppState, token: &str) -> Result { +pub async fn validate(state: &AppState, token: &str) -> Result { let decoding_key = DecodingKey::from_secret(state.config.jwt_secret.as_slice()); let validation = Validation::new(Algorithm::HS256); let claims: Claims = decode(token, &decoding_key, &validation)?.claims; - if claims.exp - Utc::now() > Duration::days(VALID_DAYS) { - // tokens that are valid for longer than VALID_DAYS are sus - todo!() + if unix_now() - claims.exp > VALID_SECS { + // tokens that are valid for longer than VALID_SECS are sus + return Err(Error::BadCredentials); } - let account_id: Id = claims.acct.parse().expect("We issued an invalid token??"); + let account_id: Id = claims.sub.parse().expect("We issued an invalid token??"); state.repo.accounts.by_id(account_id).await }