team04_server/lobby/game/
unit_purchase.rs1use std::{sync::Arc, time::Duration};
2
3use rand::seq::IndexedRandom;
4use tokio::sync::Semaphore;
5
6use crate::{
7 lobby::state::{LobbyPhase, LockedLobbyState, SharedLobbyState},
8 log,
9 messages::{MessageTx, unit_options::UnitOptions},
10};
11
12impl SharedLobbyState {
13 pub async fn unit_purchase(&self, current_round: u64) -> LockedLobbyState {
26 let mut pfx = log::pfx();
27 pfx.lobby(self.id());
28
29 let mut lock = self.lock().await;
30 assert!(lock.game_started());
31
32 let timeout = Duration::from_millis(lock.configs.game_config.timeout_unit_shop_phase);
33 lock.phase = LobbyPhase::UnitShopPhase;
34 lock.broadcast_gamestate().await;
35
36 let units = {
38 let game_config = &lock.configs.game_config;
39 let unit_config = &lock.configs.unit_config;
40 let unit_probabilities = game_config
41 .unit_probabilities
42 .get(current_round.saturating_sub(1) as usize)
43 .unwrap_or(
44 game_config
45 .unit_probabilities
46 .last()
47 .expect("unit_probabilities in game config should not be empty"),
48 );
49 let mut units = vec![];
50 for _ in 0..3 {
51 let random_number = rand::random_range(0.0..1.0f64);
52 units.push('pick_unit: {
53 let level_preference = if random_number < unit_probabilities.level1 {
54 [1, 2, 3]
56 } else if random_number < unit_probabilities.level1 + unit_probabilities.level2
57 {
58 if rand::random() { [2, 1, 3] } else { [2, 3, 1] }
60 } else {
61 [3, 2, 1]
63 };
64 for level in level_preference {
65 let options = [
66 &unit_config.level1,
67 &unit_config.level2,
68 &unit_config.level3,
69 ][level - 1]
70 .iter()
71 .map(|unit| unit.unit_type)
72 .filter(|unit| !units.contains(unit))
73 .collect::<Vec<_>>();
74 if let Some(unit) = options.choose(&mut rand::rng()) {
75 break 'pick_unit *unit;
76 } else {
77 }
79 }
80 unreachable!(
81 "unreachable assuming a valid config: can only happen if there are less than 3 units in total"
82 )
83 });
84 }
85 assert_eq!(units.len(), 3);
86 [units[0], units[1], units[2]]
87 };
88 lock.hist_unit_options.push(units);
89 log::debug!(
90 "Unit purchase phase {} starting with units\n {:?}, {:?}, and {:?}",
91 lock.hist_unit_options.len(),
92 units[0], units[1], units[2]; &pfx
93 );
94
95 let player_count = lock.clients.players.len();
97 let semaphore = Arc::new(Semaphore::new(player_count));
98 for player in lock.clients.players.players_alive_mut() {
99 let permit = Arc::clone(&semaphore)
100 .try_acquire_owned()
101 .expect("there are enough permits available for all players");
102 player.unit_choice = Err(Some(permit));
103 player.unit_choice_allowed = true;
104 }
105
106 lock.clients
108 .broadcast_message(&UnitOptions(units).serialize())
109 .await;
110
111 drop(lock);
112
113 let hit_timeout =
115 tokio::time::timeout(timeout, semaphore.acquire_many(player_count as u32))
116 .await
117 .is_err();
118 log::debug!("Unit purchase phase ended {}", if hit_timeout { "due to timeout" } else { "early (all players have purchased)" }; &pfx);
119
120 let mut lock = self.lock().await;
122 for player in lock.clients.players.players_alive_mut() {
123 player.unit_choice_allowed = false;
124 if player.unit_choice.is_err() {
125 player.unit_choice = Ok(*units.choose(&mut rand::rng()).unwrap());
126 }
127 }
128 lock
129 }
130}
131
#[cfg(test)]
mod test {
    use rand::seq::IteratorRandom;
    use serde::Deserialize;
    use strum::VariantArray;

    use crate::{
        config::game::UnitProbability,
        lobby::test::{
            FakeCon, config_set_modified, get_server_and_lobby_with_config, player_join,
        },
        messages::{RxMessage, error::error_code, unit_chosen::UnitChosen},
        unit::UnitType,
    };

    /// Drives [`SharedLobbyState::unit_purchase`](crate::lobby::state::SharedLobbyState)
    /// with three players for several level-probability configs.
    ///
    /// Valid-choice case: every player picks a different offered unit. The test then checks
    /// that the three `unit_choice`s are exactly the three offered units. It also checks
    /// that each player's `unit_bank` holds one offered unit.
    ///
    /// Invalid-choice case (uniform config only): every player sends a unit that was not
    /// offered. The test expects an `UNIT_CHOSEN_NOT_IN_OPTIONS` error for each of them.
    #[tokio::test]
    async fn test_unit_choosing() {
        // Level probabilities: each of the first three forces a single level.
        for [p1, p2, p3] in [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0 / 3.0; 3],
        ] {
            // Invalid choices are only exercised for the uniform distribution.
            let choice_modes: &[bool] = if p1 == p2 && p2 == p3 {
                &[true, false]
            } else {
                &[true]
            };
            for &try_valid_unit_choices in choice_modes {
                let (server, lobby) = get_server_and_lobby_with_config(config_set_modified(
                    |game_config| {
                        game_config.unit_probabilities = vec![UnitProbability {
                            level1: p1,
                            level2: p2,
                            level3: p3,
                        }];
                    },
                    |_| {},
                    |_| {},
                ));
                let (p1id, _, mut p1con) = player_join(&server, &lobby).await;
                let (p2id, _, mut p2con) = player_join(&server, &lobby).await;
                let (p3id, _, mut p3con) = player_join(&server, &lobby).await;
                lobby.lock().await.start_game_now().await;
                p1con.clear();
                p2con.clear();
                p3con.clear();

                // Run the phase in the background so the test can act as the players.
                let phase = tokio::spawn({
                    let lobby = lobby.clone();
                    async move {
                        let _ = lobby.unit_purchase(1).await;
                    }
                });

                // Skip the game-state broadcast sent when the shop phase starts.
                p1con.recv().await;
                p2con.recv().await;
                p3con.recv().await;

                #[derive(Deserialize)]
                struct UnitOptions {
                    options: [UnitType; 3],
                }
                let UnitOptions { options: p1opts } =
                    serde_json::from_str(&p1con.recv().await.unwrap())
                        .expect("should receive a unit options message");
                let UnitOptions { options: p2opts } =
                    serde_json::from_str(&p2con.recv().await.unwrap())
                        .expect("should receive a unit options message");
                let UnitOptions { options: p3opts } =
                    serde_json::from_str(&p3con.recv().await.unwrap())
                        .expect("should receive a unit options message");
                // All players must be offered the same units.
                assert_eq!(p1opts, p2opts);
                assert_eq!(p1opts, p3opts);

                for (pid, choice) in [(p1id, p1opts[0]), (p2id, p1opts[1]), (p3id, p1opts[2])] {
                    let choice = if try_valid_unit_choices {
                        choice
                    } else {
                        UnitType::VARIANTS
                            .iter()
                            .copied()
                            .filter(|unit_type| !p1opts.contains(unit_type))
                            .choose(&mut rand::rng())
                            .unwrap()
                    };
                    lobby
                        .lock()
                        .await
                        .message_from(
                            &pid,
                            RxMessage::UnitChosen(UnitChosen { choice }),
                            String::new(),
                        )
                        .await;
                }

                if !try_valid_unit_choices {
                    #[derive(Deserialize)]
                    struct ErrorMessage {
                        pub code: String,
                    }
                    let ErrorMessage { code: p1code, .. } =
                        serde_json::from_str(&p1con.recv().await.unwrap())
                            .expect("should receive an error message");
                    let ErrorMessage { code: p2code, .. } =
                        serde_json::from_str(&p2con.recv().await.unwrap())
                            .expect("should receive an error message");
                    let ErrorMessage { code: p3code, .. } =
                        serde_json::from_str(&p3con.recv().await.unwrap())
                            .expect("should receive an error message");
                    assert_eq!(p1code.as_str(), error_code::UNIT_CHOSEN_NOT_IN_OPTIONS);
                    assert_eq!(p2code.as_str(), error_code::UNIT_CHOSEN_NOT_IN_OPTIONS);
                    assert_eq!(p3code.as_str(), error_code::UNIT_CHOSEN_NOT_IN_OPTIONS);
                    // No valid choice was made, so the phase is presumably still waiting
                    // for its timeout. Stop it instead of leaving it running in the background.
                    phase.abort();
                } else {
                    phase.await.unwrap();

                    let units = lobby
                        .lock()
                        .await
                        .clients
                        .players
                        .players_alive()
                        .map(|p| *p.unit_choice.as_ref().unwrap())
                        .collect::<Vec<_>>();
                    // The three choices are pairwise distinct...
                    assert!(units[0] != units[1] && units[1] != units[2] && units[0] != units[2]);
                    // ...every choice was offered...
                    assert!(p1opts.contains(&units[0]));
                    assert!(p1opts.contains(&units[1]));
                    assert!(p1opts.contains(&units[2]));
                    // ...and every offered unit was chosen by someone.
                    assert!(units.contains(&p1opts[0]));
                    assert!(units.contains(&p1opts[1]));
                    assert!(units.contains(&p1opts[2]));
                    for player in lobby.lock().await.clients.players.players_alive() {
                        assert_eq!(player.unit_bank.len(), 1);
                        assert!(p1opts.contains(&player.unit_bank[0]));
                    }
                }
            }
        }
    }
}