From 88b7268dcc965b94932988b2fd6be66161a81b47 Mon Sep 17 00:00:00 2001 From: fef Date: Thu, 26 Jan 2023 17:47:31 +0100 Subject: [PATCH] util: support parsing non-RFC8941 headers --- src/ap/loader/content_type.rs | 12 +- src/ap/loader/link.rs | 54 ++-- src/util/header.rs | 491 ++++++++++++++++++++++++++++------ src/util/slice.rs | 88 ++++++ 4 files changed, 525 insertions(+), 120 deletions(-) diff --git a/src/ap/loader/content_type.rs b/src/ap/loader/content_type.rs index 8a79ddf..66273bd 100644 --- a/src/ap/loader/content_type.rs +++ b/src/ap/loader/content_type.rs @@ -12,7 +12,7 @@ use mime::Mime; use reqwest::header::HeaderValue; use std::str::FromStr; -use crate::util::header::{Item, ParseHeader}; +use crate::util::header::{Item, ParseHeader, ParseOptions}; /// Helper structure for parsing the `Content-Type` header for JSON-LD. pub struct ContentType { @@ -22,13 +22,11 @@ pub struct ContentType { impl ContentType { pub fn from_header(value: &HeaderValue) -> Option { - let item = Item::parse_from_header(value, true).ok()?; + let item = Item::parse_from_header(value, ParseOptions::rfc8941()).ok()?; let mime = Mime::from_str(item.as_token()?).ok()?; - let profile = item.param("profile").and_then(|profile| { - profile - .as_string() - .or_else(|| profile.as_token().map(String::from)) - }); + let profile = item + .param("profile") + .and_then(|profile| profile.as_string_or_token()); Some(ContentType { mime, profile }) } diff --git a/src/ap/loader/link.rs b/src/ap/loader/link.rs index 1a6ceed..843e2ad 100644 --- a/src/ap/loader/link.rs +++ b/src/ap/loader/link.rs @@ -12,7 +12,7 @@ use iref::{IriRef, IriRefBuf}; use reqwest::header::HeaderValue; use std::str::FromStr; -use crate::util::header::{Item, ParseHeader}; +use crate::util::header::{Item, ParseHeader, ParseOptions}; pub struct Link { href: IriRefBuf, @@ -22,7 +22,7 @@ pub struct Link { impl Link { pub fn from_header(value: &HeaderValue) -> Option { - let item = Item::parse_from_header(value, false).unwrap(); + 
let item = Item::parse_from_header(value, ParseOptions::link_header()).ok()?; let href = IriRefBuf::from_str(item.as_url()?).ok()?; let rel = item.param("rel").and_then(|rel| rel.as_string()); let typ = item.param("type").and_then(|typ| typ.as_string()); @@ -43,8 +43,8 @@ impl Link { pub fn is_proper_json_ld(&self) -> bool { self.typ() - .map(|typ| ["application/activity+json", "application/ld+json"].contains(&typ)) - .unwrap_or(false) + .filter(|t| ["application/activity+json", "application/ld+json"].contains(t)) + .is_some() } } @@ -52,13 +52,14 @@ impl Link { mod tests { use super::*; + fn mklink(header_value: &'static str) -> Option { + Link::from_header(&HeaderValue::from_str(header_value).unwrap()) + } + #[test] fn parse_link_1() { - let link = Link::from_header( - &HeaderValue::from_str( - "; rel=\"context\"; type=\"application/ld+json\"", - ) - .unwrap(), + let link = mklink( + "; rel=\"context\"; type=\"application/ld+json\"", ) .unwrap(); assert_eq!(link.href(), "http://www.example.org/context"); @@ -68,7 +69,10 @@ mod tests { #[test] fn parse_link_2() { - let link = Link::from_header(&HeaderValue::from_str("; rel=\"context\"; type=\"application/ld+json\"; foo=\"bar\"").unwrap()).unwrap(); + let link = mklink( + "; rel = \"context\" ; type=\"application/ld+json\" ; foo=\"bar\"", + ) + .unwrap(); assert_eq!(link.href(), "http://www.example.org/context"); assert_eq!(link.rel(), Some("context")); assert_eq!(link.typ(), Some("application/ld+json")) @@ -76,39 +80,27 @@ mod tests { #[test] fn parse_link_3() { - let link = - Link::from_header(&HeaderValue::from_str("").unwrap()) - .unwrap(); + let link = mklink("").unwrap(); assert_eq!(link.href(), "http://www.example.org/context") } #[test] fn is_proper_json_ld() { - let link = Link::from_header( - &HeaderValue::from_str( - "; rel=\"context\"; type=\"application/ld+json\"", - ) - .unwrap(), + let link = mklink( + "; rel=\"context\"; type=\"application/ld+json\"", ) .unwrap(); assert!(link.is_proper_json_ld()); 
- let link = Link::from_header( - &HeaderValue::from_str( - "; rel=\"context\"; type=\"application/activity+json\"", - ) - .unwrap(), + let link = mklink( + "; rel=\"context\"; type=\"application/activity+json\"", ) - .unwrap(); + .unwrap(); assert!(link.is_proper_json_ld()); - let link = Link::from_header( - &HeaderValue::from_str( - "; rel=\"context\"; type=\"application/json\"", - ) - .unwrap(), - ) - .unwrap(); + let link = + mklink("; rel=\"context\"; type=\"application/json\"") + .unwrap(); assert!(!link.is_proper_json_ld()); } } diff --git a/src/util/header.rs b/src/util/header.rs index e139622..4cdc9ca 100644 --- a/src/util/header.rs +++ b/src/util/header.rs @@ -8,40 +8,206 @@ use crate::util::transcode; /// Parse an HTTP Structured Field Value according to /// [RFC 8941](https://www.rfc-editor.org/info/rfc8941). +/// +/// Note: This parser is only compliant with RFC 8941 in strict mode; +/// see [`ParseOptions::strict`] for details. +/// /// Note: This only parses one "line" although the RFC says conforming /// software MUST support values split over several headers. +/// If you wish to comply with the RFC, you MUST call this for every header +/// line individually. pub trait ParseHeader<'a>: Sized { - fn parse_from_ascii(header: &'a [u8], strict: bool) -> Result; + fn parse_from_ascii(header: &'a [u8], options: ParseOptions) -> Result; - fn parse_from_header(header: &'a HeaderValue, strict: bool) -> Result { - Self::parse_from_ascii(header.as_bytes(), strict) + fn parse_from_header(header: &'a HeaderValue, options: ParseOptions) -> Result { + Self::parse_from_ascii(header.as_bytes(), options) } } +/// Options for the header parser. The default is strict mode, i.e. conforming to +/// [RFC 8941](https://www.rfc-editor.org/info/rfc8941) except for multiline headers. 
+pub struct ParseOptions { + strict: bool, + allow_utf8: bool, + allow_url: bool, + allow_param_bws: bool, + max_dict_members: usize, + max_list_members: usize, + max_inner_list_members: usize, + max_params: usize, +} + +impl ParseOptions { + /// Return the default options; see [`ParseOptions::default()`]. + pub fn new() -> Self { + Self::default() + } + + /// Return options for parsing an HTTP header defined on top of RFC 8941. + /// This is currently the default. + pub fn rfc8941() -> Self { + ParseOptions { + strict: true, + allow_utf8: false, + allow_url: false, + allow_param_bws: false, + max_dict_members: 1024, + max_list_members: 1024, + max_inner_list_members: 256, + max_params: 256, + } + } + + /// Return parser options suitable for parsing the HTTP `Link` header as + /// defined in section 3 of [RFC 8288](https://www.rfc-editor.org/info/rfc8288). + pub fn link_header() -> Self { + Self::default() + .strict(false) + .allow_utf8(true) + .allow_url(true) + .allow_param_bws(true) + } + + /// Enable strict mode, i.e. fully comply with RFC 8941 (except for the + /// multiline header thing; consumers of this utility MUST call the parser + /// on every header value with the same name manually). + /// + /// This option exists because the parser is also useful for headers that + /// *almost* conform to the RFC with only some minor deviations (e.g. the + /// `Link` header, which allows URLs enclosed in angle brackets). + /// When parsing a header that is defined based on RFC 8941, this option + /// MUST be set to `true`. + /// + /// This option takes precedence over all other ones and defaults to `true`. + pub fn strict(mut self, strict: bool) -> Self { + self.strict = strict; + self + } + + /// Accept the entire UTF-8 alphabet instead of just ASCII. + /// Strict mode implies this is `false`. 
+ pub fn allow_utf8(mut self, allow_utf8: bool) -> Self { + self.allow_utf8 = allow_utf8; + self + } + + /// Enable the non-standard URL Item type for values enclosed in angle + /// brackets (`<>`). Strict mode implies this is `false`. + pub fn allow_url(mut self, allow_url: bool) -> Self { + self.allow_url = allow_url; + self + } + + /// Allow "bad" whitespace (as per the BWS rule in section 3.2.3 of + /// [RFC 7230](https://www.rfc-editor.org/info/rfc7230)) before and after + /// the `=` token in parameters, as well as before the semicolon. + /// Strict mode implies this is `false`. + pub fn allow_param_bws(mut self, allow_param_bws: bool) -> Self { + self.allow_param_bws = allow_param_bws; + self + } + + /// Maximum number of members to allow in a Dictionary (minimum 1). + /// Strict mode implies this is no less than 1024. + pub fn max_dict_members(mut self, max_dict_members: usize) -> Self { + self.max_dict_members = max_dict_members.max(1); + self + } + + /// Maximum number of members to allow in a List (minimum 1). + /// Strict mode implies this is no less than 1024. + pub fn max_list_members(mut self, max_list_members: usize) -> Self { + self.max_list_members = max_list_members.max(1); + self + } + + /// Maximum number of members to allow in an Inner List (minimum 1). + /// Strict mode implies this is no less than 256. + pub fn max_inner_list_members(mut self, max_inner_list_members: usize) -> Self { + self.max_inner_list_members = max_inner_list_members.max(1); + self + } + + /// Maximum number of parameters to allow on Items (minimum 1). + /// Strict mode implies this is no less than 256. + pub fn max_params(mut self, max_params: usize) -> Self { + self.max_params = max_params.max(1); + self + } + + /// In strict mode, override all options to comply with RFC 8941. 
+ fn normalize(mut self) -> Self { + if self.strict { + if self.allow_utf8 { + debug!("Strict mode enabled, overriding allow_utf8 to false"); + self.allow_utf8 = false; + } + if self.allow_url { + debug!("Strict mode enabled, overriding allow_url to false"); + self.allow_url = false; + } + if self.allow_param_bws { + debug!("Strict mode enabled, overriding allow_param_bws to false"); + self.allow_param_bws = false; + } + if self.max_dict_members < 1024 { + debug!("Strict mode enabled, overriding max_dict_members to 1024"); + self.max_dict_members = 1024; + } + if self.max_list_members < 1024 { + debug!("Strict mode enabled, overriding max_list_members to 1024"); + self.max_list_members = 1024; + } + if self.max_inner_list_members < 256 { + debug!("Strict mode enabled, overriding max_inner_list_members to 256"); + self.max_inner_list_members = 256; + } + if self.max_params < 256 { + debug!("Strict mode enabled, overriding max_params to 256"); + self.max_params = 256; + } + } + + self + } +} + +impl Default for ParseOptions { + fn default() -> Self { + Self::rfc8941() + } +} + +/// A Dictionary (section 3.2). #[derive(Debug, PartialEq)] pub struct Dictionary<'a>(Vec<(&'a str, Member<'a>)>); +/// A List (section 3.1). #[derive(Debug, PartialEq)] pub struct List<'a>(Vec>); +/// A Member of a List or Dictionary. #[derive(Debug, PartialEq)] pub enum Member<'a> { Item(Item<'a>), InnerList(InnerList<'a>), } +/// An Inner List (section 3.1.1). #[derive(Debug, PartialEq)] pub struct InnerList<'a> { items: Vec>, params: Vec<(&'a str, BareItem<'a>)>, } +/// An Item (section 3.3). #[derive(Debug, PartialEq)] pub struct Item<'a> { bare_item: BareItem<'a>, params: Vec<(&'a str, BareItem<'a>)>, } +/// An Item without Parameters. 
#[derive(Debug, PartialEq)] pub enum BareItem<'a> { Integer(i64), @@ -63,13 +229,17 @@ pub struct UrlItem<'a>(&'a str); pub struct ByteSequenceItem<'a>(&'a str); impl<'a> ParseHeader<'a> for Dictionary<'a> { - fn parse_from_ascii(header: &'a [u8], strict: bool) -> Result { - Parser::new(header, strict)?.parse_dictionary() + fn parse_from_ascii(header: &'a [u8], options: ParseOptions) -> Result { + Parser::new(header, options)?.parse_dictionary() } } impl<'a> Dictionary<'a> { - pub fn get(&self, key: &'a str) -> Option<&Member<'a>> { + pub fn get<'k, K>(&self, key: K) -> Option<&Member<'a>> + where + K: Into<&'k str>, + { + let key = key.into(); self.0.iter().find_map(|(k, v)| key.eq(*k).then_some(v)) } @@ -79,8 +249,8 @@ impl<'a> Dictionary<'a> { } impl<'a> ParseHeader<'a> for List<'a> { - fn parse_from_ascii(header: &'a [u8], strict: bool) -> Result { - Parser::new(header, strict)?.parse_list() + fn parse_from_ascii(header: &'a [u8], options: ParseOptions) -> Result { + Parser::new(header, options)?.parse_list() } } @@ -99,8 +269,8 @@ impl<'a> List<'a> { } impl<'a> ParseHeader<'a> for Item<'a> { - fn parse_from_ascii(header: &'a [u8], strict: bool) -> Result { - Parser::new(header, strict)?.parse_item(!strict) + fn parse_from_ascii(header: &'a [u8], options: ParseOptions) -> Result { + Parser::new(header, options)?.parse_item() } } @@ -109,9 +279,9 @@ impl<'a> Item<'a> { self.params.as_slice() } - pub fn param(&self, key: K) -> Option<&BareItem<'a>> + pub fn param<'k, K>(&self, key: K) -> Option<&BareItem<'a>> where - K: Into<&'a str>, + K: Into<&'k str>, { let key = key.into(); self.params @@ -119,9 +289,19 @@ impl<'a> Item<'a> { .find_map(|(k, v)| key.eq(*k).then_some(v)) } - pub fn has_param(&self, key: K) -> bool + pub fn param_nocase<'k, K>(&self, key: K) -> Option<&BareItem<'a>> where - K: Into<&'a str>, + K: Into<&'k str>, + { + let key = key.into(); + self.params + .iter() + .find_map(|(k, v)| key.eq_ignore_ascii_case(k).then_some(v)) + } + + pub fn 
has_param<'k, K>(&self, key: K) -> bool + where + K: Into<&'k str>, { let key = key.into(); self.params.iter().any(|(k, _)| key.eq(*k)) @@ -143,6 +323,10 @@ impl<'a> Item<'a> { self.bare_item.as_token() } + pub fn as_string_or_token(&self) -> Option { + self.bare_item.as_string_or_token() + } + pub fn as_url(&self) -> Option<&'a str> { self.bare_item.as_url() } @@ -185,6 +369,14 @@ impl<'a> BareItem<'a> { } } + pub fn as_string_or_token(&self) -> Option { + match self { + BareItem::String(s) => Some(remove_escapes_stupid(s.0)), + BareItem::Token(t) => Some(String::from(t)), + _ => None, + } + } + pub fn as_url(&self) -> Option<&'a str> { match self { BareItem::Url(u) => Some(u.0), @@ -321,23 +513,28 @@ impl<'a> InnerList<'a> { } } +/// Internal implementation of Structured Field Values. +/// Parsing methods have their respective production rules in the doc comment, +/// which was extracted from the RFC. See section 1.2 for details. struct Parser<'a> { cursor: SliceCursor<'a, u8>, - strict: bool, + options: ParseOptions, } impl<'a> Parser<'a> { - fn new(data: &'a [u8], strict: bool) -> Result { - if data.is_ascii() || (std::str::from_utf8(data).is_ok() && !strict) { - Ok(Parser { - cursor: SliceCursor::new(data), - strict, - }) - } else { - Err(Error::BadHeader(String::from( - "RFC 8941 prohibits non-ASCII characters", - ))) + fn new(data: &'a [u8], options: ParseOptions) -> Result { + let options = options.normalize(); + + if options.allow_utf8 { + std::str::from_utf8(data).map_err(|e| Error::BadHeader(e.to_string()))?; + } else if !data.is_ascii() { + return Err(Error::BadHeader(String::from("Not an ASCII string"))); } + + Ok(Parser { + cursor: SliceCursor::new(data), + options, + }) } /// Parse a full List (section 3.1). 
@@ -348,15 +545,17 @@ impl<'a> Parser<'a> { fn parse_list(&mut self) -> Result> { let mut members = Vec::with_capacity(1); members.push(self.parse_list_member()?); - self.skip_whitespace(); + self.skip_ows(); while self.skip_if(|c| c == b',') { - self.skip_whitespace(); + self.skip_ows(); members.push(self.parse_list_member()?); - // > Parsers MUST support Lists containing at least 1024 members. - if members.len() == 1024 { - break; + if members.len() > self.options.max_list_members { + return Err(self.make_error(format!( + "List exceeds configured member limit of {}", + self.options.max_list_members + ))); } - self.skip_whitespace(); + self.skip_ows(); } Ok(List(members)) } @@ -370,7 +569,7 @@ impl<'a> Parser<'a> { if self.cursor.peek().copied() == Some(b'(') { self.parse_inner_list().map(Member::InnerList) } else { - self.parse_item(false).map(Member::Item) + self.parse_item().map(Member::Item) } } @@ -382,16 +581,17 @@ impl<'a> Parser<'a> { fn parse_dictionary(&mut self) -> Result> { let mut members = Vec::with_capacity(1); members.push(self.parse_dict_member()?); - self.skip_whitespace(); + self.skip_ows(); while self.skip_if(|c| c == b',') { - self.skip_whitespace(); + self.skip_ows(); members.push(self.parse_dict_member()?); - // > Parsers MUST support Dictionaries containing at least - // > 1024 key/value pairs and keys with at least 64 characters. 
- if members.len() == 1024 { - break; + if members.len() > self.options.max_dict_members { + return Err(self.make_error(format!( + "Dictionary exceeds configured member limit of {}", + self.options.max_dict_members + ))); } - self.skip_whitespace(); + self.skip_ows(); } Ok(Dictionary(members)) } @@ -407,12 +607,20 @@ impl<'a> Parser<'a> { // member-key let key = self.parse_key()?; + if self.options.allow_param_bws { + self.skip_bws_if_next_matches(|c| c == b'='); + } + let val = if self.skip_if(|c| c == b'=') { + if self.options.allow_param_bws { + self.skip_bws(); + } + // member-value if self.cursor.peek().copied() == Some(b'(') { Member::InnerList(self.parse_inner_list()?) } else { - Member::Item(self.parse_item(false)?) + Member::Item(self.parse_item()?) } } else { // parameters @@ -438,7 +646,7 @@ impl<'a> Parser<'a> { if self.skip_if(|c| c == b')') { break; } - items.push(self.parse_item(false)?); + items.push(self.parse_item()?); // > Parsers MUST support Inner Lists containing at least 256 members. 
if items.len() == 256 { break; @@ -459,8 +667,8 @@ impl<'a> Parser<'a> { /// ```notrust /// sf-item = bare-item parameters /// ``` - fn parse_item(&mut self, allow_url: bool) -> Result> { - let bare_item = self.parse_bare_item(allow_url)?; + fn parse_item(&mut self) -> Result> { + let bare_item = self.parse_bare_item()?; let params = self.parse_parameters()?; Ok(Item { bare_item, params }) } @@ -469,14 +677,34 @@ impl<'a> Parser<'a> { /// /// ```notrust /// parameters = *( ";" *SP parameter ) + /// + /// ; deviations in non-strict mode: + /// parameters = *( ";" OWS parameter ) + /// + /// ; deviations if allow_param_bws: + /// parameters = *( BWS ";" OWS parameter ) /// ``` fn parse_parameters(&mut self) -> Result)>> { let mut params = Vec::new(); + if self.options.allow_param_bws { + self.skip_bws_if_next_matches(|c| c == b';'); + } while self.skip_if(|c| c == b';') { - self.skip_sp(); + if self.options.strict { + self.skip_sp(); + } else { + self.skip_ows(); + } + params.push(self.parse_parameter()?); - if params.len() == 256 { - break; + if params.len() > self.options.max_params { + return Err(self.make_error(format!( + "Parameter count exceeds configured limit of {}", + self.options.max_params + ))); + } + if self.options.allow_param_bws { + self.skip_bws_if_next_matches(|c| c == b';'); } } Ok(params) @@ -488,11 +716,23 @@ impl<'a> Parser<'a> { /// parameter = param-key [ "=" param-value ] /// param-key = key /// param-value = bare-item + /// + /// ; deviations if allow_param_bws: + /// parameter = token [ BWS "=" BWS bare-item ] /// ``` fn parse_parameter(&mut self) -> Result<(&'a str, BareItem<'a>)> { let key = self.parse_key()?; + + if self.options.allow_param_bws { + self.skip_bws_if_next_matches(|c| c == b'='); + } + let value = if self.skip_if(|c| c == b'=') { - self.parse_bare_item(false)? + if self.options.allow_param_bws { + self.skip_bws(); + } + + self.parse_bare_item()? 
} else { BareItem::Boolean(true) }; @@ -519,7 +759,7 @@ impl<'a> Parser<'a> { /// bare-item = sf-integer / sf-decimal / sf-string /// / sf-token / sf-binary / sf-boolean /// ``` - fn parse_bare_item(&mut self, allow_url: bool) -> Result> { + fn parse_bare_item(&mut self) -> Result> { match self .cursor .peek() @@ -528,7 +768,7 @@ impl<'a> Parser<'a> { { c if is_numeric_start(c) => self.parse_numeric(), b'"' => self.parse_string(), - b'<' if allow_url => self.parse_url(), + b'<' => self.parse_url(), c if is_token_start(c) => self.parse_token(), b':' => self.parse_byte_sequence(), b'?' => self.parse_boolean(), @@ -566,9 +806,17 @@ impl<'a> Parser<'a> { /// chr = unescaped / escaped /// unescaped = %x20-21 / %x23-5B / %x5D-7E /// escaped = "\" ( DQUOTE / "\" ) + /// + /// ; deviations if allow_utf8: + /// unescaped = %x20-21 / %x23-5B / %x5D-7E / %x80-FF /// ``` fn parse_string(&mut self) -> Result> { self.assert_next(|c| c == b'"')?; + let is_allowed_char = if self.options.allow_utf8 { + is_string_part_utf8 + } else { + is_string_part + }; self.chop(); loop { @@ -577,8 +825,8 @@ impl<'a> Parser<'a> { b'\\' => { self.assert_next(|c| c == b'\\' || c == b'"')?; } - c if is_string_part(c) => continue, - _ => return Err(self.make_error("Unexpected character in string")), + c if is_allowed_char(c) => continue, + c => return Err(self.make_error(format!("Unexpected character {:?} in string", c))), } } let slice = self.chop(); @@ -586,19 +834,26 @@ impl<'a> Parser<'a> { Ok(BareItem::String(StringItem(slice))) } + /// Parse a non-standard URL item if `allow_url` is enabled in the options. 
fn parse_url(&mut self) -> Result> { - if self.strict { - return Err( - self.make_error("URLs enclosed in are forbidden in strict mode") - ); - } + if self.options.allow_url { + self.assert_next(|c| c == b'<')?; - self.assert_next(|c| c == b'<')?; - self.chop(); - self.skip_while(|c| c != b'>'); - let slice = self.chop(); - self.assert_next(|c| c == b'>')?; - Ok(BareItem::Url(UrlItem(slice))) + self.chop(); + if self.options.allow_utf8 { + self.skip_while(is_url_part_utf8); + } else { + self.skip_while(is_url_part); + } + let slice = self.chop(); + + self.assert_next(|c| c == b'>')?; + Ok(BareItem::Url(UrlItem(slice))) + } else { + Err(self.make_error( + "allow_url is disabled, refusing to parse URL enclosed in ", + )) + } } /// Parse a Token item (section 3.3.4). @@ -659,10 +914,41 @@ impl<'a> Parser<'a> { self.cursor.next_while(|&c| c == b' ').len() } - fn skip_whitespace(&mut self) -> usize { + /// Skip optional whitespace as per section 3.2.3 of + /// [RFC 7230](https://www.rfc-editor.org/info/rfc7230). + /// + /// ```notrust + /// OWS = *( SP / HTAB ) + /// ; optional whitespace + /// ``` + fn skip_ows(&mut self) -> usize { self.cursor.next_while(|&c| c == b' ' || c == b'\t').len() } + /// Skip "bad" whitespace as per section 3.2.3 of + /// [RFC 7230](https://www.rfc-editor.org/info/rfc7230). + /// + /// ```notrust + /// BWS = OWS + /// ; "bad" whitespace + /// ``` + fn skip_bws(&mut self) -> usize { + self.skip_ows() + } + + /// Skip "bad" whitespace (see [`Self::skip_bws`]) if the first character + /// after the whitespace matches `predicate`. + /// The cursor will point to the last whitespace character. 
+ fn skip_bws_if_next_matches(&mut self, predicate: F) -> Option + where + F: FnOnce(u8) -> bool, + { + self.cursor.attempt(|cursor| { + let bws_count = cursor.next_while(|&c| c == b' ' || c == b'\t').len(); + cursor.peek().filter(|&&c| predicate(c)).map(|_| bws_count) + }) + } + fn assert_next(&mut self, predicate: F) -> Result where F: FnOnce(u8) -> bool, @@ -710,42 +996,65 @@ impl<'a> Parser<'a> { } } -fn is_numeric_start(c: u8) -> bool { +const fn is_numeric_start(c: u8) -> bool { c.is_ascii_digit() || c == b'-' } -fn is_string_start(c: u8) -> bool { - c == b'"' +const fn is_string_part(c: u8) -> bool { + matches!(c, b'\x20'..=b'\x21' | b'\x23'..=b'\x5b' | b'\x5d'..=b'\x7e') +} + +const fn is_string_part_utf8(c: u8) -> bool { + !c.is_ascii() || is_string_part(c) +} + +const fn is_url_part(c: u8) -> bool { + c != b'>' && is_string_part(c) } -fn is_string_part(c: u8) -> bool { - (b'\x20'..=b'\x21').contains(&c) - || (b'\x23'..=b'\x5b').contains(&c) - || (b'\x5d'..=b'\x7e').contains(&c) +const fn is_url_part_utf8(c: u8) -> bool { + c != b'>' && is_string_part_utf8(c) } -fn is_token_start(c: u8) -> bool { +const fn is_token_start(c: u8) -> bool { c.is_ascii_alphabetic() || c == b'*' } -fn is_tchar(c: u8) -> bool { - c.is_ascii_alphanumeric() || b"!#$%&'*+-.^_`|~".contains(&c) +const fn is_tchar(c: u8) -> bool { + c.is_ascii_alphanumeric() + || matches!( + c, + b'!' | b'#' + | b'$' + | b'%' + | b'&' + | b'\'' + | b'*' + | b'+' + | b'-' + | b'.' 
+ | b'^' + | b'_' + | b'`' + | b'|' + | b'~' + ) } -fn is_byte_sequence_start(c: u8) -> bool { +const fn is_byte_sequence_start(c: u8) -> bool { c == b':' } -fn is_base64(c: u8) -> bool { - c.is_ascii_alphanumeric() || c == b'+' || c == b'/' || c == b'=' +const fn is_base64(c: u8) -> bool { + c.is_ascii_alphanumeric() || matches!(c, b'+' | b'/' | b'=') } -fn is_key_start(c: u8) -> bool { +const fn is_key_start(c: u8) -> bool { c.is_ascii_lowercase() || c == b'*' } -fn is_key_part(c: u8) -> bool { - c.is_ascii_lowercase() || c.is_ascii_digit() || b"_-.*".contains(&c) +const fn is_key_part(c: u8) -> bool { + c.is_ascii_lowercase() || c.is_ascii_digit() || matches!(c, b'_' | b'-' | b'.' | b'*') } fn remove_escapes_stupid(s: &str) -> String { @@ -787,15 +1096,15 @@ mod tests { use crate::util::transcode::base64_decode; fn mklist(header: &'static str) -> Result> { - List::parse_from_ascii(header.as_bytes(), true) + List::parse_from_ascii(header.as_bytes(), Default::default()) } fn mkdict(header: &'static str) -> Result> { - Dictionary::parse_from_ascii(header.as_bytes(), true) + Dictionary::parse_from_ascii(header.as_bytes(), Default::default()) } fn mkitem(header: &'static str) -> Result> { - Item::parse_from_ascii(header.as_bytes(), true) + Item::parse_from_ascii(header.as_bytes(), Default::default()) } #[test] @@ -928,6 +1237,24 @@ mod tests { ); } + #[test] + fn parse_item_url() { + let header = r#"; type="text/html""#; + + assert!(mkitem(header).is_err()); + + let item = Item::parse_from_ascii( + header.as_bytes(), + ParseOptions::default().strict(false).allow_url(true), + ) + .unwrap(); + assert_eq!(item.as_url(), Some("https://example.com/a")); + assert_eq!( + item.param("type").unwrap().as_string(), + Some("text/html".into()) + ); + } + #[test] fn parse_item_byte_sequence() { let base64_str = diff --git a/src/util/slice.rs b/src/util/slice.rs index b80b2c2..fb9aec2 100644 --- a/src/util/slice.rs +++ b/src/util/slice.rs @@ -8,6 +8,7 @@ pub struct SliceCursor<'a, 
T> { } /// Helper for the [`SliceCursor`] helper. +#[derive(Copy, Clone)] struct Position { /// Always within -1 and `end` (both inclusive). pos: isize, @@ -15,6 +16,16 @@ struct Position { end: usize, } +impl<'a, T> Clone for SliceCursor<'a, T> { + fn clone(&self) -> Self { + Self { + data: self.data, + pos: self.pos, + chop: self.chop, + } + } +} + impl<'a, T> SliceCursor<'a, T> { pub fn new(data: &'a [T]) -> Self { assert!(data.len() <= isize::MAX as usize); @@ -60,6 +71,23 @@ impl<'a, T> SliceCursor<'a, T> { } } + /// Advance to the last item for which `predicate` is true and return a + /// slice from the current position up to and including that last item. + /// + /// Besides the fact that it will not modify the chop position, + /// this operation is functionally equivalent to: + /// + /// ``` + /// cursor.chop(); + /// while let Some(c) = cursor.peek() { + /// if predicate(c) { + /// cursor.next(); + /// } else { + /// break; + /// } + /// } + /// let result = cursor.chop(); + /// ``` pub fn next_while(&mut self, mut predicate: F) -> &'a [T] where F: FnMut(&'a T) -> bool, @@ -75,6 +103,24 @@ impl<'a, T> SliceCursor<'a, T> { &self.data[start..end] } + /// Save the cursor's state and perform an arbitrary operation on it. + /// If the operation failed (i.e. yielded `None`), restore the cursor's + /// state. Passes on the return value of `op`. + /// + /// `op` SHOULD NOT redefine the cursor unless you want buggy code. + /// This is because the original slice is only restored if `op` failed. + pub fn attempt(&mut self, op: F) -> Option + where + F: FnOnce(&mut Self) -> Option, + { + let backup = self.save(); + let result = op(self); + if result.is_none() { + self.restore(backup); + } + result + } + /// Return a slice over all elements since the last time this method was called. /// If the cursor went backwards, the slice is empty. 
pub fn chop(&mut self) -> &'a [T] { @@ -97,6 +143,16 @@ impl<'a, T> SliceCursor<'a, T> { pub fn remaining(&self) -> usize { self.data.len() - self.pos.next_index_or_end() } + + fn save(&self) -> Self { + self.clone() + } + + fn restore(&mut self, backup: Self) { + self.data = backup.data; + self.chop = backup.chop; + self.pos = backup.pos; + } } impl Position { @@ -254,4 +310,36 @@ mod tests { assert_eq!(cursor.current(), Some(&4)); assert_eq!(cursor.chop(), &data[0..5]); } + + #[test] + fn attempt() { + let data: Vec = (0..10).collect(); + let mut cursor = SliceCursor::new(&data); + + let result = cursor.attempt(|cursor| cursor.next().copied().filter(|c| *c == 0)); + assert_eq!(result, Some(0)); + assert_eq!(cursor.remaining(), 9); + assert_eq!(cursor.current(), Some(&0)); + + let result = cursor.attempt(|cursor| cursor.next().copied().filter(|c| *c == 0)); + assert_eq!(result, None); + assert_eq!(cursor.remaining(), 9); + assert_eq!(cursor.current(), Some(&0)); + + let data2: Vec = (10..20).collect(); + + let _: Option<()> = cursor.attempt(|cursor| { + *cursor = SliceCursor::new(&data2); + cursor.next(); + None + }); + assert_eq!(cursor.current(), Some(&0)); + + cursor.attempt(|cursor| { + *cursor = SliceCursor::new(&data2); + cursor.next(); + Some(()) + }); + assert_eq!(cursor.current(), Some(&10)); + } }