learned images
This commit is contained in:
parent
3121905620
commit
3bffdd1b8d
|
@ -43,5 +43,6 @@
|
||||||
"height": 43
|
"height": 43
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"ocr_server_endpoint": "http://localhost:3000/"
|
|
||||||
|
"ocr_server_endpoint": "https://tesserver.spruett.dev/"
|
||||||
}
|
}
|
|
@ -1,4 +1,5 @@
|
||||||
{
|
{
|
||||||
"learned_images": {},
|
"learned_images": {
|
||||||
"learned_tracks": {}
|
},
|
||||||
|
"learned_tracks": {}
|
||||||
}
|
}
|
|
@ -21,9 +21,8 @@ fn get_raw_frame(capturer: &mut Capturer) -> Result<Vec<u8>> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_frame() -> Result<RgbImage> {
|
pub fn get_frame(capturer: &mut Capturer) -> Result<RgbImage> {
|
||||||
let mut capturer = Capturer::new(Display::primary()?)?;
|
let frame = get_raw_frame(capturer)?;
|
||||||
let frame = get_raw_frame(&mut capturer)?;
|
|
||||||
let mut image = RgbImage::new(capturer.width() as u32, capturer.height() as u32);
|
let mut image = RgbImage::new(capturer.width() as u32, capturer.height() as u32);
|
||||||
|
|
||||||
let stride = frame.len() / capturer.height();
|
let stride = frame.len() / capturer.height();
|
||||||
|
|
|
@ -41,7 +41,7 @@ fn load_or_make_default<T: DeserializeOwned>(path: &str, default: &str) -> Resul
|
||||||
if !file_path.exists() {
|
if !file_path.exists() {
|
||||||
std::fs::write(&path, default)?;
|
std::fs::write(&path, default)?;
|
||||||
}
|
}
|
||||||
load_json_config(&path)
|
load_json_config(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> {
|
fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> {
|
||||||
|
|
|
@ -43,5 +43,5 @@
|
||||||
"height": 43
|
"height": 43
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"ocr_server_endpoint": "http://localhost:3000/"
|
"ocr_server_endpoint": "https://tesserver.spruett.dev/"
|
||||||
}
|
}
|
|
@ -1,6 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
time::{Duration, Instant},
|
time::{Duration, Instant}, thread,
|
||||||
};
|
};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
@ -8,9 +8,10 @@ use egui_extras::RetainedImage;
|
||||||
use image::RgbImage;
|
use image::RgbImage;
|
||||||
use scrap::{Capturer, Display};
|
use scrap::{Capturer, Display};
|
||||||
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
capture,
|
capture,
|
||||||
image_processing::{self, Region},
|
image_processing::{self, hash_image},
|
||||||
ocr,
|
ocr,
|
||||||
state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState},
|
state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState},
|
||||||
};
|
};
|
||||||
|
@ -103,12 +104,15 @@ fn handle_new_frame(state: &mut AppState, frame: ParsedFrame, image: &RgbImage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_loop_once(state: &SharedAppState) -> Result<()> {
|
fn run_loop_once(state: &SharedAppState) -> Result<()> {
|
||||||
|
let mut capturer = Capturer::new(Display::primary()?)?;
|
||||||
let config = state.lock().unwrap().config.clone();
|
let config = state.lock().unwrap().config.clone();
|
||||||
let frame = capture::get_frame()?;
|
let learned_config = state.lock().unwrap().learned.clone();
|
||||||
let ocr_results = ocr::ocr_all_regions(&frame, &config.ocr_regions).await;
|
let frame = capture::get_frame(&mut capturer)?;
|
||||||
|
let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config.clone());
|
||||||
|
|
||||||
let mut saved_frames = HashMap::new();
|
let mut saved_frames = HashMap::new();
|
||||||
|
|
||||||
if state.lock().unwrap().debug_frames {
|
if state.lock().unwrap().debug_frames {
|
||||||
let hasher = img_hash::HasherConfig::new().to_hasher();
|
let hasher = img_hash::HasherConfig::new().to_hasher();
|
||||||
for region in &config.ocr_regions {
|
for region in &config.ocr_regions {
|
||||||
|
@ -119,13 +123,7 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> {
|
||||||
&image_processing::to_png_bytes(&extracted),
|
&image_processing::to_png_bytes(&extracted),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let have_to_use_other_image_library_version = img_hash::image::RgbImage::from_raw(
|
let hash = hash_image(&extracted);
|
||||||
extracted.width(),
|
|
||||||
extracted.height(),
|
|
||||||
extracted.as_raw().to_vec(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let hash = hasher.hash_image(&have_to_use_other_image_library_version);
|
|
||||||
saved_frames.insert(
|
saved_frames.insert(
|
||||||
region.name.clone(),
|
region.name.clone(),
|
||||||
DebugOcrFrame {
|
DebugOcrFrame {
|
||||||
|
@ -146,11 +144,11 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_control_loop(state: SharedAppState) -> Result<()> {
|
pub fn run_control_loop(state: SharedAppState) {
|
||||||
loop {
|
loop {
|
||||||
if let Err(e) = run_loop_once(&state).await {
|
if let Err(e) = run_loop_once(&state) {
|
||||||
eprintln!("Error in control loop: {:?}", e)
|
eprintln!("Error in control loop: {:?}", e)
|
||||||
}
|
}
|
||||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
thread::sleep(Duration::from_millis(500));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use anyhow::Result;
|
|
||||||
use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage};
|
use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
@ -64,3 +63,12 @@ pub fn to_png_bytes(image: &RgbImage) -> Vec<u8> {
|
||||||
.expect("failed encoding image to PNG");
|
.expect("failed encoding image to PNG");
|
||||||
buffer
|
buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn hash_image(image: &RgbImage) -> String {
|
||||||
|
let hasher = img_hash::HasherConfig::new().to_hasher();
|
||||||
|
let have_to_use_other_image_library_version =
|
||||||
|
img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec())
|
||||||
|
.unwrap();
|
||||||
|
let hash = hasher.hash_image(&have_to_use_other_image_library_version);
|
||||||
|
hash.to_base64()
|
||||||
|
}
|
||||||
|
|
54
src/main.rs
54
src/main.rs
|
@ -9,7 +9,7 @@ mod config;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
sync::{Arc, Mutex},
|
sync::{Arc, Mutex},
|
||||||
time::Duration,
|
time::Duration, thread,
|
||||||
};
|
};
|
||||||
|
|
||||||
use config::{Config, LearnedConfig};
|
use config::{Config, LearnedConfig};
|
||||||
|
@ -19,18 +19,15 @@ use eframe::{
|
||||||
};
|
};
|
||||||
use state::{AppState, RaceState, SharedAppState};
|
use state::{AppState, RaceState, SharedAppState};
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
|
fn main() -> anyhow::Result<()> {
|
||||||
async fn main() -> anyhow::Result<()> {
|
|
||||||
let mut app_state = AppState::default();
|
let mut app_state = AppState::default();
|
||||||
app_state.config = Arc::new(Config::load().unwrap());
|
app_state.config = Arc::new(Config::load().unwrap());
|
||||||
app_state.learned = Arc::new(LearnedConfig::load().unwrap());
|
app_state.learned = Arc::new(LearnedConfig::load().unwrap());
|
||||||
let state = Arc::new(Mutex::new(app_state));
|
let state = Arc::new(Mutex::new(app_state));
|
||||||
{
|
{
|
||||||
let state = state.clone();
|
let state = state.clone();
|
||||||
let _ = tokio::spawn(async move {
|
let _ = thread::spawn(move || {
|
||||||
control_loop::run_control_loop(state)
|
control_loop::run_control_loop(state);
|
||||||
.await
|
|
||||||
.expect("control loop failed");
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,11 +88,16 @@ fn label_time_delta(ui: &mut Ui, time: Duration, old: Option<Duration>) {
|
||||||
|
|
||||||
struct MyApp {
|
struct MyApp {
|
||||||
state: SharedAppState,
|
state: SharedAppState,
|
||||||
|
|
||||||
|
config_load_err: Option<String>,
|
||||||
|
|
||||||
|
hash_to_learn: String,
|
||||||
|
value_to_learn: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MyApp {
|
impl MyApp {
|
||||||
pub fn new(state: SharedAppState) -> Self {
|
pub fn new(state: SharedAppState) -> Self {
|
||||||
Self { state }
|
Self { state, config_load_err: None, hash_to_learn: "".to_owned(), value_to_learn: "".to_owned() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +112,7 @@ fn show_race_state(ui: &mut Ui, race: &RaceState) {
|
||||||
ui.label("Tyres");
|
ui.label("Tyres");
|
||||||
ui.end_row();
|
ui.end_row();
|
||||||
for (i, lap) in race.laps.iter().enumerate() {
|
for (i, lap) in race.laps.iter().enumerate() {
|
||||||
if let Some(lap_time) = *&lap.lap_time {
|
if let Some(lap_time) = lap.lap_time {
|
||||||
let prev_lap = race.laps.get(i - 1);
|
let prev_lap = race.laps.get(i - 1);
|
||||||
ui.label(format!("#{}", lap.lap.unwrap_or(i + 1)));
|
ui.label(format!("#{}", lap.lap.unwrap_or(i + 1)));
|
||||||
|
|
||||||
|
@ -211,9 +213,41 @@ impl eframe::App for MyApp {
|
||||||
screenshots_sorted.sort_by_key(|(name, _)| name.clone());
|
screenshots_sorted.sort_by_key(|(name, _)| name.clone());
|
||||||
for (name, image) in screenshots_sorted {
|
for (name, image) in screenshots_sorted {
|
||||||
ui.label(name);
|
ui.label(name);
|
||||||
ui.label(image.img_hash.to_base64());
|
if ui.button(&image.img_hash).on_hover_text("Copy").clicked() {
|
||||||
|
ui.output().copied_text = image.img_hash.clone();
|
||||||
|
}
|
||||||
image.image.show_max_size(ui, ui.available_size());
|
image.image.show_max_size(ui, ui.available_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ui.button("Reload config").clicked() {
|
||||||
|
match Config::load() {
|
||||||
|
Ok(c) => {
|
||||||
|
state.config = Arc::new(c);
|
||||||
|
self.config_load_err = None;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
self.config_load_err = Some(format!("failed to load config: {:?}", e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(e) = &self.config_load_err {
|
||||||
|
ui.colored_label(Color32::RED, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
ui.separator();
|
||||||
|
ui.label("Hash");
|
||||||
|
ui.text_edit_singleline(&mut self.hash_to_learn);
|
||||||
|
ui.label("Value");
|
||||||
|
ui.text_edit_singleline(&mut self.value_to_learn);
|
||||||
|
if ui.button("Learn").clicked() {
|
||||||
|
let mut learned_config = (*state.learned).clone();
|
||||||
|
learned_config.learned_images.insert(self.hash_to_learn.clone(), self.value_to_learn.clone());
|
||||||
|
learned_config.save().unwrap();
|
||||||
|
state.learned = Arc::new(learned_config);
|
||||||
|
|
||||||
|
self.hash_to_learn = "".to_owned();
|
||||||
|
self.value_to_learn = "".to_owned();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
60
src/ocr.rs
60
src/ocr.rs
|
@ -1,10 +1,16 @@
|
||||||
use std::{collections::HashMap, sync::{Arc, Mutex}};
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use image::RgbImage;
|
use image::RgbImage;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::image_processing::{Region, extract_region, filter_to_white};
|
use crate::{
|
||||||
|
config::{Config, LearnedConfig},
|
||||||
|
image_processing::{extract_region, filter_to_white, hash_image, Region},
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct OcrRegion {
|
pub struct OcrRegion {
|
||||||
|
@ -18,46 +24,56 @@ struct OcrResult {
|
||||||
error: Option<String>,
|
error: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_ocr(image: &RgbImage) -> Result<Vec<OcrRegion>> {
|
async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let response = client
|
let response = client
|
||||||
.post("http://localhost:3000/")
|
.post(url)
|
||||||
.body(crate::image_processing::to_png_bytes(&image))
|
.body(crate::image_processing::to_png_bytes(image))
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
|
eprintln!("failed to run OCR query");
|
||||||
anyhow::bail!("failed to run OCR query")
|
anyhow::bail!("failed to run OCR query")
|
||||||
}
|
}
|
||||||
let result: OcrResult = response.json().await?;
|
let result: OcrResult = response.json().await?;
|
||||||
Ok(result.regions)
|
let result = if result.regions.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mut buffer = String::new();
|
||||||
|
for r in &result.regions {
|
||||||
|
buffer += &r.value;
|
||||||
|
}
|
||||||
|
Some(buffer)
|
||||||
|
};
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn ocr_all_regions(image: &RgbImage, regions: &[Region]) -> HashMap<String, Option<String>> {
|
#[tokio::main(flavor = "current_thread")]
|
||||||
|
pub async fn ocr_all_regions(
|
||||||
|
image: &RgbImage,
|
||||||
|
config: Arc<Config>,
|
||||||
|
learned: Arc<LearnedConfig>,
|
||||||
|
) -> HashMap<String, Option<String>> {
|
||||||
let results = Arc::new(Mutex::new(HashMap::new()));
|
let results = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
|
||||||
let mut handles = Vec::new();
|
let mut handles = Vec::new();
|
||||||
for region in regions {
|
for region in &config.ocr_regions {
|
||||||
let filtered_image = extract_region(image, region);
|
let filtered_image = extract_region(image, region);
|
||||||
let region = region.clone();
|
let region = region.clone();
|
||||||
let results = results.clone();
|
let results = results.clone();
|
||||||
|
let config = config.clone();
|
||||||
|
let learned = learned.clone();
|
||||||
handles.push(tokio::spawn(async move {
|
handles.push(tokio::spawn(async move {
|
||||||
let mut image = filtered_image;
|
let mut image = filtered_image;
|
||||||
filter_to_white(&mut image);
|
filter_to_white(&mut image);
|
||||||
let ocr_results = run_ocr(&image).await;
|
let hash = hash_image(&image);
|
||||||
let value = match ocr_results {
|
let value = if let Some(learned_value) = learned.learned_images.get(&hash) {
|
||||||
Ok(ocr_regions) => {
|
Some(learned_value.clone())
|
||||||
if ocr_regions.is_empty() {
|
} else {
|
||||||
None
|
run_ocr(&image, &config.ocr_server_endpoint)
|
||||||
} else {
|
.await
|
||||||
let mut out = String::new();
|
.unwrap_or(None)
|
||||||
for r in ocr_regions {
|
|
||||||
out += &r.value;
|
|
||||||
}
|
|
||||||
Some(out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => None
|
|
||||||
};
|
};
|
||||||
results.lock().unwrap().insert(region.name, value);
|
results.lock().unwrap().insert(region.name, value);
|
||||||
}));
|
}));
|
||||||
|
|
|
@ -73,7 +73,7 @@ pub struct RaceState {
|
||||||
pub struct DebugOcrFrame {
|
pub struct DebugOcrFrame {
|
||||||
pub image: RetainedImage,
|
pub image: RetainedImage,
|
||||||
pub rgb_image: RgbImage,
|
pub rgb_image: RgbImage,
|
||||||
pub img_hash: img_hash::ImageHash,
|
pub img_hash: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
|
Loading…
Reference in New Issue