learned images

This commit is contained in:
Scott Pruett 2022-05-22 14:03:47 -04:00
parent 3121905620
commit 3bffdd1b8d
10 changed files with 115 additions and 58 deletions

View File

@ -43,5 +43,6 @@
"height": 43 "height": 43
} }
], ],
"ocr_server_endpoint": "http://localhost:3000/"
"ocr_server_endpoint": "https://tesserver.spruett.dev/"
} }

View File

@ -1,4 +1,5 @@
{ {
"learned_images": {}, "learned_images": {
"learned_tracks": {} },
"learned_tracks": {}
} }

View File

@ -21,9 +21,8 @@ fn get_raw_frame(capturer: &mut Capturer) -> Result<Vec<u8>> {
} }
} }
pub fn get_frame() -> Result<RgbImage> { pub fn get_frame(capturer: &mut Capturer) -> Result<RgbImage> {
let mut capturer = Capturer::new(Display::primary()?)?; let frame = get_raw_frame(capturer)?;
let frame = get_raw_frame(&mut capturer)?;
let mut image = RgbImage::new(capturer.width() as u32, capturer.height() as u32); let mut image = RgbImage::new(capturer.width() as u32, capturer.height() as u32);
let stride = frame.len() / capturer.height(); let stride = frame.len() / capturer.height();

View File

@ -41,7 +41,7 @@ fn load_or_make_default<T: DeserializeOwned>(path: &str, default: &str) -> Resul
if !file_path.exists() { if !file_path.exists() {
std::fs::write(&path, default)?; std::fs::write(&path, default)?;
} }
load_json_config(&path) load_json_config(path)
} }
fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> { fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> {

View File

@ -43,5 +43,5 @@
"height": 43 "height": 43
} }
], ],
"ocr_server_endpoint": "http://localhost:3000/" "ocr_server_endpoint": "https://tesserver.spruett.dev/"
} }

View File

@ -1,6 +1,6 @@
use std::{ use std::{
collections::HashMap, collections::HashMap,
time::{Duration, Instant}, time::{Duration, Instant}, thread,
}; };
use anyhow::Result; use anyhow::Result;
@ -8,9 +8,10 @@ use egui_extras::RetainedImage;
use image::RgbImage; use image::RgbImage;
use scrap::{Capturer, Display}; use scrap::{Capturer, Display};
use crate::{ use crate::{
capture, capture,
image_processing::{self, Region}, image_processing::{self, hash_image},
ocr, ocr,
state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState}, state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState},
}; };
@ -103,12 +104,15 @@ fn handle_new_frame(state: &mut AppState, frame: ParsedFrame, image: &RgbImage)
} }
} }
async fn run_loop_once(state: &SharedAppState) -> Result<()> { fn run_loop_once(state: &SharedAppState) -> Result<()> {
let mut capturer = Capturer::new(Display::primary()?)?;
let config = state.lock().unwrap().config.clone(); let config = state.lock().unwrap().config.clone();
let frame = capture::get_frame()?; let learned_config = state.lock().unwrap().learned.clone();
let ocr_results = ocr::ocr_all_regions(&frame, &config.ocr_regions).await; let frame = capture::get_frame(&mut capturer)?;
let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config.clone());
let mut saved_frames = HashMap::new(); let mut saved_frames = HashMap::new();
if state.lock().unwrap().debug_frames { if state.lock().unwrap().debug_frames {
let hasher = img_hash::HasherConfig::new().to_hasher(); let hasher = img_hash::HasherConfig::new().to_hasher();
for region in &config.ocr_regions { for region in &config.ocr_regions {
@ -119,13 +123,7 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> {
&image_processing::to_png_bytes(&extracted), &image_processing::to_png_bytes(&extracted),
) )
.unwrap(); .unwrap();
let have_to_use_other_image_library_version = img_hash::image::RgbImage::from_raw( let hash = hash_image(&extracted);
extracted.width(),
extracted.height(),
extracted.as_raw().to_vec(),
)
.unwrap();
let hash = hasher.hash_image(&have_to_use_other_image_library_version);
saved_frames.insert( saved_frames.insert(
region.name.clone(), region.name.clone(),
DebugOcrFrame { DebugOcrFrame {
@ -146,11 +144,11 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> {
Ok(()) Ok(())
} }
pub async fn run_control_loop(state: SharedAppState) -> Result<()> { pub fn run_control_loop(state: SharedAppState) {
loop { loop {
if let Err(e) = run_loop_once(&state).await { if let Err(e) = run_loop_once(&state) {
eprintln!("Error in control loop: {:?}", e) eprintln!("Error in control loop: {:?}", e)
} }
tokio::time::sleep(Duration::from_millis(500)).await; thread::sleep(Duration::from_millis(500));
} }
} }

View File

@ -1,4 +1,3 @@
use anyhow::Result;
use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage}; use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -64,3 +63,12 @@ pub fn to_png_bytes(image: &RgbImage) -> Vec<u8> {
.expect("failed encoding image to PNG"); .expect("failed encoding image to PNG");
buffer buffer
} }
pub fn hash_image(image: &RgbImage) -> String {
let hasher = img_hash::HasherConfig::new().to_hasher();
let have_to_use_other_image_library_version =
img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec())
.unwrap();
let hash = hasher.hash_image(&have_to_use_other_image_library_version);
hash.to_base64()
}

View File

@ -9,7 +9,7 @@ mod config;
use std::{ use std::{
sync::{Arc, Mutex}, sync::{Arc, Mutex},
time::Duration, time::Duration, thread,
}; };
use config::{Config, LearnedConfig}; use config::{Config, LearnedConfig};
@ -19,18 +19,15 @@ use eframe::{
}; };
use state::{AppState, RaceState, SharedAppState}; use state::{AppState, RaceState, SharedAppState};
#[tokio::main(flavor = "multi_thread", worker_threads = 8)] fn main() -> anyhow::Result<()> {
async fn main() -> anyhow::Result<()> {
let mut app_state = AppState::default(); let mut app_state = AppState::default();
app_state.config = Arc::new(Config::load().unwrap()); app_state.config = Arc::new(Config::load().unwrap());
app_state.learned = Arc::new(LearnedConfig::load().unwrap()); app_state.learned = Arc::new(LearnedConfig::load().unwrap());
let state = Arc::new(Mutex::new(app_state)); let state = Arc::new(Mutex::new(app_state));
{ {
let state = state.clone(); let state = state.clone();
let _ = tokio::spawn(async move { let _ = thread::spawn(move || {
control_loop::run_control_loop(state) control_loop::run_control_loop(state);
.await
.expect("control loop failed");
}); });
} }
@ -91,11 +88,16 @@ fn label_time_delta(ui: &mut Ui, time: Duration, old: Option<Duration>) {
struct MyApp { struct MyApp {
state: SharedAppState, state: SharedAppState,
config_load_err: Option<String>,
hash_to_learn: String,
value_to_learn: String,
} }
impl MyApp { impl MyApp {
pub fn new(state: SharedAppState) -> Self { pub fn new(state: SharedAppState) -> Self {
Self { state } Self { state, config_load_err: None, hash_to_learn: "".to_owned(), value_to_learn: "".to_owned() }
} }
} }
@ -110,7 +112,7 @@ fn show_race_state(ui: &mut Ui, race: &RaceState) {
ui.label("Tyres"); ui.label("Tyres");
ui.end_row(); ui.end_row();
for (i, lap) in race.laps.iter().enumerate() { for (i, lap) in race.laps.iter().enumerate() {
if let Some(lap_time) = *&lap.lap_time { if let Some(lap_time) = lap.lap_time {
let prev_lap = race.laps.get(i - 1); let prev_lap = race.laps.get(i - 1);
ui.label(format!("#{}", lap.lap.unwrap_or(i + 1))); ui.label(format!("#{}", lap.lap.unwrap_or(i + 1)));
@ -211,9 +213,41 @@ impl eframe::App for MyApp {
screenshots_sorted.sort_by_key(|(name, _)| name.clone()); screenshots_sorted.sort_by_key(|(name, _)| name.clone());
for (name, image) in screenshots_sorted { for (name, image) in screenshots_sorted {
ui.label(name); ui.label(name);
ui.label(image.img_hash.to_base64()); if ui.button(&image.img_hash).on_hover_text("Copy").clicked() {
ui.output().copied_text = image.img_hash.clone();
}
image.image.show_max_size(ui, ui.available_size()); image.image.show_max_size(ui, ui.available_size());
} }
if ui.button("Reload config").clicked() {
match Config::load() {
Ok(c) => {
state.config = Arc::new(c);
self.config_load_err = None;
}
Err(e) => {
self.config_load_err = Some(format!("failed to load config: {:?}", e));
}
}
}
if let Some(e) = &self.config_load_err {
ui.colored_label(Color32::RED, e);
}
ui.separator();
ui.label("Hash");
ui.text_edit_singleline(&mut self.hash_to_learn);
ui.label("Value");
ui.text_edit_singleline(&mut self.value_to_learn);
if ui.button("Learn").clicked() {
let mut learned_config = (*state.learned).clone();
learned_config.learned_images.insert(self.hash_to_learn.clone(), self.value_to_learn.clone());
learned_config.save().unwrap();
state.learned = Arc::new(learned_config);
self.hash_to_learn = "".to_owned();
self.value_to_learn = "".to_owned();
}
}); });
} }

View File

@ -1,10 +1,16 @@
use std::{collections::HashMap, sync::{Arc, Mutex}}; use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use anyhow::Result; use anyhow::Result;
use image::RgbImage; use image::RgbImage;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::image_processing::{Region, extract_region, filter_to_white}; use crate::{
config::{Config, LearnedConfig},
image_processing::{extract_region, filter_to_white, hash_image, Region},
};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct OcrRegion { pub struct OcrRegion {
@ -18,46 +24,56 @@ struct OcrResult {
error: Option<String>, error: Option<String>,
} }
async fn run_ocr(image: &RgbImage) -> Result<Vec<OcrRegion>> { async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let response = client let response = client
.post("http://localhost:3000/") .post(url)
.body(crate::image_processing::to_png_bytes(&image)) .body(crate::image_processing::to_png_bytes(image))
.send() .send()
.await?; .await?;
if !response.status().is_success() { if !response.status().is_success() {
eprintln!("failed to run OCR query");
anyhow::bail!("failed to run OCR query") anyhow::bail!("failed to run OCR query")
} }
let result: OcrResult = response.json().await?; let result: OcrResult = response.json().await?;
Ok(result.regions) let result = if result.regions.is_empty() {
None
} else {
let mut buffer = String::new();
for r in &result.regions {
buffer += &r.value;
}
Some(buffer)
};
Ok(result)
} }
pub async fn ocr_all_regions(image: &RgbImage, regions: &[Region]) -> HashMap<String, Option<String>> { #[tokio::main(flavor = "current_thread")]
pub async fn ocr_all_regions(
image: &RgbImage,
config: Arc<Config>,
learned: Arc<LearnedConfig>,
) -> HashMap<String, Option<String>> {
let results = Arc::new(Mutex::new(HashMap::new())); let results = Arc::new(Mutex::new(HashMap::new()));
let mut handles = Vec::new(); let mut handles = Vec::new();
for region in regions { for region in &config.ocr_regions {
let filtered_image = extract_region(image, region); let filtered_image = extract_region(image, region);
let region = region.clone(); let region = region.clone();
let results = results.clone(); let results = results.clone();
let config = config.clone();
let learned = learned.clone();
handles.push(tokio::spawn(async move { handles.push(tokio::spawn(async move {
let mut image = filtered_image; let mut image = filtered_image;
filter_to_white(&mut image); filter_to_white(&mut image);
let ocr_results = run_ocr(&image).await; let hash = hash_image(&image);
let value = match ocr_results { let value = if let Some(learned_value) = learned.learned_images.get(&hash) {
Ok(ocr_regions) => { Some(learned_value.clone())
if ocr_regions.is_empty() { } else {
None run_ocr(&image, &config.ocr_server_endpoint)
} else { .await
let mut out = String::new(); .unwrap_or(None)
for r in ocr_regions {
out += &r.value;
}
Some(out)
}
}
Err(_) => None
}; };
results.lock().unwrap().insert(region.name, value); results.lock().unwrap().insert(region.name, value);
})); }));
@ -68,4 +84,4 @@ pub async fn ocr_all_regions(image: &RgbImage, regions: &[Region]) -> HashMap<St
let results = results.lock().unwrap().clone(); let results = results.lock().unwrap().clone();
results results
} }

View File

@ -73,7 +73,7 @@ pub struct RaceState {
pub struct DebugOcrFrame { pub struct DebugOcrFrame {
pub image: RetainedImage, pub image: RetainedImage,
pub rgb_image: RgbImage, pub rgb_image: RgbImage,
pub img_hash: img_hash::ImageHash, pub img_hash: String,
} }
#[derive(Default)] #[derive(Default)]