auth-server/src/main.rs

278 lines
7.1 KiB
Rust

#![deny(warnings)]
#![feature(proc_macro_hygiene)]
#[macro_use] extern crate diesel;
#[macro_use] extern crate rocket;
#[macro_use] extern crate rocket_contrib;
use log::info;
use serde_derive::Serialize;
use std::env;
use std::fs;
use std::path::Path;
use biscuit::{
Empty,
};
use biscuit::jwa::{
SignatureAlgorithm,
Algorithm,
};
use biscuit::jwk::{
RSAKeyParameters,
CommonParameters,
AlgorithmParameters,
JWK,
JWKSet,
};
use num::BigUint;
use openssl::rsa::Rsa;
use ldap3::{ LdapConn, Scope, SearchEntry };
use rocket::request::{
FlashMessage,
Form,
FromRequest,
Outcome,
Request,
};
use rocket::response::{
Flash,
Redirect,
};
use rocket_contrib::json::Json;
use rocket_contrib::templates::Template;
mod schema;
mod models;
/// Username/password pair submitted by a client, used to bind against LDAP.
///
/// `Debug` is implemented by hand so that formatting this value (e.g. in a log
/// line or panic message) never prints the plaintext password.
struct BasicAuthentication {
    pub username: String,
    pub password: String,
}

impl std::fmt::Debug for BasicAuthentication {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicAuthentication")
            .field("username", &self.username)
            // Never render the secret itself.
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Reasons an authentication attempt can fail.
///
/// Implements `Display` and `std::error::Error` so the value can be shown to
/// operators or boxed into `Box<dyn std::error::Error>`.
#[derive(Debug)]
pub enum AuthError {
    Parse,
    Decode,
    /// The LDAP server rejected the bind (or the credentials were unusable).
    LdapBind,
    /// The LDAP server address could not be read from the environment.
    LdapConfig,
    /// A connection to the LDAP server could not be established.
    LdapConnection,
    /// Looking up the user's directory entry failed.
    LdapSearch,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            AuthError::Parse => "failed to parse authentication data",
            AuthError::Decode => "failed to decode authentication data",
            AuthError::LdapBind => "LDAP bind was rejected",
            AuthError::LdapConfig => "LDAP server address (LDAP_SERVER_ADDR) is not configured",
            AuthError::LdapConnection => "could not connect to the LDAP server",
            AuthError::LdapSearch => "LDAP search for the user failed",
        };
        f.write_str(text)
    }
}

impl std::error::Error for AuthError {}
/// Directory information for a user who has successfully bound to LDAP, as
/// assembled by `auth_user`.
#[derive(Debug)]
struct LdapUser {
/// DN the user bound with, built from the submitted username under
/// `ou=people,dc=xeentech,dc=com`.
pub dn: String,
/// Values of the entry's `memberOf` attribute (empty when the attribute is absent).
pub groups: Vec<String>,
/// Values of the entry's `mail` attribute (empty when the attribute is absent).
pub mail: Vec<String>,
/// Values of the entry's `enabledService` attribute (empty when the attribute is absent).
pub services: Vec<String>,
/// First value of the entry's `uid` attribute.
pub username: String,
}
/// Escapes `value` for use as an assertion value inside an LDAP search filter
/// (RFC 4515 §3): `*`, `(`, `)`, `\` and NUL become `\XX` hex escapes so user
/// input cannot change the structure of the filter.
fn ldap_filter_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '*' => escaped.push_str("\\2a"),
            '(' => escaped.push_str("\\28"),
            ')' => escaped.push_str("\\29"),
            '\\' => escaped.push_str("\\5c"),
            '\0' => escaped.push_str("\\00"),
            other => escaped.push(other),
        }
    }
    escaped
}

/// Escapes `value` for use as an attribute value in a distinguished name
/// (RFC 4514 §2.4) so user input cannot add RDNs or otherwise change the DN.
fn ldap_dn_escape(value: &str) -> String {
    let char_count = value.chars().count();
    let mut escaped = String::with_capacity(value.len());
    for (index, ch) in value.chars().enumerate() {
        let is_first = index == 0;
        let is_last = index + 1 == char_count;
        match ch {
            ',' | '+' | '"' | '\\' | '<' | '>' | ';' | '=' => {
                escaped.push('\\');
                escaped.push(ch);
            }
            '#' if is_first => escaped.push_str("\\#"),
            ' ' if is_first || is_last => escaped.push_str("\\ "),
            '\0' => escaped.push_str("\\00"),
            other => escaped.push(other),
        }
    }
    escaped
}

/// Authenticates `auth` against the LDAP server and returns the user's directory entry.
///
/// The server address comes from the `LDAP_SERVER_ADDR` environment variable.
/// The bind DN is `uid=<username>,ou=people,dc=xeentech,dc=com` (username
/// DN-escaped). After a successful simple bind, that DN is searched for its
/// `uid`, `mail`, `enabledService` and `memberOf` attributes. Only the first
/// search result is used.
///
/// # Errors
/// * `AuthError::LdapBind` – empty username or password, or the bind was rejected.
/// * `AuthError::LdapConfig` – `LDAP_SERVER_ADDR` is unset or not valid Unicode.
/// * `AuthError::LdapConnection` – the server could not be reached, or the bind
///   request itself failed with an I/O error.
/// * `AuthError::LdapSearch` – the search failed, found no entry, or the entry
///   has no `uid` value.
fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
    // A simple bind with an empty password is an "unauthenticated bind"
    // (RFC 4513 §5.1.2) that servers report as *success*; without this guard
    // any username with an empty password would be accepted.
    if auth.username.is_empty() || auth.password.is_empty() {
        return Err(AuthError::LdapBind);
    }

    let ldap_server_addr = env::var("LDAP_SERVER_ADDR").map_err(|_| AuthError::LdapConfig)?;
    let ldap = LdapConn::new(&ldap_server_addr).map_err(|_| AuthError::LdapConnection)?;

    // Escape user input so it cannot inject extra RDNs or filter clauses.
    let base = format!(
        "uid={},ou=people,dc=xeentech,dc=com",
        ldap_dn_escape(&auth.username)
    );
    ldap.simple_bind(&base, &auth.password)
        .map_err(|_| AuthError::LdapConnection)?
        .success()
        .map_err(|_| AuthError::LdapBind)?;

    let filter = format!("(uid={})", ldap_filter_escape(&auth.username));
    let (entries, _) = ldap
        .search(
            &base,
            Scope::Subtree,
            &filter,
            vec!["uid", "mail", "enabledService", "memberOf"],
        )
        .map_err(|_| AuthError::LdapSearch)?
        .success()
        .map_err(|_| AuthError::LdapSearch)?;

    // Use the first result, if any, and discard the rest.
    let entry = entries.into_iter().next().ok_or(AuthError::LdapSearch)?;
    let se = SearchEntry::construct(entry);
    // Missing multi-valued attributes are treated as empty.
    let attr = |name: &str| se.attrs.get(name).cloned().unwrap_or_default();

    // Never hand an empty username to the caller; it would map every such
    // login onto the same local account.
    let username = attr("uid")
        .into_iter()
        .next()
        .ok_or(AuthError::LdapSearch)?;
    let groups = attr("memberOf");
    let mail = attr("mail");
    let services = attr("enabledService");

    info!("Authentication success for {:?}", base);
    Ok(LdapUser {
        dn: base,
        groups,
        mail,
        services,
        username,
    })
}
use models::{ User };
/// Fields submitted by the login form (`POST /login`); the field names are
/// the form keys Rocket reads.
#[derive(FromForm)]
struct LoginData {
username: String,
password: String,
}
/// Template context for the `login_form` template.
#[derive(Serialize)]
struct LoginFormContext {
/// Text of the flash message to show on the form (e.g. set by a failed
/// `login`), or `None` when there is nothing to show.
message: Option<String>,
}
/// Renders the login page, passing along any pending flash message (for
/// example the error text set by a failed `login`).
#[get("/login")]
fn login_form(flash: Option<FlashMessage<'_, '_>>) -> Template {
    let message = flash.map(|pending| pending.msg().to_string());
    Template::render("login_form", &LoginFormContext { message })
}
#[post("/login", data = "<form_data>")]
fn login(form_data: Form<LoginData>, conn: AuthDb) -> Result<Redirect, Flash<Redirect>> {
let auth = BasicAuthentication {
username: form_data.username.to_owned(),
password: form_data.password.to_owned(),
};
let ldap_user = match auth_user(&auth) {
Ok(ldap_user) => ldap_user,
_ => return Err(Flash::error(Redirect::to(uri!(login_form)), "Not able to authenticate with given credentials.")),
};
let user = match User::find_or_create(&conn, ldap_user.username) {
Ok(user) => user,
_ => return Err(Flash::error(Redirect::to(uri!(login_form)), "Failed to fetch user")),
};
if ! user.is_active {
return Err(Flash::error(Redirect::to(uri!(login_form)), "Account is suspended"));
}
println!("User: {:?}", user);
Ok(Redirect::to("/"))
}
/// Loads an RSA private key from a PEM file and returns its public half as an
/// RS256 signing JWK.
///
/// The key id (`kid`) is the file's name, so each key file yields a
/// distinguishable key.
///
/// # Errors
/// Returns an error if the path has no UTF-8 file name, the file cannot be
/// read, or its contents are not a PEM-encoded RSA private key.
fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Error + 'static>> {
    // The file name doubles as the key id. Report a missing or non-UTF-8 name
    // as an error instead of panicking.
    let key_id = file_path
        .file_name()
        .and_then(|name| name.to_str())
        .ok_or("PEM key path has no UTF-8 file name")?
        .to_string();
    let key_bytes = fs::read(file_path)?;
    let rsa = Rsa::private_key_from_pem(key_bytes.as_slice())?;
    Ok(JWK {
        common: CommonParameters {
            algorithm: Some(Algorithm::Signature(SignatureAlgorithm::RS256)),
            key_id: Some(key_id),
            ..Default::default()
        },
        // Only the public modulus and exponent are copied; every other RSA
        // parameter comes from `Default`.
        algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
            n: BigUint::from_bytes_be(&rsa.n().to_vec()),
            e: BigUint::from_bytes_be(&rsa.e().to_vec()),
            ..Default::default()
        }),
        additional: Default::default(),
    })
}
#[get("/oauth2/keys")]
fn get_keys() -> Json<JWKSet<Empty>> {
let jwks: Vec<JWK<Empty>> = fs::read_dir("./").unwrap()
.filter_map(|dir_entry| {
let path = dir_entry.unwrap().path();
let ext = match path.extension() {
Some(ext) => ext.to_str().unwrap().to_owned(),
None => return None,
};
match ext.as_ref() {
"pem" => match jwk_from_pem(path.as_path()) {
Ok(jwk) => Some(jwk),
_ => None,
},
_ => None,
}
})
.collect();
let jwks = JWKSet { keys: jwks };
Json(jwks)
}
/// Body of the OpenID Connect discovery document served by `oidc_config`.
/// Only `jwks_uri` is currently advertised.
#[derive(Debug, Serialize)]
struct OidcConfig {
/// Absolute URL of the JWKS endpoint.
pub jwks_uri: String,
}
/// OpenID Connect discovery endpoint.
///
/// Currently it only tells clients where to fetch the signing keys (the
/// public `/oauth2/keys` URL).
#[get("/.well-known/openid-configuration")]
fn oidc_config() -> Json<OidcConfig> {
    Json(OidcConfig {
        jwks_uri: String::from("https://auth.xeen.dev/oauth2/keys"),
    })
}
/// Request guard that resolves the signed-in `User` from the private
/// `user_id` cookie.
///
/// Forwards (so another route or a 404 handles the request) when the cookie is
/// missing or `User::get_with_id` fails. If the `AuthDb` connection guard fails
/// or forwards, that outcome is passed through instead of panicking.
impl<'a, 'r> FromRequest<'a, 'r> for User {
    type Error = ();

    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
        let user_id = match request.cookies().get_private("user_id") {
            Some(cookie) => cookie.value().to_string(),
            None => return Outcome::Forward(()),
        };
        // Propagate the DB guard's outcome instead of unwrapping it.
        let conn = match request.guard::<AuthDb>() {
            Outcome::Success(conn) => conn,
            Outcome::Failure(failure) => return Outcome::Failure(failure),
            Outcome::Forward(()) => return Outcome::Forward(()),
        };
        match User::get_with_id(&conn, user_id) {
            Ok(user) => Outcome::Success(user),
            _ => Outcome::Forward(()),
        }
    }
}
/// Template context for the `hello` template.
#[derive(Serialize)]
struct HelloContext {
/// Username of the signed-in user.
username: String,
}
/// Landing page for a signed-in user. Renders the `hello` template with their
/// username.
///
/// Only reached when the `User` request guard succeeds; otherwise the request
/// is forwarded.
#[get("/")]
fn hello(user: User) -> Template {
    println!("User: {:?}", &user);
    Template::render("hello", &HelloContext { username: user.username })
}
/// Collects every request handler defined in this file so `main` can mount
/// them at `/`.
fn routes() -> Vec<rocket::Route> {
    routes![hello, oidc_config, get_keys, login, login_form]
}
// Request guard wrapping a Diesel PostgreSQL connection, generated by
// rocket_contrib's `#[database]` attribute. "xeenauth" is the database name
// it looks up (presumably under `databases` in Rocket's config; confirm).
// `main` attaches `AuthDb::fairing()` so the guard can be used in handlers.
#[database("xeenauth")]
struct AuthDb(diesel::PgConnection);
/// Entry point: initialises logging, then launches Rocket with the database
/// and template fairings and every route mounted at `/`.
fn main() {
    env_logger::init();

    let server = rocket::ignite()
        .attach(AuthDb::fairing())
        .attach(Template::fairing())
        .mount("/", routes());
    server.launch();
}