learned images

Commit 3bffdd1b8d (parent 3121905620)
@@ -43,5 +43,6 @@
       "height": 43
     }
   ],
-  "ocr_server_endpoint": "http://localhost:3000/"
+
+  "ocr_server_endpoint": "https://tesserver.spruett.dev/"
 }
@@ -1,4 +1,5 @@
 {
-  "learned_images": {},
-  "learned_tracks": {}
+  "learned_images": {
+  },
+  "learned_tracks": {}
 }
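(For reference: learned_images is keyed by the base64 perceptual hash produced by the hash_image helper added later in this diff, and each value is the text the OCR step should report for any region whose hash matches. A populated file would look roughly like the sketch below; the hash and lap-time value are made-up placeholders, not data from this commit.)

{
  "learned_images": {
    "JJKSUlJSUlI=": "1:23.456"
  },
  "learned_tracks": {}
}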
@@ -21,9 +21,8 @@ fn get_raw_frame(capturer: &mut Capturer) -> Result<Vec<u8>> {
     }
 }
 
-pub fn get_frame() -> Result<RgbImage> {
-    let mut capturer = Capturer::new(Display::primary()?)?;
-    let frame = get_raw_frame(&mut capturer)?;
+pub fn get_frame(capturer: &mut Capturer) -> Result<RgbImage> {
+    let frame = get_raw_frame(capturer)?;
     let mut image = RgbImage::new(capturer.width() as u32, capturer.height() as u32);
 
     let stride = frame.len() / capturer.height();
@@ -41,7 +41,7 @@ fn load_or_make_default<T: DeserializeOwned>(path: &str, default: &str) -> Resul
     if !file_path.exists() {
         std::fs::write(&path, default)?;
     }
-    load_json_config(&path)
+    load_json_config(path)
 }
 
 fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> {
@@ -43,5 +43,5 @@
       "height": 43
     }
   ],
-  "ocr_server_endpoint": "http://localhost:3000/"
+  "ocr_server_endpoint": "https://tesserver.spruett.dev/"
 }
@@ -1,6 +1,6 @@
 use std::{
     collections::HashMap,
-    time::{Duration, Instant},
+    time::{Duration, Instant}, thread,
 };
 
 use anyhow::Result;
@@ -8,9 +8,10 @@ use egui_extras::RetainedImage;
 use image::RgbImage;
+use scrap::{Capturer, Display};
 
 use crate::{
     capture,
-    image_processing::{self, Region},
+    image_processing::{self, hash_image},
     ocr,
     state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState},
 };
@@ -103,12 +104,15 @@ fn handle_new_frame(state: &mut AppState, frame: ParsedFrame, image: &RgbImage)
     }
 }
 
-async fn run_loop_once(state: &SharedAppState) -> Result<()> {
+fn run_loop_once(state: &SharedAppState) -> Result<()> {
+    let mut capturer = Capturer::new(Display::primary()?)?;
     let config = state.lock().unwrap().config.clone();
-    let frame = capture::get_frame()?;
-    let ocr_results = ocr::ocr_all_regions(&frame, &config.ocr_regions).await;
+    let learned_config = state.lock().unwrap().learned.clone();
+    let frame = capture::get_frame(&mut capturer)?;
+    let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config.clone());
+
     let mut saved_frames = HashMap::new();
 
     if state.lock().unwrap().debug_frames {
         let hasher = img_hash::HasherConfig::new().to_hasher();
         for region in &config.ocr_regions {
@@ -119,13 +123,7 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> {
                 &image_processing::to_png_bytes(&extracted),
             )
             .unwrap();
-            let have_to_use_other_image_library_version = img_hash::image::RgbImage::from_raw(
-                extracted.width(),
-                extracted.height(),
-                extracted.as_raw().to_vec(),
-            )
-            .unwrap();
-            let hash = hasher.hash_image(&have_to_use_other_image_library_version);
+            let hash = hash_image(&extracted);
             saved_frames.insert(
                 region.name.clone(),
                 DebugOcrFrame {
@@ -146,11 +144,11 @@ async fn run_loop_once(state: &SharedAppState) -> Result<()> {
     Ok(())
 }
 
-pub async fn run_control_loop(state: SharedAppState) -> Result<()> {
+pub fn run_control_loop(state: SharedAppState) {
     loop {
-        if let Err(e) = run_loop_once(&state).await {
+        if let Err(e) = run_loop_once(&state) {
             eprintln!("Error in control loop: {:?}", e)
         }
-        tokio::time::sleep(Duration::from_millis(500)).await;
+        thread::sleep(Duration::from_millis(500));
     }
 }
@@ -1,4 +1,3 @@
-use anyhow::Result;
 use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage};
 use serde::{Deserialize, Serialize};
 
@@ -64,3 +63,12 @@ pub fn to_png_bytes(image: &RgbImage) -> Vec<u8> {
         .expect("failed encoding image to PNG");
     buffer
 }
+
+pub fn hash_image(image: &RgbImage) -> String {
+    let hasher = img_hash::HasherConfig::new().to_hasher();
+    let have_to_use_other_image_library_version =
+        img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec())
+            .unwrap();
+    let hash = hasher.hash_image(&have_to_use_other_image_library_version);
+    hash.to_base64()
+}
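The new hash_image helper is the glue between the debug UI and learned.json: the perceptual hash is serialized to base64 so it can be shown as a button label, copied to the clipboard, and used as a plain JSON key. Below is a small sanity-check sketch of that behaviour, assuming the same image and img_hash crates the project already depends on; the helper body is copied from the hunk above, the rest is illustrative.

use image::RgbImage;

// Same helper as added above: perceptual hash of an RGB frame, as base64.
fn hash_image(image: &RgbImage) -> String {
    let hasher = img_hash::HasherConfig::new().to_hasher();
    // img_hash re-exports its own copy of the image crate (a different
    // version), so the buffer has to be rebuilt against that version first.
    let converted =
        img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec())
            .unwrap();
    hasher.hash_image(&converted).to_base64()
}

fn main() {
    // Identical frames hash to the same string, which is what lets ocr.rs
    // replace an OCR round-trip with a HashMap lookup on learned_images.
    let frame = RgbImage::from_pixel(64, 16, image::Rgb([200, 200, 200]));
    assert_eq!(hash_image(&frame), hash_image(&frame.clone()));
}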
src/main.rs
@@ -9,7 +9,7 @@ mod config;
 
 use std::{
     sync::{Arc, Mutex},
-    time::Duration,
+    time::Duration, thread,
 };
 
 use config::{Config, LearnedConfig};
@@ -19,18 +19,15 @@ use eframe::{
 };
 use state::{AppState, RaceState, SharedAppState};
 
-#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
-async fn main() -> anyhow::Result<()> {
+fn main() -> anyhow::Result<()> {
     let mut app_state = AppState::default();
     app_state.config = Arc::new(Config::load().unwrap());
     app_state.learned = Arc::new(LearnedConfig::load().unwrap());
     let state = Arc::new(Mutex::new(app_state));
     {
         let state = state.clone();
-        let _ = tokio::spawn(async move {
-            control_loop::run_control_loop(state)
-                .await
-                .expect("control loop failed");
+        let _ = thread::spawn(move || {
+            control_loop::run_control_loop(state);
         });
     }
 
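The shape of the change in main: the control loop moves off the Tokio runtime onto a plain OS thread that owns its own clone of the shared state, while the original Arc stays with the egui UI on the main thread. A stripped-down sketch of that ownership pattern follows; the counter stands in for AppState and is illustrative, not the project's type.

use std::{
    sync::{Arc, Mutex},
    thread,
    time::Duration,
};

fn main() {
    let state = Arc::new(Mutex::new(0u64));
    {
        // Clone the Arc so the background thread gets its own handle;
        // the original stays usable on the main (UI) thread.
        let state = state.clone();
        let _ = thread::spawn(move || loop {
            *state.lock().unwrap() += 1;
            thread::sleep(Duration::from_millis(500));
        });
    }
    // Stand-in for the blocking eframe::run_native call.
    thread::sleep(Duration::from_millis(1600));
    println!("control-loop ticks: {}", state.lock().unwrap());
}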
@@ -91,11 +88,16 @@ fn label_time_delta(ui: &mut Ui, time: Duration, old: Option<Duration>) {
 
 struct MyApp {
     state: SharedAppState,
+
+    config_load_err: Option<String>,
+
+    hash_to_learn: String,
+    value_to_learn: String,
 }
 
 impl MyApp {
     pub fn new(state: SharedAppState) -> Self {
-        Self { state }
+        Self { state, config_load_err: None, hash_to_learn: "".to_owned(), value_to_learn: "".to_owned() }
     }
 }
 
@@ -110,7 +112,7 @@ fn show_race_state(ui: &mut Ui, race: &RaceState) {
         ui.label("Tyres");
         ui.end_row();
         for (i, lap) in race.laps.iter().enumerate() {
-            if let Some(lap_time) = *&lap.lap_time {
+            if let Some(lap_time) = lap.lap_time {
                 let prev_lap = race.laps.get(i - 1);
                 ui.label(format!("#{}", lap.lap.unwrap_or(i + 1)));
 
@@ -211,9 +213,41 @@ impl eframe::App for MyApp {
             screenshots_sorted.sort_by_key(|(name, _)| name.clone());
             for (name, image) in screenshots_sorted {
                 ui.label(name);
-                ui.label(image.img_hash.to_base64());
+                if ui.button(&image.img_hash).on_hover_text("Copy").clicked() {
+                    ui.output().copied_text = image.img_hash.clone();
+                }
                 image.image.show_max_size(ui, ui.available_size());
             }
+
+            if ui.button("Reload config").clicked() {
+                match Config::load() {
+                    Ok(c) => {
+                        state.config = Arc::new(c);
+                        self.config_load_err = None;
+                    }
+                    Err(e) => {
+                        self.config_load_err = Some(format!("failed to load config: {:?}", e));
+                    }
+                }
+            }
+            if let Some(e) = &self.config_load_err {
+                ui.colored_label(Color32::RED, e);
+            }
+
+            ui.separator();
+            ui.label("Hash");
+            ui.text_edit_singleline(&mut self.hash_to_learn);
+            ui.label("Value");
+            ui.text_edit_singleline(&mut self.value_to_learn);
+            if ui.button("Learn").clicked() {
+                let mut learned_config = (*state.learned).clone();
+                learned_config.learned_images.insert(self.hash_to_learn.clone(), self.value_to_learn.clone());
+                learned_config.save().unwrap();
+                state.learned = Arc::new(learned_config);
+
+                self.hash_to_learn = "".to_owned();
+                self.value_to_learn = "".to_owned();
+            }
         });
     }
 
src/ocr.rs
@@ -1,10 +1,16 @@
-use std::{collections::HashMap, sync::{Arc, Mutex}};
+use std::{
+    collections::HashMap,
+    sync::{Arc, Mutex},
+};
 
 use anyhow::Result;
 use image::RgbImage;
 use serde::{Deserialize, Serialize};
 
-use crate::image_processing::{Region, extract_region, filter_to_white};
+use crate::{
+    config::{Config, LearnedConfig},
+    image_processing::{extract_region, filter_to_white, hash_image, Region},
+};
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct OcrRegion {
@@ -18,46 +24,56 @@ struct OcrResult {
     error: Option<String>,
 }
 
-async fn run_ocr(image: &RgbImage) -> Result<Vec<OcrRegion>> {
+async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
     let client = reqwest::Client::new();
     let response = client
-        .post("http://localhost:3000/")
-        .body(crate::image_processing::to_png_bytes(&image))
+        .post(url)
+        .body(crate::image_processing::to_png_bytes(image))
         .send()
        .await?;
 
     if !response.status().is_success() {
         eprintln!("failed to run OCR query");
         anyhow::bail!("failed to run OCR query")
     }
     let result: OcrResult = response.json().await?;
-    Ok(result.regions)
+    let result = if result.regions.is_empty() {
+        None
+    } else {
+        let mut buffer = String::new();
+        for r in &result.regions {
+            buffer += &r.value;
+        }
+        Some(buffer)
+    };
+    Ok(result)
 }
 
-pub async fn ocr_all_regions(image: &RgbImage, regions: &[Region]) -> HashMap<String, Option<String>> {
+#[tokio::main(flavor = "current_thread")]
+pub async fn ocr_all_regions(
+    image: &RgbImage,
+    config: Arc<Config>,
+    learned: Arc<LearnedConfig>,
+) -> HashMap<String, Option<String>> {
     let results = Arc::new(Mutex::new(HashMap::new()));
 
     let mut handles = Vec::new();
-    for region in regions {
+    for region in &config.ocr_regions {
         let filtered_image = extract_region(image, region);
         let region = region.clone();
         let results = results.clone();
+        let config = config.clone();
+        let learned = learned.clone();
         handles.push(tokio::spawn(async move {
             let mut image = filtered_image;
             filter_to_white(&mut image);
-            let ocr_results = run_ocr(&image).await;
-            let value = match ocr_results {
-                Ok(ocr_regions) => {
-                    if ocr_regions.is_empty() {
-                        None
-                    } else {
-                        let mut out = String::new();
-                        for r in ocr_regions {
-                            out += &r.value;
-                        }
-                        Some(out)
-                    }
-                }
-                Err(_) => None
+            let hash = hash_image(&image);
+            let value = if let Some(learned_value) = learned.learned_images.get(&hash) {
+                Some(learned_value.clone())
+            } else {
+                run_ocr(&image, &config.ocr_server_endpoint)
+                    .await
+                    .unwrap_or(None)
             };
             results.lock().unwrap().insert(region.name, value);
         }));
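The attribute swap on ocr_all_regions is what makes the rest of the synchronous conversion work: #[tokio::main(flavor = "current_thread")] turns an async fn into an ordinary blocking function by building a single-threaded runtime per call, so the std::thread control loop can call it without holding a runtime itself. A minimal sketch of that bridging pattern, assuming tokio with the rt and macros features; the function below is illustrative, not from the repo.

// Callers see a plain synchronous function; the macro-generated wrapper
// builds a current-thread Tokio runtime and drives the async body on it.
#[tokio::main(flavor = "current_thread")]
async fn double(x: u64) -> u64 {
    async { x * 2 }.await
}

fn main() {
    // Ordinary blocking call site, e.g. from a std::thread loop.
    assert_eq!(double(21), 42);
}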
@@ -73,7 +73,7 @@ pub struct RaceState {
 pub struct DebugOcrFrame {
     pub image: RetainedImage,
     pub rgb_image: RgbImage,
-    pub img_hash: img_hash::ImageHash,
+    pub img_hash: String,
 }
 
 #[derive(Default)]