diff --git a/src/analysis.rs b/src/analysis.rs
index 9d07f81..7c5ab7b 100644
--- a/src/analysis.rs
+++ b/src/analysis.rs
@@ -14,7 +14,7 @@ use crate::{
     capture,
     config::Config,
     image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes},
-    ocr,
+    remote_ocr,
     state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState},
     learned_tracks::get_track_hash,
 };
@@ -159,7 +159,7 @@ fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()
             locked.should_sample_ocr_data
         )
     };
-    let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample);
+    let ocr_results = remote_ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample);
 
     if state.lock().unwrap().debug_frames {
         let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results);
diff --git a/src/local_ocr.rs b/src/local_ocr.rs
deleted file mode 100644
index 1512cbe..0000000
--- a/src/local_ocr.rs
+++ /dev/null
@@ -1,135 +0,0 @@
-use image::{Rgb, RgbImage};
-use img_hash::{image::GenericImageView};
-
-use crate::image_processing;
-
-#[derive(Debug)]
-struct BoundingBox {
-    x: u32,
-    y: u32,
-    width: u32,
-    height: u32,
-}
-
-fn column_has_any_dark(image: &RgbImage, x: u32) -> bool {
-    for y in 0..image.height() {
-        let [r, g, b] = image.get_pixel(x, y).0;
-        if r < 100 && g < 100 && b < 100 {
-            return true;
-        }
-    }
-    false
-}
-
-fn row_has_any_dark(image: &RgbImage, y: u32, start_x: u32, width: u32) -> bool {
-    for x in start_x..(start_x + width) {
-        let [r, g, b] = image.get_pixel(x, y).0;
-        if r < 100 && g < 100 && b < 100 {
-            return true;
-        }
-    }
-    false
-}
-
-fn take_while<F: Fn(u32) -> bool>(x: &mut u32, max: u32, f: F) {
-    while *x < max && f(*x) {
-        *x += 1;
-    }
-}
-
-fn get_character_bounding_boxes(image: &RgbImage) -> Vec<BoundingBox> {
-    let mut x = 0;
-    let mut boxes = Vec::new();
-    while x < image.width() {
-        take_while(&mut x, image.width(), |x| !column_has_any_dark(image, x));
-
-        let start_x = x;
-        take_while(&mut x, image.width(), |x| column_has_any_dark(image, x));
-        let width = x - start_x;
-
-        if width >= 1 {
-            let mut y = 0;
-            take_while(&mut y, image.height(), |y| {
-                !row_has_any_dark(image, y, start_x, width)
-            });
-
-            let start_y = y;
-
-            let mut inverse_y = 0;
-            take_while(&mut inverse_y, image.height(), |y| {
-                !row_has_any_dark(image, image.height() - 1 - y, start_x, width)
-            });
-            let end_y = image.height() - inverse_y;
-            let height = end_y - start_y;
-            if height >= 1 {
-                boxes.push(BoundingBox {
-                    x: start_x,
-                    y: start_y,
-                    width,
-                    height,
-                });
-            }
-        }
-    }
-    boxes
-}
-
-fn trim_to_bounding_box(image: &RgbImage, bounding_box: &BoundingBox) -> RgbImage {
-    const PADDING: u32 = 2;
-    let mut buffer = RgbImage::from_pixel(
-        bounding_box.width + 2 * PADDING,
-        bounding_box.height + 2 * PADDING,
-        Rgb([0xFF, 0xFF, 0xFF]),
-    );
-    for y in 0..bounding_box.height {
-        for x in 0..bounding_box.width {
-            buffer.put_pixel(
-                x + PADDING,
-                y + PADDING,
-                *image.get_pixel(bounding_box.x + x, bounding_box.y + y),
-            );
-        }
-    }
-    buffer
-}
-
-pub fn bounding_box_images(image: &RgbImage) -> Vec<RgbImage> {
-    let mut trimmed = Vec::new();
-
-    let boxes = get_character_bounding_boxes(image);
-    for bounding_box in boxes {
-        trimmed.push(trim_to_bounding_box(image, &bounding_box));
-    }
-    trimmed
-}
-
-pub fn compute_box_hashes(image: &RgbImage) -> Vec<String> {
-    let mut hashes = Vec::new();
-
-    let boxes = get_character_bounding_boxes(image);
-    for bounding_box in boxes {
-        let trimmed = trim_to_bounding_box(image, &bounding_box);
-        hashes.push(image_processing::hash_image(&trimmed))
-    }
-    hashes
-}
-
-#[test]
-fn test_bounding_boxes() {
-    let image_bytes = include_bytes!("test_data/test-image-2.png");
-    let image = image::load_from_memory(image_bytes).unwrap().to_rgb8();
-    let boxes = get_character_bounding_boxes(&image);
-    assert_eq!(boxes.len(), 10);
-    assert_ne!(boxes[0].x, 0);
-    assert_ne!(boxes[0].y, 0);
-    assert_ne!(boxes[0].height, 0);
-    assert_ne!(boxes[0].width, 0);
-}
-
-#[test]
-fn test_box_hashes() {
-    let image_bytes = include_bytes!("test_data/test-image-2.png");
-    let image = image::load_from_memory(image_bytes).unwrap().to_rgb8();
-    let hashes = compute_box_hashes(&image);
-    assert_eq!(hashes.len(), 10);
-}
diff --git a/src/main.rs b/src/main.rs
index 10b701c..d193861 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,8 +4,8 @@ mod analysis;
 mod capture;
 mod config;
 mod image_processing;
-mod local_ocr;
 mod ocr;
+mod remote_ocr;
 mod state;
 mod stats_writer;
 mod training_ui;
@@ -125,9 +125,6 @@ struct DebugLap {
 
 struct UiState {
     config_load_err: Option<String>,
-    hash_to_learn: String,
-    value_to_learn: String,
-
     debug_lap: Option<DebugLap>,
 }
 
@@ -289,7 +286,7 @@ fn open_debug_lap(
 ) {
     if let Some(screenshot_bytes) = &lap.screenshot {
         let screenshot = from_png_bytes(screenshot_bytes);
-        let ocr_results = ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false);
+        let ocr_results = remote_ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false);
         let debug_lap = DebugLap {
             screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot))
                 .unwrap(),
@@ -308,9 +305,6 @@ fn show_combo_box(ui: &mut Ui, name: &str, label: &str, options: &[String], valu
     *value = options[index].clone();
 }
 
-fn save_learned_track(_learned_tracks: &mut Arc<LearnedTracks>, _track: &str, _hash: &str) {
-}
-
 impl eframe::App for AppUi {
     fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
         let mut state = self.state.lock().unwrap();
diff --git a/src/ocr.rs b/src/ocr.rs
index 673ba35..c3bfb70 100644
--- a/src/ocr.rs
+++ b/src/ocr.rs
@@ -1,121 +1,134 @@
-use std::{
-    collections::HashMap,
-    sync::{Arc, Mutex, RwLock},
-};
-
-use anyhow::Result;
-use image::RgbImage;
-use serde::{Deserialize, Serialize};
-
-use crate::{
-    config::Config,
-    image_processing::{extract_and_filter, hash_image},
-    state::OcrCache,
-};
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct OcrRegion {
-    pub confidence: f64,
-    pub value: String,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-struct OcrResult {
-    regions: Vec<OcrRegion>,
-    error: Option<String>,
-}
-
-async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
-    let client = reqwest::Client::new();
-    let response = client
-        .post(url)
-        .body(crate::image_processing::to_png_bytes(image))
-        .send()
-        .await?;
-
-    if !response.status().is_success() {
-        eprintln!("failed to run OCR query");
-        anyhow::bail!("failed to run OCR query")
-    }
-    let result: OcrResult = response.json().await?;
-    let result = if result.regions.is_empty() {
-        None
-    } else {
-        let mut buffer = String::new();
-        for r in &result.regions {
-            buffer += &r.value;
-        }
-        Some(buffer)
-    };
-    Ok(result)
-}
-
-async fn run_ocr_cached(
-    ocr_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
-    hash: String,
-    region: &crate::image_processing::Region,
-    config: Arc<Config>,
-    filtered_image: &image::ImageBuffer<image::Rgb<u8>, Vec<u8>>,
-) -> Option<String> {
-    let cached = {
-        let locked = ocr_cache.read().unwrap();
-        locked.get(&hash).cloned()
-    };
-    let use_cache = region.use_ocr_cache.unwrap_or(true) && config.use_ocr_cache.unwrap_or(true);
-    if let Some(cached) = cached {
-        if use_cache {
-            return cached;
-        }
-    }
-    match run_ocr(filtered_image, &config.ocr_server_endpoint).await {
-        Ok(v) => {
-            if use_cache {
-                ocr_cache.write().unwrap().insert(hash.clone(), v.clone());
-            }
-            v
-        }
-        Err(_) => None,
-    }
-}
-
-#[tokio::main(flavor = "current_thread")]
-pub async fn ocr_all_regions(
-    image: &RgbImage,
-    config: Arc<Config>,
-    ocr_cache: Arc<OcrCache>,
-    should_sample: bool,
-) -> HashMap<String, Option<String>> {
-    let results = Arc::new(Mutex::new(HashMap::new()));
-
-    let mut handles = Vec::new();
-    for region in &config.ocr_regions {
-        let filtered_image = extract_and_filter(image, region);
-        let region = region.clone();
-        let results = results.clone();
-        let config = config.clone();
-        let ocr_cache = ocr_cache.clone();
-        handles.push(tokio::spawn(async move {
-            let filtered_image = filtered_image;
-            let hash = hash_image(&filtered_image);
-            let value =
-                run_ocr_cached(ocr_cache, hash, &region, config.clone(), &filtered_image).await;
-
-            if let Some(sample_fraction) = &config.dump_frame_fraction {
-                if rand::random::<f64>() < *sample_fraction && should_sample {
-                    let file_id = rand::random::<u32>();
-                    let img_filename = format!("ocr_data/{}.png", file_id);
-                    filtered_image.save(img_filename).unwrap();
-                    let value_filename = format!("ocr_data/{}.txt", file_id);
-                    std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap();
-                }
-            }
-            results.lock().unwrap().insert(region.name, value);
-        }));
-    }
-    for handle in handles {
-        handle.await.expect("failed to join task in OCR");
-    }
-
-    let results = results.lock().unwrap().clone();
-    results
-}
+use image::{Rgb, RgbImage};
+
+use crate::image_processing;
+
+#[derive(Debug)]
+struct BoundingBox {
+    x: u32,
+    y: u32,
+    width: u32,
+    height: u32,
+}
+
+fn column_has_any_dark(image: &RgbImage, x: u32) -> bool {
+    for y in 0..image.height() {
+        let [r, g, b] = image.get_pixel(x, y).0;
+        if r < 100 && g < 100 && b < 100 {
+            return true;
+        }
+    }
+    false
+}
+
+fn row_has_any_dark(image: &RgbImage, y: u32, start_x: u32, width: u32) -> bool {
+    for x in start_x..(start_x + width) {
+        let [r, g, b] = image.get_pixel(x, y).0;
+        if r < 100 && g < 100 && b < 100 {
+            return true;
+        }
+    }
+    false
+}
+
+fn take_while<F: Fn(u32) -> bool>(x: &mut u32, max: u32, f: F) {
+    while *x < max && f(*x) {
+        *x += 1;
+    }
+}
+
+fn get_character_bounding_boxes(image: &RgbImage) -> Vec<BoundingBox> {
+    let mut x = 0;
+    let mut boxes = Vec::new();
+    while x < image.width() {
+        take_while(&mut x, image.width(), |x| !column_has_any_dark(image, x));
+
+        let start_x = x;
+        take_while(&mut x, image.width(), |x| column_has_any_dark(image, x));
+        let width = x - start_x;
+
+        if width >= 1 {
+            let mut y = 0;
+            take_while(&mut y, image.height(), |y| {
+                !row_has_any_dark(image, y, start_x, width)
+            });
+
+            let start_y = y;
+
+            let mut inverse_y = 0;
+            take_while(&mut inverse_y, image.height(), |y| {
+                !row_has_any_dark(image, image.height() - 1 - y, start_x, width)
+            });
+            let end_y = image.height() - inverse_y;
+            let height = end_y - start_y;
+            if height >= 1 {
+                boxes.push(BoundingBox {
+                    x: start_x,
+                    y: start_y,
+                    width,
+                    height,
+                });
+            }
+        }
+    }
+    boxes
+}
+
+fn trim_to_bounding_box(image: &RgbImage, bounding_box: &BoundingBox) -> RgbImage {
+    const PADDING: u32 = 2;
+    let mut buffer = RgbImage::from_pixel(
+        bounding_box.width + 2 * PADDING,
+        bounding_box.height + 2 * PADDING,
+        Rgb([0xFF, 0xFF, 0xFF]),
+    );
+    for y in 0..bounding_box.height {
+        for x in 0..bounding_box.width {
+            buffer.put_pixel(
+                x + PADDING,
+                y + PADDING,
+                *image.get_pixel(bounding_box.x + x, bounding_box.y + y),
+            );
+        }
+    }
+    buffer
+}
+
+pub fn bounding_box_images(image: &RgbImage) -> Vec<RgbImage> {
+    let mut trimmed = Vec::new();
+
+    let boxes = get_character_bounding_boxes(image);
+    for bounding_box in boxes {
+        trimmed.push(trim_to_bounding_box(image, &bounding_box));
+    }
+    trimmed
+}
+
+pub fn compute_box_hashes(image: &RgbImage) -> Vec<String> {
+    let mut hashes = Vec::new();
+
+    let boxes = get_character_bounding_boxes(image);
+    for bounding_box in boxes {
+        let trimmed = trim_to_bounding_box(image, &bounding_box);
+        hashes.push(image_processing::hash_image(&trimmed))
+    }
+    hashes
+}
+
+#[test]
+fn test_bounding_boxes() {
+    let image_bytes = include_bytes!("test_data/test-image-2.png");
+    let image = image::load_from_memory(image_bytes).unwrap().to_rgb8();
+    let boxes = get_character_bounding_boxes(&image);
+    assert_eq!(boxes.len(), 10);
+    assert_ne!(boxes[0].x, 0);
+    assert_ne!(boxes[0].y, 0);
+    assert_ne!(boxes[0].height, 0);
+    assert_ne!(boxes[0].width, 0);
+}
+
+#[test]
+fn test_box_hashes() {
+    let image_bytes = include_bytes!("test_data/test-image-2.png");
+    let image = image::load_from_memory(image_bytes).unwrap().to_rgb8();
+    let hashes = compute_box_hashes(&image);
+    assert_eq!(hashes.len(), 10);
+}
diff --git a/src/remote_ocr.rs b/src/remote_ocr.rs
new file mode 100644
index 0000000..673ba35
--- /dev/null
+++ b/src/remote_ocr.rs
@@ -0,0 +1,121 @@
+use std::{
+    collections::HashMap,
+    sync::{Arc, Mutex, RwLock},
+};
+
+use anyhow::Result;
+use image::RgbImage;
+use serde::{Deserialize, Serialize};
+
+use crate::{
+    config::Config,
+    image_processing::{extract_and_filter, hash_image},
+    state::OcrCache,
+};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OcrRegion {
+    pub confidence: f64,
+    pub value: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+struct OcrResult {
+    regions: Vec<OcrRegion>,
+    error: Option<String>,
+}
+
+async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
+    let client = reqwest::Client::new();
+    let response = client
+        .post(url)
+        .body(crate::image_processing::to_png_bytes(image))
+        .send()
+        .await?;
+
+    if !response.status().is_success() {
+        eprintln!("failed to run OCR query");
+        anyhow::bail!("failed to run OCR query")
+    }
+    let result: OcrResult = response.json().await?;
+    let result = if result.regions.is_empty() {
+        None
+    } else {
+        let mut buffer = String::new();
+        for r in &result.regions {
+            buffer += &r.value;
+        }
+        Some(buffer)
+    };
+    Ok(result)
+}
+
+async fn run_ocr_cached(
+    ocr_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
+    hash: String,
+    region: &crate::image_processing::Region,
+    config: Arc<Config>,
+    filtered_image: &image::ImageBuffer<image::Rgb<u8>, Vec<u8>>,
+) -> Option<String> {
+    let cached = {
+        let locked = ocr_cache.read().unwrap();
+        locked.get(&hash).cloned()
+    };
+    let use_cache = region.use_ocr_cache.unwrap_or(true) && config.use_ocr_cache.unwrap_or(true);
+    if let Some(cached) = cached {
+        if use_cache {
+            return cached;
+        }
+    }
+    match run_ocr(filtered_image, &config.ocr_server_endpoint).await {
+        Ok(v) => {
+            if use_cache {
+                ocr_cache.write().unwrap().insert(hash.clone(), v.clone());
+            }
+            v
+        }
+        Err(_) => None,
+    }
+}
+
+#[tokio::main(flavor = "current_thread")]
+pub async fn ocr_all_regions(
+    image: &RgbImage,
+    config: Arc<Config>,
+    ocr_cache: Arc<OcrCache>,
+    should_sample: bool,
+) -> HashMap<String, Option<String>> {
+    let results = Arc::new(Mutex::new(HashMap::new()));
+
+    let mut handles = Vec::new();
+    for region in &config.ocr_regions {
+        let filtered_image = extract_and_filter(image, region);
+        let region = region.clone();
+        let results = results.clone();
+        let config = config.clone();
+        let ocr_cache = ocr_cache.clone();
+        handles.push(tokio::spawn(async move {
+            let filtered_image = filtered_image;
+            let hash = hash_image(&filtered_image);
+            let value =
+                run_ocr_cached(ocr_cache, hash, &region, config.clone(), &filtered_image).await;
+
+            if let Some(sample_fraction) = &config.dump_frame_fraction {
+                if rand::random::<f64>() < *sample_fraction && should_sample {
+                    let file_id = rand::random::<u32>();
+                    let img_filename = format!("ocr_data/{}.png", file_id);
+                    filtered_image.save(img_filename).unwrap();
+                    let value_filename = format!("ocr_data/{}.txt", file_id);
+                    std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap();
+                }
+            }
+            results.lock().unwrap().insert(region.name, value);
+        }));
+    }
+    for handle in handles {
+        handle.await.expect("failed to join task in OCR");
+    }
+
+    let results = results.lock().unwrap().clone();
+    results
+}
diff --git a/src/state.rs b/src/state.rs
index fbb49c4..a70122c 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -66,13 +66,11 @@ impl LapState {
 fn median_wear(values: Vec<Option<u32>>) -> Option<u32> {
     let mut wear_values = Vec::new();
     let mut last_value = 100;
-    for val in values {
-        if let Some(val) = val {
-            if val < last_value {
-                wear_values.push(last_value - val);
-            }
-            last_value = val;
+    for val in values.into_iter().flatten() {
+        if val < last_value {
+            wear_values.push(last_value - val);
         }
+        last_value = val;
     }
     wear_values.sort_unstable();
     wear_values.get(wear_values.len() / 2).cloned()
diff --git a/src/training_ui.rs b/src/training_ui.rs
index 3cc7cf6..ab23915 100644
--- a/src/training_ui.rs
+++ b/src/training_ui.rs
@@ -10,7 +10,7 @@ use eframe::{
 use egui_extras::RetainedImage;
 use image::RgbImage;
 
-use crate::{image_processing::to_png_bytes, local_ocr};
+use crate::{image_processing::to_png_bytes, ocr};
 
 #[derive(Default)]
 struct TrainingUi {
@@ -24,7 +24,6 @@ struct TrainingImage {
     img_file: PathBuf,
     data_file: PathBuf,
-    image: RgbImage,
     text: String,
     ui_image: RetainedImage,
     char_images: Vec<RetainedImage>,
     char_hashes: Vec<String>,
@@ -74,15 +73,14 @@ fn get_training_data() -> Vec<TrainingImage> {
         let ocr_value = String::from_utf8(std::fs::read(&ocr_file).unwrap()).unwrap();
         let ui_image = load_retained_image(&image);
-        let char_images = local_ocr::bounding_box_images(&image)
+        let char_images = ocr::bounding_box_images(&image)
             .iter()
             .map(load_retained_image)
            .collect();
-        let char_hashes = local_ocr::compute_box_hashes(&image);
+        let char_hashes = ocr::compute_box_hashes(&image);
 
         data.push(TrainingImage {
             img_file,
             data_file: ocr_file,
-            image,
             text: ocr_value,
             ui_image,
             char_images,
@@ -170,7 +168,7 @@ impl eframe::App for TrainingUi {
                 .write(true)
                 .open("learned_chars.txt")
                 .unwrap();
-            file.write(buffer.as_bytes()).unwrap();
+            file.write_all(buffer.as_bytes()).unwrap();
         }
 
         current_image.ui_image.show(ui);