diff --git a/src/core/error.rs b/src/core/error.rs index c11221e..c5289a6 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -6,9 +6,11 @@ use actix_web::{ use serde::{ser::SerializeMap, Serialize, Serializer}; use std::{fmt, io}; -use crate::util::{crypto, validate}; +use crate::util::{crypto, validate, xsd}; pub type Result = std::result::Result; +pub use std::error::Error as StdError; +pub use std::result::Result as StdResult; #[derive(Debug)] pub enum Error { @@ -27,6 +29,7 @@ pub enum Error { MalformedHeader(header::ToStrError), NotFound, Reqwest(reqwest::Error), + Xsd(xsd::FromStrError), } impl ResponseError for Error { @@ -43,6 +46,7 @@ impl ResponseError for Error { Error::MalformedApub(_) => StatusCode::UNPROCESSABLE_ENTITY, Error::MalformedHeader(_) => StatusCode::BAD_REQUEST, Error::NotFound => StatusCode::NOT_FOUND, + Error::Xsd(_) => StatusCode::UNPROCESSABLE_ENTITY, _ => StatusCode::INTERNAL_SERVER_ERROR, } } @@ -52,8 +56,8 @@ impl ResponseError for Error { } } -impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { match self { Error::BadToken(e) => Some(e), Error::Crypto(e) => Some(e), @@ -85,6 +89,7 @@ impl fmt::Display for Error { Error::MalformedApub(msg) => write!(f, "Malformed ActivityPub: {msg}"), Error::MalformedHeader(to_str_error) => to_str_error.fmt(f), Error::Reqwest(reqwest_error) => reqwest_error.fmt(f), + Error::Xsd(xsd_error) => xsd_error.fmt(f), } } } @@ -134,6 +139,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: xsd::FromStrError) -> Error { + Error::Xsd(e) + } +} + impl Serialize for Error { fn serialize(&self, serializer: S) -> std::result::Result where diff --git a/src/util/mod.rs b/src/util/mod.rs index 59914b1..c188ee5 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -14,3 +14,5 @@ pub mod token; pub mod transcode; /// Validation framework. pub mod validate; +/// Some XSD types and parsers. +pub mod xsd; diff --git a/src/util/xsd/duration.rs b/src/util/xsd/duration.rs new file mode 100644 index 0000000..e96d95b --- /dev/null +++ b/src/util/xsd/duration.rs @@ -0,0 +1,492 @@ +use std::fmt; +use std::fmt::Write; +use std::str::FromStr; + +use crate::core::*; +use crate::util::slice::SliceCursor; +use crate::util::xsd::{from_str_error, FromStrError}; + +#[derive(Clone, Eq, PartialEq, Default, Debug)] +pub struct Duration { + negative: bool, + years: u16, + months: u32, + days: u32, + hours: u32, + mins: u32, + millis: u64, +} + +enum DateUnit { + Years(u16), + Months(u16), + Days(u32), +} + +enum TimeUnit { + Hours(u32), + Minutes(u32), + Millis(u64), +} + +trait UnitDomain: Sized { + fn index(&self) -> usize; + fn val(&self) -> u64; + fn parse(cursor: &mut SliceCursor) -> StdResult; +} + +impl Duration { + pub fn is_negative(&self) -> bool { + self.negative + } + + pub fn set_negative(&mut self, negative: bool) { + self.negative = negative; + } + + pub fn years(&self) -> u16 { + self.years + } + + pub fn set_years(&mut self, years: u16) { + self.years = years; + } + + pub fn from_years(years: u16) -> Self { + Self { + years, + ..Default::default() + } + } + + pub fn months(&self) -> u32 { + self.months + } + + pub fn set_months(&mut self, months: u32) { + self.months = months; + } + + pub fn from_months(months: u32) -> Self { + Self { + months, + ..Default::default() + } + } + + pub fn days(&self) -> u32 { + self.days + } + + pub fn set_days(&mut self, days: u32) { + self.days = days; + } + + pub fn from_days(days: u32) -> Self { + Self { + days, + ..Default::default() + } + } + + pub fn hours(&self) -> u32 { + self.hours + } + + pub fn set_hours(&mut self, hours: u32) { + self.hours = hours; + } + + pub fn from_hours(hours: u32) -> Self { + Self { + hours, + ..Default::default() + } + } + + pub fn mins(&self) -> u32 { + self.mins + } + + pub fn set_mins(&mut self, mins: u32) { + self.mins = mins; + } + + pub fn from_mins(mins: u32) -> Self { + Self { + mins, + ..Default::default() + } + } + + pub fn secs(&self) -> u64 { + self.millis / 1000 + } + + pub fn set_secs(&mut self, secs: u64) { + self.millis = secs * 1000 + self.millis % 1000; + } + + pub fn from_secs(secs: u64) -> Self { + Self { + millis: secs * 1000, + ..Default::default() + } + } + + pub fn millis(&self) -> u16 { + (self.millis % 1000) as u16 + } + + pub fn set_millis(&mut self, millis: u16) { + debug_assert!(millis < 1000); + self.millis = self.secs() * 1000 + millis as u64; + } + + pub fn from_millis(millis: u64) -> Self { + Self { + millis, + ..Default::default() + } + } +} + +impl FromStr for Duration { + type Err = FromStrError; + + fn from_str(s: &str) -> StdResult { + let mut cursor = SliceCursor::new(s.as_bytes()); + + let negative = cursor.next_if(|&b| b == b'-').is_some(); + cursor + .next_if(|&b| b == b'P') + .ok_or_else(|| from_str_error("Expected a 'P'"))?; + let mut vals: [Option; 6] = [None; 6]; + let mut prev_index = None; + + parse_unit_domain::(&mut vals, &mut prev_index, &mut cursor)?; + match cursor.peek().copied() { + Some(b'T') => { + cursor.next(); + let prev_index_before = prev_index; + parse_unit_domain::(&mut vals, &mut prev_index, &mut cursor)?; + if prev_index_before == prev_index { + return Err(from_str_error("Time separator must be followed by a value")); + } + } + Some(_) => return Err(from_str_error("Syntax error")), + None => {} + }; + if prev_index.is_none() { + return Err(from_str_error("No values specified")); + } + if cursor.next().is_some() { + return Err(from_str_error("Unexpected trailing character")); + } + + Ok(Self { + negative, + years: convert_to(vals[0], "Year value exceeds range")?, + months: convert_to(vals[1], "Months exceed range")?, + days: convert_to(vals[2], "Days exceed range")?, + hours: convert_to(vals[3], "Hours exceed range")?, + mins: convert_to(vals[4], "Minutes exceed range")?, + millis: vals[5].unwrap_or_default(), + }) + } +} + +fn parse_unit_domain( + vals: &mut [Option; 6], + prev_index: &mut Option, + cursor: &mut SliceCursor, +) -> StdResult<(), FromStrError> { + for _ in 0..3 { + match cursor.peek().copied() { + Some(b'T') => break, + None => break, + Some(_) => { + let val = T::parse(cursor)?; + let index = val.index(); + if Some(index) < *prev_index { + return Err(from_str_error("Wrong unit order")); + } + *prev_index = Some(index); + + if vals[val.index()].replace(val.val()).is_some() { + return Err(from_str_error("Duplicate unit definition")); + } + } + } + } + + Ok(()) +} + +fn convert_to(t: Option, msg: impl Into) -> StdResult +where + T: TryInto + Default, +{ + t.unwrap_or_default() + .try_into() + .map_err(|_| from_str_error(msg)) +} + +impl fmt::Display for Duration { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.negative { + f.write_char('-')?; + } + f.write_char('P')?; + if self.years == 0 + && self.months == 0 + && self.days == 0 + && self.hours == 0 + && self.mins == 0 + && self.millis == 0 + { + return f.write_str("T0S"); + } + + if self.years > 0 { + f.write_fmt(format_args!("{}Y", self.years))?; + } + if self.months > 0 { + f.write_fmt(format_args!("{}M", self.months))?; + } + if self.days > 0 { + f.write_fmt(format_args!("{}D", self.days))?; + } + if self.hours as u64 + self.mins as u64 + self.millis > 0 { + f.write_char('T')?; + if self.hours > 0 { + f.write_fmt(format_args!("{}H", self.hours))?; + } + if self.mins > 0 { + f.write_fmt(format_args!("{}M", self.mins))?; + } + if self.millis > 0 { + f.write_fmt(format_args!("{}", self.millis / 1000))?; + if self.millis % 1000 > 0 { + f.write_fmt(format_args!(".{:03}", self.millis % 1000))?; + } + f.write_fmt(format_args!("S"))?; + } + } + Ok(()) + } +} + +impl UnitDomain for DateUnit { + fn index(&self) -> usize { + match self { + DateUnit::Years(_) => 0, + DateUnit::Months(_) => 1, + DateUnit::Days(_) => 2, + } + } + + fn val(&self) -> u64 { + match self { + DateUnit::Years(n) => *n as u64, + DateUnit::Months(n) => *n as u64, + DateUnit::Days(n) => *n as u64, + } + } + + fn parse(cursor: &mut SliceCursor) -> StdResult { + let n: u32 = parse_int(cursor)?; + match cursor.next().copied() { + Some(b'Y') => { + Ok(DateUnit::Years(n.try_into().map_err(|e| { + from_str_error(format!("Invalid years value: {e}")) + })?)) + } + Some(b'M') => { + Ok(DateUnit::Months(n.try_into().map_err(|e| { + from_str_error(format!("Invalid months value: {e}")) + })?)) + } + Some(b'D') => Ok(DateUnit::Days(n)), + Some(_) => Err(from_str_error("Invalid unit specifier")), + None => Err(from_str_error("Missing unit specifier")), + } + } +} + +impl UnitDomain for TimeUnit { + fn index(&self) -> usize { + match self { + TimeUnit::Hours(_) => 3, + TimeUnit::Minutes(_) => 4, + TimeUnit::Millis(_) => 5, + } + } + + fn val(&self) -> u64 { + match self { + TimeUnit::Hours(n) => *n as u64, + TimeUnit::Minutes(n) => *n as u64, + TimeUnit::Millis(n) => *n, + } + } + + fn parse(cursor: &mut SliceCursor) -> StdResult { + let int: u64 = parse_int(cursor)?; + let decimals = cursor + .next_if(|&c| c == b'.') + .and_then(|_| { + let fract_bytes = cursor.next_while(|c| c.is_ascii_digit()); + // only support up to 3 decimal digits + let fract_bytes = &fract_bytes[..fract_bytes.len().min(4)]; + std::str::from_utf8(fract_bytes) + .ok()? + .parse::() // this fails if the string is empty + .ok() + .map(|f| f * 10u32.pow(3 - fract_bytes.len() as u32)) + }) + .map(|n| n as u64); + + match cursor.next().copied() { + Some(b'H') => { + if decimals.is_none() { + Ok(Self::Hours(int.try_into().map_err(|e| { + from_str_error(format!("Invalid hours value: {e}")) + })?)) + } else { + Err(from_str_error("Hours cannot contain a fractional part")) + } + } + Some(b'M') => { + if decimals.is_none() { + Ok(Self::Minutes(int.try_into().map_err(|e| { + from_str_error(format!("Invalid minutes value: {e}")) + })?)) + } else { + Err(from_str_error("Minutes cannot contain a fractional part")) + } + } + Some(b'S') => int + .checked_mul(1000) + .and_then(|ms| ms.checked_add(decimals.unwrap_or(0))) + .map(Self::Millis) + .ok_or_else(|| from_str_error("Integer overflow while parsing seconds")), + Some(_) => Err(from_str_error("Invalid unit specified")), + None => Err(from_str_error("Missing unit specifier")), + } + } +} + +fn parse_int(cursor: &mut SliceCursor) -> StdResult +where + F: FromStr, + F::Err: fmt::Display, +{ + let num = std::str::from_utf8(cursor.next_while(|&c| c.is_ascii_digit())) + .expect("ASCII digits are always valid UTF-8"); + num.parse::() + .map_err(|e| from_str_error(format!("Cannot parse number: {e}"))) +} + +impl fmt::Display for FromStrError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.msg) + } +} + +impl StdError for FromStrError { + fn description(&self) -> &str { + &self.msg + } +} + +#[cfg(test)] +mod tests { + use super::Duration; + use std::str::FromStr; + + #[test] + fn set_vals() { + let mut duration = Duration::default(); + + duration.set_years(1); + assert_eq!(duration.years(), 1); + duration.set_months(2); + assert_eq!(duration.months(), 2); + duration.set_days(3); + assert_eq!(duration.days(), 3); + duration.set_hours(4); + assert_eq!(duration.hours(), 4); + duration.set_mins(5); + assert_eq!(duration.mins(), 5); + duration.set_secs(6); + assert_eq!(duration.secs(), 6); + duration.set_millis(7); + assert_eq!(duration.millis(), 7); + assert_eq!(duration.secs(), 6); + } + + #[test] + fn from_str() { + let test_values = [ + ( + "-P1Y2M3DT4H5M6.7S", + (true, 1u16, 2u32, 3u32, 4u32, 5u32, 6u64, 700u16), + "-P1Y2M3DT4H5M6.700S", + ), + ( + "P7Y6M5DT4H3M2S", + (false, 7, 6, 5, 4, 3, 2, 0), + "P7Y6M5DT4H3M2S", + ), + ("PT0S", (false, 0, 0, 0, 0, 0, 0, 0), "PT0S"), + ("P1Y", (false, 1, 0, 0, 0, 0, 0, 0), "P1Y"), + ("P1M", (false, 0, 1, 0, 0, 0, 0, 0), "P1M"), + ("P1D", (false, 0, 0, 1, 0, 0, 0, 0), "P1D"), + ("PT1H", (false, 0, 0, 0, 1, 0, 0, 0), "PT1H"), + ("PT1M", (false, 0, 0, 0, 0, 1, 0, 0), "PT1M"), + ("PT1S", (false, 0, 0, 0, 0, 0, 1, 0), "PT1S"), + ("PT0.001S", (false, 0, 0, 0, 0, 0, 0, 1), "PT0.001S"), + ]; + + for (input, (negative, years, months, days, hours, mins, secs, millis), output) in + test_values + { + let duration = Duration::from_str(input).unwrap(); + assert_eq!(duration.is_negative(), negative); + assert_eq!(duration.years(), years); + assert_eq!(duration.months(), months); + assert_eq!(duration.days(), days); + assert_eq!(duration.hours(), hours); + assert_eq!(duration.mins(), mins); + assert_eq!(duration.secs(), secs); + assert_eq!(duration.millis(), millis); + + assert_eq!(duration.to_string(), output); + } + + let bad_values = [ + "P", + "-P", + "P1H", + "P1HT", + "P1S", + "P1ST", + "PT1Y", + "PT1D", + "PT", + "P1YT", + "P65536Y", + "P4294967296M", + "P4294967296D", + "PT4294967296H", + "PT4294967296M", + "P1DT18446744073709551.616S", + ]; + for input in bad_values { + let result = Duration::from_str(input); + eprintln!("{result:?}"); + assert!(result.is_err()); + } + } +} diff --git a/src/util/xsd/mod.rs b/src/util/xsd/mod.rs new file mode 100644 index 0000000..80b474b --- /dev/null +++ b/src/util/xsd/mod.rs @@ -0,0 +1,19 @@ +mod duration; +pub use duration::Duration; + +#[derive(Debug)] +pub struct FromStrError { + msg: String, +} + +fn from_str_error(msg: O) -> FromStrError +where + T: Into, + O: Into>, +{ + let msg = msg + .into() + .map(|t| t.into()) + .unwrap_or_else(|| String::from("Invalid value")); + FromStrError { msg } +}