pub mod ratelimit; use std::sync::Arc; use axum::{Json, extract::Path, response::Html, routing::get}; use rand::seq::IndexedRandom; use ratelimit::Ratelimit; use reqwest::StatusCode; use tokio::sync::Mutex; use crate::{ bahn_api::{ basic_types::{route_ids::RouteId, station_ids::StationEvaNumber}, departures_arrivals::departures, route_info::route_info, station_search::query_stations, transport_modes::{TransportMode, TransportModesSet}, }, storage::stations::{Station, Stations}, }; pub async fn main() { let index = Arc::new( tokio::fs::read_to_string("index.html") .await .expect("failed to load index.html (must be present in cwd)"), ); let ratelimit = Ratelimit::new(10, 10); let stations = Arc::new(Mutex::new(Stations::new())); axum::serve( tokio::net::TcpListener::bind( std::env::var("ADDR").unwrap_or_else(|_| "[::1]:8000".to_owned()), ) .await .unwrap(), axum::Router::new() .route("/", get(|| async move { Html(index.as_ref().clone()) })) .route( "/query_stations/{query}/", get({ let ratelimit = ratelimit.clone(); let stations = Arc::clone(&stations); |Path(query): Path| async move { if query.trim().is_empty() { return Err(StatusCode::NOT_FOUND); } ratelimit.wait().await?; match query_stations(&query).await { Err(e) => { eprintln!( "Tried to query station named {query}, but got error: {e:?}" ); Err(StatusCode::INTERNAL_SERVER_ERROR) } Ok(queried_stations) => { let json = queried_stations .stations .iter() .map(|station| { ( station.station.name.clone(), station.station.id.eva_number.0.clone(), transport_modes_num_from_set(&station.transport_modes), ) }) .collect::>(); let mut lock = stations.lock().await; for station in queried_stations.stations.into_iter() { lock.insert( station.station.id.eva_number, Station { name: station.station.name, db_station_id: station.station.id.db_station_id, lat_lon: station.lat_lon, transport_modes: Some(station.transport_modes), }, ); } drop(lock); Ok(Json(json)) } } } }), ) .route( "/query_departures/{transport_modes}/{eva_number}/", get({ let ratelimit = ratelimit.clone(); let stations = Arc::clone(&stations); |Path((transport_modes_num, eva_number)): Path<(u8, String)>| async move { ratelimit.wait().await?; let transport_modes = transport_modes_num_to_set(transport_modes_num); if transport_modes.len() == 0 { return Err(StatusCode::NOT_FOUND); } let eva_number = StationEvaNumber(eva_number); if let Some(station) = stations.lock().await.get(&eva_number) { match departures((&eva_number, station), transport_modes).await { Err(e) => { eprintln!( "Tried to get departures at {}, but got error: {e:?}", station.name, ); Err(StatusCode::INTERNAL_SERVER_ERROR) } Ok(departures) => Ok(Json( departures .departures .into_iter() .map(|departure| { ( departure.route, departure .stops .into_iter() .map(|stop| stop.name) .collect::>(), departure.id.0, ) }) .collect::>(), )), } } else { Err(StatusCode::GONE) } } }), ) .route( "/query_route/{route_id}/", get({ let ratelimit = ratelimit.clone(); let stations = Arc::clone(&stations); |Path(route_id): Path| async move { ratelimit.wait().await?; if route_id.trim().is_empty() { return Err(StatusCode::NOT_FOUND); } match route_info(RouteId(route_id)).await { Err(e) => { eprintln!( "Tried to get route info but got error: {e:?}", ); Err(StatusCode::INTERNAL_SERVER_ERROR) } Ok(route) => { let json_stops = route.stops .iter() .map(|station| ( station.name.clone(), station.id.eva_number.0.clone(), )) .collect::>(); let mut lock = stations.lock().await; for station in route.stops.into_iter() { lock.insert( station.id.eva_number, Station { name: station.name, db_station_id: station.id.db_station_id, lat_lon: None, transport_modes: None, }, ); } drop(lock); Ok(Json((json_stops, route.canceled))) } } } }), ) .route( "/random_stations/{transport_modes_num}/{count}/", get({ let stations = Arc::clone(&stations); |Path((transport_modes_num, count)): Path<(u8, u8)>| async move { let transport_modes = transport_modes_num_to_set(transport_modes_num); if transport_modes.len() == 0 { return Err(StatusCode::NOT_FOUND); } let lock = stations.lock().await; let all_stations = lock .stations .iter() .filter(|(_, station)| { station .transport_modes .as_ref() .is_some_and(|modes| modes.contains_any(&transport_modes)) }) .collect::>(); let json = all_stations .choose_multiple( &mut rand::rng(), all_stations.len().min(count as usize), ) .map(|(eva_number, station)| { ( station.name.clone(), eva_number.0.clone(), transport_modes_num_from_set(station.transport_modes.as_ref().expect( "otherwise this would not have passed the pre-rand filter", )), ) }) .collect::>(); drop(lock); Ok(Json(json)) } }), ), ) .await .unwrap(); } fn transport_modes_num_to_set(num: u8) -> TransportModesSet { let mut transport_modes = TransportModesSet::new(); for (i, m) in [ (1, [TransportMode::Bus, TransportMode::Bus]), // Bus (2, [TransportMode::Tram, TransportMode::Ubahn]), // Tram+U (4, [TransportMode::Sbahn, TransportMode::Sbahn]), // S-Bahn (8, [TransportMode::Regional, TransportMode::Ir]), // RB/RE (16, [TransportMode::EcIc, TransportMode::Ice]), // IC/ICE ] { if (num & i) != 0 { for m in m { transport_modes.insert(m); } } } transport_modes } fn transport_modes_num_from_set(set: &TransportModesSet) -> u8 { let mut num = 0; for (i, m) in [ (1, [TransportMode::Bus, TransportMode::Bus]), // Bus (2, [TransportMode::Tram, TransportMode::Ubahn]), // Tram+U (4, [TransportMode::Sbahn, TransportMode::Sbahn]), // S-Bahn (8, [TransportMode::Regional, TransportMode::Ir]), // RB/RE (16, [TransportMode::EcIc, TransportMode::Ice]), // IC/ICE ] { for m in m { if set.contains(m) { num |= i; } } } num }