diff --git a/config.json b/config.json index 9ceb643..ff58122 100644 --- a/config.json +++ b/config.json @@ -1,62 +1,72 @@ -{ - "ocr_regions": [ - { - "name": "lap", - "x": 2300, - "y": 46, - "width": 140, - "height": 90 - }, - { - "name": "health", - "x": 90, - "y": 1364, - "width": 52, - "height": 24 - }, - { - "name": "gas", - "x": 208, - "y": 1364, - "width": 52, - "height": 24 - }, - { - "name": "tyres", - "x": 325, - "y": 1364, - "width": 52, - "height": 24 - }, - { - "name": "best", - "x": 2325, - "y": 169, - "width": 183, - "height": 43 - }, - { - "name": "lap_time", - "x": 2325, - "y": 222, - "width": 183, - "height": 43 - } - ], - "track_region": { - "name": "track", - "x": 2020, - "y": 1030, - "width": 540, - "height": 410, - "threshold": 0.85 - }, - "penalty_orange_region": { - "name": "penalty", - "x": 989, - "y": 117, - "width": 30, - "height": 30 - }, - "track_recognition_threshold": 10 +{ + "ocr_regions": [ + { + "name": "lap", + "x": 2300, + "y": 46, + "width": 140, + "height": 90, + "threshold": null, + "use_ocr_cache": null + }, + { + "name": "health", + "x": 90, + "y": 1364, + "width": 52, + "height": 24, + "threshold": null, + "use_ocr_cache": null + }, + { + "name": "gas", + "x": 208, + "y": 1364, + "width": 52, + "height": 24, + "threshold": null, + "use_ocr_cache": null + }, + { + "name": "tyres", + "x": 325, + "y": 1364, + "width": 52, + "height": 24, + "threshold": null, + "use_ocr_cache": null + }, + { + "name": "best", + "x": 2325, + "y": 169, + "width": 183, + "height": 43, + "threshold": null, + "use_ocr_cache": null + }, + { + "name": "lap_time", + "x": 2325, + "y": 222, + "width": 183, + "height": 43, + "threshold": null, + "use_ocr_cache": null + } + ], + "track_region": { + "name": "track", + "x": 2020, + "y": 1030, + "width": 540, + "height": 410, + "threshold": 0.85, + "use_ocr_cache": null + }, + "ocr_interval_ms": 500, + "track_recognition_threshold": 10, + "dump_frame_fraction": null, + "light_mode": false, + "font_scale": 1.3 } \ No newline at end of file diff --git a/src/analysis.rs b/src/analysis.rs index 72a800a..945b02c 100644 --- a/src/analysis.rs +++ b/src/analysis.rs @@ -145,7 +145,7 @@ fn add_saved_frame( image: retained, rgb_image: extracted, img_hash: hash, - recognized_text: ocr_results.get(®ion.name).cloned(), + recognized_text: ocr_results.get(®ion.name).cloned().unwrap_or_default(), }, ); } @@ -224,7 +224,7 @@ pub fn run_control_loop(state: SharedAppState) { if let Err(e) = run_loop_once(&mut capturer, &state) { eprintln!("Error in control loop: {:?}", e) } - let interval = state.lock().unwrap().config.ocr_interval_ms.unwrap_or(500); + let interval = state.lock().unwrap().config.ocr_interval_ms; thread::sleep(Duration::from_millis(interval)); } } diff --git a/src/config.rs b/src/config.rs index 7427643..3140029 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,23 +1,55 @@ -use std::path::PathBuf; +use std::{path::PathBuf, sync::Arc}; use anyhow::Result; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use crate::image_processing::Region; +fn default_track_recognition_threshold() -> u32 { + 10 +} +fn default_ocr_interval_ms() -> u64 { + 500 +} +fn default_font_scale() -> f32 { + 1.0 +} + #[derive(Default, Serialize, Deserialize, Clone)] pub struct Config { pub ocr_regions: Vec, pub track_region: Option, - pub ocr_interval_ms: Option, - pub track_recognition_threshold: Option, + + #[serde(default = "default_ocr_interval_ms")] + pub ocr_interval_ms: u64, + + #[serde(default = "default_track_recognition_threshold")] + pub track_recognition_threshold: u32, + pub dump_frame_fraction: Option, + + #[serde(default = "Default::default")] + pub light_mode: bool, + + #[serde(default = "default_font_scale")] + pub font_scale: f32, } impl Config { pub fn load() -> Result { load_config_or_make_default("config.json", include_str!("configs/config.default.json")) } + + pub fn update_and_save( + self: &mut Arc, + update_fn: F, + ) -> Result<()> { + let mut config = (**self).clone(); + update_fn(&mut config); + save_json_config("config.json", &config)?; + *self = Arc::new(config); + Ok(()) + } } pub fn load_config_or_make_default(path: &str, default: &str) -> Result { diff --git a/src/learned_tracks.rs b/src/learned_tracks.rs index 197e925..5a12e98 100644 --- a/src/learned_tracks.rs +++ b/src/learned_tracks.rs @@ -28,8 +28,7 @@ impl LearnedTracks { for (learned_hash_b64, learned_track) in &self.learned_tracks { let learned_hash: ImageHash> = ImageHash::from_base64(learned_hash_b64).ok()?; let current_hash: ImageHash> = ImageHash::from_base64(hash).ok()?; - if current_hash.dist(&learned_hash) <= config.track_recognition_threshold.unwrap_or(10) - { + if current_hash.dist(&learned_hash) <= config.track_recognition_threshold { return Some(learned_track.to_owned()); } } diff --git a/src/main.rs b/src/main.rs index ca28135..0d321d1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,7 +23,7 @@ use std::{ use analysis::save_frames_from; use config::Config; use eframe::{ - egui::{self, Ui, Visuals}, + egui::{self, FontTweak, Ui, Visuals}, emath::Vec2, epaint::Color32, }; @@ -238,20 +238,22 @@ fn show_race_state( }); } -fn show_debug_frames(ui: &mut Ui, debug_frames: &mut HashMap) { +fn show_debug_frames( + ui: &mut Ui, + ocr_db: &OcrDatabase, + debug_frames: &mut HashMap, +) { for (name, debug_image) in debug_frames.iter_mut() { ui.label(name); - if let Some(text) = &debug_image.recognized_text { - ui.label(text); - } - if ui - .button(&debug_image.img_hash) - .on_hover_text("Copy") - .clicked() - { - ui.output().copied_text = debug_image.img_hash.clone(); - } + ui.text_edit_singleline(&mut debug_image.recognized_text); debug_image.image.show_max_size(ui, Vec2::new(300.0, 300.0)); + + if ui.button("Learn OCR").clicked() { + let hashes = ocr::compute_box_hashes(&debug_image.rgb_image); + ocr_db + .learn_phrase(&hashes, &debug_image.recognized_text) + .unwrap(); + } ui.separator(); } } @@ -298,6 +300,20 @@ fn show_combo_box(ui: &mut Ui, name: &str, label: &str, options: &[String], valu impl eframe::App for AppUi { fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { let mut state = self.state.lock().unwrap(); + if state.config.light_mode { + ctx.set_visuals(Visuals::light()); + } else { + ctx.set_visuals(Visuals::dark()); + } + let mut fonts = egui::FontDefinitions::default(); + for font in fonts.font_data.values_mut() { + *font = font.clone().tweak(FontTweak { + scale: state.config.font_scale, + ..Default::default() + }); + } + ctx.set_fonts(fonts); + let ocr_db = state.ocr_db.clone(); let mut debug_lap_window = self.ui_state.debug_lap.is_some(); @@ -310,7 +326,7 @@ impl eframe::App for AppUi { .show_max_size(ui, Vec2::new(800.0, 600.0)); ui.separator(); if let Some(debug_lap) = &mut self.ui_state.debug_lap { - show_debug_frames(ui, &mut debug_lap.debug_regions); + show_debug_frames(ui, &ocr_db, &mut debug_lap.debug_regions); } show_config_controls(ui, &mut self.ui_state, state.deref_mut()); } @@ -320,7 +336,6 @@ impl eframe::App for AppUi { self.ui_state.debug_lap = None; } - ctx.set_visuals(Visuals::dark()); egui::SidePanel::left("frame").show(ctx, |ui| { if let Some(frame) = &state.last_frame { ui.heading("Race data"); @@ -349,7 +364,7 @@ impl eframe::App for AppUi { ui.separator(); ui.heading("Strategy"); if let Some(tyre_wear) = race.tyre_wear() { - ui.heading(&format!("p50 Tyre Wear: {}", tyre_wear)); + ui.label(&format!("Median Tyre Wear: {}", tyre_wear)); if let Some(tyres) = frame.tyres { ui.label(&format!( "Out of tires in {:.1} lap(s)", @@ -358,7 +373,7 @@ impl eframe::App for AppUi { } } if let Some(gas_wear) = race.gas_per_lap() { - ui.heading(&format!("p50 Gas Wear: {}", gas_wear)); + ui.label(&format!("Median Gas Wear: {}", gas_wear)); if let Some(gas) = frame.gas { ui.label(&format!( "Out of gas in {:.1} lap(s)", @@ -368,20 +383,55 @@ impl eframe::App for AppUi { } } } - - ui.separator(); - ui.checkbox(&mut state.debug_frames, "Debug OCR regions"); - if state.config.dump_frame_fraction.is_some() { - ui.checkbox( - &mut state.should_sample_ocr_data, - "Dump OCR training frames", - ); - } }); egui::CentralPanel::default().show(ctx, |ui| { + egui::menu::bar(ui, |ui| { + egui::menu::menu_button(ui, "Control", |ui| { + if ui.button("Debug OCR regions").clicked() { + state.debug_frames = !state.debug_frames; + } + if state.config.dump_frame_fraction.is_some() { + let button_text = if state.should_sample_ocr_data { + "Stop OCR training dump" + } else { + "Dump OCR training data" + }; + if ui.button(button_text).clicked() { + state.should_sample_ocr_data = !state.should_sample_ocr_data; + } + } + show_config_controls(ui, &mut self.ui_state, &mut state); + }); + + egui::menu::menu_button(ui, "Preferences", |ui| { + let light_mode_text = if state.config.light_mode { + "☀ light mode" + } else { + "☀ dark mode" + }; + if ui.button(light_mode_text).clicked() { + state + .config + .update_and_save(|config| config.light_mode = !config.light_mode) + .unwrap(); + } + + let mut font_scale = state.config.font_scale; + ui.add(egui::Slider::new(&mut font_scale, 0.1..=5.0).text("Font scale")); + if font_scale != state.config.font_scale { + state + .config + .update_and_save(|config| { + config.font_scale = font_scale; + }) + .unwrap(); + } + }); + }); + ui.separator(); + egui::ScrollArea::vertical().show(ui, |ui| { let config = state.config.clone(); - let _learned_tracks = state.learned_tracks.clone(); if let Some(race) = &mut state.current_race { ui.heading(&format!("Current Race: {}", race.name())); show_race_state( @@ -465,8 +515,8 @@ impl eframe::App for AppUi { if state.debug_frames { egui::SidePanel::right("screenshots").show(ctx, |ui| { egui::ScrollArea::vertical().show(ui, |ui| { - show_debug_frames(ui, &mut state.saved_frames); - show_config_controls(ui, &mut self.ui_state, state.deref_mut()); + let ocr_db = state.ocr_db.clone(); + show_debug_frames(ui, ocr_db.as_ref(), &mut state.saved_frames); }); }); } diff --git a/src/state.rs b/src/state.rs index 590a99d..22ead1f 100644 --- a/src/state.rs +++ b/src/state.rs @@ -128,7 +128,7 @@ pub struct DebugOcrFrame { pub image: RetainedImage, pub rgb_image: RgbImage, pub img_hash: String, - pub recognized_text: Option, + pub recognized_text: String, } #[derive(Default)]