257 lines
11 KiB
Rust
257 lines
11 KiB
Rust
pub mod ratelimit;
|
|
|
|
use std::sync::Arc;
|
|
|
|
use axum::{Json, extract::Path, response::Html, routing::get};
|
|
use rand::seq::IndexedRandom;
|
|
use ratelimit::Ratelimit;
|
|
use reqwest::StatusCode;
|
|
use tokio::sync::Mutex;
|
|
|
|
use crate::{
|
|
bahn_api::{
|
|
basic_types::{route_ids::RouteId, station_ids::StationEvaNumber},
|
|
departures_arrivals::departures,
|
|
route_info::route_info,
|
|
station_search::query_stations,
|
|
transport_modes::{TransportMode, TransportModesSet},
|
|
},
|
|
storage::stations::{Station, Stations},
|
|
};
|
|
|
|
pub async fn main() {
|
|
let index = Arc::new(
|
|
tokio::fs::read_to_string("index.html")
|
|
.await
|
|
.expect("failed to load index.html (must be present in cwd)"),
|
|
);
|
|
|
|
let ratelimit = Ratelimit::new(10, 10);
|
|
|
|
let stations = Arc::new(Mutex::new(Stations::new()));
|
|
|
|
axum::serve(
|
|
tokio::net::TcpListener::bind(
|
|
std::env::var("ADDR").unwrap_or_else(|_| "[::1]:8000".to_owned()),
|
|
)
|
|
.await
|
|
.unwrap(),
|
|
axum::Router::new()
|
|
.route("/", get(|| async move { Html(index.as_ref().clone()) }))
|
|
.route(
|
|
"/query_stations/{query}/",
|
|
get({
|
|
let ratelimit = ratelimit.clone();
|
|
let stations = Arc::clone(&stations);
|
|
|Path(query): Path<String>| async move {
|
|
if query.trim().is_empty() {
|
|
return Err(StatusCode::NOT_FOUND);
|
|
}
|
|
ratelimit.wait().await?;
|
|
match query_stations(&query).await {
|
|
Err(e) => {
|
|
eprintln!(
|
|
"Tried to query station named {query}, but got error: {e:?}"
|
|
);
|
|
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
|
}
|
|
Ok(queried_stations) => {
|
|
let json = queried_stations
|
|
.stations
|
|
.iter()
|
|
.map(|station| {
|
|
(
|
|
station.station.name.clone(),
|
|
station.station.id.eva_number.0.clone(),
|
|
transport_modes_num_from_set(&station.transport_modes),
|
|
)
|
|
})
|
|
.collect::<Vec<_>>();
|
|
let mut lock = stations.lock().await;
|
|
for station in queried_stations.stations.into_iter() {
|
|
lock.insert(
|
|
station.station.id.eva_number,
|
|
Station {
|
|
name: station.station.name,
|
|
db_station_id: station.station.id.db_station_id,
|
|
lat_lon: station.lat_lon,
|
|
transport_modes: Some(station.transport_modes),
|
|
},
|
|
);
|
|
}
|
|
drop(lock);
|
|
Ok(Json(json))
|
|
}
|
|
}
|
|
}
|
|
}),
|
|
)
|
|
.route(
|
|
"/query_departures/{transport_modes}/{eva_number}/",
|
|
get({
|
|
let ratelimit = ratelimit.clone();
|
|
let stations = Arc::clone(&stations);
|
|
|Path((transport_modes_num, eva_number)): Path<(u8, String)>| async move {
|
|
ratelimit.wait().await?;
|
|
let transport_modes = transport_modes_num_to_set(transport_modes_num);
|
|
if transport_modes.len() == 0 {
|
|
return Err(StatusCode::NOT_FOUND);
|
|
}
|
|
let eva_number = StationEvaNumber(eva_number);
|
|
if let Some(station) = stations.lock().await.get(&eva_number) {
|
|
match departures((&eva_number, station), transport_modes).await {
|
|
Err(e) => {
|
|
eprintln!(
|
|
"Tried to get departures at {}, but got error: {e:?}",
|
|
station.name,
|
|
);
|
|
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
|
}
|
|
Ok(departures) => Ok(Json(
|
|
departures
|
|
.departures
|
|
.into_iter()
|
|
.map(|departure| {
|
|
(
|
|
departure.route,
|
|
departure
|
|
.stops
|
|
.into_iter()
|
|
.map(|stop| stop.name)
|
|
.collect::<Vec<_>>(),
|
|
departure.id.0,
|
|
)
|
|
})
|
|
.collect::<Vec<_>>(),
|
|
)),
|
|
}
|
|
} else {
|
|
Err(StatusCode::GONE)
|
|
}
|
|
}
|
|
}),
|
|
)
|
|
.route(
|
|
"/query_route/{route_id}/",
|
|
get({
|
|
let ratelimit = ratelimit.clone();
|
|
let stations = Arc::clone(&stations);
|
|
|Path(route_id): Path<String>| async move {
|
|
ratelimit.wait().await?;
|
|
if route_id.trim().is_empty() {
|
|
return Err(StatusCode::NOT_FOUND);
|
|
}
|
|
match route_info(RouteId(route_id)).await {
|
|
Err(e) => {
|
|
eprintln!(
|
|
"Tried to get route info but got error: {e:?}",
|
|
);
|
|
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
|
}
|
|
Ok(route) => {
|
|
let json_stops = route.stops
|
|
.iter()
|
|
.map(|station| (
|
|
station.name.clone(),
|
|
station.id.eva_number.0.clone(),
|
|
))
|
|
.collect::<Vec<_>>();
|
|
let mut lock = stations.lock().await;
|
|
for station in route.stops.into_iter() {
|
|
lock.insert(
|
|
station.id.eva_number,
|
|
Station {
|
|
name: station.name,
|
|
db_station_id: station.id.db_station_id,
|
|
lat_lon: None,
|
|
transport_modes: None,
|
|
},
|
|
);
|
|
}
|
|
drop(lock);
|
|
Ok(Json((json_stops, route.canceled)))
|
|
}
|
|
}
|
|
}
|
|
}),
|
|
)
|
|
.route(
|
|
"/random_stations/{transport_modes_num}/{count}/",
|
|
get({
|
|
let stations = Arc::clone(&stations);
|
|
|Path((transport_modes_num, count)): Path<(u8, u8)>| async move {
|
|
let transport_modes = transport_modes_num_to_set(transport_modes_num);
|
|
if transport_modes.len() == 0 {
|
|
return Err(StatusCode::NOT_FOUND);
|
|
}
|
|
let lock = stations.lock().await;
|
|
let all_stations = lock
|
|
.stations
|
|
.iter()
|
|
.filter(|(_, station)| {
|
|
station
|
|
.transport_modes
|
|
.as_ref()
|
|
.is_some_and(|modes| modes.contains_any(&transport_modes))
|
|
})
|
|
.collect::<Vec<_>>();
|
|
let json = all_stations
|
|
.choose_multiple(
|
|
&mut rand::rng(),
|
|
all_stations.len().min(count as usize),
|
|
)
|
|
.map(|(eva_number, station)| {
|
|
(
|
|
station.name.clone(),
|
|
eva_number.0.clone(),
|
|
transport_modes_num_from_set(station.transport_modes.as_ref().expect(
|
|
"otherwise this would not have passed the pre-rand filter",
|
|
)),
|
|
)
|
|
})
|
|
.collect::<Vec<_>>();
|
|
drop(lock);
|
|
Ok(Json(json))
|
|
}
|
|
}),
|
|
),
|
|
)
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
fn transport_modes_num_to_set(num: u8) -> TransportModesSet {
|
|
let mut transport_modes = TransportModesSet::new();
|
|
for (i, m) in [
|
|
(1, [TransportMode::Bus, TransportMode::Bus]), // Bus
|
|
(2, [TransportMode::Tram, TransportMode::Ubahn]), // Tram+U
|
|
(4, [TransportMode::Sbahn, TransportMode::Sbahn]), // S-Bahn
|
|
(8, [TransportMode::Regional, TransportMode::Ir]), // RB/RE
|
|
(16, [TransportMode::EcIc, TransportMode::Ice]), // IC/ICE
|
|
] {
|
|
if (num & i) != 0 {
|
|
for m in m {
|
|
transport_modes.insert(m);
|
|
}
|
|
}
|
|
}
|
|
transport_modes
|
|
}
|
|
fn transport_modes_num_from_set(set: &TransportModesSet) -> u8 {
|
|
let mut num = 0;
|
|
for (i, m) in [
|
|
(1, [TransportMode::Bus, TransportMode::Bus]), // Bus
|
|
(2, [TransportMode::Tram, TransportMode::Ubahn]), // Tram+U
|
|
(4, [TransportMode::Sbahn, TransportMode::Sbahn]), // S-Bahn
|
|
(8, [TransportMode::Regional, TransportMode::Ir]), // RB/RE
|
|
(16, [TransportMode::EcIc, TransportMode::Ice]), // IC/ICE
|
|
] {
|
|
for m in m {
|
|
if set.contains(m) {
|
|
num |= i;
|
|
}
|
|
}
|
|
}
|
|
num
|
|
}
|