diff --git a/src/analysis.rs b/src/analysis.rs
index 51e9aab..9d07f81 100644
--- a/src/analysis.rs
+++ b/src/analysis.rs
@@ -7,17 +7,18 @@ use std::{
 use anyhow::Result;
 use egui_extras::RetainedImage;
 use image::RgbImage;
-use img_hash::ImageHash;
+
 use scrap::{Capturer, Display};
 
 use crate::{
     capture,
-    config::{Config, LearnedConfig},
+    config::Config,
     image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes},
     ocr,
-    state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState},
+    state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState},
+    learned_tracks::get_track_hash,
 };
 
+
 fn is_finished_lap(state: &AppState, frame: &LapState) -> bool {
     if let Some(race) = &state.current_race {
         if let Some(last_finish) = &race.last_lap_record_time {
@@ -62,25 +63,6 @@ fn merge_frames(prev: &LapState, next: &LapState) -> LapState {
         ..Default::default()
     }
 }
-
-fn get_track_hash(config: &Config, image: &RgbImage) -> Option<String> {
-    let track_region = config.track_region.as_ref()?;
-    let extracted = extract_and_filter(image, track_region);
-    Some(hash_image(&extracted))
-}
-
-fn get_track_name(learned: &LearnedConfig, hash: &Option<String>, config: &Config) -> Option<String> {
-    let hash = hash.as_ref()?;
-    for (learned_hash_b64, learned_track) in &learned.learned_tracks {
-        let learned_hash: ImageHash<Vec<u8>> = img_hash::ImageHash::from_base64(learned_hash_b64).ok()?;
-        let current_hash: ImageHash<Vec<u8>> = img_hash::ImageHash::from_base64(hash).ok()?;
-        if current_hash.dist(&learned_hash) <= config.track_recognition_threshold.unwrap_or(10) {
-            return Some(learned_track.to_owned())
-        }
-    }
-    learned.learned_tracks.get(hash).cloned()
-}
-
 fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
     if frame.lap_time.is_some() {
         state.last_frame = Some(frame.clone());
     }
@@ -88,7 +70,8 @@ fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
 
     if state.current_race.is_none() {
         let track_hash = get_track_hash(state.config.as_ref(), &image);
-        let track_name = get_track_name(state.learned.as_ref(), &track_hash, state.config.as_ref());
+        let track_name = state.learned_tracks.infer_track(&track_hash, state.config.as_ref());
+        let inferred_track = track_name.is_some();
         let race = RaceState {
             screencap: Some(
                 RetainedImage::from_image_bytes(
@@ -100,6 +83,7 @@
             race_time: Some(SystemTime::now()),
             track_hash,
             track: track_name.unwrap_or_default(),
+            inferred_track,
             ..Default::default()
         };
         state.current_race = Some(race);
@@ -167,16 +151,15 @@ fn add_saved_frame(
 
 fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> {
     let frame = capture::get_frame(capturer)?;
-    let (config, learned_config, ocr_cache, should_sample) = {
+    let (config, ocr_cache, should_sample) = {
         let locked = state.lock().unwrap();
         (
             locked.config.clone(),
-            locked.learned.clone(),
             locked.ocr_cache.clone(),
             locked.should_sample_ocr_data
         )
     };
-    let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config, ocr_cache, should_sample);
+    let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample);
 
     if state.lock().unwrap().debug_frames {
         let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results);
diff --git a/src/config.rs b/src/config.rs
index 2d6598b..3ef0d7c 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, path::PathBuf};
+use std::{path::PathBuf};
 
 use anyhow::Result;
 use serde::{Serialize, Deserialize, de::DeserializeOwned};
@@ -19,26 +19,11 @@ pub struct Config {
 
 impl Config {
     pub fn load() -> Result<Self> {
-        load_or_make_default("config.json", include_str!("configs/config.default.json"))
+        load_config_or_make_default("config.json", include_str!("configs/config.default.json"))
     }
 }
 
-#[derive(Default, Serialize, Deserialize, Clone)]
-pub struct LearnedConfig {
-    pub learned_images: HashMap<String, String>,
-    pub learned_tracks: HashMap<String, String>,
-}
-
-impl LearnedConfig {
-    pub fn load() -> Result<Self> {
-        load_or_make_default("learned.json", include_str!("configs/learned.default.json"))
-    }
-    pub fn save(&self) -> Result<()> {
-        save_json_config("learned.json", self)
-    }
-}
-
-fn load_or_make_default<T: DeserializeOwned>(path: &str, default: &str) -> Result<T> {
+pub fn load_config_or_make_default<T: DeserializeOwned>(path: &str, default: &str) -> Result<T> {
     let file_path = PathBuf::from(path);
     if !file_path.exists() {
         std::fs::write(&path, default)?;
@@ -52,7 +37,7 @@ fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> {
     Ok(value)
 }
 
-fn save_json_config<T: Serialize>(path: &str, val: &T) -> Result<()> {
+pub fn save_json_config<T: Serialize>(path: &str, val: &T) -> Result<()> {
     let serialized = serde_json::to_vec_pretty(val)?;
     Ok(std::fs::write(path, &serialized)?)
 }
\ No newline at end of file
diff --git a/src/image_processing.rs b/src/image_processing.rs
index 2dee34e..b41ab0e 100644
--- a/src/image_processing.rs
+++ b/src/image_processing.rs
@@ -76,7 +76,7 @@ pub fn check_target_color_fraction(
             }
         }
     }
-    return ((region.height() * region.width()) as f64) / (color_area as f64);
+    ((region.height() * region.width()) as f64) / (color_area as f64)
 }
 
 pub fn to_png_bytes(image: &RgbImage) -> Vec<u8> {
diff --git a/src/learned_tracks.rs b/src/learned_tracks.rs
new file mode 100644
index 0000000..00d4870
--- /dev/null
+++ b/src/learned_tracks.rs
@@ -0,0 +1,53 @@
+use std::{collections::HashMap, sync::Arc};
+
+use crate::{config::{load_config_or_make_default, save_json_config, Config}, image_processing::{extract_and_filter, hash_image}};
+
+use anyhow::Result;
+use image::RgbImage;
+use img_hash::ImageHash;
+use serde::{Serialize, Deserialize};
+
+
+#[derive(Default, Serialize, Deserialize, Clone)]
+pub struct LearnedTracks {
+    learned_tracks: HashMap<String, String>,
+}
+
+impl LearnedTracks {
+    pub fn load() -> Result<Self> {
+        load_config_or_make_default("learned.json", include_str!("configs/learned.default.json"))
+    }
+    pub fn save(&self) -> Result<()> {
+        save_json_config("learned.json", self)
+    }
+
+    pub fn infer_track(&self, hash: &Option<String>, config: &Config) -> Option<String> {
+        let hash = hash.as_ref()?;
+        for (learned_hash_b64, learned_track) in &self.learned_tracks {
+            let learned_hash: ImageHash<Vec<u8>> = ImageHash::from_base64(learned_hash_b64).ok()?;
+            let current_hash: ImageHash<Vec<u8>> = ImageHash::from_base64(hash).ok()?;
+            if current_hash.dist(&learned_hash) <= config.track_recognition_threshold.unwrap_or(10) {
+                return Some(learned_track.to_owned())
+            }
+        }
+        None
+    }
+
+    // `self: &mut Arc<Self>` is not a valid method receiver on stable Rust (E0307),
+    // so this takes the Arc as an explicit parameter instead.
+    pub fn learn_and_save(this: &mut Arc<Self>, hash: &str, track: &str) -> Result<()> {
+        let mut tracks = (**this).clone();
+        tracks
+            .learned_tracks
+            .insert(hash.to_owned(), track.to_owned());
+        tracks.save()?;
+        *this = Arc::new(tracks);
+        Ok(())
+    }
+}
+
+pub fn get_track_hash(config: &Config, image: &RgbImage) -> Option<String> {
+    let track_region = config.track_region.as_ref()?;
+    let extracted = extract_and_filter(image, track_region);
+    Some(hash_image(&extracted))
+}
diff --git a/src/local_ocr.rs b/src/local_ocr.rs
index 19270ca..1512cbe 100644
--- a/src/local_ocr.rs
+++ b/src/local_ocr.rs
@@ -1,5 +1,5 @@
-use image::{DynamicImage, Rgb, RgbImage};
-use img_hash::{image::GenericImageView, ImageHash};
+use image::{Rgb, RgbImage};
+use img_hash::{image::GenericImageView};
 
 use crate::image_processing;
 
@@ -33,7 +33,7 @@ fn row_has_any_dark(image: &RgbImage, y: u32, start_x: u32, width: u32) -> bool
 
 fn take_while<F: Fn(u32) -> bool>(x: &mut u32, max: u32, f: F) {
     while *x < max && f(*x) {
-        *x = *x + 1;
+        *x += 1;
     }
 }
 
diff --git a/src/main.rs b/src/main.rs
index f7606a2..10b701c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,43 +1,46 @@
 // #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release
 
+mod analysis;
 mod capture;
 mod config;
-mod analysis;
 mod image_processing;
+mod local_ocr;
 mod ocr;
 mod state;
 mod stats_writer;
-mod local_ocr;
 mod training_ui;
+mod learned_tracks;
 
 use std::{
     collections::HashMap,
     ops::DerefMut,
+    path::PathBuf,
     sync::{Arc, Mutex},
     thread,
-    time::Duration, path::PathBuf,
+    time::Duration,
 };
 
-use config::{Config, LearnedConfig};
 use analysis::save_frames_from;
+use config::Config;
 use eframe::{
     egui::{self, Ui, Visuals},
     emath::Vec2,
     epaint::Color32,
 };
 use egui_extras::RetainedImage;
-use image_processing::{to_png_bytes, from_png_bytes};
+use image_processing::{from_png_bytes, to_png_bytes};
+use learned_tracks::LearnedTracks;
 use state::{AppState, DebugOcrFrame, LapState, OcrCache, RaceState, SharedAppState};
 use stats_writer::export_race_stats;
 
 fn main() -> anyhow::Result<()> {
-    let mode = std::env::args().nth(1).unwrap_or_default().to_string();
+    let mode = std::env::args().nth(1).unwrap_or_default();
     if mode == "train" {
         return training_ui::training_ui();
     }
     let app_state = AppState {
         config: Arc::new(Config::load().unwrap()),
-        learned: Arc::new(LearnedConfig::load().unwrap()),
+        learned_tracks: Arc::new(LearnedTracks::load().unwrap()),
         ..Default::default()
     };
     let state = Arc::new(Mutex::new(app_state));
@@ -51,7 +54,10 @@ fn main() -> anyhow::Result<()> {
     let options = eframe::NativeOptions::default();
     let current_exe = std::env::current_exe().unwrap_or_else(|_| PathBuf::from("supper.exe"));
     eframe::run_native(
-        &format!("Supper OCR ({})", current_exe.file_name().unwrap().to_string_lossy()),
+        &format!(
+            "Supper OCR ({})",
+            current_exe.file_name().unwrap().to_string_lossy()
+        ),
         options,
         Box::new(|_cc| Box::new(AppUi::new(state))),
     );
@@ -172,7 +178,6 @@ fn show_race_state(
     race_name: &str,
     race: &mut RaceState,
     config: Arc<Config>,
-    learned: Arc<LearnedConfig>,
     ocr_cache: Arc<OcrCache>,
 ) {
     egui::Grid::new(format!("race:{}", race_name)).show(ui, |ui| {
@@ -228,7 +233,6 @@
                         ui_state,
                         lap,
                         config.clone(),
-                        learned.clone(),
                         ocr_cache.clone(),
                     )
                 }
@@ -275,36 +279,17 @@ fn show_config_controls(ui: &mut Ui, ui_state: &mut UiState, state: &mut AppStat
     if let Some(e) = &ui_state.config_load_err {
         ui.colored_label(Color32::RED, e);
     }
-
-    ui.separator();
-    ui.label("Hash");
-    ui.text_edit_singleline(&mut ui_state.hash_to_learn);
-    ui.label("Value");
-    ui.text_edit_singleline(&mut ui_state.value_to_learn);
-    if ui.button("Learn").clicked() {
-        let mut learned_config = (*state.learned).clone();
-        learned_config.learned_images.insert(
-            ui_state.hash_to_learn.clone(),
-            ui_state.value_to_learn.clone(),
-        );
-        learned_config.save().unwrap();
-        state.learned = Arc::new(learned_config);
-
-        ui_state.hash_to_learn = "".to_owned();
-        ui_state.value_to_learn = "".to_owned();
-    }
 }
 
 fn open_debug_lap(
     ui_state: &mut UiState,
     lap: &LapState,
     config: Arc<Config>,
-    learned: Arc<LearnedConfig>,
     ocr_cache: Arc<OcrCache>,
 ) {
     if let Some(screenshot_bytes) = &lap.screenshot {
         let screenshot = from_png_bytes(screenshot_bytes);
-        let ocr_results = ocr::ocr_all_regions(&screenshot, config.clone(), learned, ocr_cache, false);
+        let ocr_results = ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false);
         let debug_lap = DebugLap {
             screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot))
                 .unwrap(),
@@ -323,14 +308,7 @@
     *value = options[index].clone();
 }
 
-fn save_learned_track(learned: &mut Arc<LearnedConfig>, track: &str, hash: &str) {
-    let mut learned_config = (**learned).clone();
-    learned_config.learned_tracks.insert(
-        hash.to_owned(),
-        track.to_owned(),
-    );
-    learned_config.save().unwrap();
-    *learned = Arc::new(learned_config);
+fn save_learned_track(_learned_tracks: &mut Arc<LearnedTracks>, _track: &str, _hash: &str) {
 }
 
 impl eframe::App for AppUi {
@@ -388,13 +366,19 @@
            if let Some(tyre_wear) = race.tyre_wear() {
                ui.heading(&format!("p50 Tyre Wear: {}", tyre_wear));
                if let Some(tyres) = frame.tyres {
-                   ui.label(&format!("Out of tires in {:.1} lap(s)", (tyres as f64) / (tyre_wear as f64)));
+                   ui.label(&format!(
+                       "Out of tires in {:.1} lap(s)",
+                       (tyres as f64) / (tyre_wear as f64)
+                   ));
                }
            }
            if let Some(gas_wear) = race.gas_per_lap() {
                ui.heading(&format!("p50 Gas Wear: {}", gas_wear));
                if let Some(gas) = frame.gas {
-                   ui.label(&format!("Out of gas in {:.1} lap(s)", (gas as f64) / (gas_wear as f64)));
+                   ui.label(&format!(
+                       "Out of gas in {:.1} lap(s)",
+                       (gas as f64) / (gas_wear as f64)
+                   ));
                }
            }
        }
@@ -402,12 +386,15 @@
 
        ui.separator();
        ui.checkbox(&mut state.debug_frames, "Debug OCR regions");
-       ui.checkbox(&mut state.should_sample_ocr_data, "Dump OCR training frames");
+       ui.checkbox(
+           &mut state.should_sample_ocr_data,
+           "Dump OCR training frames",
+       );
        });
        egui::CentralPanel::default().show(ctx, |ui| {
            egui::ScrollArea::vertical().show(ui, |ui| {
                let config = state.config.clone();
-               let learned = state.learned.clone();
+               let _learned_tracks = state.learned_tracks.clone();
                let ocr_cache = state.ocr_cache.clone();
                if let Some(race) = &mut state.current_race {
                    ui.heading(&format!("Current Race: {}", race.name()));
@@ -417,13 +404,12 @@
                        "current",
                        race,
                        config.clone(),
-                       learned,
                        ocr_cache.clone(),
                    );
                }
                let len = state.past_races.len();
                let mut races_to_remove = Vec::new();
-               let mut learned = state.learned.clone();
+               let mut learned_tracks = state.learned_tracks.clone();
                for (i, race) in state.past_races.iter_mut().enumerate() {
                    ui.separator();
                    ui.heading(format!("Race #{}: {}", len - i, race.name()));
@@ -433,14 +419,19 @@
                        &format!("race {}:", i),
                        race,
                        config.clone(),
-                       learned.clone(),
                        ocr_cache.clone(),
                    );
                    if let Some(img) = &race.screencap {
                        img.show_max_size(ui, Vec2::new(600.0, 500.0));
                    }
                    if !race.exported {
-                       show_combo_box(ui, &format!("car-combo {}", i), "Car", &self.data.cars, &mut race.car);
+                       show_combo_box(
+                           ui,
+                           &format!("car-combo {}", i),
+                           "Car",
+                           &self.data.cars,
+                           &mut race.car,
+                       );
                        show_combo_box(
                            ui,
                            &format!("track-combo {}", i),
@@ -452,7 +443,9 @@
                        ui.text_edit_singleline(&mut race.comments);
                        if ui.button("Export").clicked() {
                            if let Some(track_hash) = &race.track_hash {
-                               save_learned_track(&mut learned, &race.track, track_hash);
+                               if !race.inferred_track {
+                                   LearnedTracks::learn_and_save(&mut learned_tracks, track_hash, &race.track).unwrap();
+                               }
                            }
                            match export_race_stats(race) {
                                Ok(_) => {
@@ -474,7 +467,7 @@
                        races_to_remove.push(i);
                    }
                }
-               state.learned = learned;
+               state.learned_tracks = learned_tracks;
                for index in races_to_remove {
                    state.past_races.remove(index);
                }
diff --git a/src/ocr.rs b/src/ocr.rs
index 73fec4e..673ba35 100644
--- a/src/ocr.rs
+++ b/src/ocr.rs
@@ -8,8 +8,9 @@ use image::RgbImage;
 use serde::{Deserialize, Serialize};
 
 use crate::{
-    config::{Config, LearnedConfig},
-    image_processing::{extract_and_filter, hash_image}, state::OcrCache,
+    config::Config,
+    image_processing::{extract_and_filter, hash_image},
+    state::OcrCache,
 };
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -81,9 +82,8 @@ async fn run_ocr_cached(
 
 pub async fn ocr_all_regions(
     image: &RgbImage,
     config: Arc<Config>,
-    learned: Arc<LearnedConfig>,
     ocr_cache: Arc<OcrCache>,
-    should_sample: bool
+    should_sample: bool,
 ) -> HashMap<String, Option<String>> {
     let results = Arc::new(Mutex::new(HashMap::new()));
@@ -93,19 +93,15 @@ pub async fn ocr_all_regions(
         let region = region.clone();
         let results = results.clone();
         let config = config.clone();
-        let learned = learned.clone();
         let ocr_cache = ocr_cache.clone();
 
         handles.push(tokio::spawn(async move {
             let filtered_image = filtered_image;
             let hash = hash_image(&filtered_image);
-            let value = if let Some(learned_value) = learned.learned_images.get(&hash) {
-                Some(learned_value.clone())
-            } else {
-                run_ocr_cached(ocr_cache, hash, &region, config.clone(), &filtered_image).await
-            };
+            let value =
+                run_ocr_cached(ocr_cache, hash, &region, config.clone(), &filtered_image).await;
             if let Some(sample_fraction) = &config.dump_frame_fraction {
-                if rand::random::<f64>() < *sample_fraction {
+                if rand::random::<f64>() < *sample_fraction && should_sample {
                     let file_id = rand::random::<u64>();
                     let img_filename = format!("ocr_data/{}.png", file_id);
                     filtered_image.save(img_filename).unwrap();
diff --git a/src/state.rs b/src/state.rs
index e658abf..fbb49c4 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -4,7 +4,7 @@ use egui_extras::RetainedImage;
 use image::RgbImage;
 use time::{OffsetDateTime, format_description};
 
-use crate::config::{Config, LearnedConfig};
+use crate::{config::Config, learned_tracks::LearnedTracks};
 
 
 #[derive(Debug, Clone, Default)]
@@ -74,7 +74,7 @@ fn median_wear(values: Vec<Option<u32>>) -> Option<u32> {
             last_value = val;
         }
     }
-    wear_values.sort();
+    wear_values.sort_unstable();
     wear_values.get(wear_values.len() / 2).cloned()
 }
 
@@ -93,6 +93,8 @@ pub struct RaceState {
     pub car: String,
     pub track: String,
     pub comments: String,
+
+    pub inferred_track: bool,
 }
 
 impl RaceState {
@@ -143,7 +145,7 @@ pub struct AppState {
     pub should_sample_ocr_data: bool,
 
     pub config: Arc<Config>,
-    pub learned: Arc<LearnedConfig>,
+    pub learned_tracks: Arc<LearnedTracks>,
     pub ocr_cache: Arc<OcrCache>,
 }
 
diff --git a/src/training_ui.rs b/src/training_ui.rs
index 2ba6d6d..3cc7cf6 100644
--- a/src/training_ui.rs
+++ b/src/training_ui.rs
@@ -1,15 +1,10 @@
 use std::{
-    collections::HashMap,
     io::Write,
     path::{PathBuf, Path},
-    sync::{Arc, Mutex},
-    thread,
-    time::Duration,
 };
 
 use eframe::{
-    egui::{self, Ui, Visuals},
-    emath::Vec2,
+    egui::{self, Visuals},
     epaint::Color32,
 };
 use egui_extras::RetainedImage;
@@ -63,7 +58,7 @@ fn get_training_data_paths() -> Vec<(PathBuf, PathBuf)> {
 
 fn predict_ocr(hashes: &[(String, char)], hash: &str) -> Option<char> {
     let hash = img_hash::ImageHash::<Vec<u8>>::from_base64(hash).unwrap();
-    let (_, best_char) = hashes.iter().min_by_key(|(learned_hash, c)| {
+    let (_, best_char) = hashes.iter().min_by_key(|(learned_hash, _c)| {
         img_hash::ImageHash::from_base64(learned_hash)
             .unwrap()
             .dist(&hash)
@@ -106,8 +101,8 @@ fn load_learned_hashes() -> Vec<(String, char)> {
     let data = String::from_utf8(std::fs::read(path).unwrap()).unwrap();
     let mut parsed = Vec::new();
     for line in data.lines() {
-        if let Some((c, hash)) = line.split_once(" ") {
-            if let Some(c) = c.chars().nth(0) {
+        if let Some((c, hash)) = line.split_once(' ') {
+            if let Some(c) = c.chars().next() {
                 parsed.push((hash.to_owned(), c));
             }
         }
@@ -197,7 +192,7 @@ impl eframe::App for TrainingUi {
             }
             for c in &current_image.char_hashes {
                 ui.label(c);
-                if let Some(predicted) = predict_ocr(&self.learned_char_hashes, &c) {
+                if let Some(predicted) = predict_ocr(&self.learned_char_hashes, c) {
                     ui.label(format!("Predicted: {}", predicted));
                 }
             }