#![cfg(feature = "service_account")]
use crate::error::Error;
use crate::types::TokenInfo;
use std::{io, path::PathBuf, error::Error as StdError};
use hyper::client::connect::Connection;
use hyper::header;
use http::Uri;
use rustls::{
self,
sign::{self, SigningKey},
PrivateKey,
};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tower_service::Service;
use time::OffsetDateTime;
use url::form_urlencoded;
const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";
const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#;
fn append_base64<T: AsRef<[u8]> + ?Sized>(s: &T, out: &mut String) {
base64::encode_config_buf(s, base64::URL_SAFE, out)
}
fn decode_rsa_key(pem_pkcs8: &str) -> Result<PrivateKey, io::Error> {
let private_keys = rustls_pemfile::pkcs8_private_keys(&mut pem_pkcs8.as_bytes());
match private_keys {
Ok(mut keys) if !keys.is_empty() => {
keys.truncate(1);
Ok(rustls::PrivateKey(keys.remove(0)))
}
Ok(_) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Not enough private keys in PEM",
)),
Err(_) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Error reading key from PEM",
)),
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ServiceAccountKey {
#[serde(rename = "type")]
pub key_type: Option<String>,
pub project_id: Option<String>,
pub private_key_id: Option<String>,
pub private_key: String,
pub client_email: String,
pub client_id: Option<String>,
pub auth_uri: Option<String>,
pub token_uri: String,
pub auth_provider_x509_cert_url: Option<String>,
pub client_x509_cert_url: Option<String>,
}
#[derive(Serialize, Debug)]
struct Claims<'a> {
iss: &'a str,
aud: &'a str,
exp: i64,
iat: i64,
#[serde(rename = "sub")]
subject: Option<&'a str>,
scope: String,
}
impl<'a> Claims<'a> {
fn new<T>(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self
where
T: AsRef<str>,
{
let iat = OffsetDateTime::now_utc().unix_timestamp();
let expiry = iat + 3600 - 5; let scope = crate::helper::join(scopes, " ");
Claims {
iss: &key.client_email,
aud: &key.token_uri,
exp: expiry,
iat,
subject,
scope,
}
}
}
pub(crate) struct JWTSigner {
signer: Box<dyn rustls::sign::Signer>,
}
impl JWTSigner {
fn new(private_key: &str) -> Result<Self, io::Error> {
let key = decode_rsa_key(private_key)?;
let signing_key = sign::RsaSigningKey::new(&key)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?;
let signer = signing_key
.choose_scheme(&[rustls::SignatureScheme::RSA_PKCS1_SHA256])
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme")
})?;
Ok(JWTSigner { signer })
}
fn sign_claims(&self, claims: &Claims) -> Result<String, rustls::Error> {
let mut jwt_head = Self::encode_claims(claims);
let signature = self.signer.sign(jwt_head.as_bytes())?;
jwt_head.push('.');
append_base64(&signature, &mut jwt_head);
Ok(jwt_head)
}
fn encode_claims(claims: &Claims) -> String {
let mut head = String::new();
append_base64(GOOGLE_RS256_HEAD, &mut head);
head.push('.');
append_base64(&serde_json::to_string(&claims).unwrap(), &mut head);
head
}
}
pub struct ServiceAccountFlowOpts {
pub(crate) key: FlowOptsKey,
pub(crate) subject: Option<String>,
}
pub(crate) enum FlowOptsKey {
Path(PathBuf),
Key(ServiceAccountKey),
}
pub struct ServiceAccountFlow {
key: ServiceAccountKey,
subject: Option<String>,
signer: JWTSigner,
}
impl ServiceAccountFlow {
pub(crate) async fn new(opts: ServiceAccountFlowOpts) -> Result<Self, io::Error> {
let key = match opts.key {
FlowOptsKey::Path(path) => crate::read_service_account_key(path).await?,
FlowOptsKey::Key(key) => key,
};
let signer = JWTSigner::new(&key.private_key)?;
Ok(ServiceAccountFlow {
key,
subject: opts.subject,
signer,
})
}
pub(crate) async fn token<S, T>(
&self,
hyper_client: &hyper::Client<S>,
scopes: &[T],
) -> Result<TokenInfo, Error>
where
T: AsRef<str>,
S: Service<Uri> + Clone + Send + Sync + 'static,
S::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static,
S::Future: Send + Unpin + 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
{
let claims = Claims::new(&self.key, scopes, self.subject.as_deref());
let signed = self.signer.sign_claims(&claims).map_err(|_| {
Error::LowLevelError(io::Error::new(
io::ErrorKind::Other,
"unable to sign claims",
))
})?;
let rqbody = form_urlencoded::Serializer::new(String::new())
.extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())])
.finish();
let request = hyper::Request::post(&self.key.token_uri)
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(hyper::Body::from(rqbody))
.unwrap();
log::debug!("requesting token from service account: {:?}", request);
let (head, body) = hyper_client.request(request).await?.into_parts();
let body = hyper::body::to_bytes(body).await?;
log::debug!("received response; head: {:?}, body: {:?}", head, body);
TokenInfo::from_json(&body)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::helper::read_service_account_key;
const TEST_PRIVATE_KEY_PATH: &'static str = "examples/Sanguine-69411a0c0eea.json";
#[allow(dead_code)]
async fn test_service_account_e2e() {
let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts {
key: FlowOptsKey::Path(TEST_PRIVATE_KEY_PATH.into()),
subject: None,
})
.await
.unwrap();
let client = hyper::Client::builder().build(
hyper_rustls::HttpsConnectorBuilder::new()
.with_native_roots()
.https_only()
.enable_http1()
.enable_http2()
.build(),
);
println!(
"{:?}",
acc.token(&client, &["https://www.googleapis.com/auth/pubsub"])
.await
);
println!(
"{:?}",
acc.token(
&client,
&["https://some.scope/likely-to-hand-out-id-tokens"]
)
.await
);
}
#[tokio::test]
async fn test_jwt_initialize_claims() {
let key = read_service_account_key(TEST_PRIVATE_KEY_PATH)
.await
.unwrap();
let scopes = vec!["scope1", "scope2", "scope3"];
let claims = Claims::new(&key, &scopes, None);
assert_eq!(
claims.iss,
"oauth2-public-test@sanguine-rhythm-105020.iam.gserviceaccount.com".to_string()
);
assert_eq!(claims.scope, "scope1 scope2 scope3".to_string());
assert_eq!(
claims.aud,
"https://accounts.google.com/o/oauth2/token".to_string()
);
assert!(claims.exp > 1000000000);
assert!(claims.iat < claims.exp);
assert_eq!(claims.exp - claims.iat, 3595);
}
#[tokio::test]
async fn test_jwt_sign() {
let key = read_service_account_key(TEST_PRIVATE_KEY_PATH)
.await
.unwrap();
let scopes = vec!["scope1", "scope2", "scope3"];
let signer = JWTSigner::new(&key.private_key).unwrap();
let claims = Claims::new(&key, &scopes, None);
let signature = signer.sign_claims(&claims);
assert!(signature.is_ok());
let signature = signature.unwrap();
assert_eq!(
signature.split(".").nth(0).unwrap(),
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
);
}
}