util: add parser for xsd:duration

i am in immeasurable pain
This commit is contained in:
anna 2023-08-06 06:18:47 +02:00
parent 36becee078
commit dad5d07dd6
Signed by: fef
GPG key ID: 2585C2DC6D79B485
4 changed files with 527 additions and 3 deletions

View file

@ -6,9 +6,11 @@ use actix_web::{
use serde::{ser::SerializeMap, Serialize, Serializer}; use serde::{ser::SerializeMap, Serialize, Serializer};
use std::{fmt, io}; use std::{fmt, io};
use crate::util::{crypto, validate}; use crate::util::{crypto, validate, xsd};
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
pub use std::error::Error as StdError;
pub use std::result::Result as StdResult;
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
@ -27,6 +29,7 @@ pub enum Error {
MalformedHeader(header::ToStrError), MalformedHeader(header::ToStrError),
NotFound, NotFound,
Reqwest(reqwest::Error), Reqwest(reqwest::Error),
Xsd(xsd::FromStrError),
} }
impl ResponseError for Error { impl ResponseError for Error {
@ -43,6 +46,7 @@ impl ResponseError for Error {
Error::MalformedApub(_) => StatusCode::UNPROCESSABLE_ENTITY, Error::MalformedApub(_) => StatusCode::UNPROCESSABLE_ENTITY,
Error::MalformedHeader(_) => StatusCode::BAD_REQUEST, Error::MalformedHeader(_) => StatusCode::BAD_REQUEST,
Error::NotFound => StatusCode::NOT_FOUND, Error::NotFound => StatusCode::NOT_FOUND,
Error::Xsd(_) => StatusCode::UNPROCESSABLE_ENTITY,
_ => StatusCode::INTERNAL_SERVER_ERROR, _ => StatusCode::INTERNAL_SERVER_ERROR,
} }
} }
@ -52,8 +56,8 @@ impl ResponseError for Error {
} }
} }
impl std::error::Error for Error { impl StdError for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self { match self {
Error::BadToken(e) => Some(e), Error::BadToken(e) => Some(e),
Error::Crypto(e) => Some(e), Error::Crypto(e) => Some(e),
@ -85,6 +89,7 @@ impl fmt::Display for Error {
Error::MalformedApub(msg) => write!(f, "Malformed ActivityPub: {msg}"), Error::MalformedApub(msg) => write!(f, "Malformed ActivityPub: {msg}"),
Error::MalformedHeader(to_str_error) => to_str_error.fmt(f), Error::MalformedHeader(to_str_error) => to_str_error.fmt(f),
Error::Reqwest(reqwest_error) => reqwest_error.fmt(f), Error::Reqwest(reqwest_error) => reqwest_error.fmt(f),
Error::Xsd(xsd_error) => xsd_error.fmt(f),
} }
} }
} }
@ -134,6 +139,12 @@ impl From<crypto::Error> for Error {
} }
} }
impl From<xsd::FromStrError> for Error {
fn from(e: xsd::FromStrError) -> Error {
Error::Xsd(e)
}
}
impl Serialize for Error { impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where where

View file

@ -14,3 +14,5 @@ pub mod token;
pub mod transcode; pub mod transcode;
/// Validation framework. /// Validation framework.
pub mod validate; pub mod validate;
/// Some XSD types and parsers.
pub mod xsd;

492
src/util/xsd/duration.rs Normal file
View file

@ -0,0 +1,492 @@
use std::fmt;
use std::fmt::Write;
use std::str::FromStr;
use crate::core::*;
use crate::util::slice::SliceCursor;
use crate::util::xsd::{from_str_error, FromStrError};
#[derive(Clone, Eq, PartialEq, Default, Debug)]
pub struct Duration {
negative: bool,
years: u16,
months: u32,
days: u32,
hours: u32,
mins: u32,
millis: u64,
}
enum DateUnit {
Years(u16),
Months(u16),
Days(u32),
}
enum TimeUnit {
Hours(u32),
Minutes(u32),
Millis(u64),
}
trait UnitDomain: Sized {
fn index(&self) -> usize;
fn val(&self) -> u64;
fn parse(cursor: &mut SliceCursor<u8>) -> StdResult<Self, FromStrError>;
}
impl Duration {
pub fn is_negative(&self) -> bool {
self.negative
}
pub fn set_negative(&mut self, negative: bool) {
self.negative = negative;
}
pub fn years(&self) -> u16 {
self.years
}
pub fn set_years(&mut self, years: u16) {
self.years = years;
}
pub fn from_years(years: u16) -> Self {
Self {
years,
..Default::default()
}
}
pub fn months(&self) -> u32 {
self.months
}
pub fn set_months(&mut self, months: u32) {
self.months = months;
}
pub fn from_months(months: u32) -> Self {
Self {
months,
..Default::default()
}
}
pub fn days(&self) -> u32 {
self.days
}
pub fn set_days(&mut self, days: u32) {
self.days = days;
}
pub fn from_days(days: u32) -> Self {
Self {
days,
..Default::default()
}
}
pub fn hours(&self) -> u32 {
self.hours
}
pub fn set_hours(&mut self, hours: u32) {
self.hours = hours;
}
pub fn from_hours(hours: u32) -> Self {
Self {
hours,
..Default::default()
}
}
pub fn mins(&self) -> u32 {
self.mins
}
pub fn set_mins(&mut self, mins: u32) {
self.mins = mins;
}
pub fn from_mins(mins: u32) -> Self {
Self {
mins,
..Default::default()
}
}
pub fn secs(&self) -> u64 {
self.millis / 1000
}
pub fn set_secs(&mut self, secs: u64) {
self.millis = secs * 1000 + self.millis % 1000;
}
pub fn from_secs(secs: u64) -> Self {
Self {
millis: secs * 1000,
..Default::default()
}
}
pub fn millis(&self) -> u16 {
(self.millis % 1000) as u16
}
pub fn set_millis(&mut self, millis: u16) {
debug_assert!(millis < 1000);
self.millis = self.secs() * 1000 + millis as u64;
}
pub fn from_millis(millis: u64) -> Self {
Self {
millis,
..Default::default()
}
}
}
impl FromStr for Duration {
type Err = FromStrError;
fn from_str(s: &str) -> StdResult<Self, Self::Err> {
let mut cursor = SliceCursor::new(s.as_bytes());
let negative = cursor.next_if(|&b| b == b'-').is_some();
cursor
.next_if(|&b| b == b'P')
.ok_or_else(|| from_str_error("Expected a 'P'"))?;
let mut vals: [Option<u64>; 6] = [None; 6];
let mut prev_index = None;
parse_unit_domain::<DateUnit>(&mut vals, &mut prev_index, &mut cursor)?;
match cursor.peek().copied() {
Some(b'T') => {
cursor.next();
let prev_index_before = prev_index;
parse_unit_domain::<TimeUnit>(&mut vals, &mut prev_index, &mut cursor)?;
if prev_index_before == prev_index {
return Err(from_str_error("Time separator must be followed by a value"));
}
}
Some(_) => return Err(from_str_error("Syntax error")),
None => {}
};
if prev_index.is_none() {
return Err(from_str_error("No values specified"));
}
if cursor.next().is_some() {
return Err(from_str_error("Unexpected trailing character"));
}
Ok(Self {
negative,
years: convert_to(vals[0], "Year value exceeds range")?,
months: convert_to(vals[1], "Months exceed range")?,
days: convert_to(vals[2], "Days exceed range")?,
hours: convert_to(vals[3], "Hours exceed range")?,
mins: convert_to(vals[4], "Minutes exceed range")?,
millis: vals[5].unwrap_or_default(),
})
}
}
fn parse_unit_domain<T: UnitDomain>(
vals: &mut [Option<u64>; 6],
prev_index: &mut Option<usize>,
cursor: &mut SliceCursor<u8>,
) -> StdResult<(), FromStrError> {
for _ in 0..3 {
match cursor.peek().copied() {
Some(b'T') => break,
None => break,
Some(_) => {
let val = T::parse(cursor)?;
let index = val.index();
if Some(index) < *prev_index {
return Err(from_str_error("Wrong unit order"));
}
*prev_index = Some(index);
if vals[val.index()].replace(val.val()).is_some() {
return Err(from_str_error("Duplicate unit definition"));
}
}
}
}
Ok(())
}
fn convert_to<T, U>(t: Option<T>, msg: impl Into<String>) -> StdResult<U, FromStrError>
where
T: TryInto<U> + Default,
{
t.unwrap_or_default()
.try_into()
.map_err(|_| from_str_error(msg))
}
impl fmt::Display for Duration {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.negative {
f.write_char('-')?;
}
f.write_char('P')?;
if self.years == 0
&& self.months == 0
&& self.days == 0
&& self.hours == 0
&& self.mins == 0
&& self.millis == 0
{
return f.write_str("T0S");
}
if self.years > 0 {
f.write_fmt(format_args!("{}Y", self.years))?;
}
if self.months > 0 {
f.write_fmt(format_args!("{}M", self.months))?;
}
if self.days > 0 {
f.write_fmt(format_args!("{}D", self.days))?;
}
if self.hours as u64 + self.mins as u64 + self.millis > 0 {
f.write_char('T')?;
if self.hours > 0 {
f.write_fmt(format_args!("{}H", self.hours))?;
}
if self.mins > 0 {
f.write_fmt(format_args!("{}M", self.mins))?;
}
if self.millis > 0 {
f.write_fmt(format_args!("{}", self.millis / 1000))?;
if self.millis % 1000 > 0 {
f.write_fmt(format_args!(".{:03}", self.millis % 1000))?;
}
f.write_fmt(format_args!("S"))?;
}
}
Ok(())
}
}
impl UnitDomain for DateUnit {
fn index(&self) -> usize {
match self {
DateUnit::Years(_) => 0,
DateUnit::Months(_) => 1,
DateUnit::Days(_) => 2,
}
}
fn val(&self) -> u64 {
match self {
DateUnit::Years(n) => *n as u64,
DateUnit::Months(n) => *n as u64,
DateUnit::Days(n) => *n as u64,
}
}
fn parse(cursor: &mut SliceCursor<u8>) -> StdResult<Self, FromStrError> {
let n: u32 = parse_int(cursor)?;
match cursor.next().copied() {
Some(b'Y') => {
Ok(DateUnit::Years(n.try_into().map_err(|e| {
from_str_error(format!("Invalid years value: {e}"))
})?))
}
Some(b'M') => {
Ok(DateUnit::Months(n.try_into().map_err(|e| {
from_str_error(format!("Invalid months value: {e}"))
})?))
}
Some(b'D') => Ok(DateUnit::Days(n)),
Some(_) => Err(from_str_error("Invalid unit specifier")),
None => Err(from_str_error("Missing unit specifier")),
}
}
}
impl UnitDomain for TimeUnit {
fn index(&self) -> usize {
match self {
TimeUnit::Hours(_) => 3,
TimeUnit::Minutes(_) => 4,
TimeUnit::Millis(_) => 5,
}
}
fn val(&self) -> u64 {
match self {
TimeUnit::Hours(n) => *n as u64,
TimeUnit::Minutes(n) => *n as u64,
TimeUnit::Millis(n) => *n,
}
}
fn parse(cursor: &mut SliceCursor<u8>) -> StdResult<Self, FromStrError> {
let int: u64 = parse_int(cursor)?;
let decimals = cursor
.next_if(|&c| c == b'.')
.and_then(|_| {
let fract_bytes = cursor.next_while(|c| c.is_ascii_digit());
// only support up to 3 decimal digits
let fract_bytes = &fract_bytes[..fract_bytes.len().min(4)];
std::str::from_utf8(fract_bytes)
.ok()?
.parse::<u32>() // this fails if the string is empty
.ok()
.map(|f| f * 10u32.pow(3 - fract_bytes.len() as u32))
})
.map(|n| n as u64);
match cursor.next().copied() {
Some(b'H') => {
if decimals.is_none() {
Ok(Self::Hours(int.try_into().map_err(|e| {
from_str_error(format!("Invalid hours value: {e}"))
})?))
} else {
Err(from_str_error("Hours cannot contain a fractional part"))
}
}
Some(b'M') => {
if decimals.is_none() {
Ok(Self::Minutes(int.try_into().map_err(|e| {
from_str_error(format!("Invalid minutes value: {e}"))
})?))
} else {
Err(from_str_error("Minutes cannot contain a fractional part"))
}
}
Some(b'S') => int
.checked_mul(1000)
.and_then(|ms| ms.checked_add(decimals.unwrap_or(0)))
.map(Self::Millis)
.ok_or_else(|| from_str_error("Integer overflow while parsing seconds")),
Some(_) => Err(from_str_error("Invalid unit specified")),
None => Err(from_str_error("Missing unit specifier")),
}
}
}
fn parse_int<F>(cursor: &mut SliceCursor<u8>) -> StdResult<F, FromStrError>
where
F: FromStr,
F::Err: fmt::Display,
{
let num = std::str::from_utf8(cursor.next_while(|&c| c.is_ascii_digit()))
.expect("ASCII digits are always valid UTF-8");
num.parse::<F>()
.map_err(|e| from_str_error(format!("Cannot parse number: {e}")))
}
impl fmt::Display for FromStrError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.msg)
}
}
impl StdError for FromStrError {
fn description(&self) -> &str {
&self.msg
}
}
#[cfg(test)]
mod tests {
use super::Duration;
use std::str::FromStr;
#[test]
fn set_vals() {
let mut duration = Duration::default();
duration.set_years(1);
assert_eq!(duration.years(), 1);
duration.set_months(2);
assert_eq!(duration.months(), 2);
duration.set_days(3);
assert_eq!(duration.days(), 3);
duration.set_hours(4);
assert_eq!(duration.hours(), 4);
duration.set_mins(5);
assert_eq!(duration.mins(), 5);
duration.set_secs(6);
assert_eq!(duration.secs(), 6);
duration.set_millis(7);
assert_eq!(duration.millis(), 7);
assert_eq!(duration.secs(), 6);
}
#[test]
fn from_str() {
let test_values = [
(
"-P1Y2M3DT4H5M6.7S",
(true, 1u16, 2u32, 3u32, 4u32, 5u32, 6u64, 700u16),
"-P1Y2M3DT4H5M6.700S",
),
(
"P7Y6M5DT4H3M2S",
(false, 7, 6, 5, 4, 3, 2, 0),
"P7Y6M5DT4H3M2S",
),
("PT0S", (false, 0, 0, 0, 0, 0, 0, 0), "PT0S"),
("P1Y", (false, 1, 0, 0, 0, 0, 0, 0), "P1Y"),
("P1M", (false, 0, 1, 0, 0, 0, 0, 0), "P1M"),
("P1D", (false, 0, 0, 1, 0, 0, 0, 0), "P1D"),
("PT1H", (false, 0, 0, 0, 1, 0, 0, 0), "PT1H"),
("PT1M", (false, 0, 0, 0, 0, 1, 0, 0), "PT1M"),
("PT1S", (false, 0, 0, 0, 0, 0, 1, 0), "PT1S"),
("PT0.001S", (false, 0, 0, 0, 0, 0, 0, 1), "PT0.001S"),
];
for (input, (negative, years, months, days, hours, mins, secs, millis), output) in
test_values
{
let duration = Duration::from_str(input).unwrap();
assert_eq!(duration.is_negative(), negative);
assert_eq!(duration.years(), years);
assert_eq!(duration.months(), months);
assert_eq!(duration.days(), days);
assert_eq!(duration.hours(), hours);
assert_eq!(duration.mins(), mins);
assert_eq!(duration.secs(), secs);
assert_eq!(duration.millis(), millis);
assert_eq!(duration.to_string(), output);
}
let bad_values = [
"P",
"-P",
"P1H",
"P1HT",
"P1S",
"P1ST",
"PT1Y",
"PT1D",
"PT",
"P1YT",
"P65536Y",
"P4294967296M",
"P4294967296D",
"PT4294967296H",
"PT4294967296M",
"P1DT18446744073709551.616S",
];
for input in bad_values {
let result = Duration::from_str(input);
eprintln!("{result:?}");
assert!(result.is_err());
}
}
}

19
src/util/xsd/mod.rs Normal file
View file

@ -0,0 +1,19 @@
mod duration;
pub use duration::Duration;
#[derive(Debug)]
pub struct FromStrError {
msg: String,
}
fn from_str_error<T, O>(msg: O) -> FromStrError
where
T: Into<String>,
O: Into<Option<T>>,
{
let msg = msg
.into()
.map(|t| t.into())
.unwrap_or_else(|| String::from("Invalid value"));
FromStrError { msg }
}