use rsa::{ pkcs1, pkcs1v15::{Signature, SigningKey, VerifyingKey}, pkcs8::{ self, DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey, LineEnding, }, rand_core::OsRng, sha2::Sha256, signature::{RandomizedSigner, Verifier}, RsaPrivateKey, RsaPublicKey, }; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use sqlx::{ database::{HasArguments, HasValueRef}, encode::IsNull, error::BoxDynError, Database, }; use std::{fmt, future}; use tokio::sync::OnceCell; use crate::core::*; pub const DEFAULT_KEY_SIZE: usize = 2048; /// Our abstract wrapper around "any" type of public key. /// We currently assume all keys are RSA. #[derive(Clone)] pub struct PubKey { pkey: OnceCell, der: Vec, } /// Our abstract wrapper around "any" type of private key. /// We currently assume all keys are RSA. #[derive(Clone)] pub struct PrivKey { pkey: OnceCell, der: Vec, } pub struct Error(rsa::errors::Error); impl PubKey { pub fn from_der_unchecked(der: Vec) -> PubKey { PubKey { pkey: OnceCell::new(), der, } } pub fn from_pem(pem: &str) -> Result { if pem.starts_with("-----BEGIN PUBLIC KEY-----") { PubKey::from_pkcs8_pem(pem) } else { PubKey::from_pkcs1_pem(pem) } } pub fn from_pkcs8_pem(pem: &str) -> Result { let pkey = RsaPublicKey::from_public_key_pem(pem).map_err(Error::from)?; let der = pkey.to_public_key_der().map_err(Error::from)?.into_vec(); Ok(PubKey { pkey: OnceCell::from(pkey), der, }) } pub fn from_pkcs1_pem(pem: &str) -> Result { let pkey = ::from_pkcs1_pem(pem) .map_err(Error::from)?; let der = pkey.to_public_key_der().map_err(Error::from)?.into_vec(); Ok(PubKey { pkey: OnceCell::from(pkey), der, }) } pub async fn to_pem(&self) -> Result { let pkey = self.get_pkey().await?; let pem = pkey .to_public_key_pem(LineEnding::LF) .map_err(Error::from)?; Ok(pem) } pub async fn verify(&self, data: &[u8], signature: &[u8]) -> Result<()> { let pkey = self.get_pkey().await?; let signature = Signature::from(Box::from(signature)); let verifying_key: VerifyingKey = VerifyingKey::new_with_prefix(pkey.clone()); verifying_key .verify(data, &signature) .map_err(|_| crate::core::Error::BadSignature) } async fn get_pkey(&self) -> Result<&RsaPublicKey> { self.pkey .get_or_try_init(|| { future::ready( RsaPublicKey::from_public_key_der(self.der.as_slice()) .map_err(Error::from) .map_err(crate::core::Error::from), ) }) .await } } impl PrivKey { /// Generate a new private key. pub fn new() -> Result { // The rsa crate takes like two orders of magnitude longer to generate a key, // so until they get that under control we'll use the raw OpenSSL bindings to // generate a key, encode it to PKCS#1 DER, and load it again. let pkey = openssl::rsa::Rsa::generate(DEFAULT_KEY_SIZE as u32).unwrap(); let pkcs1_der = pkey.private_key_to_der().unwrap(); let pkey = ::from_pkcs1_der(pkcs1_der.as_slice()) .map_err(Error::from)?; let der = pkey.to_pkcs8_der().map_err(Error::from)?; let der = Vec::from(der.as_bytes()); Ok(PrivKey { pkey: OnceCell::from(pkey), der, }) } pub fn from_der_unchecked(der: Vec) -> PrivKey { PrivKey { pkey: OnceCell::new(), der, } } pub async fn derive_pubkey(&self) -> Result { let pkey = self.get_pkey().await?; PubKey::try_from(pkey.to_public_key()) } pub async fn sign(&self, data: &[u8]) -> Result> { let pkey = self.get_pkey().await?; let signing_key: SigningKey = SigningKey::new_with_prefix(pkey.clone()); let signature = signing_key.sign_with_rng(&mut OsRng, data); Ok(Vec::from(signature.as_ref())) } async fn get_pkey(&self) -> Result<&RsaPrivateKey> { self.pkey .get_or_try_init(|| { future::ready( RsaPrivateKey::from_pkcs8_der(self.der.as_slice()) .map_err(|e| Error::from(e).into()), ) }) .await } } impl TryFrom for PrivKey { type Error = crate::core::Error; fn try_from(val: RsaPrivateKey) -> Result { let der = val.to_pkcs8_der().map_err(Error::from)?; let der = Vec::from(der.as_bytes()); Ok(PrivKey { pkey: OnceCell::from(val), der, }) } } impl TryFrom for PubKey { type Error = crate::core::Error; fn try_from(val: RsaPublicKey) -> Result { let der = val.to_public_key_der().map_err(Error::from)?.into_vec(); Ok(PubKey { pkey: OnceCell::from(val), der, }) } } impl sqlx::Type for PubKey where Vec: sqlx::Type, { fn type_info() -> DB::TypeInfo { as sqlx::Type>::type_info() } fn compatible(ty: &DB::TypeInfo) -> bool { as sqlx::Type>::compatible(ty) } } impl sqlx::Type for PrivKey where Vec: sqlx::Type, { fn type_info() -> DB::TypeInfo { as sqlx::Type>::type_info() } fn compatible(ty: &DB::TypeInfo) -> bool { as sqlx::Type>::compatible(ty) } } impl<'q, DB: Database> sqlx::Encode<'q, DB> for PubKey where Vec: sqlx::Encode<'q, DB>, { fn encode(self, buf: &mut >::ArgumentBuffer) -> IsNull { self.der.encode(buf) } fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { self.der.encode_by_ref(buf) } fn produces(&self) -> Option { self.der.produces() } fn size_hint(&self) -> usize { self.der.size_hint() } } impl<'q, DB: Database> sqlx::Encode<'q, DB> for PrivKey where Vec: sqlx::Encode<'q, DB>, { fn encode(self, buf: &mut >::ArgumentBuffer) -> IsNull { self.der.encode(buf) } fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { self.der.encode_by_ref(buf) } fn produces(&self) -> Option { self.der.produces() } fn size_hint(&self) -> usize { self.der.size_hint() } } impl<'r, DB: Database> sqlx::Decode<'r, DB> for PubKey where Vec: sqlx::Decode<'r, DB>, { fn decode(value: >::ValueRef) -> std::result::Result { let value = as sqlx::Decode<'r, DB>>::decode(value)?; Ok(PubKey::from_der_unchecked(value)) } } impl<'r, DB: Database> sqlx::Decode<'r, DB> for PrivKey where Vec: sqlx::Decode<'r, DB>, { fn decode(value: >::ValueRef) -> std::result::Result { let value = as sqlx::Decode<'r, DB>>::decode(value)?; Ok(PrivKey::from_der_unchecked(value)) } } impl Serialize for PubKey where Vec: Serialize, { fn serialize(&self, serializer: S) -> std::result::Result where S: Serializer, { self.der.serialize(serializer) } } impl Serialize for PrivKey where Vec: Serialize, { fn serialize(&self, serializer: S) -> std::result::Result where S: Serializer, { self.der.serialize(serializer) } } impl<'de> Deserialize<'de> for PubKey where Vec: Deserialize<'de>, { fn deserialize(deserializer: D) -> std::result::Result where D: Deserializer<'de>, { Vec::::deserialize(deserializer).map(PubKey::from_der_unchecked) } } impl<'de> Deserialize<'de> for PrivKey where Vec: Deserialize<'de>, { fn deserialize(deserializer: D) -> std::result::Result where D: Deserializer<'de>, { Vec::::deserialize(deserializer).map(PrivKey::from_der_unchecked) } } impl From for Error { fn from(val: rsa::errors::Error) -> Error { Error(val) } } impl From for Error { fn from(val: pkcs1::Error) -> Error { Error::from(rsa::errors::Error::from(val)) } } impl From for Error { fn from(val: pkcs8::Error) -> Error { Error::from(rsa::errors::Error::from(val)) } } impl From for Error { fn from(val: pkcs8::spki::Error) -> Error { Error::from(pkcs8::Error::from(val)) } } impl fmt::Debug for Error where rsa::errors::Error: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&self.0, f) } } impl fmt::Display for Error where rsa::errors::Error: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&self.0, f) } } #[cfg(test)] mod tests { use super::*; const TEST_PUBKEY_PKCS8: &str = "-----BEGIN PUBLIC KEY----- MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAz4muPogkMCRdIShQLyCv 6QMb7d9epNfKGFyPi4C5w9rZTJ5Ox7X4cueXA6imMDJ0DCfD34QESJoIjkXht3W0 AanLSQnFh+p/5RsbEPb4zaUzG7OHGrYsIE2/LEUFyAuE15KVnXMmtoqN4k8Y5NtC 2GWGnnNW/iD+mr6SMLPQ44+bdPegBjbQmAJ3I/H4byoYvRWWE7g9klWyEZmlSwQQ MG4m86utQeO7JQ9dHUiG6PtuEm0PVB0pUT0a/qF3wRCMPIpPiA/E+z3yfYYnivKu wPsehgVguIGxzQaIOaN5UU7UmL36bAT3E0yhelmDdXkxeo6dQDnkuLRBwMTtFY3w /QIDAQAB -----END PUBLIC KEY----- "; const TEST_PUBKEY_PKCS1: &str = "-----BEGIN RSA PUBLIC KEY----- MIIBCgKCAQEAz4muPogkMCRdIShQLyCv6QMb7d9epNfKGFyPi4C5w9rZTJ5Ox7X4 cueXA6imMDJ0DCfD34QESJoIjkXht3W0AanLSQnFh+p/5RsbEPb4zaUzG7OHGrYs IE2/LEUFyAuE15KVnXMmtoqN4k8Y5NtC2GWGnnNW/iD+mr6SMLPQ44+bdPegBjbQ mAJ3I/H4byoYvRWWE7g9klWyEZmlSwQQMG4m86utQeO7JQ9dHUiG6PtuEm0PVB0p UT0a/qF3wRCMPIpPiA/E+z3yfYYnivKuwPsehgVguIGxzQaIOaN5UU7UmL36bAT3 E0yhelmDdXkxeo6dQDnkuLRBwMTtFY3w/QIDAQAB -----END RSA PUBLIC KEY----- "; #[actix_web::test] async fn verify_signatures() { let priv_key = PrivKey::new().unwrap(); let pub_key = priv_key.derive_pubkey().await.unwrap(); let message = String::from("hello, world"); let signature = priv_key.sign(message.as_bytes()).await.unwrap(); assert_ne!(signature.as_slice(), message.as_bytes()); assert!(pub_key .verify(message.as_bytes(), signature.as_slice()) .await .is_ok()); let tampered_message = String::from("hello, world!"); assert!(pub_key .verify(tampered_message.as_bytes(), signature.as_slice()) .await .is_err()); let mut tampered_signature = signature.clone(); tampered_signature[0] ^= 1; assert!(pub_key .verify(message.as_bytes(), tampered_signature.as_slice()) .await .is_err()); } #[actix_web::test] async fn parse_pkcs8_pem() { let pub_key = PubKey::from_pem(TEST_PUBKEY_PKCS8).unwrap(); let pem = pub_key.to_pem().await.unwrap(); assert_eq!(pem, TEST_PUBKEY_PKCS8); } #[actix_web::test] async fn parse_pkcs1_pem() { let pub_key = PubKey::from_pem(TEST_PUBKEY_PKCS1).unwrap(); let pem = pub_key.to_pem().await.unwrap(); // to_pem() should always return PKCS#8 assert_eq!(pem, TEST_PUBKEY_PKCS8); } }