Switched to Rocket

This commit is contained in:
Alex Wright 2020-02-29 16:48:07 +01:00
parent e0d1eb8897
commit f32681d95d

View File

@ -1,21 +1,12 @@
#![deny(warnings)]
#![feature(proc_macro_hygiene)]
#[macro_use] extern crate rocket;
use log::info;
use serde_derive::Serialize;
use std::env;
use std::fs;
use std::path::Path;
use std::str::{
FromStr,
from_utf8,
};
use std::thread;
use hyper::rt::{Future};
use hyper::{Body, Request, Response, Server, StatusCode};
use hyper::header::{AUTHORIZATION};
use hyper_router::{Route, RouterBuilder, RouterService};
use base64::decode;
use biscuit::{
Empty,
@ -35,6 +26,8 @@ use num::BigUint;
use openssl::rsa::Rsa;
use ldap3::{ LdapConn, Scope, SearchEntry };
use rocket::request::Form;
use rocket_contrib::json::Json;
#[derive(Debug)]
struct BasicAuthentication {
@ -52,25 +45,6 @@ pub enum AuthError {
LdapSearch,
}
impl FromStr for BasicAuthentication {
type Err = AuthError;
fn from_str(s: &str) -> Result<BasicAuthentication, AuthError> {
match decode(s) {
Ok(bytes) => match from_utf8(&bytes) {
Ok(text) => {
let mut pair = text.splitn(2, ":");
Ok(BasicAuthentication {
username: pair.next().unwrap().to_string(),
password: pair.next().unwrap().to_string(),
})
},
Err(_) => Err(AuthError::Parse)
},
Err(_) => Err(AuthError::Decode)
}
}
}
#[derive(Debug)]
struct LdapUser {
pub dn: String,
@ -114,6 +88,7 @@ fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
None => [].to_vec(),
};
info!("Authentication success for {:?}", base);
Ok(LdapUser {
dn: base,
mail: mail,
@ -121,49 +96,22 @@ fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
})
}
fn auth_handler(req: Request<Body>) -> Response<Body> {
let header = match req.headers().get(AUTHORIZATION) {
Some(auth_value) => auth_value.to_str().unwrap(),
None => return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::from("Authentication header missing"))
.unwrap(),
};
let (auth_type, credentials) = {
let mut split = header.split_ascii_whitespace();
let auth_type = split.next().unwrap();
let credentials = split.next().unwrap();
(auth_type, credentials)
};
#[derive(FromForm)]
struct LoginData {
username: String,
password: String,
}
if auth_type != "Basic" {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::from("Basic Authentication was expected"))
.unwrap();
#[post("/login", data = "<form_data>")]
fn login(form_data: Form<LoginData>) -> String {
let auth = BasicAuthentication {
username: form_data.username.to_owned(),
password: form_data.password.to_owned(),
};
match auth_user(&auth) {
Ok(ldap_user) => format!("OK! {:?}", ldap_user),
_ => format!("Bad :("),
}
let auth = BasicAuthentication::from_str(credentials).unwrap();
let worker = thread::spawn(move || {
let user = auth_user(&auth);
user
});
let user = match worker.join().unwrap() {
Ok(ldap_user) => ldap_user,
Err(AuthError::LdapBind) => {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::from("LDAP bind failed"))
.unwrap();
},
_ => {
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Something is broken"))
.unwrap();
}
};
Response::new(Body::from(format!("BasicAuthentication {:?}", user)))
}
fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Error + 'static>> {
@ -184,7 +132,8 @@ fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Erro
})
}
fn get_keys(_req: Request<Body>) -> Response<Body> {
#[get("/oauth2/keys")]
fn get_keys() -> Json<JWKSet<Empty>> {
let jwks: Vec<JWK<Empty>> = fs::read_dir("./").unwrap()
.filter_map(|dir_entry| {
let path = dir_entry.unwrap().path();
@ -202,12 +151,7 @@ fn get_keys(_req: Request<Body>) -> Response<Body> {
})
.collect();
let jwks = JWKSet { keys: jwks };
let jwks_json = serde_json::to_string(&jwks).unwrap();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Body::from(jwks_json))
.unwrap()
Json(jwks)
}
#[derive(Debug, Serialize)]
@ -215,44 +159,30 @@ struct OidcConfig {
pub jwks_uri: String,
}
fn oidc_config(_req: Request<Body>) -> Response<Body> {
#[get("/.well-known/openid-configuration")]
fn oidc_config() -> Json<OidcConfig> {
let config = OidcConfig {
jwks_uri: "https://auth.xeen.dev/oauth2/keys".to_string(),
};
let config_json = serde_json::to_string(&config).unwrap();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Body::from(config_json))
.unwrap()
Json(config)
}
fn hello(_req: Request<Body>) -> Response<Body> {
Response::new(Body::from("Hi!"))
#[get("/")]
fn hello() -> &'static str {
"Hello!"
}
fn router_service() -> Result<RouterService, std::io::Error> {
let router = RouterBuilder::new()
.add(Route::get("/").using(hello))
.add(Route::post("/auth").using(auth_handler))
.add(Route::get("/oauth2/keys").using(get_keys))
.add(Route::get("/.well-known/openid-configuration").using(oidc_config))
.build();
Ok(RouterService::new(router))
fn routes() -> Vec<rocket::Route> {
routes![
hello,
oidc_config,
get_keys,
login,
]
}
fn main() {
env_logger::init();
let addr_str = match env::var("LISTEN_ADDR") {
Ok(addr) => addr,
_ => "0.0.0.0:3000".to_string(),
};
let addr = addr_str.parse().expect("Bad Address");
let server = Server::bind(&addr)
.serve(router_service)
.map_err(|e| eprintln!("server error: {}", e));
info!("Listening on http://{}", addr);
tokio::run(server);
rocket::ignite().mount("/", routes()).launch();
}