use std::sync::Arc; use askama::Template; use axum::{ Router, extract::{MatchedPath, Path, State}, http::Request, response::Html, routing::{get, post}, }; use parking_lot::Mutex; use rusqlite::{CachedStatement, Connection}; use tower_http::trace::TraceLayer; use tower_request_id::{RequestId, RequestIdLayer}; use tracing::{debug, error, info, info_span}; use tracing_subscriber::{layer::SubscriberExt as _, util::SubscriberInitExt as _}; #[derive(Template)] #[template(path = "index.html")] struct IndexTemplate { foods: Vec, sum: i32, } #[derive(Template)] #[template(path = "food-update.html")] struct FoodUpdateTemplate { food: Food, sum: i32, } #[derive(Debug, Clone, PartialEq)] struct Food { id: i32, portion: String, name: String, kc_per_serving: i32, target_servings: i32, actual_servings: i32, color: String, } #[derive(Clone)] struct PreparedStatements {} impl<'conn> PreparedStatements { fn check(conn: &Connection) -> rusqlite::Result { conn.prepare_cached(include_str!("create_tables.sql"))?; conn.prepare_cached(include_str!("increase.sql"))?; conn.prepare_cached(include_str!("decrease.sql"))?; conn.prepare_cached(include_str!("get_food.sql"))?; conn.prepare_cached(include_str!("get_foods.sql"))?; conn.prepare_cached(include_str!("get_sum.sql"))?; Ok(PreparedStatements {}) } fn create_tables(conn: &'conn Connection) -> CachedStatement<'conn> { conn.prepare_cached(include_str!("create_tables.sql")) .expect("cached statement is invalid") } fn increase(conn: &'conn Connection) -> CachedStatement<'conn> { conn.prepare_cached(include_str!("increase.sql")) .expect("cached statement is invalid") } fn decrease(conn: &'conn Connection) -> CachedStatement<'conn> { conn.prepare_cached(include_str!("decrease.sql")) .expect("cached statement is invalid") } fn get_food(conn: &'conn Connection) -> CachedStatement<'conn> { conn.prepare_cached(include_str!("get_food.sql")) .expect("cached statement is invalid") } fn get_foods(conn: &'conn Connection) -> CachedStatement<'conn> { conn.prepare_cached(include_str!("get_foods.sql")) .expect("cached statement is invalid") } fn get_sum(conn: &'conn Connection) -> CachedStatement<'conn> { conn.prepare_cached(include_str!("get_sum.sql")) .expect("cached statement is invalid") } } type ConnState = Arc>; #[tokio::main] async fn main() -> Result<(), std::io::Error> { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { // axum logs rejections from built-in extractors with the `axum::rejection` // target, at `TRACE` level. `axum::rejection=trace` enables showing those events format!( "{}=debug,tower_http=debug,axum::rejection=trace", env!("CARGO_CRATE_NAME") ) .into() }), ) .with(tracing_subscriber::fmt::layer()) .init(); let db_connecion_str = "./foods.db".to_string(); debug!(db_connecion_str, "opening database"); let conn = Connection::open(db_connecion_str).expect("failed to open database"); PreparedStatements::check(&conn).expect("failed to prepare sql statements"); if let Err(e) = PreparedStatements::create_tables(&conn).execute(()) { error!(?e, "failed to create tables"); panic!("failed to create tables: {:#?}", e); } let app = Router::new() .route("/", get(root)) .route("/increase/{id}", post(increase)) .route("/decrease/{id}", post(decrease)) .layer( TraceLayer::new_for_http().make_span_with(|request: &Request<_>| { let matched_path = request .extensions() .get::() .map(MatchedPath::as_str); let request_id = request .extensions() .get::() .map(ToString::to_string) .unwrap_or_else(|| "unknown".into()); info_span!( "request", method = ?request.method(), matched_path, uri = ?request.uri(), id = %request_id, ) }), ) .layer(RequestIdLayer) .with_state(Arc::new(Mutex::new(conn))); let address = "0.0.0.0:3001"; let listener = tokio::net::TcpListener::bind(address) .await .expect("failed to bind to address"); info!( "listening on {}", listener .local_addr() .expect("failed to get local listening address") ); axum::serve(listener, app).await } fn get_foods(conn: &ConnState) -> Vec { let conn = conn.lock(); let mut stmt = PreparedStatements::get_foods(&conn); let foods: Vec<_> = stmt .query_map((), |row| { Ok(Food { id: row.get(0).unwrap(), portion: row.get(1).unwrap(), name: row.get(2).unwrap(), kc_per_serving: row.get(3).unwrap(), target_servings: row.get(4).unwrap(), actual_servings: row.get(5).unwrap(), color: row.get(6).unwrap(), }) }) .unwrap() .collect::>() .unwrap(); debug!(num_foods = foods.len()); foods } fn get_sum(conn: &Arc>) -> i32 { let conn = conn.lock(); let mut stmt = PreparedStatements::get_sum(&conn); let sum = stmt.query_one((), |row| row.get(0)).unwrap(); debug!(sum); sum } async fn root(State(conn): State) -> Html { let foods = get_foods(&conn); let sum = get_sum(&conn); let index = IndexTemplate { foods, sum }; Html(index.render().unwrap()) } fn do_increase(conn: &Arc>, id: i32) { let conn = conn.lock(); let mut stmt = PreparedStatements::increase(&conn); let new: i32 = stmt.query_one((id,), |row| row.get(0)).unwrap(); debug!(id, new_serving_count = new, "increase"); } fn get_food(conn: &Arc>, id: i32) -> Food { let conn = conn.lock(); let mut stmt = PreparedStatements::get_food(&conn); let food = stmt .query_one((id,), |row| { Ok(Food { id: row.get(0).unwrap(), portion: row.get(1).unwrap(), name: row.get(2).unwrap(), kc_per_serving: row.get(3).unwrap(), target_servings: row.get(4).unwrap(), actual_servings: row.get(5).unwrap(), color: row.get(6).unwrap(), }) }) .unwrap(); debug!(?food); food } async fn increase(State(conn): State>>, Path(id): Path) -> Html { do_increase(&conn, id); let food = get_food(&conn, id); let sum = get_sum(&conn); let update = FoodUpdateTemplate { food, sum }; Html(update.render().unwrap()) } fn do_decrease(conn: &Arc>, id: i32) { let conn = conn.lock(); let mut stmt = PreparedStatements::decrease(&conn); let new: i32 = stmt.query_one((id,), |row| row.get(0)).unwrap(); debug!(id, new_serving_count = new, "decrease"); } async fn decrease(State(conn): State>>, Path(id): Path) -> Html { do_decrease(&conn, id); let food = get_food(&conn, id); let sum = get_sum(&conn); let update = FoodUpdateTemplate { food, sum }; Html(update.render().unwrap()) }