Switched to Rocket
This commit is contained in:
		
							parent
							
								
									e0d1eb8897
								
							
						
					
					
						commit
						f32681d95d
					
				
							
								
								
									
										142
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										142
									
								
								src/main.rs
									
									
									
									
									
								
							@ -1,21 +1,12 @@
 | 
			
		||||
#![deny(warnings)]
 | 
			
		||||
#![feature(proc_macro_hygiene)]
 | 
			
		||||
#[macro_use] extern crate rocket;
 | 
			
		||||
 | 
			
		||||
use log::info;
 | 
			
		||||
use serde_derive::Serialize;
 | 
			
		||||
use std::env;
 | 
			
		||||
use std::fs;
 | 
			
		||||
use std::path::Path;
 | 
			
		||||
use std::str::{
 | 
			
		||||
    FromStr,
 | 
			
		||||
    from_utf8,
 | 
			
		||||
};
 | 
			
		||||
use std::thread;
 | 
			
		||||
 | 
			
		||||
use hyper::rt::{Future};
 | 
			
		||||
use hyper::{Body, Request, Response, Server, StatusCode};
 | 
			
		||||
use hyper::header::{AUTHORIZATION};
 | 
			
		||||
use hyper_router::{Route, RouterBuilder, RouterService};
 | 
			
		||||
use base64::decode;
 | 
			
		||||
 | 
			
		||||
use biscuit::{
 | 
			
		||||
    Empty,
 | 
			
		||||
@ -35,6 +26,8 @@ use num::BigUint;
 | 
			
		||||
use openssl::rsa::Rsa;
 | 
			
		||||
 | 
			
		||||
use ldap3::{ LdapConn, Scope, SearchEntry };
 | 
			
		||||
use rocket::request::Form;
 | 
			
		||||
use rocket_contrib::json::Json;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
struct BasicAuthentication {
 | 
			
		||||
@ -52,25 +45,6 @@ pub enum AuthError {
 | 
			
		||||
    LdapSearch,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl FromStr for BasicAuthentication {
 | 
			
		||||
    type Err = AuthError;
 | 
			
		||||
    fn from_str(s: &str) -> Result<BasicAuthentication, AuthError> {
 | 
			
		||||
        match decode(s) {
 | 
			
		||||
            Ok(bytes) => match from_utf8(&bytes) {
 | 
			
		||||
                Ok(text) => {
 | 
			
		||||
                    let mut pair = text.splitn(2, ":");
 | 
			
		||||
                    Ok(BasicAuthentication {
 | 
			
		||||
                        username: pair.next().unwrap().to_string(),
 | 
			
		||||
                        password: pair.next().unwrap().to_string(),
 | 
			
		||||
                    })
 | 
			
		||||
                },
 | 
			
		||||
                Err(_) => Err(AuthError::Parse)
 | 
			
		||||
            },
 | 
			
		||||
            Err(_) => Err(AuthError::Decode)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
struct LdapUser {
 | 
			
		||||
    pub dn: String,
 | 
			
		||||
@ -114,6 +88,7 @@ fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
 | 
			
		||||
        None => [].to_vec(),
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    info!("Authentication success for {:?}", base);
 | 
			
		||||
    Ok(LdapUser {
 | 
			
		||||
        dn: base,
 | 
			
		||||
        mail: mail,
 | 
			
		||||
@ -121,49 +96,22 @@ fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
 | 
			
		||||
    })
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn auth_handler(req: Request<Body>) -> Response<Body> {
 | 
			
		||||
    let header = match req.headers().get(AUTHORIZATION) {
 | 
			
		||||
        Some(auth_value) => auth_value.to_str().unwrap(),
 | 
			
		||||
        None => return Response::builder()
 | 
			
		||||
            .status(StatusCode::UNAUTHORIZED)
 | 
			
		||||
            .body(Body::from("Authentication header missing"))
 | 
			
		||||
            .unwrap(),
 | 
			
		||||
    };
 | 
			
		||||
    let (auth_type, credentials) = {
 | 
			
		||||
        let mut split = header.split_ascii_whitespace();
 | 
			
		||||
        let auth_type = split.next().unwrap();
 | 
			
		||||
        let credentials = split.next().unwrap();
 | 
			
		||||
        (auth_type, credentials)
 | 
			
		||||
    };
 | 
			
		||||
#[derive(FromForm)]
 | 
			
		||||
struct LoginData {
 | 
			
		||||
    username: String,
 | 
			
		||||
    password: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
    if auth_type != "Basic" {
 | 
			
		||||
        return Response::builder()
 | 
			
		||||
            .status(StatusCode::UNAUTHORIZED)
 | 
			
		||||
            .body(Body::from("Basic Authentication was expected"))
 | 
			
		||||
            .unwrap();
 | 
			
		||||
#[post("/login", data = "<form_data>")]
 | 
			
		||||
fn login(form_data: Form<LoginData>) -> String {
 | 
			
		||||
    let auth = BasicAuthentication {
 | 
			
		||||
        username: form_data.username.to_owned(),
 | 
			
		||||
        password: form_data.password.to_owned(),
 | 
			
		||||
    };
 | 
			
		||||
    match auth_user(&auth) {
 | 
			
		||||
        Ok(ldap_user) => format!("OK! {:?}", ldap_user),
 | 
			
		||||
        _ => format!("Bad :("),
 | 
			
		||||
    }
 | 
			
		||||
    let auth = BasicAuthentication::from_str(credentials).unwrap();
 | 
			
		||||
    let worker = thread::spawn(move || {
 | 
			
		||||
        let user = auth_user(&auth);
 | 
			
		||||
        user
 | 
			
		||||
    });
 | 
			
		||||
    let user = match worker.join().unwrap() {
 | 
			
		||||
        Ok(ldap_user) => ldap_user,
 | 
			
		||||
        Err(AuthError::LdapBind) => {
 | 
			
		||||
            return Response::builder()
 | 
			
		||||
                .status(StatusCode::UNAUTHORIZED)
 | 
			
		||||
                .body(Body::from("LDAP bind failed"))
 | 
			
		||||
                .unwrap();
 | 
			
		||||
        },
 | 
			
		||||
        _ => {
 | 
			
		||||
            return Response::builder()
 | 
			
		||||
                .status(StatusCode::INTERNAL_SERVER_ERROR)
 | 
			
		||||
                .body(Body::from("Something is broken"))
 | 
			
		||||
                .unwrap();
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    Response::new(Body::from(format!("BasicAuthentication {:?}", user)))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Error + 'static>> {
 | 
			
		||||
@ -184,7 +132,8 @@ fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Erro
 | 
			
		||||
    })
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn get_keys(_req: Request<Body>) -> Response<Body> {
 | 
			
		||||
#[get("/oauth2/keys")]
 | 
			
		||||
fn get_keys() -> Json<JWKSet<Empty>> {
 | 
			
		||||
    let jwks: Vec<JWK<Empty>> = fs::read_dir("./").unwrap()
 | 
			
		||||
        .filter_map(|dir_entry| {
 | 
			
		||||
            let path = dir_entry.unwrap().path();
 | 
			
		||||
@ -202,12 +151,7 @@ fn get_keys(_req: Request<Body>) -> Response<Body> {
 | 
			
		||||
        })
 | 
			
		||||
        .collect();
 | 
			
		||||
    let jwks = JWKSet { keys: jwks };
 | 
			
		||||
    let jwks_json = serde_json::to_string(&jwks).unwrap();
 | 
			
		||||
    Response::builder()
 | 
			
		||||
        .status(StatusCode::OK)
 | 
			
		||||
        .header("Content-Type", "application/json")
 | 
			
		||||
        .body(Body::from(jwks_json))
 | 
			
		||||
        .unwrap()
 | 
			
		||||
    Json(jwks)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Serialize)]
 | 
			
		||||
@ -215,44 +159,30 @@ struct OidcConfig {
 | 
			
		||||
    pub jwks_uri: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn oidc_config(_req: Request<Body>) -> Response<Body> {
 | 
			
		||||
#[get("/.well-known/openid-configuration")]
 | 
			
		||||
fn oidc_config() -> Json<OidcConfig> {
 | 
			
		||||
    let config = OidcConfig {
 | 
			
		||||
        jwks_uri: "https://auth.xeen.dev/oauth2/keys".to_string(),
 | 
			
		||||
    };
 | 
			
		||||
    let config_json = serde_json::to_string(&config).unwrap();
 | 
			
		||||
    Response::builder()
 | 
			
		||||
        .status(StatusCode::OK)
 | 
			
		||||
        .header("Content-Type", "application/json")
 | 
			
		||||
        .body(Body::from(config_json))
 | 
			
		||||
        .unwrap()
 | 
			
		||||
    Json(config)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn hello(_req: Request<Body>) -> Response<Body> {
 | 
			
		||||
    Response::new(Body::from("Hi!"))
 | 
			
		||||
#[get("/")]
 | 
			
		||||
fn hello() -> &'static str {
 | 
			
		||||
    "Hello!"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn router_service() -> Result<RouterService, std::io::Error> {
 | 
			
		||||
    let router = RouterBuilder::new()
 | 
			
		||||
        .add(Route::get("/").using(hello))
 | 
			
		||||
        .add(Route::post("/auth").using(auth_handler))
 | 
			
		||||
        .add(Route::get("/oauth2/keys").using(get_keys))
 | 
			
		||||
        .add(Route::get("/.well-known/openid-configuration").using(oidc_config))
 | 
			
		||||
        .build();
 | 
			
		||||
    Ok(RouterService::new(router))
 | 
			
		||||
fn routes() -> Vec<rocket::Route> {
 | 
			
		||||
    routes![
 | 
			
		||||
        hello,
 | 
			
		||||
        oidc_config,
 | 
			
		||||
        get_keys,
 | 
			
		||||
        login,
 | 
			
		||||
    ]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn main() {
 | 
			
		||||
    env_logger::init();
 | 
			
		||||
 | 
			
		||||
    let addr_str = match env::var("LISTEN_ADDR") {
 | 
			
		||||
        Ok(addr) => addr,
 | 
			
		||||
        _ => "0.0.0.0:3000".to_string(),
 | 
			
		||||
    };
 | 
			
		||||
    let addr = addr_str.parse().expect("Bad Address");
 | 
			
		||||
    let server = Server::bind(&addr)
 | 
			
		||||
        .serve(router_service)
 | 
			
		||||
        .map_err(|e| eprintln!("server error: {}", e));
 | 
			
		||||
 | 
			
		||||
    info!("Listening on http://{}", addr);
 | 
			
		||||
    tokio::run(server);
 | 
			
		||||
    rocket::ignite().mount("/", routes()).launch();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user