You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

430 lines
12 KiB
Rust

use rsa::{
pkcs1,
pkcs1v15::{Signature, SigningKey, VerifyingKey},
pkcs8::{
self, DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey, LineEnding,
},
rand_core::OsRng,
sha2::Sha256,
signature::{RandomizedSigner, Verifier},
RsaPrivateKey, RsaPublicKey,
};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sqlx::{
database::{HasArguments, HasValueRef},
encode::IsNull,
error::BoxDynError,
Database,
};
use std::{fmt, future};
use tokio::sync::OnceCell;
use crate::core::*;
/// Modulus size, in bits, that [`PrivKey::new`] asks OpenSSL to generate.
pub const DEFAULT_KEY_SIZE: usize = 2_048;
/// Our abstract wrapper around "any" type of public key.
///
/// We currently assume all keys are RSA. The DER bytes are the canonical form
/// (they are what gets stored and serialized); the parsed key is cached next to
/// them and may be filled lazily on first use.
#[derive(Clone)]
pub struct PubKey {
    /// Parsed key. Filled eagerly by the PEM / `TryFrom` constructors, or on the
    /// first call that needs it when built via `from_der_unchecked`.
    pkey: OnceCell<RsaPublicKey>,
    /// SubjectPublicKeyInfo DER, as produced by `to_public_key_der` and read back
    /// with `from_public_key_der`.
    der: Vec<u8>,
}
/// Our abstract wrapper around "any" type of private key.
///
/// We currently assume all keys are RSA. The DER bytes are the canonical form
/// (they are what gets stored and serialized); the parsed key is cached next to
/// them and may be filled lazily on first use.
#[derive(Clone)]
pub struct PrivKey {
    /// Parsed key. Filled eagerly by `new` / `TryFrom`, or on the first call that
    /// needs it when built via `from_der_unchecked`.
    pkey: OnceCell<RsaPrivateKey>,
    /// PKCS#8 DER, as produced by `to_pkcs8_der` and read back with
    /// `from_pkcs8_der`.
    der: Vec<u8>,
}
pub struct Error(rsa::errors::Error);
impl PubKey {
pub fn from_der_unchecked(der: Vec<u8>) -> PubKey {
PubKey {
pkey: OnceCell::new(),
der,
}
}
pub fn from_pem(pem: &str) -> Result<PubKey> {
if pem.starts_with("-----BEGIN PUBLIC KEY-----") {
PubKey::from_pkcs8_pem(pem)
} else {
PubKey::from_pkcs1_pem(pem)
}
}
pub fn from_pkcs8_pem(pem: &str) -> Result<PubKey> {
let pkey = RsaPublicKey::from_public_key_pem(pem).map_err(Error::from)?;
let der = pkey.to_public_key_der().map_err(Error::from)?.into_vec();
Ok(PubKey {
pkey: OnceCell::from(pkey),
der,
})
}
pub fn from_pkcs1_pem(pem: &str) -> Result<PubKey> {
let pkey = <RsaPublicKey as pkcs1::DecodeRsaPublicKey>::from_pkcs1_pem(pem)
.map_err(Error::from)?;
let der = pkey.to_public_key_der().map_err(Error::from)?.into_vec();
Ok(PubKey {
pkey: OnceCell::from(pkey),
der,
})
}
pub async fn to_pem(&self) -> Result<String> {
let pkey = self.get_pkey().await?;
let pem = pkey
.to_public_key_pem(LineEnding::LF)
.map_err(Error::from)?;
Ok(pem)
}
pub async fn verify(&self, data: &[u8], signature: &[u8]) -> Result<()> {
let pkey = self.get_pkey().await?;
let signature = Signature::from(Box::from(signature));
let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new_with_prefix(pkey.clone());
verifying_key
.verify(data, &signature)
.map_err(|_| crate::core::Error::BadSignature)
}
async fn get_pkey(&self) -> Result<&RsaPublicKey> {
self.pkey
.get_or_try_init(|| {
future::ready(
RsaPublicKey::from_public_key_der(self.der.as_slice())
.map_err(Error::from)
.map_err(crate::core::Error::from),
)
})
.await
}
}
impl PrivKey {
/// Generate a new private key.
pub fn new() -> Result<PrivKey> {
// The rsa crate takes like two orders of magnitude longer to generate a key,
// so until they get that under control we'll use the raw OpenSSL bindings to
// generate a key, encode it to PKCS#1 DER, and load it again.
let pkey = openssl::rsa::Rsa::generate(DEFAULT_KEY_SIZE as u32).unwrap();
let pkcs1_der = pkey.private_key_to_der().unwrap();
let pkey =
<RsaPrivateKey as pkcs1::DecodeRsaPrivateKey>::from_pkcs1_der(pkcs1_der.as_slice())
.map_err(Error::from)?;
let der = pkey.to_pkcs8_der().map_err(Error::from)?;
let der = Vec::from(der.as_bytes());
Ok(PrivKey {
pkey: OnceCell::from(pkey),
der,
})
}
pub fn from_der_unchecked(der: Vec<u8>) -> PrivKey {
PrivKey {
pkey: OnceCell::new(),
der,
}
}
pub async fn derive_pubkey(&self) -> Result<PubKey> {
let pkey = self.get_pkey().await?;
PubKey::try_from(pkey.to_public_key())
}
pub async fn sign(&self, data: &[u8]) -> Result<Vec<u8>> {
let pkey = self.get_pkey().await?;
let signing_key: SigningKey<Sha256> = SigningKey::new_with_prefix(pkey.clone());
let signature = signing_key.sign_with_rng(&mut OsRng, data);
Ok(Vec::from(signature.as_ref()))
}
async fn get_pkey(&self) -> Result<&RsaPrivateKey> {
self.pkey
.get_or_try_init(|| {
future::ready(
RsaPrivateKey::from_pkcs8_der(self.der.as_slice())
.map_err(|e| Error::from(e).into()),
)
})
.await
}
}
impl TryFrom<RsaPrivateKey> for PrivKey {
type Error = crate::core::Error;
fn try_from(val: RsaPrivateKey) -> Result<PrivKey> {
let der = val.to_pkcs8_der().map_err(Error::from)?;
let der = Vec::from(der.as_bytes());
Ok(PrivKey {
pkey: OnceCell::from(val),
der,
})
}
}
impl TryFrom<RsaPublicKey> for PubKey {
type Error = crate::core::Error;
fn try_from(val: RsaPublicKey) -> Result<PubKey> {
let der = val.to_public_key_der().map_err(Error::from)?.into_vec();
Ok(PubKey {
pkey: OnceCell::from(val),
der,
})
}
}
/// `PubKey` maps to whatever SQL type `Vec<u8>` maps to for the database.
impl<DB> sqlx::Type<DB> for PubKey
where
    DB: Database,
    Vec<u8>: sqlx::Type<DB>,
{
    fn type_info() -> DB::TypeInfo {
        <Vec<u8> as sqlx::Type<DB>>::type_info()
    }

    fn compatible(ty: &DB::TypeInfo) -> bool {
        <Vec<u8> as sqlx::Type<DB>>::compatible(ty)
    }
}
/// `PrivKey` maps to whatever SQL type `Vec<u8>` maps to for the database.
impl<DB> sqlx::Type<DB> for PrivKey
where
    DB: Database,
    Vec<u8>: sqlx::Type<DB>,
{
    fn type_info() -> DB::TypeInfo {
        <Vec<u8> as sqlx::Type<DB>>::type_info()
    }

    fn compatible(ty: &DB::TypeInfo) -> bool {
        <Vec<u8> as sqlx::Type<DB>>::compatible(ty)
    }
}
/// Binds a `PubKey` as its SubjectPublicKeyInfo DER bytes; every method
/// delegates to the `Vec<u8>` encoder.
impl<'q, DB> sqlx::Encode<'q, DB> for PubKey
where
    DB: Database,
    Vec<u8>: sqlx::Encode<'q, DB>,
{
    fn encode(self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
        <Vec<u8> as sqlx::Encode<'q, DB>>::encode(self.der, buf)
    }

    fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
        <Vec<u8> as sqlx::Encode<'q, DB>>::encode_by_ref(&self.der, buf)
    }

    fn produces(&self) -> Option<DB::TypeInfo> {
        <Vec<u8> as sqlx::Encode<'q, DB>>::produces(&self.der)
    }

    fn size_hint(&self) -> usize {
        <Vec<u8> as sqlx::Encode<'q, DB>>::size_hint(&self.der)
    }
}
/// Binds a `PrivKey` as its PKCS#8 DER bytes; every method delegates to the
/// `Vec<u8>` encoder.
impl<'q, DB> sqlx::Encode<'q, DB> for PrivKey
where
    DB: Database,
    Vec<u8>: sqlx::Encode<'q, DB>,
{
    fn encode(self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
        <Vec<u8> as sqlx::Encode<'q, DB>>::encode(self.der, buf)
    }

    fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
        <Vec<u8> as sqlx::Encode<'q, DB>>::encode_by_ref(&self.der, buf)
    }

    fn produces(&self) -> Option<DB::TypeInfo> {
        <Vec<u8> as sqlx::Encode<'q, DB>>::produces(&self.der)
    }

    fn size_hint(&self) -> usize {
        <Vec<u8> as sqlx::Encode<'q, DB>>::size_hint(&self.der)
    }
}
impl<'r, DB> sqlx::Decode<'r, DB> for PubKey
where
    DB: Database,
    Vec<u8>: sqlx::Decode<'r, DB>,
{
    /// Read the column as raw bytes; the key itself is parsed lazily on first use.
    fn decode(value: <DB as HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, BoxDynError> {
        <Vec<u8> as sqlx::Decode<'r, DB>>::decode(value).map(PubKey::from_der_unchecked)
    }
}
impl<'r, DB> sqlx::Decode<'r, DB> for PrivKey
where
    DB: Database,
    Vec<u8>: sqlx::Decode<'r, DB>,
{
    /// Read the column as raw bytes; the key itself is parsed lazily on first use.
    fn decode(value: <DB as HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, BoxDynError> {
        <Vec<u8> as sqlx::Decode<'r, DB>>::decode(value).map(PrivKey::from_der_unchecked)
    }
}
impl Serialize for PubKey
where
Vec<u8>: Serialize,
{
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
self.der.serialize(serializer)
}
}
impl Serialize for PrivKey
where
Vec<u8>: Serialize,
{
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
self.der.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for PubKey
where
Vec<u8>: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> std::result::Result<PubKey, D::Error>
where
D: Deserializer<'de>,
{
Vec::<u8>::deserialize(deserializer).map(PubKey::from_der_unchecked)
}
}
impl<'de> Deserialize<'de> for PrivKey
where
Vec<u8>: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> std::result::Result<PrivKey, D::Error>
where
D: Deserializer<'de>,
{
Vec::<u8>::deserialize(deserializer).map(PrivKey::from_der_unchecked)
}
}
impl From<rsa::errors::Error> for Error {
fn from(val: rsa::errors::Error) -> Error {
Error(val)
}
}
impl From<pkcs1::Error> for Error {
fn from(val: pkcs1::Error) -> Error {
Error::from(rsa::errors::Error::from(val))
}
}
impl From<pkcs8::Error> for Error {
fn from(val: pkcs8::Error) -> Error {
Error::from(rsa::errors::Error::from(val))
}
}
impl From<pkcs8::spki::Error> for Error {
fn from(val: pkcs8::spki::Error) -> Error {
Error::from(pkcs8::Error::from(val))
}
}
impl fmt::Debug for Error
where
rsa::errors::Error: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, f)
}
}
impl fmt::Display for Error
where
rsa::errors::Error: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_PUBKEY_PKCS8: &str = "-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAz4muPogkMCRdIShQLyCv
6QMb7d9epNfKGFyPi4C5w9rZTJ5Ox7X4cueXA6imMDJ0DCfD34QESJoIjkXht3W0
AanLSQnFh+p/5RsbEPb4zaUzG7OHGrYsIE2/LEUFyAuE15KVnXMmtoqN4k8Y5NtC
2GWGnnNW/iD+mr6SMLPQ44+bdPegBjbQmAJ3I/H4byoYvRWWE7g9klWyEZmlSwQQ
MG4m86utQeO7JQ9dHUiG6PtuEm0PVB0pUT0a/qF3wRCMPIpPiA/E+z3yfYYnivKu
wPsehgVguIGxzQaIOaN5UU7UmL36bAT3E0yhelmDdXkxeo6dQDnkuLRBwMTtFY3w
/QIDAQAB
-----END PUBLIC KEY-----
";

    const TEST_PUBKEY_PKCS1: &str = "-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEAz4muPogkMCRdIShQLyCv6QMb7d9epNfKGFyPi4C5w9rZTJ5Ox7X4
cueXA6imMDJ0DCfD34QESJoIjkXht3W0AanLSQnFh+p/5RsbEPb4zaUzG7OHGrYs
IE2/LEUFyAuE15KVnXMmtoqN4k8Y5NtC2GWGnnNW/iD+mr6SMLPQ44+bdPegBjbQ
mAJ3I/H4byoYvRWWE7g9klWyEZmlSwQQMG4m86utQeO7JQ9dHUiG6PtuEm0PVB0p
UT0a/qF3wRCMPIpPiA/E+z3yfYYnivKuwPsehgVguIGxzQaIOaN5UU7UmL36bAT3
E0yhelmDdXkxeo6dQDnkuLRBwMTtFY3w/QIDAQAB
-----END RSA PUBLIC KEY-----
";

    #[actix_web::test]
    async fn verify_signatures() {
        let priv_key = PrivKey::new().unwrap();
        let pub_key = priv_key.derive_pubkey().await.unwrap();
        let message = String::from("hello, world");
        let signature = priv_key.sign(message.as_bytes()).await.unwrap();
        assert_ne!(signature.as_slice(), message.as_bytes());
        // `expect` (rather than `assert!(.. .is_ok())`) reports *why* verification failed.
        pub_key
            .verify(message.as_bytes(), signature.as_slice())
            .await
            .expect("untampered signature must verify");

        let tampered_message = String::from("hello, world!");
        assert!(pub_key
            .verify(tampered_message.as_bytes(), signature.as_slice())
            .await
            .is_err());

        let mut tampered_signature = signature.clone();
        tampered_signature[0] ^= 1;
        assert!(pub_key
            .verify(message.as_bytes(), tampered_signature.as_slice())
            .await
            .is_err());
    }

    /// Keys rebuilt from stored DER (the path taken by the sqlx `Decode` and serde
    /// `Deserialize` impls) must parse lazily and behave like the eager originals.
    #[actix_web::test]
    async fn lazy_der_keys_roundtrip() {
        let priv_key = PrivKey::new().unwrap();
        let lazy_priv = PrivKey::from_der_unchecked(priv_key.der.clone());
        let message = b"hello, world";
        let signature = lazy_priv.sign(message).await.unwrap();

        let pub_key = priv_key.derive_pubkey().await.unwrap();
        let lazy_pub = PubKey::from_der_unchecked(pub_key.der.clone());
        lazy_pub
            .verify(message, &signature)
            .await
            .expect("signature from lazily-decoded private key must verify");
        assert_eq!(
            lazy_pub.to_pem().await.unwrap(),
            pub_key.to_pem().await.unwrap()
        );
    }

    /// Malformed DER is accepted by `from_der_unchecked` but rejected on first use.
    #[actix_web::test]
    async fn malformed_der_is_rejected_lazily() {
        let pub_key = PubKey::from_der_unchecked(vec![0x30, 0x00]);
        assert!(pub_key.to_pem().await.is_err());

        let priv_key = PrivKey::from_der_unchecked(vec![0x30, 0x00]);
        assert!(priv_key.sign(b"data").await.is_err());
    }

    #[actix_web::test]
    async fn parse_pkcs8_pem() {
        let pub_key = PubKey::from_pem(TEST_PUBKEY_PKCS8).unwrap();
        let pem = pub_key.to_pem().await.unwrap();
        assert_eq!(pem, TEST_PUBKEY_PKCS8);
    }

    #[actix_web::test]
    async fn parse_pkcs1_pem() {
        let pub_key = PubKey::from_pem(TEST_PUBKEY_PKCS1).unwrap();
        let pem = pub_key.to_pem().await.unwrap();
        // to_pem() should always return PKCS#8
        assert_eq!(pem, TEST_PUBKEY_PKCS8);
    }
}