2025-08-19 23:42:39 +02:00

257 lines
11 KiB
Rust

pub mod ratelimit;
use std::sync::Arc;
use axum::{Json, extract::Path, response::Html, routing::get};
use rand::seq::IndexedRandom;
use ratelimit::Ratelimit;
use reqwest::StatusCode;
use tokio::sync::Mutex;
use crate::{
bahn_api::{
basic_types::{route_ids::RouteId, station_ids::StationEvaNumber},
departures_arrivals::departures,
route_info::route_info,
station_search::query_stations,
transport_modes::{TransportMode, TransportModesSet},
},
storage::stations::{Station, Stations},
};
/// Runs the HTTP server.
///
/// Serves `index.html` (read once from the current working directory at
/// startup) on `/` and exposes four JSON endpoints:
///
/// * `/query_stations/{query}/` – searches stations by name via the upstream
///   API and caches every hit in the shared [`Stations`] store.
/// * `/query_departures/{transport_modes}/{eva_number}/` – departures at a
///   station that must already be in the store (`410 Gone` otherwise).
/// * `/query_route/{route_id}/` – stops of a route (plus its `canceled` flag);
///   the stops are cached in the store as well.
/// * `/random_stations/{transport_modes_num}/{count}/` – up to `count` random
///   cached stations that serve at least one of the requested modes.
///
/// Transport modes in paths are the bitmask decoded by
/// [`transport_modes_num_to_set`]. The three endpoints that call the upstream
/// API share one [`Ratelimit`]; each of them validates its path input first so
/// malformed requests are rejected without going through the rate limiter.
///
/// The listen address is taken from the `ADDR` environment variable and
/// defaults to `[::1]:8000`.
///
/// # Panics
///
/// Panics if `index.html` cannot be read, if the address cannot be bound, or
/// if the server terminates with an error.
pub async fn main() {
    let index = Arc::new(
        tokio::fs::read_to_string("index.html")
            .await
            .expect("failed to load index.html (must be present in cwd)"),
    );
    let ratelimit = Ratelimit::new(10, 10);
    let stations = Arc::new(Mutex::new(Stations::new()));

    let addr = std::env::var("ADDR").unwrap_or_else(|_| "[::1]:8000".to_owned());
    let listener = tokio::net::TcpListener::bind(addr.as_str())
        .await
        .unwrap_or_else(|e| panic!("failed to bind to {addr}: {e}"));

    let app = axum::Router::new()
        .route("/", get(|| async move { Html(index.as_ref().clone()) }))
        .route(
            "/query_stations/{query}/",
            get({
                let ratelimit = ratelimit.clone();
                let stations = Arc::clone(&stations);
                |Path(query): Path<String>| async move {
                    if query.trim().is_empty() {
                        return Err(StatusCode::NOT_FOUND);
                    }
                    ratelimit.wait().await?;
                    match query_stations(&query).await {
                        Err(e) => {
                            eprintln!(
                                "Tried to query station named {query}, but got error: {e:?}"
                            );
                            Err(StatusCode::INTERNAL_SERVER_ERROR)
                        }
                        Ok(queried_stations) => {
                            // Build the response first: the insert loop below
                            // consumes `queried_stations`.
                            let json = queried_stations
                                .stations
                                .iter()
                                .map(|station| {
                                    (
                                        station.station.name.clone(),
                                        station.station.id.eva_number.0.clone(),
                                        transport_modes_num_from_set(&station.transport_modes),
                                    )
                                })
                                .collect::<Vec<_>>();
                            // Lock only after the upstream call has completed.
                            let mut lock = stations.lock().await;
                            for station in queried_stations.stations.into_iter() {
                                lock.insert(
                                    station.station.id.eva_number,
                                    Station {
                                        name: station.station.name,
                                        db_station_id: station.station.id.db_station_id,
                                        lat_lon: station.lat_lon,
                                        transport_modes: Some(station.transport_modes),
                                    },
                                );
                            }
                            drop(lock);
                            Ok(Json(json))
                        }
                    }
                }
            }),
        )
        .route(
            "/query_departures/{transport_modes}/{eva_number}/",
            get({
                let ratelimit = ratelimit.clone();
                let stations = Arc::clone(&stations);
                |Path((transport_modes_num, eva_number)): Path<(u8, String)>| async move {
                    // Validate before touching the rate limiter so bad masks
                    // are rejected without waiting on it.
                    let transport_modes = transport_modes_num_to_set(transport_modes_num);
                    if transport_modes.len() == 0 {
                        return Err(StatusCode::NOT_FOUND);
                    }
                    ratelimit.wait().await?;
                    let eva_number = StationEvaNumber(eva_number);
                    // NOTE(review): this guard stays alive until the handler
                    // returns, i.e. the `stations` mutex is held across the
                    // upstream `departures(..).await` below, so every other
                    // handler that locks `stations` waits on this request. If
                    // `Station` is (or can be made) `Clone`, copy it out here
                    // and drop the guard before awaiting.
                    let lock = stations.lock().await;
                    let Some(station) = lock.get(&eva_number) else {
                        // Only stations previously seen via /query_stations/ or
                        // /query_route/ are known.
                        return Err(StatusCode::GONE);
                    };
                    match departures((&eva_number, station), transport_modes).await {
                        Err(e) => {
                            eprintln!(
                                "Tried to get departures at {}, but got error: {e:?}",
                                station.name,
                            );
                            Err(StatusCode::INTERNAL_SERVER_ERROR)
                        }
                        Ok(departures) => Ok(Json(
                            departures
                                .departures
                                .into_iter()
                                .map(|departure| {
                                    (
                                        departure.route,
                                        departure
                                            .stops
                                            .into_iter()
                                            .map(|stop| stop.name)
                                            .collect::<Vec<_>>(),
                                        departure.id.0,
                                    )
                                })
                                .collect::<Vec<_>>(),
                        )),
                    }
                }
            }),
        )
        .route(
            "/query_route/{route_id}/",
            get({
                let ratelimit = ratelimit.clone();
                let stations = Arc::clone(&stations);
                |Path(route_id): Path<String>| async move {
                    // Validate before touching the rate limiter, as in
                    // /query_stations/.
                    if route_id.trim().is_empty() {
                        return Err(StatusCode::NOT_FOUND);
                    }
                    ratelimit.wait().await?;
                    match route_info(RouteId(route_id)).await {
                        Err(e) => {
                            eprintln!(
                                "Tried to get route info but got error: {e:?}",
                            );
                            Err(StatusCode::INTERNAL_SERVER_ERROR)
                        }
                        Ok(route) => {
                            let json_stops = route
                                .stops
                                .iter()
                                .map(|station| {
                                    (station.name.clone(), station.id.eva_number.0.clone())
                                })
                                .collect::<Vec<_>>();
                            let mut lock = stations.lock().await;
                            for station in route.stops.into_iter() {
                                // Route stops carry no coordinates or mode
                                // information, hence the `None`s.
                                lock.insert(
                                    station.id.eva_number,
                                    Station {
                                        name: station.name,
                                        db_station_id: station.id.db_station_id,
                                        lat_lon: None,
                                        transport_modes: None,
                                    },
                                );
                            }
                            drop(lock);
                            Ok(Json((json_stops, route.canceled)))
                        }
                    }
                }
            }),
        )
        .route(
            "/random_stations/{transport_modes_num}/{count}/",
            get({
                let stations = Arc::clone(&stations);
                // Served purely from the local cache, so no rate limiting.
                |Path((transport_modes_num, count)): Path<(u8, u8)>| async move {
                    let transport_modes = transport_modes_num_to_set(transport_modes_num);
                    if transport_modes.len() == 0 {
                        return Err(StatusCode::NOT_FOUND);
                    }
                    let lock = stations.lock().await;
                    // Stations with unknown modes (`None`, e.g. from routes)
                    // are never candidates.
                    let all_stations = lock
                        .stations
                        .iter()
                        .filter(|(_, station)| {
                            station
                                .transport_modes
                                .as_ref()
                                .is_some_and(|modes| modes.contains_any(&transport_modes))
                        })
                        .collect::<Vec<_>>();
                    let json = all_stations
                        .choose_multiple(
                            &mut rand::rng(),
                            all_stations.len().min(count as usize),
                        )
                        .map(|(eva_number, station)| {
                            (
                                station.name.clone(),
                                eva_number.0.clone(),
                                transport_modes_num_from_set(station.transport_modes.as_ref().expect(
                                    "otherwise this would not have passed the pre-rand filter",
                                )),
                            )
                        })
                        .collect::<Vec<_>>();
                    drop(lock);
                    Ok(Json(json))
                }
            }),
        );

    axum::serve(listener, app)
        .await
        .expect("HTTP server terminated with an error");
}
/// Decodes the numeric transport-mode mask used in request paths into a
/// [`TransportModesSet`].
///
/// Bit values and the modes each one selects:
///
/// | bit  | modes            | label    |
/// |------|------------------|----------|
/// | `1`  | `Bus`            | Bus      |
/// | `2`  | `Tram`, `Ubahn`  | Tram+U   |
/// | `4`  | `Sbahn`          | S-Bahn   |
/// | `8`  | `Regional`, `Ir` | RB/RE    |
/// | `16` | `EcIc`, `Ice`    | IC/ICE   |
///
/// Any other bits are ignored, so a mask that sets none of the bits above
/// yields an empty set. [`transport_modes_num_from_set`] performs the reverse
/// mapping.
fn transport_modes_num_to_set(num: u8) -> TransportModesSet {
    // Every row holds exactly two modes so the array is homogeneous;
    // single-mode groups list their mode twice, which merely re-inserts it.
    // Keep this table in sync with `transport_modes_num_from_set`.
    let bit_groups = [
        (1, [TransportMode::Bus, TransportMode::Bus]),
        (2, [TransportMode::Tram, TransportMode::Ubahn]),
        (4, [TransportMode::Sbahn, TransportMode::Sbahn]),
        (8, [TransportMode::Regional, TransportMode::Ir]),
        (16, [TransportMode::EcIc, TransportMode::Ice]),
    ];

    let mut selected = TransportModesSet::new();
    for (bit, group) in bit_groups {
        if (num & bit) == 0 {
            continue;
        }
        for mode in group {
            selected.insert(mode);
        }
    }
    selected
}
/// Encodes a [`TransportModesSet`] as the numeric mask sent back to clients.
///
/// A group's bit is set when the set contains at least one of that group's
/// modes:
///
/// | bit  | modes            | label    |
/// |------|------------------|----------|
/// | `1`  | `Bus`            | Bus      |
/// | `2`  | `Tram`, `Ubahn`  | Tram+U   |
/// | `4`  | `Sbahn`          | S-Bahn   |
/// | `8`  | `Regional`, `Ir` | RB/RE    |
/// | `16` | `EcIc`, `Ice`    | IC/ICE   |
///
/// Modes outside this table contribute nothing. [`transport_modes_num_to_set`]
/// performs the reverse mapping.
fn transport_modes_num_from_set(set: &TransportModesSet) -> u8 {
    // Same layout as in `transport_modes_num_to_set`; keep both in sync.
    let bit_groups = [
        (1, [TransportMode::Bus, TransportMode::Bus]),
        (2, [TransportMode::Tram, TransportMode::Ubahn]),
        (4, [TransportMode::Sbahn, TransportMode::Sbahn]),
        (8, [TransportMode::Regional, TransportMode::Ir]),
        (16, [TransportMode::EcIc, TransportMode::Ice]),
    ];

    let mut mask: u8 = 0;
    for (bit, group) in bit_groups {
        let mut group_present = false;
        for mode in group {
            group_present |= set.contains(mode);
        }
        if group_present {
            mask |= bit;
        }
    }
    mask
}