diff --git a/src/main.rs b/src/main.rs index 1e84b4b..9164b2a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,12 @@ #![deny(warnings)] +#![feature(proc_macro_hygiene)] +#[macro_use] extern crate rocket; use log::info; use serde_derive::Serialize; use std::env; use std::fs; use std::path::Path; -use std::str::{ - FromStr, - from_utf8, -}; -use std::thread; - -use hyper::rt::{Future}; -use hyper::{Body, Request, Response, Server, StatusCode}; -use hyper::header::{AUTHORIZATION}; -use hyper_router::{Route, RouterBuilder, RouterService}; -use base64::decode; use biscuit::{ Empty, @@ -35,6 +26,8 @@ use num::BigUint; use openssl::rsa::Rsa; use ldap3::{ LdapConn, Scope, SearchEntry }; +use rocket::request::Form; +use rocket_contrib::json::Json; #[derive(Debug)] struct BasicAuthentication { @@ -52,25 +45,6 @@ pub enum AuthError { LdapSearch, } -impl FromStr for BasicAuthentication { - type Err = AuthError; - fn from_str(s: &str) -> Result { - match decode(s) { - Ok(bytes) => match from_utf8(&bytes) { - Ok(text) => { - let mut pair = text.splitn(2, ":"); - Ok(BasicAuthentication { - username: pair.next().unwrap().to_string(), - password: pair.next().unwrap().to_string(), - }) - }, - Err(_) => Err(AuthError::Parse) - }, - Err(_) => Err(AuthError::Decode) - } - } -} - #[derive(Debug)] struct LdapUser { pub dn: String, @@ -114,6 +88,7 @@ fn auth_user(auth: &BasicAuthentication) -> Result { None => [].to_vec(), }; + info!("Authentication success for {:?}", base); Ok(LdapUser { dn: base, mail: mail, @@ -121,49 +96,22 @@ fn auth_user(auth: &BasicAuthentication) -> Result { }) } -fn auth_handler(req: Request) -> Response { - let header = match req.headers().get(AUTHORIZATION) { - Some(auth_value) => auth_value.to_str().unwrap(), - None => return Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(Body::from("Authentication header missing")) - .unwrap(), - }; - let (auth_type, credentials) = { - let mut split = header.split_ascii_whitespace(); - let auth_type = split.next().unwrap(); - let credentials = split.next().unwrap(); - (auth_type, credentials) - }; +#[derive(FromForm)] +struct LoginData { + username: String, + password: String, +} - if auth_type != "Basic" { - return Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(Body::from("Basic Authentication was expected")) - .unwrap(); +#[post("/login", data = "")] +fn login(form_data: Form) -> String { + let auth = BasicAuthentication { + username: form_data.username.to_owned(), + password: form_data.password.to_owned(), + }; + match auth_user(&auth) { + Ok(ldap_user) => format!("OK! {:?}", ldap_user), + _ => format!("Bad :("), } - let auth = BasicAuthentication::from_str(credentials).unwrap(); - let worker = thread::spawn(move || { - let user = auth_user(&auth); - user - }); - let user = match worker.join().unwrap() { - Ok(ldap_user) => ldap_user, - Err(AuthError::LdapBind) => { - return Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(Body::from("LDAP bind failed")) - .unwrap(); - }, - _ => { - return Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from("Something is broken")) - .unwrap(); - } - }; - - Response::new(Body::from(format!("BasicAuthentication {:?}", user))) } fn jwk_from_pem(file_path: &Path) -> Result, Box> { @@ -184,7 +132,8 @@ fn jwk_from_pem(file_path: &Path) -> Result, Box) -> Response { +#[get("/oauth2/keys")] +fn get_keys() -> Json> { let jwks: Vec> = fs::read_dir("./").unwrap() .filter_map(|dir_entry| { let path = dir_entry.unwrap().path(); @@ -202,12 +151,7 @@ fn get_keys(_req: Request) -> Response { }) .collect(); let jwks = JWKSet { keys: jwks }; - let jwks_json = serde_json::to_string(&jwks).unwrap(); - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(Body::from(jwks_json)) - .unwrap() + Json(jwks) } #[derive(Debug, Serialize)] @@ -215,44 +159,30 @@ struct OidcConfig { pub jwks_uri: String, } -fn oidc_config(_req: Request) -> Response { +#[get("/.well-known/openid-configuration")] +fn oidc_config() -> Json { let config = OidcConfig { jwks_uri: "https://auth.xeen.dev/oauth2/keys".to_string(), }; - let config_json = serde_json::to_string(&config).unwrap(); - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(Body::from(config_json)) - .unwrap() + Json(config) } -fn hello(_req: Request) -> Response { - Response::new(Body::from("Hi!")) +#[get("/")] +fn hello() -> &'static str { + "Hello!" } -fn router_service() -> Result { - let router = RouterBuilder::new() - .add(Route::get("/").using(hello)) - .add(Route::post("/auth").using(auth_handler)) - .add(Route::get("/oauth2/keys").using(get_keys)) - .add(Route::get("/.well-known/openid-configuration").using(oidc_config)) - .build(); - Ok(RouterService::new(router)) +fn routes() -> Vec { + routes![ + hello, + oidc_config, + get_keys, + login, + ] } fn main() { env_logger::init(); - let addr_str = match env::var("LISTEN_ADDR") { - Ok(addr) => addr, - _ => "0.0.0.0:3000".to_string(), - }; - let addr = addr_str.parse().expect("Bad Address"); - let server = Server::bind(&addr) - .serve(router_service) - .map_err(|e| eprintln!("server error: {}", e)); - - info!("Listening on http://{}", addr); - tokio::run(server); + rocket::ignite().mount("/", routes()).launch(); }