278 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			278 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| #![deny(warnings)]
 | |
| #![feature(proc_macro_hygiene)]
 | |
| #[macro_use] extern crate diesel;
 | |
| #[macro_use] extern crate rocket;
 | |
| #[macro_use] extern crate rocket_contrib;
 | |
| 
 | |
| use log::info;
 | |
| use serde_derive::Serialize;
 | |
| use std::env;
 | |
| use std::fs;
 | |
| use std::path::Path;
 | |
| 
 | |
| use biscuit::{
 | |
|     Empty,
 | |
| };
 | |
| use biscuit::jwa::{
 | |
|     SignatureAlgorithm,
 | |
|     Algorithm,
 | |
| };
 | |
| use biscuit::jwk::{
 | |
|     RSAKeyParameters,
 | |
|     CommonParameters,
 | |
|     AlgorithmParameters,
 | |
|     JWK,
 | |
|     JWKSet,
 | |
| };
 | |
| use num::BigUint;
 | |
| use openssl::rsa::Rsa;
 | |
| 
 | |
| use ldap3::{ LdapConn, Scope, SearchEntry };
 | |
| use rocket::request::{
 | |
|     FlashMessage,
 | |
|     Form,
 | |
|     FromRequest,
 | |
|     Outcome,
 | |
|     Request,
 | |
| };
 | |
| use rocket::response::{
 | |
|     Flash,
 | |
|     Redirect,
 | |
| };
 | |
| use rocket_contrib::json::Json;
 | |
| use rocket_contrib::templates::Template;
 | |
| 
 | |
| mod schema;
 | |
| mod models;
 | |
| 
 | |
/// Username/password pair submitted by a client, checked against LDAP by
/// `auth_user`.
///
/// NOTE(review): `Debug` is derived, so `{:?}` formatting prints the password
/// in clear text — avoid logging values of this type.
#[derive(Debug)]
struct BasicAuthentication {
    /// Login name; `auth_user` builds the bind DN and search filter from it.
    pub username: String,
    /// Plain-text password used for the LDAP simple bind.
    pub password: String,
}
 | |
| 
 | |
/// Reasons an authentication attempt can fail.
///
/// `auth_user` produces only the `Ldap*` variants. `Parse` and `Decode` are not
/// constructed anywhere in the code visible here (NOTE(review): presumably
/// reserved for credential/header decoding — confirm or remove).
#[derive(Debug)]
pub enum AuthError {
    Parse,
    Decode,
    /// Binding to the directory with the supplied credentials did not succeed.
    LdapBind,
    /// `LDAP_SERVER_ADDR` is unset or not valid Unicode.
    LdapConfig,
    /// No connection to the LDAP server could be opened.
    LdapConnection,
    /// Searching for the user's directory entry failed.
    LdapSearch,
}
 | |
| 
 | |
/// Directory attributes of an authenticated user, as collected by `auth_user`.
#[derive(Debug)]
struct LdapUser {
    /// Distinguished name the user bound as.
    pub dn: String,
    /// All values of the `memberOf` attribute (empty when absent).
    pub groups: Vec<String>,
    /// All values of the `mail` attribute (empty when absent).
    pub mail: Vec<String>,
    /// All values of the `enabledService` attribute (empty when absent).
    pub services: Vec<String>,
    /// First value of the `uid` attribute (empty string when absent).
    pub username: String,
}
 | |
| 
 | |
| fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
 | |
|     let ldap_server_addr = match env::var("LDAP_SERVER_ADDR") {
 | |
|         Ok(addr) => addr,
 | |
|         _ => return Err(AuthError::LdapConfig),
 | |
|     };
 | |
|     let ldap = match LdapConn::new(&ldap_server_addr) {
 | |
|         Ok(conn) => conn,
 | |
|         Err(_err) => return Err(AuthError::LdapConnection),
 | |
|     };
 | |
| 
 | |
|     let base = format!("uid={},ou=people,dc=xeentech,dc=com", auth.username);
 | |
|     match ldap.simple_bind(&base, &auth.password).unwrap().success() {
 | |
|         Ok(_ldap) => println!("Connected and authenticated"),
 | |
|         Err(_err) => return Err(AuthError::LdapBind),
 | |
|     };
 | |
| 
 | |
|     let filter = format!("(uid={})", auth.username);
 | |
|     let s = match ldap.search(&base, Scope::Subtree, &filter, vec!["uid", "mail", "enabledService", "memberOf"]) {
 | |
|         Ok(result) => {
 | |
|             let (rs, _) = result.success().unwrap();
 | |
|             rs
 | |
|         },
 | |
|         Err(_err) => return Err(AuthError::LdapSearch),
 | |
|     };
 | |
| 
 | |
|     // Grab the first, if any, result and discard the rest
 | |
|     let se = SearchEntry::construct(s.first().unwrap().to_owned());
 | |
|     let services = match se.attrs.get("enabledService") {
 | |
|         Some(services) => services.to_vec(),
 | |
|         None => [].to_vec(),
 | |
|     };
 | |
|     let mail = match se.attrs.get("mail") {
 | |
|         Some(mail) => mail.to_vec(),
 | |
|         None => [].to_vec(),
 | |
|     };
 | |
|     let groups = match se.attrs.get("memberOf") {
 | |
|         Some(groups) => groups.to_vec(),
 | |
|         None => [].to_vec(),
 | |
|     };
 | |
|     let username = match se.attrs.get("uid") {
 | |
|         Some(username) => username[0].to_owned(),
 | |
|         None => "".to_string(),
 | |
|     };
 | |
| 
 | |
|     info!("Authentication success for {:?}", base);
 | |
|     Ok(LdapUser {
 | |
|         dn: base,
 | |
|         groups: groups,
 | |
|         mail: mail,
 | |
|         services: services,
 | |
|         username: username,
 | |
|     })
 | |
| }
 | |
| 
 | |
| use models::{ User };
 | |
| 
 | |
/// Credentials posted by the login form to `POST /login`.
#[derive(FromForm)]
struct LoginData {
    /// Submitted login name.
    username: String,
    /// Submitted plain-text password.
    password: String,
}
 | |
| 
 | |
/// Template context for the `login_form` template.
#[derive(Serialize)]
struct LoginFormContext {
    /// Text of a pending flash message (e.g. an error set by `login`), if any.
    message: Option<String>,
}
 | |
| 
 | |
| #[get("/login")]
 | |
| fn login_form(flash: Option<FlashMessage<'_, '_>>) -> Template {
 | |
|     let context = LoginFormContext {
 | |
|         message: match flash {
 | |
|             Some(ref msg) => Some(msg.msg().to_string()),
 | |
|             _ => None,
 | |
|         },
 | |
|     };
 | |
|     Template::render("login_form", &context)
 | |
| }
 | |
| 
 | |
| #[post("/login", data = "<form_data>")]
 | |
| fn login(form_data: Form<LoginData>, conn: AuthDb) -> Result<Redirect, Flash<Redirect>> {
 | |
|     let auth = BasicAuthentication {
 | |
|         username: form_data.username.to_owned(),
 | |
|         password: form_data.password.to_owned(),
 | |
|     };
 | |
|     let ldap_user = match auth_user(&auth) {
 | |
|         Ok(ldap_user) => ldap_user,
 | |
|         _ => return Err(Flash::error(Redirect::to(uri!(login_form)), "Not able to authenticate with given credentials.")),
 | |
|     };
 | |
|     let user = match User::find_or_create(&conn, ldap_user.username) {
 | |
|         Ok(user) => user,
 | |
|         _ => return Err(Flash::error(Redirect::to(uri!(login_form)), "Failed to fetch user")),
 | |
|     };
 | |
|     if ! user.is_active {
 | |
|         return Err(Flash::error(Redirect::to(uri!(login_form)), "Account is suspended"));
 | |
|     }
 | |
|     println!("User: {:?}", user);
 | |
| 
 | |
|     Ok(Redirect::to("/"))
 | |
| }
 | |
| 
 | |
| fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Error + 'static>> {
 | |
|     let key_bytes = fs::read(file_path)?;
 | |
|     let rsa = Rsa::private_key_from_pem(key_bytes.as_slice())?;
 | |
|     Ok(JWK {
 | |
|         common: CommonParameters {
 | |
|             algorithm: Some(Algorithm::Signature(SignatureAlgorithm::RS256)),
 | |
|             key_id: Some(file_path.file_name().unwrap().to_str().unwrap().to_string()),
 | |
|             ..Default::default()
 | |
|         },
 | |
|         algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
 | |
|             n: BigUint::from_bytes_be(&rsa.n().to_vec()),
 | |
|             e: BigUint::from_bytes_be(&rsa.e().to_vec()),
 | |
|             ..Default::default()
 | |
|         }),
 | |
|         additional: Default::default(),
 | |
|     })
 | |
| }
 | |
| 
 | |
| #[get("/oauth2/keys")]
 | |
| fn get_keys() -> Json<JWKSet<Empty>> {
 | |
|     let jwks: Vec<JWK<Empty>> = fs::read_dir("./").unwrap()
 | |
|         .filter_map(|dir_entry| {
 | |
|             let path = dir_entry.unwrap().path();
 | |
|             let ext = match path.extension() {
 | |
|                 Some(ext) => ext.to_str().unwrap().to_owned(),
 | |
|                 None => return None,
 | |
|             };
 | |
|             match ext.as_ref() {
 | |
|                 "pem" => match jwk_from_pem(path.as_path()) {
 | |
|                     Ok(jwk) => Some(jwk),
 | |
|                     _ => None,
 | |
|                 },
 | |
|                 _ => None,
 | |
|             }
 | |
|         })
 | |
|         .collect();
 | |
|     let jwks = JWKSet { keys: jwks };
 | |
|     Json(jwks)
 | |
| }
 | |
| 
 | |
/// Body of the `GET /.well-known/openid-configuration` response (OpenID
/// Connect discovery document). Only `jwks_uri` is published.
#[derive(Debug, Serialize)]
struct OidcConfig {
    /// Absolute URL of the JWK set endpoint.
    pub jwks_uri: String,
}
 | |
| 
 | |
| #[get("/.well-known/openid-configuration")]
 | |
| fn oidc_config() -> Json<OidcConfig> {
 | |
|     let config = OidcConfig {
 | |
|         jwks_uri: "https://auth.xeen.dev/oauth2/keys".to_string(),
 | |
|     };
 | |
|     Json(config)
 | |
| }
 | |
| 
 | |
| impl<'a, 'r> FromRequest<'a, 'r> for User {
 | |
|     type Error = ();
 | |
| 
 | |
|     fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
 | |
|         let mut user_id = match request.cookies().get_private("user_id") {
 | |
|             Some(cookie) => cookie.value().to_string(),
 | |
|             None => return Outcome::Forward(()),
 | |
|         };
 | |
|         let conn = request.guard::<AuthDb>().unwrap();
 | |
|         match User::get_with_id(&conn, user_id) {
 | |
|             Ok(user) => Outcome::Success(user),
 | |
|             _ => Outcome::Forward(()),
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
/// Template context for the `hello` template.
#[derive(Serialize)]
struct HelloContext {
    /// Username of the signed-in user.
    username: String,
}
 | |
| 
 | |
| #[get("/")]
 | |
| fn hello(user: User) -> Template {
 | |
|     println!("User: {:?}", &user);
 | |
|     let context = HelloContext {
 | |
|         username: user.username,
 | |
|     };
 | |
|     Template::render("hello", &context)
 | |
| }
 | |
| 
 | |
/// Every route handler of the application; `main` mounts them at `/`.
fn routes() -> Vec<rocket::Route> {
    routes![
        hello,
        oidc_config,
        get_keys,
        login,
        login_form,
    ]
}
 | |
| 
 | |
/// Handle to the `xeenauth` database, wrapping a Diesel `PgConnection`.
///
/// Used as a request guard by `login` and by the `User` guard. Its fairing is
/// attached in `main`.
#[database("xeenauth")]
struct AuthDb(diesel::PgConnection);
 | |
| 
 | |
| fn main() {
 | |
|     env_logger::init();
 | |
| 
 | |
|     rocket::ignite()
 | |
|         .attach(AuthDb::fairing())
 | |
|         .attach(Template::fairing())
 | |
|         .mount("/", routes())
 | |
|         .launch();
 | |
| }
 | 
