From 3bffdd1b8dfff3ba7c6a0ad1fb15ffe40fab670e Mon Sep 17 00:00:00 2001 From: Scott Pruett Date: Sun, 22 May 2022 14:03:47 -0400 Subject: [PATCH] learned images --- config.json | 3 +- learned.json | 5 +-- src/capture.rs | 5 ++- src/config.rs | 2 +- src/configs/config.default.json | 2 +- src/control_loop.rs | 28 +++++++-------- src/image_processing.rs | 10 +++++- src/main.rs | 54 ++++++++++++++++++++++------ src/ocr.rs | 62 +++++++++++++++++++++------------ src/state.rs | 2 +- 10 files changed, 115 insertions(+), 58 deletions(-) diff --git a/config.json b/config.json index 5cec55a..6e39311 100644 --- a/config.json +++ b/config.json @@ -43,5 +43,6 @@ "height": 43 } ], - "ocr_server_endpoint": "http://localhost:3000/" + + "ocr_server_endpoint": "https://tesserver.spruett.dev/" } \ No newline at end of file diff --git a/learned.json b/learned.json index 9190260..1c7f325 100644 --- a/learned.json +++ b/learned.json @@ -1,4 +1,5 @@ { - "learned_images": {}, - "learned_tracks": {} + "learned_images": { + }, + "learned_tracks": {} } \ No newline at end of file diff --git a/src/capture.rs b/src/capture.rs index 996a924..74a21ac 100644 --- a/src/capture.rs +++ b/src/capture.rs @@ -21,9 +21,8 @@ fn get_raw_frame(capturer: &mut Capturer) -> Result> { } } -pub fn get_frame() -> Result { - let mut capturer = Capturer::new(Display::primary()?)?; - let frame = get_raw_frame(&mut capturer)?; +pub fn get_frame(capturer: &mut Capturer) -> Result { + let frame = get_raw_frame(capturer)?; let mut image = RgbImage::new(capturer.width() as u32, capturer.height() as u32); let stride = frame.len() / capturer.height(); diff --git a/src/config.rs b/src/config.rs index e34941d..b5230fd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -41,7 +41,7 @@ fn load_or_make_default(path: &str, default: &str) -> Resul if !file_path.exists() { std::fs::write(&path, default)?; } - load_json_config(&path) + load_json_config(path) } fn load_json_config(path: &str) -> Result { diff --git a/src/configs/config.default.json b/src/configs/config.default.json index 5cec55a..c021f96 100644 --- a/src/configs/config.default.json +++ b/src/configs/config.default.json @@ -43,5 +43,5 @@ "height": 43 } ], - "ocr_server_endpoint": "http://localhost:3000/" + "ocr_server_endpoint": "https://tesserver.spruett.dev/" } \ No newline at end of file diff --git a/src/control_loop.rs b/src/control_loop.rs index 187e708..928773b 100644 --- a/src/control_loop.rs +++ b/src/control_loop.rs @@ -1,6 +1,6 @@ use std::{ collections::HashMap, - time::{Duration, Instant}, + time::{Duration, Instant}, thread, }; use anyhow::Result; @@ -8,9 +8,10 @@ use egui_extras::RetainedImage; use image::RgbImage; use scrap::{Capturer, Display}; + use crate::{ capture, - image_processing::{self, Region}, + image_processing::{self, hash_image}, ocr, state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState}, }; @@ -103,12 +104,15 @@ fn handle_new_frame(state: &mut AppState, frame: ParsedFrame, image: &RgbImage) } } -async fn run_loop_once(state: &SharedAppState) -> Result<()> { +fn run_loop_once(state: &SharedAppState) -> Result<()> { + let mut capturer = Capturer::new(Display::primary()?)?; let config = state.lock().unwrap().config.clone(); - let frame = capture::get_frame()?; - let ocr_results = ocr::ocr_all_regions(&frame, &config.ocr_regions).await; + let learned_config = state.lock().unwrap().learned.clone(); + let frame = capture::get_frame(&mut capturer)?; + let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config.clone()); let mut saved_frames = HashMap::new(); + if state.lock().unwrap().debug_frames { let hasher = img_hash::HasherConfig::new().to_hasher(); for region in &config.ocr_regions { @@ -119,13 +123,7 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> { &image_processing::to_png_bytes(&extracted), ) .unwrap(); - let have_to_use_other_image_library_version = img_hash::image::RgbImage::from_raw( - extracted.width(), - extracted.height(), - extracted.as_raw().to_vec(), - ) - .unwrap(); - let hash = hasher.hash_image(&have_to_use_other_image_library_version); + let hash = hash_image(&extracted); saved_frames.insert( region.name.clone(), DebugOcrFrame { @@ -146,11 +144,11 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> { Ok(()) } -pub async fn run_control_loop(state: SharedAppState) -> Result<()> { +pub fn run_control_loop(state: SharedAppState) { loop { - if let Err(e) = run_loop_once(&state).await { + if let Err(e) = run_loop_once(&state) { eprintln!("Error in control loop: {:?}", e) } - tokio::time::sleep(Duration::from_millis(500)).await; + thread::sleep(Duration::from_millis(500)); } } diff --git a/src/image_processing.rs b/src/image_processing.rs index f39d60e..d6831bb 100644 --- a/src/image_processing.rs +++ b/src/image_processing.rs @@ -1,4 +1,3 @@ -use anyhow::Result; use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage}; use serde::{Deserialize, Serialize}; @@ -64,3 +63,12 @@ pub fn to_png_bytes(image: &RgbImage) -> Vec { .expect("failed encoding image to PNG"); buffer } + +pub fn hash_image(image: &RgbImage) -> String { + let hasher = img_hash::HasherConfig::new().to_hasher(); + let have_to_use_other_image_library_version = + img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec()) + .unwrap(); + let hash = hasher.hash_image(&have_to_use_other_image_library_version); + hash.to_base64() +} diff --git a/src/main.rs b/src/main.rs index a3e5bed..fbabc49 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ mod config; use std::{ sync::{Arc, Mutex}, - time::Duration, + time::Duration, thread, }; use config::{Config, LearnedConfig}; @@ -19,18 +19,15 @@ use eframe::{ }; use state::{AppState, RaceState, SharedAppState}; -#[tokio::main(flavor = "multi_thread", worker_threads = 8)] -async fn main() -> anyhow::Result<()> { +fn main() -> anyhow::Result<()> { let mut app_state = AppState::default(); app_state.config = Arc::new(Config::load().unwrap()); app_state.learned = Arc::new(LearnedConfig::load().unwrap()); let state = Arc::new(Mutex::new(app_state)); { let state = state.clone(); - let _ = tokio::spawn(async move { - control_loop::run_control_loop(state) - .await - .expect("control loop failed"); + let _ = thread::spawn(move || { + control_loop::run_control_loop(state); }); } @@ -91,11 +88,16 @@ fn label_time_delta(ui: &mut Ui, time: Duration, old: Option) { struct MyApp { state: SharedAppState, + + config_load_err: Option, + + hash_to_learn: String, + value_to_learn: String, } impl MyApp { pub fn new(state: SharedAppState) -> Self { - Self { state } + Self { state, config_load_err: None, hash_to_learn: "".to_owned(), value_to_learn: "".to_owned() } } } @@ -110,7 +112,7 @@ fn show_race_state(ui: &mut Ui, race: &RaceState) { ui.label("Tyres"); ui.end_row(); for (i, lap) in race.laps.iter().enumerate() { - if let Some(lap_time) = *&lap.lap_time { + if let Some(lap_time) = lap.lap_time { let prev_lap = race.laps.get(i - 1); ui.label(format!("#{}", lap.lap.unwrap_or(i + 1))); @@ -211,9 +213,41 @@ impl eframe::App for MyApp { screenshots_sorted.sort_by_key(|(name, _)| name.clone()); for (name, image) in screenshots_sorted { ui.label(name); - ui.label(image.img_hash.to_base64()); + if ui.button(&image.img_hash).on_hover_text("Copy").clicked() { + ui.output().copied_text = image.img_hash.clone(); + } image.image.show_max_size(ui, ui.available_size()); } + + if ui.button("Reload config").clicked() { + match Config::load() { + Ok(c) => { + state.config = Arc::new(c); + self.config_load_err = None; + } + Err(e) => { + self.config_load_err = Some(format!("failed to load config: {:?}", e)); + } + } + } + if let Some(e) = &self.config_load_err { + ui.colored_label(Color32::RED, e); + } + + ui.separator(); + ui.label("Hash"); + ui.text_edit_singleline(&mut self.hash_to_learn); + ui.label("Value"); + ui.text_edit_singleline(&mut self.value_to_learn); + if ui.button("Learn").clicked() { + let mut learned_config = (*state.learned).clone(); + learned_config.learned_images.insert(self.hash_to_learn.clone(), self.value_to_learn.clone()); + learned_config.save().unwrap(); + state.learned = Arc::new(learned_config); + + self.hash_to_learn = "".to_owned(); + self.value_to_learn = "".to_owned(); + } }); } diff --git a/src/ocr.rs b/src/ocr.rs index 3bc1a31..2c57c48 100644 --- a/src/ocr.rs +++ b/src/ocr.rs @@ -1,10 +1,16 @@ -use std::{collections::HashMap, sync::{Arc, Mutex}}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; use anyhow::Result; use image::RgbImage; use serde::{Deserialize, Serialize}; -use crate::image_processing::{Region, extract_region, filter_to_white}; +use crate::{ + config::{Config, LearnedConfig}, + image_processing::{extract_region, filter_to_white, hash_image, Region}, +}; #[derive(Serialize, Deserialize, Debug)] pub struct OcrRegion { @@ -18,46 +24,56 @@ struct OcrResult { error: Option, } -async fn run_ocr(image: &RgbImage) -> Result> { +async fn run_ocr(image: &RgbImage, url: &str) -> Result> { let client = reqwest::Client::new(); let response = client - .post("http://localhost:3000/") - .body(crate::image_processing::to_png_bytes(&image)) + .post(url) + .body(crate::image_processing::to_png_bytes(image)) .send() .await?; if !response.status().is_success() { + eprintln!("failed to run OCR query"); anyhow::bail!("failed to run OCR query") } let result: OcrResult = response.json().await?; - Ok(result.regions) + let result = if result.regions.is_empty() { + None + } else { + let mut buffer = String::new(); + for r in &result.regions { + buffer += &r.value; + } + Some(buffer) + }; + Ok(result) } -pub async fn ocr_all_regions(image: &RgbImage, regions: &[Region]) -> HashMap> { +#[tokio::main(flavor = "current_thread")] +pub async fn ocr_all_regions( + image: &RgbImage, + config: Arc, + learned: Arc, +) -> HashMap> { let results = Arc::new(Mutex::new(HashMap::new())); let mut handles = Vec::new(); - for region in regions { + for region in &config.ocr_regions { let filtered_image = extract_region(image, region); let region = region.clone(); let results = results.clone(); + let config = config.clone(); + let learned = learned.clone(); handles.push(tokio::spawn(async move { let mut image = filtered_image; filter_to_white(&mut image); - let ocr_results = run_ocr(&image).await; - let value = match ocr_results { - Ok(ocr_regions) => { - if ocr_regions.is_empty() { - None - } else { - let mut out = String::new(); - for r in ocr_regions { - out += &r.value; - } - Some(out) - } - } - Err(_) => None + let hash = hash_image(&image); + let value = if let Some(learned_value) = learned.learned_images.get(&hash) { + Some(learned_value.clone()) + } else { + run_ocr(&image, &config.ocr_server_endpoint) + .await + .unwrap_or(None) }; results.lock().unwrap().insert(region.name, value); })); @@ -68,4 +84,4 @@ pub async fn ocr_all_regions(image: &RgbImage, regions: &[Region]) -> HashMap