From e195158218e66f721c50f9493c61b2c6ce529a2b Mon Sep 17 00:00:00 2001 From: Scott Pruett Date: Fri, 3 Jun 2022 22:18:46 -0400 Subject: [PATCH] local OCR working --- config_1080p.json | 7 +- ocr.json | 101 +++++++++++++++++++++++++++ race_stats.csv | 1 + src/analysis.rs | 45 +++++++++--- src/config.rs | 3 - src/image_processing.rs | 5 +- src/learned_tracks.rs | 2 +- src/main.rs | 52 +++++++------- src/ocr.rs | 3 +- src/ocr_db.rs | 100 +++++++++++++++++++++++++++ src/remote_ocr.rs | 121 --------------------------------- src/state.rs | 18 +++-- src/test_data/test-image-3.png | Bin 0 -> 707 bytes src/training_ui.rs | 73 +++----------------- 14 files changed, 294 insertions(+), 237 deletions(-) create mode 100644 ocr.json create mode 100644 src/ocr_db.rs delete mode 100644 src/remote_ocr.rs create mode 100644 src/test_data/test-image-3.png diff --git a/config_1080p.json b/config_1080p.json index 971de45..9a70a68 100644 --- a/config_1080p.json +++ b/config_1080p.json @@ -33,15 +33,18 @@ "x": 1744, "y": 127, "width": 137, - "height": 32 + "height": 32, + "threshold": 0.88 }, { "name": "lap_time", "x": 1744, "y": 166, "width": 137, - "height": 32 + "height": 32, + "threshold": 0.88 } ], + "track_recognition_threshold": 10, "ocr_server_endpoint": "https://tesserver.spruett.dev/" } diff --git a/ocr.json b/ocr.json new file mode 100644 index 0000000..cc8638c --- /dev/null +++ b/ocr.json @@ -0,0 +1,101 @@ +{ + "learned_chars": { + "//8P+Afw4+fzx/HP+c/5z/nP8c/zx/Pnx/MP+H/+//8=": "0", + "//8H4Afg/+P/4//j/+P/4//j/+P/4//j/+P/4//j//8=": "1", + "//8f+A/g5+fzz/PP88/zn/PP88/zz+fP5+Mf8H/+//8=": "0", + "/////P/8f/4//p//3//H++P78/sBwAGA//v/+//7//8=": "4", + "//8P4A/gD+AP/w//D/gP4AfA/8P/w4fAA+AD8B/8//8=": "5", + "//8P4A/gD+AP/4//D/gH4Afg/8H/wYfgA+AD8B/8//8=": "5", + "//9/4H/wP/Af+A/8B/4DggODAYABgAGAAYD/g/+D//8=": "4", + "/////f/8f/w//5//j//H++f78/MBgP/z//v/+//7//8=": "4", + "/////wfwB+DD4cPBA+AP8Afgw8PDw4PBA+AH8P////8=": "8", + "//8BgAGAAYDhgeGB/8D/wH/gP+A/8B/wH/gP+A/8//8=": "7", + "////z/+fH/4H+AP4w/Dh8eHx4fHh8cPwA/gP/B/+//8=": "0", + "//8H4Afg/+H/4f/h/+H/4f/h/+H/4f/h/+H/4f/h//8=": "1", + "//8P+Afw4+fz5/HP+c/5z/nP+c/zx/Pnx/MP+H////8=": "0", + "///f/w/wB+DH48fhB+AH4APgw8PDw8PDB+AP8P////8=": "8", + "/////w/wB+DH48fjB+AH4Afgw8PDw8PDB+AP8P////8=": "8", + "//8f/Efx8+fz7/vP88fjww/I/8//z//n/+MH+A/+//8=": "9", + "//8/4I/g5//z//P/O/wL4OPH88/zz/fPx8cP4H/8//8=": "6", + "//8P8AfgA8CDgcGDwYPBg8GDwYPBg4OBA8AH4A/w//8=": "0", + "//8f+Afi88/zz/PPx+EP8OPH88/73/Pf48cH4H/+//8=": "8", + "/////wPwA+DD4P/h/+H/4H/wP/gP/gf+A8ADwP////8=": "2", + "//8f+A/g5+fzz/PP88/zz/PP88/zz+fH5+cP8H/+//8=": "0", + "//8D4APg//P/+X/8P/4/8P/H/8//z//P8+cD8D////8=": "3", + "/////wPAA8DDwcPBw+H/4P/wf/A/8D/4P/w/+P////8=": "7", + "//8P+APw++f/7//v/+f/4//x//w//h//z/8DwAPA//8=": "2", + "/////wPAA8Dj4OPg4+B/8H/wP/g/+D/8H/wf/P////8=": "7", + "//8/+A/w5+fjz/PP88/zz/PP88/zz+PP5+cP8D/8//8=": "0", + "//8P+CPw8+f77/vP88fjwQ/M/8//z//n//EH+D////8=": "9", + "//8HwAfA//P/8f/8f/5/8P/H/8//3//P88cD8H/+//8=": "3", + "//8DwAPA+8/75//n//P/8//5//n//H/+f/4//z////8=": "7", + "//8/4Ifg5//z//P/O/wL4OPH88/zz/PP5+cP4H/8//8=": "6", + "//+//Q/4B/CD4cPjw8PDw+PDw+PD4QfgD/Af+H/+//8=": "0", + "//9/7x/AD+AH74f/g/AD4APAg8OHwwfDD8Af4H/7//8=": "6", + "//8f+Afw++f/7//v/+f/4//4f/w//4//x/8DwAPA//8=": "2", + "///f/Q/wB+CH4cPDw8PDw8PDw8PDwwfgB/AP+L/9//8=": "0", + "//9/7x/gD+AH/4P/g/gD4APAg8ODw4fBB+Af8H/9//8=": "6", + "//8H4Afi//P/+P/8f/4/8P/j/8f/z//H8+MD8D/+//8=": "3", + "//8P4A/g7//n/+f/B/wH8P/H/8//3//P88cD8H/+//8=": "5", + "//8DwAPA88/z5//j/+P/8f/5//j//H/+P/4//h////8=": "7", + "//8P8A/wD/CH/4f/B/gH8Afg/8H/wYfgA+AD+B/+//8=": "5", + "//8P8APgAcDBgcGDwYEBgAOAD4L/g//AA8AD4AP4//8=": "9", + "//8f+Afg5+fzz/PP+9/73/vf+9/zz+PP5+cP4B/4//8=": "0", + "/////wfgA+DHwf/B/8H/4H/gP/gf/A/+B8ADwP////8=": "2", + "//9/7x/AD+AH7of/g/AD4APAg8ODwwfDD+Af4H/7//8=": "6", + "/////wfgB+B/4H/gf+B/4H/gf+B/4H/gf+B/4P////8=": "1", + "//8BgAGAAYDhgeGB7cD/wH/gf+A/8D/wH/gP+A/87/8=": "7", + "////+z/4P/wf/g/+D/+H8cPwA8ADwAPAA8D/8P////8=": "4", + "//8f+Afg5+fzz/Pf+9/73/vf+9/zz+PHx+MP8H/+//8=": "0", + "//////////8//B/4H/gf+B/4H/gf+D/8//////////8=": ".", + "//8H4Afg//P/+f/8f/w/4P/H/8//z//P4+cH4D/8//8=": "3", + "//8H4A/g/+P/4//j/+P/4//j/+P/4//j/+P/4//j//8=": "1", + "//8D4APg//P/+X/8f/4/+P/j/8//z//P+8cD8B/8//8=": "3", + "//8DgAOAA8D/4H/wH/AfwB+A/4H/g+OBA4ABwAPw//8=": "3", + "///f+w/wB+CDwcPDw8MHwA/AP8P/w0/gB/AH+F////8=": "9", + "/////f/8f/x//j//n//P8+fz4/MBwAOA//P/8//z//8=": "4", + "//8HwAfA/+P/8//9f/x/8P/H/8//3//P888D4B/4//8=": "3", + "///f/Q/4B/DD48PDw8EHwA/AH8P/4U/gB/AH+F////8=": "9", + "//8/8A/lx//3//P/8/8b4MPH48/z3/PP58eH4D/8//8=": "6", + "/////w/wD/D//////////////////w/wD/AP8H/+//8=": ":", + "//8f/Efh4+fzz/PP88fnwwfY/9//z//v/+On8A/8//8=": "9", + "////////////+D/wP/Af8A/wD/AP8B/4//////////8=": ".", + "//8D4APg//P/+f/8P/4/8P/H/8//z//P88cD8H////8=": "3", + "/////wfgB+Af4H/gf+B/4H/gf+B/4H/gf+B/4P////8=": "1", + "/////wfgB+A/4D/gP+A/4D/gP+A/4D/gP+A/4P////8=": "1", + "//9/7x/AD+AH74f/g/AD4APAg8ODwwfDD8Af4H/7//8=": "6", + "//8f+Afi8+fzz/PP48fHww/Y/8//z//v/+MH+D////8=": "9", + "//8HwAfAB8CD/wP/A+ADwAOA/4H/g+OBA4ABwAPw//8=": "5", + "//8f+Efi5+fzz/vf+9/73/vf+9/zz+PHx+MP8D/8//8=": "0", + "/////wfwA+DD4P/h/+H/4H/wP/gf/Af+B8AHwP////8=": "2", + "//8D4APg//P/+X/8f/4/+P/j/8//z//P+8cD8A/8//8=": "3", + "//8f+Afg5+fzz/vf+9/73/vf+9/zz+PHx+MP8H/+//8=": "0", + "//8P/APw++P/5//n/+f/8//4f/w//4//x/8DwAPA//8=": "2", + "//8P4A/g7//v/+//B/wH4P/P/8//3//P88cD8H/+//8=": "5", + "//8/gA+AB4AD34P/geABgAGAAYOBh4GHA4AHgA/g//8=": "6", + "/////w/wD/D/////////////////////D/AP8P////8=": ":", + "//8f/Afw5+fz5/HP+c/5z/nP+c/xz/Pn5+cP8B/8//8=": "0", + "//9/7x/gD+AH/4P/g/gD4APAg8ODw4fBD+Af8H/9//8=": "6", + "//8BgAGAAYA/gD+AP4A/gD+AP4A/gD+AP4A/gD+A//8=": "1", + "///f/Q/4B/DD4cPDw8EDwA/AH+P/4W/gB/AH+F////8=": "9", + "//8DwAPA+8/75//n//P/8f/5//n//H/+f/4//z////8=": "7", + "//8/gA+AB4AD34H/geABgAGAAYOBh4GHA4AHgA/g//8=": "6", + "//8P+CPw8+f77/vP88fjwQ/M/8//z//n//EH+J////8=": "9", + "//8H8APAAcDDgP+B/4H/wH/gH/AP+Af8A4ADgAOA//8=": "2", + "//8f+A/g5+fzz/PP88/zz/PP88/zz+fH5+cf8H/+//8=": "0", + "///v/QfwA+DD4cPBw8EDwAfAn8H/4UfwB/AD/G////8=": "9", + "/////P/8f/4//p//3//n++P78/sBwAGA//v/+//7//8=": "4", + "/////////////wfgB+AH4AfgB+AH4P////////////8=": "-", + "//8H4Afg/+H/4f/h/+H/4f/h/+H/4f/h/+H/4f/j//8=": "1", + "//8H4AfgB+D/8H/4P/w/8D/g/8P/w4fAA+AD8B/8//8=": "3", + "//8P8APAA8CBgYGDA8AH4APAgYHBg8GDgYEDwAfg//8=": "8", + "//8P+APw++f/7//v/+f/8//x//w//h//z/8DwAPA//8=": "2", + "//8f+Afw4+fzx/HP+c/5z/nP8c/zx/Pnx/MP+H////8=": "0", + "//8HwAfAB8D/4H/wP/g/4D/A/8H/w4fAA8AD8B/8//8=": "3", + "//+//R/4D/CH4cPDw8PDw8PDw8PHwwfgD/Af+H/+//8=": "0", + "//+//R/wD+CH4cfDw8PDw8PDw8PHwwfgD/Af+H/+//8=": "0", + "//////////8//D/8H/gP8A/wH/g//D/8//////////8=": ".", + "////+3/4P/w//B/+D/8H8YfhA8ADwAPAA8D/4f////8=": "4", + "//////////8BgAGAAYABgAGAAYABgAGA//////////8=": "-" + } +} \ No newline at end of file diff --git a/race_stats.csv b/race_stats.csv index 470ecea..969f133 100644 --- a/race_stats.csv +++ b/race_stats.csv @@ -178,3 +178,4 @@ 2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,11,21.062,20.296,100,48,69, 2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,12,20.942,20.296,100,39,61, 2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,13,20.576,20.296,100,30,52, +2022-06-04-02:10,,,3,25.000,25.000,75,75,75, diff --git a/src/analysis.rs b/src/analysis.rs index 7c5ab7b..ec67ea3 100644 --- a/src/analysis.rs +++ b/src/analysis.rs @@ -1,7 +1,7 @@ use std::{ collections::HashMap, thread, - time::{Duration, Instant, SystemTime}, + time::{Duration, Instant, SystemTime}, sync::Arc, }; use anyhow::Result; @@ -14,8 +14,7 @@ use crate::{ capture, config::Config, image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes}, - remote_ocr, - state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash, + state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash, ocr_db::OcrDatabase, }; @@ -130,36 +129,62 @@ fn add_saved_frame( saved_frames: &mut HashMap, frame: &RgbImage, region: &Region, - ocr_results: &HashMap>, + ocr_results: &HashMap, ) { let extracted = extract_and_filter(frame, region); let retained = RetainedImage::from_image_bytes(®ion.name, &image_processing::to_png_bytes(&extracted)) .unwrap(); - let hash = hash_image(&extracted); + let hash = hash_image(&extracted).to_base64(); saved_frames.insert( region.name.clone(), DebugOcrFrame { image: retained, rgb_image: extracted, img_hash: hash, - recognized_text: ocr_results.get(®ion.name).and_then(|p| p.clone()), + recognized_text: ocr_results.get(®ion.name).map(|p| p.clone()), }, ); } +pub fn ocr_all_regions( + ocr_db: &OcrDatabase, + image: &RgbImage, + config: &Config, + should_sample: bool, +) -> HashMap { + let mut results = HashMap::new(); + + for region in &config.ocr_regions { + let filtered_image = extract_and_filter(image, region); + let value = ocr_db.ocr_image(&filtered_image); + + if let Some(sample_fraction) = &config.dump_frame_fraction { + if rand::random::() < *sample_fraction && should_sample { + let file_id = rand::random::(); + let img_filename = format!("ocr_data/{}.png", file_id); + filtered_image.save(img_filename).unwrap(); + let value_filename = format!("ocr_data/{}.txt", file_id); + std::fs::write(value_filename, value.clone()).unwrap(); + } + } + results.insert(region.name.clone(), value); + } + results +} + fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> { let frame = capture::get_frame(capturer)?; - let (config, ocr_cache, should_sample) = { + let (ocr_db, config, should_sample) = { let locked = state.lock().unwrap(); ( + locked.ocr_db.clone(), locked.config.clone(), - locked.ocr_cache.clone(), locked.should_sample_ocr_data ) }; - let ocr_results = remote_ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample); + let ocr_results = ocr_all_regions(ocr_db.as_ref(), &frame, config.as_ref(), should_sample); if state.lock().unwrap().debug_frames { let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results); @@ -178,7 +203,7 @@ fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> pub fn save_frames_from( frame: &RgbImage, config: &Config, - ocr_results: &HashMap>, + ocr_results: &HashMap, ) -> HashMap { let mut saved_frames = HashMap::new(); for region in &config.ocr_regions { diff --git a/src/config.rs b/src/config.rs index 3ef0d7c..6423b7f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -9,9 +9,6 @@ use crate::image_processing::Region; pub struct Config { pub ocr_regions: Vec, pub track_region: Option, - pub ocr_server_endpoint: String, - pub filter_threshold: Option, - pub use_ocr_cache: Option, pub ocr_interval_ms: Option, pub track_recognition_threshold: Option, pub dump_frame_fraction: Option, diff --git a/src/image_processing.rs b/src/image_processing.rs index b41ab0e..951aba3 100644 --- a/src/image_processing.rs +++ b/src/image_processing.rs @@ -1,4 +1,5 @@ use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage}; +use img_hash::ImageHash; use serde::{Deserialize, Serialize}; #[derive(Clone, Deserialize, Serialize)] @@ -98,7 +99,7 @@ pub fn from_png_bytes(bytes: &[u8]) -> RgbImage { image.to_rgb8() } -pub fn hash_image(image: &RgbImage) -> String { +pub fn hash_image(image: &RgbImage) -> ImageHash { let hasher = img_hash::HasherConfig::new() .hash_alg(img_hash::HashAlg::Mean) .hash_size(16, 16) @@ -107,5 +108,5 @@ pub fn hash_image(image: &RgbImage) -> String { img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec()) .unwrap(); let hash = hasher.hash_image(&have_to_use_other_image_library_version); - hash.to_base64() + hash } diff --git a/src/learned_tracks.rs b/src/learned_tracks.rs index 00d4870..252135b 100644 --- a/src/learned_tracks.rs +++ b/src/learned_tracks.rs @@ -47,5 +47,5 @@ impl LearnedTracks { pub fn get_track_hash(config: &Config, image: &RgbImage) -> Option { let track_region = config.track_region.as_ref()?; let extracted = extract_and_filter(image, track_region); - Some(hash_image(&extracted)) + Some(hash_image(&extracted).to_base64()) } diff --git a/src/main.rs b/src/main.rs index d193861..477fb5f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,11 +5,11 @@ mod capture; mod config; mod image_processing; mod ocr; -mod remote_ocr; mod state; mod stats_writer; mod training_ui; mod learned_tracks; +mod ocr_db; use std::{ collections::HashMap, @@ -30,7 +30,8 @@ use eframe::{ use egui_extras::RetainedImage; use image_processing::{from_png_bytes, to_png_bytes}; use learned_tracks::LearnedTracks; -use state::{AppState, DebugOcrFrame, LapState, OcrCache, RaceState, SharedAppState}; +use ocr_db::OcrDatabase; +use state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}; use stats_writer::export_race_stats; fn main() -> anyhow::Result<()> { @@ -41,6 +42,7 @@ fn main() -> anyhow::Result<()> { let app_state = AppState { config: Arc::new(Config::load().unwrap()), learned_tracks: Arc::new(LearnedTracks::load().unwrap()), + ocr_db: Arc::new(OcrDatabase::load().unwrap()), ..Default::default() }; let state = Arc::new(Mutex::new(app_state)); @@ -174,8 +176,8 @@ fn show_race_state( ui_state: &mut UiState, race_name: &str, race: &mut RaceState, - config: Arc, - ocr_cache: Arc, + config: &Config, + ocr_db: &OcrDatabase, ) { egui::Grid::new(format!("race:{}", race_name)).show(ui, |ui| { ui.label("Lap"); @@ -228,9 +230,9 @@ fn show_race_state( if ui.button("Debug").clicked() { open_debug_lap( ui_state, + ocr_db, lap, - config.clone(), - ocr_cache.clone(), + config, ) } @@ -241,10 +243,8 @@ fn show_race_state( }); } -fn show_debug_frames(ui: &mut Ui, debug_frames: &HashMap) { - let mut screenshots_sorted: Vec<_> = debug_frames.iter().collect(); - screenshots_sorted.sort_by_key(|(name, _)| *name); - for (name, debug_image) in screenshots_sorted { +fn show_debug_frames(ui: &mut Ui, debug_frames: &mut HashMap) { + for (name, debug_image) in debug_frames.iter_mut() { ui.label(name); if let Some(text) = &debug_image.recognized_text { ui.label(text); @@ -280,13 +280,13 @@ fn show_config_controls(ui: &mut Ui, ui_state: &mut UiState, state: &mut AppStat fn open_debug_lap( ui_state: &mut UiState, + ocr_db: &OcrDatabase, lap: &LapState, - config: Arc, - ocr_cache: Arc, + config: &Config, ) { if let Some(screenshot_bytes) = &lap.screenshot { let screenshot = from_png_bytes(screenshot_bytes); - let ocr_results = remote_ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false); + let ocr_results = analysis::ocr_all_regions(ocr_db, &screenshot, config, false); let debug_lap = DebugLap { screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot)) .unwrap(), @@ -308,6 +308,7 @@ fn show_combo_box(ui: &mut Ui, name: &str, label: &str, options: &[String], valu impl eframe::App for AppUi { fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { let mut state = self.state.lock().unwrap(); + let ocr_db = state.ocr_db.clone(); let mut debug_lap_window = self.ui_state.debug_lap.is_some(); let window = egui::Window::new("Debug Lap").open(&mut debug_lap_window); @@ -318,8 +319,8 @@ impl eframe::App for AppUi { .screenshot .show_max_size(ui, Vec2::new(800.0, 600.0)); ui.separator(); - if let Some(debug_lap) = &self.ui_state.debug_lap { - show_debug_frames(ui, &debug_lap.debug_regions); + if let Some(debug_lap) = &mut self.ui_state.debug_lap { + show_debug_frames(ui, &mut debug_lap.debug_regions); } show_config_controls(ui, &mut self.ui_state, state.deref_mut()); } @@ -380,16 +381,17 @@ impl eframe::App for AppUi { ui.separator(); ui.checkbox(&mut state.debug_frames, "Debug OCR regions"); - ui.checkbox( - &mut state.should_sample_ocr_data, - "Dump OCR training frames", - ); + if state.config.dump_frame_fraction.is_some() { + ui.checkbox( + &mut state.should_sample_ocr_data, + "Dump OCR training frames", + ); + } }); egui::CentralPanel::default().show(ctx, |ui| { egui::ScrollArea::vertical().show(ui, |ui| { let config = state.config.clone(); let _learned_tracks = state.learned_tracks.clone(); - let ocr_cache = state.ocr_cache.clone(); if let Some(race) = &mut state.current_race { ui.heading(&format!("Current Race: {}", race.name())); show_race_state( @@ -397,8 +399,8 @@ impl eframe::App for AppUi { &mut self.ui_state, "current", race, - config.clone(), - ocr_cache.clone(), + config.as_ref(), + ocr_db.as_ref(), ); } let len = state.past_races.len(); @@ -412,8 +414,8 @@ impl eframe::App for AppUi { &mut self.ui_state, &format!("race {}:", i), race, - config.clone(), - ocr_cache.clone(), + config.as_ref(), + ocr_db.as_ref(), ); if let Some(img) = &race.screencap { img.show_max_size(ui, Vec2::new(600.0, 500.0)); @@ -471,7 +473,7 @@ impl eframe::App for AppUi { if state.debug_frames { egui::SidePanel::right("screenshots").show(ctx, |ui| { egui::ScrollArea::vertical().show(ui, |ui| { - show_debug_frames(ui, &state.saved_frames); + show_debug_frames(ui, &mut state.saved_frames); show_config_controls(ui, &mut self.ui_state, state.deref_mut()); }); }); diff --git a/src/ocr.rs b/src/ocr.rs index c3bfb70..dd42878 100644 --- a/src/ocr.rs +++ b/src/ocr.rs @@ -1,4 +1,5 @@ use image::{Rgb, RgbImage}; +use img_hash::ImageHash; use crate::image_processing; @@ -102,7 +103,7 @@ pub fn bounding_box_images(image: &RgbImage) -> Vec { trimmed } -pub fn compute_box_hashes(image: &RgbImage) -> Vec { +pub fn compute_box_hashes(image: &RgbImage) -> Vec { let mut hashes = Vec::new(); let boxes = get_character_bounding_boxes(image); diff --git a/src/ocr_db.rs b/src/ocr_db.rs new file mode 100644 index 0000000..e49c7cb --- /dev/null +++ b/src/ocr_db.rs @@ -0,0 +1,100 @@ +use std::{collections::HashMap, sync::{Arc, RwLock}}; + +use crate::{ + config::{load_config_or_make_default, save_json_config, Config}, + image_processing::{extract_and_filter, hash_image}, ocr, +}; + +use anyhow::Result; +use image::RgbImage; +use img_hash::ImageHash; +use serde::{Deserialize, Serialize}; + +#[derive(Default, Serialize, Deserialize, Clone)] +struct RawOcrDatabase { + learned_chars: HashMap, +} + +#[derive(Default)] +pub struct OcrDatabase { + learned_chars: RwLock>, +} + +impl From<&RawOcrDatabase> for OcrDatabase { + fn from(raw: &RawOcrDatabase) -> Self { + let db = Self::default(); + { + let mut state = db.learned_chars.write().unwrap(); + for (hash, s) in &raw.learned_chars { + state.push(( + ImageHash::from_base64(&hash).unwrap(), + s.chars().nth(0).unwrap(), + )); + } + } + db + } +} +impl From<&OcrDatabase> for RawOcrDatabase { + fn from(db: &OcrDatabase) -> Self { + Self { + learned_chars: db + .learned_chars + .read() + .unwrap() + .iter() + .map(|(hash, c)| (hash.to_base64(), c.to_string())) + .collect(), + } + } +} + +impl OcrDatabase { + pub fn load() -> Result { + let raw: RawOcrDatabase = + load_config_or_make_default("ocr.json", include_str!("configs/learned.default.json"))?; + Ok((&raw).into()) + } + pub fn save(&self) -> Result<()> { + let raw: RawOcrDatabase = self.into(); + save_json_config("ocr.json", &raw) + } + + pub fn learn(&self, hash: &str, val: char) { + self.learned_chars.write().unwrap().push((ImageHash::from_base64(hash).unwrap(), val)); + } + + pub fn learn_phrase(&self, hashes: &[ImageHash], phrase: &str) -> Result<()> { + if phrase.len() > hashes.len() { + anyhow::bail!("too many characters detected in OCR result for learned phrase") + } + if phrase.len() < hashes.len() - 1 { + anyhow::bail!("too few characters detected in OCR result for learned phrase") + } + if !phrase.is_ascii() { + anyhow::bail!("cannot learn non-ASCII characters") + } + + let chars: Vec = phrase.chars().collect(); + for (i, hash) in hashes.iter().enumerate() { + self.learn(&hash.to_base64(), chars[i]); + } + Ok(()) + } + + pub fn ocr_char(&self, hash: &ImageHash) -> Option { + let state = self.learned_chars.read().unwrap(); + let (_, c) = state.iter().min_by_key(|(learned_hash, _)| hash.dist(learned_hash))?; + Some(*c) + } + + pub fn ocr_hashes(&self, hashes: &[ImageHash]) -> String { + let buffer: String = hashes.iter().filter_map(|hash| self.ocr_char(hash)).collect(); + buffer.trim_end_matches(|c: char| !c.is_alphanumeric()).to_owned() + } + + pub fn ocr_image(&self, image: &RgbImage) -> String { + let hashes = ocr::compute_box_hashes(image); + self.ocr_hashes(&hashes) + } +} diff --git a/src/remote_ocr.rs b/src/remote_ocr.rs deleted file mode 100644 index 673ba35..0000000 --- a/src/remote_ocr.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::{ - collections::HashMap, - sync::{Arc, Mutex, RwLock}, -}; - -use anyhow::Result; -use image::RgbImage; -use serde::{Deserialize, Serialize}; - -use crate::{ - config::Config, - image_processing::{extract_and_filter, hash_image}, - state::OcrCache, -}; - -#[derive(Serialize, Deserialize, Debug)] -pub struct OcrRegion { - pub confidence: f64, - pub value: String, -} - -#[derive(Serialize, Deserialize, Debug)] -struct OcrResult { - regions: Vec, - error: Option, -} - -async fn run_ocr(image: &RgbImage, url: &str) -> Result> { - let client = reqwest::Client::new(); - let response = client - .post(url) - .body(crate::image_processing::to_png_bytes(image)) - .send() - .await?; - - if !response.status().is_success() { - eprintln!("failed to run OCR query"); - anyhow::bail!("failed to run OCR query") - } - let result: OcrResult = response.json().await?; - let result = if result.regions.is_empty() { - None - } else { - let mut buffer = String::new(); - for r in &result.regions { - buffer += &r.value; - } - Some(buffer) - }; - Ok(result) -} - -async fn run_ocr_cached( - ocr_cache: Arc>>>, - hash: String, - region: &crate::image_processing::Region, - config: Arc, - filtered_image: &image::ImageBuffer, Vec>, -) -> Option { - let cached = { - let locked = ocr_cache.read().unwrap(); - locked.get(&hash).cloned() - }; - let use_cache = region.use_ocr_cache.unwrap_or(true) && config.use_ocr_cache.unwrap_or(true); - if let Some(cached) = cached { - if use_cache { - return cached; - } - } - match run_ocr(filtered_image, &config.ocr_server_endpoint).await { - Ok(v) => { - if use_cache { - ocr_cache.write().unwrap().insert(hash.clone(), v.clone()); - } - v - } - Err(_) => None, - } -} - -#[tokio::main(flavor = "current_thread")] -pub async fn ocr_all_regions( - image: &RgbImage, - config: Arc, - ocr_cache: Arc, - should_sample: bool, -) -> HashMap> { - let results = Arc::new(Mutex::new(HashMap::new())); - - let mut handles = Vec::new(); - for region in &config.ocr_regions { - let filtered_image = extract_and_filter(image, region); - let region = region.clone(); - let results = results.clone(); - let config = config.clone(); - let ocr_cache = ocr_cache.clone(); - handles.push(tokio::spawn(async move { - let filtered_image = filtered_image; - let hash = hash_image(&filtered_image); - let value = - run_ocr_cached(ocr_cache, hash, ®ion, config.clone(), &filtered_image).await; - - if let Some(sample_fraction) = &config.dump_frame_fraction { - if rand::random::() < *sample_fraction && should_sample { - let file_id = rand::random::(); - let img_filename = format!("ocr_data/{}.png", file_id); - filtered_image.save(img_filename).unwrap(); - let value_filename = format!("ocr_data/{}.txt", file_id); - std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap(); - } - } - results.lock().unwrap().insert(region.name, value); - })); - } - for handle in handles { - handle.await.expect("failed to join task in OCR"); - } - - let results = results.lock().unwrap().clone(); - results -} diff --git a/src/state.rs b/src/state.rs index a70122c..be2ed17 100644 --- a/src/state.rs +++ b/src/state.rs @@ -4,7 +4,7 @@ use egui_extras::RetainedImage; use image::RgbImage; use time::{OffsetDateTime, format_description}; -use crate::{config::Config, learned_tracks::LearnedTracks}; +use crate::{config::Config, learned_tracks::LearnedTracks, ocr_db::OcrDatabase}; #[derive(Debug, Clone, Default)] @@ -33,8 +33,8 @@ fn parse_duration(time: &str) -> Option { Some(Duration::from_secs_f64(60.0 * minutes + secs)) } -fn parse_to_duration(time: Option<&Option>) -> Option { - parse_duration(&time?.clone()?) +fn parse_to_duration(time: Option<&String>) -> Option { + parse_duration(&time?) } fn check_0_100(v: usize) -> Option { @@ -45,12 +45,12 @@ fn check_0_100(v: usize) -> Option { } } -fn parse_to_0_100(v: Option<&Option>) -> Option { - check_0_100(v?.clone()?.parse::().ok()?) +fn parse_to_0_100(v: Option<&String>) -> Option { + check_0_100(v?.parse::().ok()?) } impl LapState { - pub fn parse(raw: &HashMap>) -> Self { + pub fn parse(raw: &HashMap) -> Self { Self { lap: parse_to_0_100(raw.get("lap")), health: parse_to_0_100(raw.get("health")), @@ -126,10 +126,9 @@ pub struct DebugOcrFrame { pub recognized_text: Option, } -pub type OcrCache = RwLock>>; #[derive(Default)] pub struct AppState { - pub raw_data: HashMap>, + pub raw_data: HashMap, pub last_frame: Option, pub buffered_frames: VecDeque, @@ -144,8 +143,7 @@ pub struct AppState { pub config: Arc, pub learned_tracks: Arc, - - pub ocr_cache: Arc, + pub ocr_db: Arc, } pub type SharedAppState = Arc>; \ No newline at end of file diff --git a/src/test_data/test-image-3.png b/src/test_data/test-image-3.png new file mode 100644 index 0000000000000000000000000000000000000000..f9d91a65e793c7dfdd9699a29a4068e7b0cd41fa GIT binary patch literal 707 zcmeAS@N?(olHy`uVBq!ia0vp^+kse{g9%7(UD5Y}fq|*Z)5S5QV$Rzeck^B=2(Ubu zym;S#?d8h7Edtz|S#l{`u#hbhW?7=g)kmxpB_&BSrIsPFH&A zJcn~?AHH95{BdKp;m_IkqQWn3aj!bKELME^tg8_C?qf&0C zZu?f5rY#gXa$--Oe$&_6Qb9YMG&?4Cm@Y~?@sla1ZzNOFuHETmPN^VeK>fo}Z`qOD=10?9cK_`_NI@dbaSO)o1Mm64x{8 zSRW|uYBlA(yXfyu&6WJz%Q-6o+4GHd2|f~d9J}I3s7QZMMoXYs>5J-<>4#P_#_7z? zIj-K4y?uxD8=>wcH;#&m{duXWIjuim&VKpC5@wwX-5J(<+n2TQTMBKq?k}0TV|DpM zzk3yG?aNYl{oN}bA6#0!n=AjzG4)pA;6v(9F86Inn-`ZW|9vsgdsk8-#mzTtKBVgB z!`botiK9!-OJLAA@)jQGb~E?zIl8pLQOsnQqtw!pIp4E7Ld0J&>hE&V+WPPQ#nq?u zPRk~%WM$3T&YPCpRj74SNN3`@Wos&L3*5}>I%pMSu=MD$UEMb&A`9}oR@^i3xwrS7 zkgkrjs6n)JsM##HyRAuq4La+N)}&rvVf(F5OMRug(PHQO-f!GLPl-4eoMh?W`P4>r d8X?Zkd^_*#)Lf?dg22Sd;OXk;vd$@?2>=F, + ocr_db: OcrDatabase, } struct TrainingImage { @@ -28,7 +29,7 @@ struct TrainingImage { ui_image: RetainedImage, char_images: Vec, - char_hashes: Vec, + char_hashes: Vec, } impl TrainingImage { @@ -55,16 +56,6 @@ fn get_training_data_paths() -> Vec<(PathBuf, PathBuf)> { data_paths } -fn predict_ocr(hashes: &[(String, char)], hash: &str) -> Option { - let hash = img_hash::ImageHash::>::from_base64(hash).unwrap(); - let (_, best_char) = hashes.iter().min_by_key(|(learned_hash, _c)| { - img_hash::ImageHash::from_base64(learned_hash) - .unwrap() - .dist(&hash) - })?; - Some(*best_char) -} - fn get_training_data() -> Vec { let mut data = Vec::new(); for (img_file, ocr_file) in get_training_data_paths() { @@ -90,24 +81,6 @@ fn get_training_data() -> Vec { data } -fn load_learned_hashes() -> Vec<(String, char)> { - let path = Path::new("learned_chars.txt"); - if !path.exists() { - return Vec::new(); - } - - let data = String::from_utf8(std::fs::read(path).unwrap()).unwrap(); - let mut parsed = Vec::new(); - for line in data.lines() { - if let Some((c, hash)) = line.split_once(' ') { - if let Some(c) = c.chars().next() { - parsed.push((hash.to_owned(), c)); - } - } - } - parsed -} - fn load_retained_image(image: &RgbImage) -> RetainedImage { RetainedImage::from_image_bytes("", &to_png_bytes(image)).unwrap() } @@ -116,7 +89,7 @@ pub fn training_ui() -> anyhow::Result<()> { let options = eframe::NativeOptions::default(); let state = TrainingUi { training_images: get_training_data(), - learned_char_hashes: load_learned_hashes(), + ocr_db: OcrDatabase::load().unwrap(), ..Default::default() }; eframe::run_native("OCR Trainer", options, Box::new(|_cc| Box::new(state))); @@ -139,46 +112,22 @@ impl eframe::App for TrainingUi { current_image.save_ocr_text(); } if ui.button("Learn").clicked() { - for (i, char) in current_image.text.chars().enumerate() { - if let Some(hash) = current_image.char_hashes.get(i) { - self.learned_char_hashes.push((hash.clone(), char)); - eprintln!("Learned {}={}", hash, char); - } - } + self.ocr_db.learn_phrase(¤t_image.char_hashes, ¤t_image.text).unwrap(); self.current_image_index += 1; } if ui.button("Learn and delete").clicked() { - for (i, char) in current_image.text.chars().enumerate() { - if let Some(hash) = current_image.char_hashes.get(i) { - self.learned_char_hashes.push((hash.clone(), char)); - eprintln!("Learned {}={}", hash, char); - } - } + self.ocr_db.learn_phrase(¤t_image.char_hashes, ¤t_image.text).unwrap(); current_image.delete_data(); self.current_image_index += 1; } if ui.button("Save learned results").clicked() { - let mut buffer = String::new(); - for (hash, c) in &self.learned_char_hashes { - buffer += &format!("{} {}\n", c, hash); - } - let mut file = std::fs::OpenOptions::new() - .append(true) - .create(true) - .write(true) - .open("learned_chars.txt") - .unwrap(); - file.write_all(buffer.as_bytes()).unwrap(); + self.ocr_db.save().unwrap(); } current_image.ui_image.show(ui); ui.label("OCR value"); ui.text_edit_singleline(&mut current_image.text); - let predicted: String = current_image - .char_hashes - .iter() - .filter_map(|hash| predict_ocr(&self.learned_char_hashes, hash)) - .collect(); + let predicted: String = self.ocr_db.ocr_hashes(¤t_image.char_hashes); if predicted == current_image.text { ui.colored_label(Color32::GREEN, format!("Predicted: {}", predicted)); } else { @@ -189,8 +138,8 @@ impl eframe::App for TrainingUi { c.show(ui); } for c in ¤t_image.char_hashes { - ui.label(c); - if let Some(predicted) = predict_ocr(&self.learned_char_hashes, c) { + ui.label(c.to_base64()); + if let Some(predicted) = self.ocr_db.ocr_char(c) { ui.label(format!("Predicted: {}", predicted)); } }