use std::{ collections::HashMap, sync::{Arc, Mutex, RwLock}, }; use anyhow::Result; use image::RgbImage; use serde::{Deserialize, Serialize}; use crate::{ config::{Config, LearnedConfig}, image_processing::{extract_and_filter, hash_image}, state::OcrCache, }; #[derive(Serialize, Deserialize, Debug)] pub struct OcrRegion { pub confidence: f64, pub value: String, } #[derive(Serialize, Deserialize, Debug)] struct OcrResult { regions: Vec, error: Option, } async fn run_ocr(image: &RgbImage, url: &str) -> Result> { let client = reqwest::Client::new(); let response = client .post(url) .body(crate::image_processing::to_png_bytes(image)) .send() .await?; if !response.status().is_success() { eprintln!("failed to run OCR query"); anyhow::bail!("failed to run OCR query") } let result: OcrResult = response.json().await?; let result = if result.regions.is_empty() { None } else { let mut buffer = String::new(); for r in &result.regions { buffer += &r.value; } Some(buffer) }; Ok(result) } async fn run_ocr_cached( ocr_cache: Arc>>>, hash: String, region: &crate::image_processing::Region, config: Arc, filtered_image: &image::ImageBuffer, Vec>, ) -> Option { let cached = { let locked = ocr_cache.read().unwrap(); locked.get(&hash).cloned() }; let use_cache = region.use_ocr_cache.unwrap_or(true) && config.use_ocr_cache.unwrap_or(true); if let Some(cached) = cached { if use_cache { return cached; } } match run_ocr(filtered_image, &config.ocr_server_endpoint).await { Ok(v) => { if use_cache { ocr_cache.write().unwrap().insert(hash.clone(), v.clone()); } v } Err(_) => None, } } #[tokio::main(flavor = "current_thread")] pub async fn ocr_all_regions( image: &RgbImage, config: Arc, learned: Arc, ocr_cache: Arc, should_sample: bool ) -> HashMap> { let results = Arc::new(Mutex::new(HashMap::new())); let mut handles = Vec::new(); for region in &config.ocr_regions { let filtered_image = extract_and_filter(image, region); let region = region.clone(); let results = results.clone(); let config = config.clone(); let learned = learned.clone(); let ocr_cache = ocr_cache.clone(); handles.push(tokio::spawn(async move { let filtered_image = filtered_image; let hash = hash_image(&filtered_image); let value = if let Some(learned_value) = learned.learned_images.get(&hash) { Some(learned_value.clone()) } else { run_ocr_cached(ocr_cache, hash, ®ion, config.clone(), &filtered_image).await }; if let Some(sample_fraction) = &config.dump_frame_fraction { if rand::random::() < *sample_fraction { let file_id = rand::random::(); let img_filename = format!("ocr_data/{}.png", file_id); filtered_image.save(img_filename).unwrap(); let value_filename = format!("ocr_data/{}.txt", file_id); std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap(); } } results.lock().unwrap().insert(region.name, value); })); } for handle in handles { handle.await.expect("failed to join task in OCR"); } let results = results.lock().unwrap().clone(); results }