Switched to Rocket
This commit is contained in:
parent
e0d1eb8897
commit
f32681d95d
142
src/main.rs
142
src/main.rs
@ -1,21 +1,12 @@
|
||||
#![deny(warnings)]
|
||||
#![feature(proc_macro_hygiene)]
|
||||
#[macro_use] extern crate rocket;
|
||||
|
||||
use log::info;
|
||||
use serde_derive::Serialize;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::str::{
|
||||
FromStr,
|
||||
from_utf8,
|
||||
};
|
||||
use std::thread;
|
||||
|
||||
use hyper::rt::{Future};
|
||||
use hyper::{Body, Request, Response, Server, StatusCode};
|
||||
use hyper::header::{AUTHORIZATION};
|
||||
use hyper_router::{Route, RouterBuilder, RouterService};
|
||||
use base64::decode;
|
||||
|
||||
use biscuit::{
|
||||
Empty,
|
||||
@ -35,6 +26,8 @@ use num::BigUint;
|
||||
use openssl::rsa::Rsa;
|
||||
|
||||
use ldap3::{ LdapConn, Scope, SearchEntry };
|
||||
use rocket::request::Form;
|
||||
use rocket_contrib::json::Json;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BasicAuthentication {
|
||||
@ -52,25 +45,6 @@ pub enum AuthError {
|
||||
LdapSearch,
|
||||
}
|
||||
|
||||
impl FromStr for BasicAuthentication {
|
||||
type Err = AuthError;
|
||||
fn from_str(s: &str) -> Result<BasicAuthentication, AuthError> {
|
||||
match decode(s) {
|
||||
Ok(bytes) => match from_utf8(&bytes) {
|
||||
Ok(text) => {
|
||||
let mut pair = text.splitn(2, ":");
|
||||
Ok(BasicAuthentication {
|
||||
username: pair.next().unwrap().to_string(),
|
||||
password: pair.next().unwrap().to_string(),
|
||||
})
|
||||
},
|
||||
Err(_) => Err(AuthError::Parse)
|
||||
},
|
||||
Err(_) => Err(AuthError::Decode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LdapUser {
|
||||
pub dn: String,
|
||||
@ -114,6 +88,7 @@ fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
|
||||
None => [].to_vec(),
|
||||
};
|
||||
|
||||
info!("Authentication success for {:?}", base);
|
||||
Ok(LdapUser {
|
||||
dn: base,
|
||||
mail: mail,
|
||||
@ -121,49 +96,22 @@ fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_handler(req: Request<Body>) -> Response<Body> {
|
||||
let header = match req.headers().get(AUTHORIZATION) {
|
||||
Some(auth_value) => auth_value.to_str().unwrap(),
|
||||
None => return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(Body::from("Authentication header missing"))
|
||||
.unwrap(),
|
||||
};
|
||||
let (auth_type, credentials) = {
|
||||
let mut split = header.split_ascii_whitespace();
|
||||
let auth_type = split.next().unwrap();
|
||||
let credentials = split.next().unwrap();
|
||||
(auth_type, credentials)
|
||||
};
|
||||
#[derive(FromForm)]
|
||||
struct LoginData {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
if auth_type != "Basic" {
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(Body::from("Basic Authentication was expected"))
|
||||
.unwrap();
|
||||
#[post("/login", data = "<form_data>")]
|
||||
fn login(form_data: Form<LoginData>) -> String {
|
||||
let auth = BasicAuthentication {
|
||||
username: form_data.username.to_owned(),
|
||||
password: form_data.password.to_owned(),
|
||||
};
|
||||
match auth_user(&auth) {
|
||||
Ok(ldap_user) => format!("OK! {:?}", ldap_user),
|
||||
_ => format!("Bad :("),
|
||||
}
|
||||
let auth = BasicAuthentication::from_str(credentials).unwrap();
|
||||
let worker = thread::spawn(move || {
|
||||
let user = auth_user(&auth);
|
||||
user
|
||||
});
|
||||
let user = match worker.join().unwrap() {
|
||||
Ok(ldap_user) => ldap_user,
|
||||
Err(AuthError::LdapBind) => {
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(Body::from("LDAP bind failed"))
|
||||
.unwrap();
|
||||
},
|
||||
_ => {
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from("Something is broken"))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
Response::new(Body::from(format!("BasicAuthentication {:?}", user)))
|
||||
}
|
||||
|
||||
fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Error + 'static>> {
|
||||
@ -184,7 +132,8 @@ fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Erro
|
||||
})
|
||||
}
|
||||
|
||||
fn get_keys(_req: Request<Body>) -> Response<Body> {
|
||||
#[get("/oauth2/keys")]
|
||||
fn get_keys() -> Json<JWKSet<Empty>> {
|
||||
let jwks: Vec<JWK<Empty>> = fs::read_dir("./").unwrap()
|
||||
.filter_map(|dir_entry| {
|
||||
let path = dir_entry.unwrap().path();
|
||||
@ -202,12 +151,7 @@ fn get_keys(_req: Request<Body>) -> Response<Body> {
|
||||
})
|
||||
.collect();
|
||||
let jwks = JWKSet { keys: jwks };
|
||||
let jwks_json = serde_json::to_string(&jwks).unwrap();
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(jwks_json))
|
||||
.unwrap()
|
||||
Json(jwks)
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@ -215,44 +159,30 @@ struct OidcConfig {
|
||||
pub jwks_uri: String,
|
||||
}
|
||||
|
||||
fn oidc_config(_req: Request<Body>) -> Response<Body> {
|
||||
#[get("/.well-known/openid-configuration")]
|
||||
fn oidc_config() -> Json<OidcConfig> {
|
||||
let config = OidcConfig {
|
||||
jwks_uri: "https://auth.xeen.dev/oauth2/keys".to_string(),
|
||||
};
|
||||
let config_json = serde_json::to_string(&config).unwrap();
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(config_json))
|
||||
.unwrap()
|
||||
Json(config)
|
||||
}
|
||||
|
||||
fn hello(_req: Request<Body>) -> Response<Body> {
|
||||
Response::new(Body::from("Hi!"))
|
||||
#[get("/")]
|
||||
fn hello() -> &'static str {
|
||||
"Hello!"
|
||||
}
|
||||
|
||||
fn router_service() -> Result<RouterService, std::io::Error> {
|
||||
let router = RouterBuilder::new()
|
||||
.add(Route::get("/").using(hello))
|
||||
.add(Route::post("/auth").using(auth_handler))
|
||||
.add(Route::get("/oauth2/keys").using(get_keys))
|
||||
.add(Route::get("/.well-known/openid-configuration").using(oidc_config))
|
||||
.build();
|
||||
Ok(RouterService::new(router))
|
||||
fn routes() -> Vec<rocket::Route> {
|
||||
routes![
|
||||
hello,
|
||||
oidc_config,
|
||||
get_keys,
|
||||
login,
|
||||
]
|
||||
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
|
||||
let addr_str = match env::var("LISTEN_ADDR") {
|
||||
Ok(addr) => addr,
|
||||
_ => "0.0.0.0:3000".to_string(),
|
||||
};
|
||||
let addr = addr_str.parse().expect("Bad Address");
|
||||
let server = Server::bind(&addr)
|
||||
.serve(router_service)
|
||||
.map_err(|e| eprintln!("server error: {}", e));
|
||||
|
||||
info!("Listening on http://{}", addr);
|
||||
tokio::run(server);
|
||||
rocket::ignite().mount("/", routes()).launch();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user