learned images

Scott Pruett 2022-05-22 14:03:47 -04:00
parent 3121905620
commit 3bffdd1b8d
10 changed files with 115 additions and 58 deletions

View File

@@ -43,5 +43,6 @@
       "height": 43
     }
   ],
-  "ocr_server_endpoint": "http://localhost:3000/"
+  "ocr_server_endpoint": "https://tesserver.spruett.dev/"
 }

View File

@@ -1,4 +1,5 @@
 {
-  "learned_images": {},
+  "learned_images": {
+  },
   "learned_tracks": {}
 }
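
Note: this file deserializes into the LearnedConfig struct used throughout
this commit. A minimal sketch of its likely shape, inferred from the diffs
below (LearnedConfig::load(), .save(), and learned_images used as a
HashMap<String, String> keyed by image hash); the exact derives are
assumptions:

    use std::collections::HashMap;
    use serde::{Deserialize, Serialize};

    // Sketch: field names match this JSON file. Clone is needed because
    // the Learn button clones the config out of an Arc before mutating it.
    #[derive(Clone, Debug, Default, Serialize, Deserialize)]
    pub struct LearnedConfig {
        pub learned_images: HashMap<String, String>, // hash -> learned text
        pub learned_tracks: HashMap<String, String>,
    }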

View File

@@ -21,9 +21,8 @@ fn get_raw_frame(capturer: &mut Capturer) -> Result<Vec<u8>> {
     }
 }
 
-pub fn get_frame() -> Result<RgbImage> {
-    let mut capturer = Capturer::new(Display::primary()?)?;
-    let frame = get_raw_frame(&mut capturer)?;
+pub fn get_frame(capturer: &mut Capturer) -> Result<RgbImage> {
+    let frame = get_raw_frame(capturer)?;
     let mut image = RgbImage::new(capturer.width() as u32, capturer.height() as u32);
     let stride = frame.len() / capturer.height();
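
Note: the new signature moves ownership of the screen capturer to the
caller, so the display is opened once and reused across frames instead of
being re-created on every capture. A usage sketch, assuming the scrap
crate API already imported above (the loop itself is hypothetical):

    use anyhow::Result;
    use scrap::{Capturer, Display};

    fn capture_frames() -> Result<()> {
        // Open the primary display once...
        let mut capturer = Capturer::new(Display::primary()?)?;
        loop {
            // ...and reuse the same capturer for every frame.
            let frame = get_frame(&mut capturer)?;
            println!("captured {}x{}", frame.width(), frame.height());
        }
    }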

View File

@@ -41,7 +41,7 @@ fn load_or_make_default<T: DeserializeOwned>(path: &str, default: &str) -> Resul
     if !file_path.exists() {
         std::fs::write(&path, default)?;
     }
-    load_json_config(&path)
+    load_json_config(path)
 }
 
 fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> {
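
Note: load_json_config is cut off by the diff context here. A plausible
body, assuming serde_json (the configs in this repo are JSON) and the
anyhow-style Result alias used above:

    fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> {
        // Read the whole file, then deserialize into the target type;
        // both error cases bubble up through the Result.
        let contents = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str(&contents)?)
    }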

View File

@@ -43,5 +43,5 @@
       "height": 43
     }
   ],
-  "ocr_server_endpoint": "http://localhost:3000/"
+  "ocr_server_endpoint": "https://tesserver.spruett.dev/"
 }

View File

@@ -1,6 +1,6 @@
 use std::{
     collections::HashMap,
-    time::{Duration, Instant},
+    time::{Duration, Instant}, thread,
 };
 
 use anyhow::Result;
@@ -8,9 +8,10 @@ use egui_extras::RetainedImage;
 use image::RgbImage;
+use scrap::{Capturer, Display};
 
 use crate::{
     capture,
-    image_processing::{self, Region},
+    image_processing::{self, hash_image},
     ocr,
     state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState},
 };
@@ -103,12 +104,15 @@ fn handle_new_frame(state: &mut AppState, frame: ParsedFrame, image: &RgbImage)
     }
 }
 
-async fn run_loop_once(state: &SharedAppState) -> Result<()> {
+fn run_loop_once(state: &SharedAppState) -> Result<()> {
+    let mut capturer = Capturer::new(Display::primary()?)?;
     let config = state.lock().unwrap().config.clone();
-    let frame = capture::get_frame()?;
-    let ocr_results = ocr::ocr_all_regions(&frame, &config.ocr_regions).await;
+    let learned_config = state.lock().unwrap().learned.clone();
+    let frame = capture::get_frame(&mut capturer)?;
+    let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config.clone());
 
     let mut saved_frames = HashMap::new();
     if state.lock().unwrap().debug_frames {
         let hasher = img_hash::HasherConfig::new().to_hasher();
         for region in &config.ocr_regions {
@@ -119,13 +123,7 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> {
                 &image_processing::to_png_bytes(&extracted),
             )
             .unwrap();
-            let have_to_use_other_image_library_version = img_hash::image::RgbImage::from_raw(
-                extracted.width(),
-                extracted.height(),
-                extracted.as_raw().to_vec(),
-            )
-            .unwrap();
-            let hash = hasher.hash_image(&have_to_use_other_image_library_version);
+            let hash = hash_image(&extracted);
             saved_frames.insert(
                 region.name.clone(),
                 DebugOcrFrame {
@@ -146,11 +144,11 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> {
     Ok(())
 }
 
-pub async fn run_control_loop(state: SharedAppState) -> Result<()> {
+pub fn run_control_loop(state: SharedAppState) {
     loop {
-        if let Err(e) = run_loop_once(&state).await {
+        if let Err(e) = run_loop_once(&state) {
             eprintln!("Error in control loop: {:?}", e)
         }
-        tokio::time::sleep(Duration::from_millis(500)).await;
+        thread::sleep(Duration::from_millis(500));
     }
 }

View File

@@ -1,4 +1,3 @@
-use anyhow::Result;
 use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage};
 use serde::{Deserialize, Serialize};
@@ -64,3 +63,12 @@ pub fn to_png_bytes(image: &RgbImage) -> Vec<u8> {
         .expect("failed encoding image to PNG");
     buffer
 }
+
+pub fn hash_image(image: &RgbImage) -> String {
+    let hasher = img_hash::HasherConfig::new().to_hasher();
+    let have_to_use_other_image_library_version =
+        img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec())
+            .unwrap();
+    let hash = hasher.hash_image(&have_to_use_other_image_library_version);
+    hash.to_base64()
+}
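
Note: hash_image computes a perceptual hash with the img_hash crate and
returns it base64-encoded, so the string can serve directly as a key into
learned_images. The from_raw round-trip exists because img_hash is built
against a different version of the image crate than the rest of the
project, so the pixel buffer must be rebuilt as the other crate's RgbImage
type. A lookup sketch (the map mirrors LearnedConfig::learned_images):

    use std::collections::HashMap;
    use image::RgbImage;

    // Sketch: hash a cropped region and check the learned map for a
    // previously taught value before falling back to OCR.
    fn lookup_learned(region: &RgbImage, learned: &HashMap<String, String>) -> Option<String> {
        let key = hash_image(region); // base64 perceptual hash
        learned.get(&key).cloned()
    }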

View File

@@ -9,7 +9,7 @@ mod config;
 use std::{
     sync::{Arc, Mutex},
-    time::Duration,
+    time::Duration, thread,
 };
 
 use config::{Config, LearnedConfig};
@@ -19,18 +19,15 @@ use eframe::{
 };
 use state::{AppState, RaceState, SharedAppState};
 
-#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
-async fn main() -> anyhow::Result<()> {
+fn main() -> anyhow::Result<()> {
     let mut app_state = AppState::default();
     app_state.config = Arc::new(Config::load().unwrap());
     app_state.learned = Arc::new(LearnedConfig::load().unwrap());
     let state = Arc::new(Mutex::new(app_state));
 
     {
         let state = state.clone();
-        let _ = tokio::spawn(async move {
-            control_loop::run_control_loop(state)
-                .await
-                .expect("control loop failed");
+        let _ = thread::spawn(move || {
+            control_loop::run_control_loop(state);
         });
     }
@@ -91,11 +88,16 @@ fn label_time_delta(ui: &mut Ui, time: Duration, old: Option<Duration>) {
 struct MyApp {
     state: SharedAppState,
+    config_load_err: Option<String>,
+    hash_to_learn: String,
+    value_to_learn: String,
 }
 
 impl MyApp {
     pub fn new(state: SharedAppState) -> Self {
-        Self { state }
+        Self { state, config_load_err: None, hash_to_learn: "".to_owned(), value_to_learn: "".to_owned() }
     }
 }
@@ -110,7 +112,7 @@ fn show_race_state(ui: &mut Ui, race: &RaceState) {
         ui.label("Tyres");
         ui.end_row();
 
         for (i, lap) in race.laps.iter().enumerate() {
-            if let Some(lap_time) = *&lap.lap_time {
+            if let Some(lap_time) = lap.lap_time {
                 let prev_lap = race.laps.get(i - 1);
                 ui.label(format!("#{}", lap.lap.unwrap_or(i + 1)));
@@ -211,9 +213,41 @@ impl eframe::App for MyApp {
             screenshots_sorted.sort_by_key(|(name, _)| name.clone());
             for (name, image) in screenshots_sorted {
                 ui.label(name);
-                ui.label(image.img_hash.to_base64());
+                if ui.button(&image.img_hash).on_hover_text("Copy").clicked() {
+                    ui.output().copied_text = image.img_hash.clone();
+                }
                 image.image.show_max_size(ui, ui.available_size());
             }
+
+            if ui.button("Reload config").clicked() {
+                match Config::load() {
+                    Ok(c) => {
+                        state.config = Arc::new(c);
+                        self.config_load_err = None;
+                    }
+                    Err(e) => {
+                        self.config_load_err = Some(format!("failed to load config: {:?}", e));
+                    }
+                }
+            }
+            if let Some(e) = &self.config_load_err {
+                ui.colored_label(Color32::RED, e);
+            }
+
+            ui.separator();
+            ui.label("Hash");
+            ui.text_edit_singleline(&mut self.hash_to_learn);
+            ui.label("Value");
+            ui.text_edit_singleline(&mut self.value_to_learn);
+            if ui.button("Learn").clicked() {
+                let mut learned_config = (*state.learned).clone();
+                learned_config.learned_images.insert(self.hash_to_learn.clone(), self.value_to_learn.clone());
+                learned_config.save().unwrap();
+                state.learned = Arc::new(learned_config);
+                self.hash_to_learn = "".to_owned();
+                self.value_to_learn = "".to_owned();
+            }
         });
     }
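
Note: the Learn button updates shared state copy-on-write: clone the
config out of the Arc, mutate the clone, persist it, then publish a fresh
Arc so concurrent readers keep their old snapshot. The same steps in
isolation (a sketch; the unwrap on save mirrors the diff):

    use std::sync::Arc;

    fn learn(state: &mut AppState, hash: String, value: String) {
        let mut learned_config = (*state.learned).clone(); // copy inner value
        learned_config.learned_images.insert(hash, value);
        learned_config.save().unwrap();           // persist to disk first
        state.learned = Arc::new(learned_config); // then publish the update
    }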

View File

@@ -1,10 +1,16 @@
-use std::{collections::HashMap, sync::{Arc, Mutex}};
+use std::{
+    collections::HashMap,
+    sync::{Arc, Mutex},
+};
 
 use anyhow::Result;
 use image::RgbImage;
 use serde::{Deserialize, Serialize};
 
-use crate::image_processing::{Region, extract_region, filter_to_white};
+use crate::{
+    config::{Config, LearnedConfig},
+    image_processing::{extract_region, filter_to_white, hash_image, Region},
+};
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct OcrRegion {
@@ -18,46 +24,56 @@ struct OcrResult {
     error: Option<String>,
 }
 
-async fn run_ocr(image: &RgbImage) -> Result<Vec<OcrRegion>> {
+async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
     let client = reqwest::Client::new();
     let response = client
-        .post("http://localhost:3000/")
-        .body(crate::image_processing::to_png_bytes(&image))
+        .post(url)
+        .body(crate::image_processing::to_png_bytes(image))
         .send()
         .await?;
     if !response.status().is_success() {
-        eprintln!("failed to run OCR query");
         anyhow::bail!("failed to run OCR query")
     }
     let result: OcrResult = response.json().await?;
-    Ok(result.regions)
+    let result = if result.regions.is_empty() {
+        None
+    } else {
+        let mut buffer = String::new();
+        for r in &result.regions {
+            buffer += &r.value;
+        }
+        Some(buffer)
+    };
+    Ok(result)
 }
 
-pub async fn ocr_all_regions(image: &RgbImage, regions: &[Region]) -> HashMap<String, Option<String>> {
+#[tokio::main(flavor = "current_thread")]
+pub async fn ocr_all_regions(
+    image: &RgbImage,
+    config: Arc<Config>,
+    learned: Arc<LearnedConfig>,
+) -> HashMap<String, Option<String>> {
     let results = Arc::new(Mutex::new(HashMap::new()));
     let mut handles = Vec::new();
-    for region in regions {
+    for region in &config.ocr_regions {
         let filtered_image = extract_region(image, region);
         let region = region.clone();
         let results = results.clone();
+        let config = config.clone();
+        let learned = learned.clone();
        handles.push(tokio::spawn(async move {
             let mut image = filtered_image;
             filter_to_white(&mut image);
-            let ocr_results = run_ocr(&image).await;
-            let value = match ocr_results {
-                Ok(ocr_regions) => {
-                    if ocr_regions.is_empty() {
-                        None
+            let hash = hash_image(&image);
+            let value = if let Some(learned_value) = learned.learned_images.get(&hash) {
+                Some(learned_value.clone())
             } else {
-                        let mut out = String::new();
-                        for r in ocr_regions {
-                            out += &r.value;
-                        }
-                        Some(out)
-                    }
-                }
-                Err(_) => None
+                run_ocr(&image, &config.ocr_server_endpoint)
+                    .await
+                    .unwrap_or(None)
             };
             results.lock().unwrap().insert(region.name, value);
         }));
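
Note: despite its name, #[tokio::main(flavor = "current_thread")] can
annotate any async fn, not just main; it rewrites ocr_all_regions into a
synchronous fn that builds a single-threaded runtime and blocks on the
async body. That is what lets the now-thread-based control loop call it
without .await. A rough hand-written equivalent of the expansion (the
_async suffix is hypothetical):

    pub fn ocr_all_regions_blocking(
        image: &RgbImage,
        config: Arc<Config>,
        learned: Arc<LearnedConfig>,
    ) -> HashMap<String, Option<String>> {
        // Build a current-thread Tokio runtime and block on the async body.
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build runtime")
            .block_on(ocr_all_regions_async(image, config, learned))
    }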

View File

@@ -73,7 +73,7 @@ pub struct RaceState {
 pub struct DebugOcrFrame {
     pub image: RetainedImage,
     pub rgb_image: RgbImage,
-    pub img_hash: img_hash::ImageHash,
+    pub img_hash: String,
 }
 
 #[derive(Default)]