rework learned tracks
This commit is contained in:
parent
db2c73d0c1
commit
eaed474ccf
|
@ -7,17 +7,18 @@ use std::{
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use egui_extras::RetainedImage;
|
use egui_extras::RetainedImage;
|
||||||
use image::RgbImage;
|
use image::RgbImage;
|
||||||
use img_hash::ImageHash;
|
|
||||||
use scrap::{Capturer, Display};
|
use scrap::{Capturer, Display};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
capture,
|
capture,
|
||||||
config::{Config, LearnedConfig},
|
config::Config,
|
||||||
image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes},
|
image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes},
|
||||||
ocr,
|
ocr,
|
||||||
state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState},
|
state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
fn is_finished_lap(state: &AppState, frame: &LapState) -> bool {
|
fn is_finished_lap(state: &AppState, frame: &LapState) -> bool {
|
||||||
if let Some(race) = &state.current_race {
|
if let Some(race) = &state.current_race {
|
||||||
if let Some(last_finish) = &race.last_lap_record_time {
|
if let Some(last_finish) = &race.last_lap_record_time {
|
||||||
|
@ -62,25 +63,6 @@ fn merge_frames(prev: &LapState, next: &LapState) -> LapState {
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_track_hash(config: &Config, image: &RgbImage) -> Option<String> {
|
|
||||||
let track_region = config.track_region.as_ref()?;
|
|
||||||
let extracted = extract_and_filter(image, track_region);
|
|
||||||
Some(hash_image(&extracted))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_track_name(learned: &LearnedConfig, hash: &Option<String>, config: &Config) -> Option<String> {
|
|
||||||
let hash = hash.as_ref()?;
|
|
||||||
for (learned_hash_b64, learned_track) in &learned.learned_tracks {
|
|
||||||
let learned_hash: ImageHash<Vec<u8>> = img_hash::ImageHash::from_base64(learned_hash_b64).ok()?;
|
|
||||||
let current_hash: ImageHash<Vec<u8>> = img_hash::ImageHash::from_base64(hash).ok()?;
|
|
||||||
if current_hash.dist(&learned_hash) <= config.track_recognition_threshold.unwrap_or(10) {
|
|
||||||
return Some(learned_track.to_owned())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
learned.learned_tracks.get(hash).cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
|
fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
|
||||||
if frame.lap_time.is_some() {
|
if frame.lap_time.is_some() {
|
||||||
state.last_frame = Some(frame.clone());
|
state.last_frame = Some(frame.clone());
|
||||||
|
@ -88,7 +70,8 @@ fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
|
||||||
|
|
||||||
if state.current_race.is_none() {
|
if state.current_race.is_none() {
|
||||||
let track_hash = get_track_hash(state.config.as_ref(), &image);
|
let track_hash = get_track_hash(state.config.as_ref(), &image);
|
||||||
let track_name = get_track_name(state.learned.as_ref(), &track_hash, state.config.as_ref());
|
let track_name = state.learned_tracks.infer_track(&track_hash, state.config.as_ref());
|
||||||
|
let inferred_track = track_name.is_some();
|
||||||
let race = RaceState {
|
let race = RaceState {
|
||||||
screencap: Some(
|
screencap: Some(
|
||||||
RetainedImage::from_image_bytes(
|
RetainedImage::from_image_bytes(
|
||||||
|
@ -100,6 +83,7 @@ fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
|
||||||
race_time: Some(SystemTime::now()),
|
race_time: Some(SystemTime::now()),
|
||||||
track_hash,
|
track_hash,
|
||||||
track: track_name.unwrap_or_default(),
|
track: track_name.unwrap_or_default(),
|
||||||
|
inferred_track,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
state.current_race = Some(race);
|
state.current_race = Some(race);
|
||||||
|
@ -167,16 +151,15 @@ fn add_saved_frame(
|
||||||
fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> {
|
fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> {
|
||||||
let frame = capture::get_frame(capturer)?;
|
let frame = capture::get_frame(capturer)?;
|
||||||
|
|
||||||
let (config, learned_config, ocr_cache, should_sample) = {
|
let (config, ocr_cache, should_sample) = {
|
||||||
let locked = state.lock().unwrap();
|
let locked = state.lock().unwrap();
|
||||||
(
|
(
|
||||||
locked.config.clone(),
|
locked.config.clone(),
|
||||||
locked.learned.clone(),
|
|
||||||
locked.ocr_cache.clone(),
|
locked.ocr_cache.clone(),
|
||||||
locked.should_sample_ocr_data
|
locked.should_sample_ocr_data
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config, ocr_cache, should_sample);
|
let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample);
|
||||||
|
|
||||||
if state.lock().unwrap().debug_frames {
|
if state.lock().unwrap().debug_frames {
|
||||||
let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results);
|
let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::{collections::HashMap, path::PathBuf};
|
use std::{path::PathBuf};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use serde::{Serialize, Deserialize, de::DeserializeOwned};
|
use serde::{Serialize, Deserialize, de::DeserializeOwned};
|
||||||
|
@ -19,26 +19,11 @@ pub struct Config {
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn load() -> Result<Self> {
|
pub fn load() -> Result<Self> {
|
||||||
load_or_make_default("config.json", include_str!("configs/config.default.json"))
|
load_config_or_make_default("config.json", include_str!("configs/config.default.json"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Serialize, Deserialize, Clone)]
|
pub fn load_config_or_make_default<T: DeserializeOwned>(path: &str, default: &str) -> Result<T> {
|
||||||
pub struct LearnedConfig {
|
|
||||||
pub learned_images: HashMap<String, String>,
|
|
||||||
pub learned_tracks: HashMap<String, String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LearnedConfig {
|
|
||||||
pub fn load() -> Result<Self> {
|
|
||||||
load_or_make_default("learned.json", include_str!("configs/learned.default.json"))
|
|
||||||
}
|
|
||||||
pub fn save(&self) -> Result<()> {
|
|
||||||
save_json_config("learned.json", self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_or_make_default<T: DeserializeOwned>(path: &str, default: &str) -> Result<T> {
|
|
||||||
let file_path = PathBuf::from(path);
|
let file_path = PathBuf::from(path);
|
||||||
if !file_path.exists() {
|
if !file_path.exists() {
|
||||||
std::fs::write(&path, default)?;
|
std::fs::write(&path, default)?;
|
||||||
|
@ -52,7 +37,7 @@ fn load_json_config<T: DeserializeOwned>(path: &str) -> Result<T> {
|
||||||
Ok(value)
|
Ok(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_json_config<T: Serialize>(path: &str, val: &T) -> Result<()> {
|
pub fn save_json_config<T: Serialize>(path: &str, val: &T) -> Result<()> {
|
||||||
let serialized = serde_json::to_vec_pretty(val)?;
|
let serialized = serde_json::to_vec_pretty(val)?;
|
||||||
Ok(std::fs::write(path, &serialized)?)
|
Ok(std::fs::write(path, &serialized)?)
|
||||||
}
|
}
|
|
@ -76,7 +76,7 @@ pub fn check_target_color_fraction(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ((region.height() * region.width()) as f64) / (color_area as f64);
|
((region.height() * region.width()) as f64) / (color_area as f64)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_png_bytes(image: &RgbImage) -> Vec<u8> {
|
pub fn to_png_bytes(image: &RgbImage) -> Vec<u8> {
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
|
use crate::{config::{load_config_or_make_default, save_json_config, Config}, image_processing::{extract_and_filter, hash_image}};
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use image::RgbImage;
|
||||||
|
use img_hash::ImageHash;
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Default, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct LearnedTracks {
|
||||||
|
learned_tracks: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LearnedTracks {
|
||||||
|
pub fn load() -> Result<Self> {
|
||||||
|
load_config_or_make_default("learned.json", include_str!("configs/learned.default.json"))
|
||||||
|
}
|
||||||
|
pub fn save(&self) -> Result<()> {
|
||||||
|
save_json_config("learned.json", self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn infer_track(&self, hash: &Option<String>, config: &Config) -> Option<String> {
|
||||||
|
let hash = hash.as_ref()?;
|
||||||
|
for (learned_hash_b64, learned_track) in &self.learned_tracks {
|
||||||
|
let learned_hash: ImageHash<Vec<u8>> = ImageHash::from_base64(learned_hash_b64).ok()?;
|
||||||
|
let current_hash: ImageHash<Vec<u8>> = ImageHash::from_base64(hash).ok()?;
|
||||||
|
if current_hash.dist(&learned_hash) <= config.track_recognition_threshold.unwrap_or(10) {
|
||||||
|
return Some(learned_track.to_owned())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn learn_and_save(self: &mut Arc<Self>, hash: &str, track: &str) -> Result<()> {
|
||||||
|
let mut tracks = (**self).clone();
|
||||||
|
tracks
|
||||||
|
.learned_tracks
|
||||||
|
.insert(hash.to_owned(), track.to_owned());
|
||||||
|
tracks.save()?;
|
||||||
|
*self = Arc::new(tracks);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_track_hash(config: &Config, image: &RgbImage) -> Option<String> {
|
||||||
|
let track_region = config.track_region.as_ref()?;
|
||||||
|
let extracted = extract_and_filter(image, track_region);
|
||||||
|
Some(hash_image(&extracted))
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
use image::{DynamicImage, Rgb, RgbImage};
|
use image::{Rgb, RgbImage};
|
||||||
use img_hash::{image::GenericImageView, ImageHash};
|
use img_hash::{image::GenericImageView};
|
||||||
|
|
||||||
use crate::image_processing;
|
use crate::image_processing;
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ fn row_has_any_dark(image: &RgbImage, y: u32, start_x: u32, width: u32) -> bool
|
||||||
|
|
||||||
fn take_while<F: Fn(u32) -> bool>(x: &mut u32, max: u32, f: F) {
|
fn take_while<F: Fn(u32) -> bool>(x: &mut u32, max: u32, f: F) {
|
||||||
while *x < max && f(*x) {
|
while *x < max && f(*x) {
|
||||||
*x = *x + 1;
|
*x += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
89
src/main.rs
89
src/main.rs
|
@ -1,43 +1,46 @@
|
||||||
// #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release
|
// #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] // hide console window on Windows in release
|
||||||
|
|
||||||
|
mod analysis;
|
||||||
mod capture;
|
mod capture;
|
||||||
mod config;
|
mod config;
|
||||||
mod analysis;
|
|
||||||
mod image_processing;
|
mod image_processing;
|
||||||
|
mod local_ocr;
|
||||||
mod ocr;
|
mod ocr;
|
||||||
mod state;
|
mod state;
|
||||||
mod stats_writer;
|
mod stats_writer;
|
||||||
mod local_ocr;
|
|
||||||
mod training_ui;
|
mod training_ui;
|
||||||
|
mod learned_tracks;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
ops::DerefMut,
|
ops::DerefMut,
|
||||||
|
path::PathBuf,
|
||||||
sync::{Arc, Mutex},
|
sync::{Arc, Mutex},
|
||||||
thread,
|
thread,
|
||||||
time::Duration, path::PathBuf,
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use config::{Config, LearnedConfig};
|
|
||||||
use analysis::save_frames_from;
|
use analysis::save_frames_from;
|
||||||
|
use config::Config;
|
||||||
use eframe::{
|
use eframe::{
|
||||||
egui::{self, Ui, Visuals},
|
egui::{self, Ui, Visuals},
|
||||||
emath::Vec2,
|
emath::Vec2,
|
||||||
epaint::Color32,
|
epaint::Color32,
|
||||||
};
|
};
|
||||||
use egui_extras::RetainedImage;
|
use egui_extras::RetainedImage;
|
||||||
use image_processing::{to_png_bytes, from_png_bytes};
|
use image_processing::{from_png_bytes, to_png_bytes};
|
||||||
|
use learned_tracks::LearnedTracks;
|
||||||
use state::{AppState, DebugOcrFrame, LapState, OcrCache, RaceState, SharedAppState};
|
use state::{AppState, DebugOcrFrame, LapState, OcrCache, RaceState, SharedAppState};
|
||||||
use stats_writer::export_race_stats;
|
use stats_writer::export_race_stats;
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let mode = std::env::args().nth(1).unwrap_or_default().to_string();
|
let mode = std::env::args().nth(1).unwrap_or_default();
|
||||||
if mode == "train" {
|
if mode == "train" {
|
||||||
return training_ui::training_ui();
|
return training_ui::training_ui();
|
||||||
}
|
}
|
||||||
let app_state = AppState {
|
let app_state = AppState {
|
||||||
config: Arc::new(Config::load().unwrap()),
|
config: Arc::new(Config::load().unwrap()),
|
||||||
learned: Arc::new(LearnedConfig::load().unwrap()),
|
learned_tracks: Arc::new(LearnedTracks::load().unwrap()),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let state = Arc::new(Mutex::new(app_state));
|
let state = Arc::new(Mutex::new(app_state));
|
||||||
|
@ -51,7 +54,10 @@ fn main() -> anyhow::Result<()> {
|
||||||
let options = eframe::NativeOptions::default();
|
let options = eframe::NativeOptions::default();
|
||||||
let current_exe = std::env::current_exe().unwrap_or_else(|_| PathBuf::from("supper.exe"));
|
let current_exe = std::env::current_exe().unwrap_or_else(|_| PathBuf::from("supper.exe"));
|
||||||
eframe::run_native(
|
eframe::run_native(
|
||||||
&format!("Supper OCR ({})", current_exe.file_name().unwrap().to_string_lossy()),
|
&format!(
|
||||||
|
"Supper OCR ({})",
|
||||||
|
current_exe.file_name().unwrap().to_string_lossy()
|
||||||
|
),
|
||||||
options,
|
options,
|
||||||
Box::new(|_cc| Box::new(AppUi::new(state))),
|
Box::new(|_cc| Box::new(AppUi::new(state))),
|
||||||
);
|
);
|
||||||
|
@ -172,7 +178,6 @@ fn show_race_state(
|
||||||
race_name: &str,
|
race_name: &str,
|
||||||
race: &mut RaceState,
|
race: &mut RaceState,
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
learned: Arc<LearnedConfig>,
|
|
||||||
ocr_cache: Arc<OcrCache>,
|
ocr_cache: Arc<OcrCache>,
|
||||||
) {
|
) {
|
||||||
egui::Grid::new(format!("race:{}", race_name)).show(ui, |ui| {
|
egui::Grid::new(format!("race:{}", race_name)).show(ui, |ui| {
|
||||||
|
@ -228,7 +233,6 @@ fn show_race_state(
|
||||||
ui_state,
|
ui_state,
|
||||||
lap,
|
lap,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
learned.clone(),
|
|
||||||
ocr_cache.clone(),
|
ocr_cache.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -275,36 +279,17 @@ fn show_config_controls(ui: &mut Ui, ui_state: &mut UiState, state: &mut AppStat
|
||||||
if let Some(e) = &ui_state.config_load_err {
|
if let Some(e) = &ui_state.config_load_err {
|
||||||
ui.colored_label(Color32::RED, e);
|
ui.colored_label(Color32::RED, e);
|
||||||
}
|
}
|
||||||
|
|
||||||
ui.separator();
|
|
||||||
ui.label("Hash");
|
|
||||||
ui.text_edit_singleline(&mut ui_state.hash_to_learn);
|
|
||||||
ui.label("Value");
|
|
||||||
ui.text_edit_singleline(&mut ui_state.value_to_learn);
|
|
||||||
if ui.button("Learn").clicked() {
|
|
||||||
let mut learned_config = (*state.learned).clone();
|
|
||||||
learned_config.learned_images.insert(
|
|
||||||
ui_state.hash_to_learn.clone(),
|
|
||||||
ui_state.value_to_learn.clone(),
|
|
||||||
);
|
|
||||||
learned_config.save().unwrap();
|
|
||||||
state.learned = Arc::new(learned_config);
|
|
||||||
|
|
||||||
ui_state.hash_to_learn = "".to_owned();
|
|
||||||
ui_state.value_to_learn = "".to_owned();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn open_debug_lap(
|
fn open_debug_lap(
|
||||||
ui_state: &mut UiState,
|
ui_state: &mut UiState,
|
||||||
lap: &LapState,
|
lap: &LapState,
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
learned: Arc<LearnedConfig>,
|
|
||||||
ocr_cache: Arc<OcrCache>,
|
ocr_cache: Arc<OcrCache>,
|
||||||
) {
|
) {
|
||||||
if let Some(screenshot_bytes) = &lap.screenshot {
|
if let Some(screenshot_bytes) = &lap.screenshot {
|
||||||
let screenshot = from_png_bytes(screenshot_bytes);
|
let screenshot = from_png_bytes(screenshot_bytes);
|
||||||
let ocr_results = ocr::ocr_all_regions(&screenshot, config.clone(), learned, ocr_cache, false);
|
let ocr_results = ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false);
|
||||||
let debug_lap = DebugLap {
|
let debug_lap = DebugLap {
|
||||||
screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot))
|
screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot))
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
|
@ -323,14 +308,7 @@ fn show_combo_box(ui: &mut Ui, name: &str, label: &str, options: &[String], valu
|
||||||
*value = options[index].clone();
|
*value = options[index].clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_learned_track(learned: &mut Arc<LearnedConfig>, track: &str, hash: &str) {
|
fn save_learned_track(_learned_tracks: &mut Arc<LearnedTracks>, _track: &str, _hash: &str) {
|
||||||
let mut learned_config = (**learned).clone();
|
|
||||||
learned_config.learned_tracks.insert(
|
|
||||||
hash.to_owned(),
|
|
||||||
track.to_owned(),
|
|
||||||
);
|
|
||||||
learned_config.save().unwrap();
|
|
||||||
*learned = Arc::new(learned_config);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl eframe::App for AppUi {
|
impl eframe::App for AppUi {
|
||||||
|
@ -388,13 +366,19 @@ impl eframe::App for AppUi {
|
||||||
if let Some(tyre_wear) = race.tyre_wear() {
|
if let Some(tyre_wear) = race.tyre_wear() {
|
||||||
ui.heading(&format!("p50 Tyre Wear: {}", tyre_wear));
|
ui.heading(&format!("p50 Tyre Wear: {}", tyre_wear));
|
||||||
if let Some(tyres) = frame.tyres {
|
if let Some(tyres) = frame.tyres {
|
||||||
ui.label(&format!("Out of tires in {:.1} lap(s)", (tyres as f64) / (tyre_wear as f64)));
|
ui.label(&format!(
|
||||||
|
"Out of tires in {:.1} lap(s)",
|
||||||
|
(tyres as f64) / (tyre_wear as f64)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(gas_wear) = race.gas_per_lap() {
|
if let Some(gas_wear) = race.gas_per_lap() {
|
||||||
ui.heading(&format!("p50 Gas Wear: {}", gas_wear));
|
ui.heading(&format!("p50 Gas Wear: {}", gas_wear));
|
||||||
if let Some(gas) = frame.gas {
|
if let Some(gas) = frame.gas {
|
||||||
ui.label(&format!("Out of gas in {:.1} lap(s)", (gas as f64) / (gas_wear as f64)));
|
ui.label(&format!(
|
||||||
|
"Out of gas in {:.1} lap(s)",
|
||||||
|
(gas as f64) / (gas_wear as f64)
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -402,12 +386,15 @@ impl eframe::App for AppUi {
|
||||||
|
|
||||||
ui.separator();
|
ui.separator();
|
||||||
ui.checkbox(&mut state.debug_frames, "Debug OCR regions");
|
ui.checkbox(&mut state.debug_frames, "Debug OCR regions");
|
||||||
ui.checkbox(&mut state.should_sample_ocr_data, "Dump OCR training frames");
|
ui.checkbox(
|
||||||
|
&mut state.should_sample_ocr_data,
|
||||||
|
"Dump OCR training frames",
|
||||||
|
);
|
||||||
});
|
});
|
||||||
egui::CentralPanel::default().show(ctx, |ui| {
|
egui::CentralPanel::default().show(ctx, |ui| {
|
||||||
egui::ScrollArea::vertical().show(ui, |ui| {
|
egui::ScrollArea::vertical().show(ui, |ui| {
|
||||||
let config = state.config.clone();
|
let config = state.config.clone();
|
||||||
let learned = state.learned.clone();
|
let _learned_tracks = state.learned_tracks.clone();
|
||||||
let ocr_cache = state.ocr_cache.clone();
|
let ocr_cache = state.ocr_cache.clone();
|
||||||
if let Some(race) = &mut state.current_race {
|
if let Some(race) = &mut state.current_race {
|
||||||
ui.heading(&format!("Current Race: {}", race.name()));
|
ui.heading(&format!("Current Race: {}", race.name()));
|
||||||
|
@ -417,13 +404,12 @@ impl eframe::App for AppUi {
|
||||||
"current",
|
"current",
|
||||||
race,
|
race,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
learned,
|
|
||||||
ocr_cache.clone(),
|
ocr_cache.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let len = state.past_races.len();
|
let len = state.past_races.len();
|
||||||
let mut races_to_remove = Vec::new();
|
let mut races_to_remove = Vec::new();
|
||||||
let mut learned = state.learned.clone();
|
let mut learned_tracks = state.learned_tracks.clone();
|
||||||
for (i, race) in state.past_races.iter_mut().enumerate() {
|
for (i, race) in state.past_races.iter_mut().enumerate() {
|
||||||
ui.separator();
|
ui.separator();
|
||||||
ui.heading(format!("Race #{}: {}", len - i, race.name()));
|
ui.heading(format!("Race #{}: {}", len - i, race.name()));
|
||||||
|
@ -433,14 +419,19 @@ impl eframe::App for AppUi {
|
||||||
&format!("race {}:", i),
|
&format!("race {}:", i),
|
||||||
race,
|
race,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
learned.clone(),
|
|
||||||
ocr_cache.clone(),
|
ocr_cache.clone(),
|
||||||
);
|
);
|
||||||
if let Some(img) = &race.screencap {
|
if let Some(img) = &race.screencap {
|
||||||
img.show_max_size(ui, Vec2::new(600.0, 500.0));
|
img.show_max_size(ui, Vec2::new(600.0, 500.0));
|
||||||
}
|
}
|
||||||
if !race.exported {
|
if !race.exported {
|
||||||
show_combo_box(ui, &format!("car-combo {}", i), "Car", &self.data.cars, &mut race.car);
|
show_combo_box(
|
||||||
|
ui,
|
||||||
|
&format!("car-combo {}", i),
|
||||||
|
"Car",
|
||||||
|
&self.data.cars,
|
||||||
|
&mut race.car,
|
||||||
|
);
|
||||||
show_combo_box(
|
show_combo_box(
|
||||||
ui,
|
ui,
|
||||||
&format!("track-combo {}", i),
|
&format!("track-combo {}", i),
|
||||||
|
@ -452,7 +443,9 @@ impl eframe::App for AppUi {
|
||||||
ui.text_edit_singleline(&mut race.comments);
|
ui.text_edit_singleline(&mut race.comments);
|
||||||
if ui.button("Export").clicked() {
|
if ui.button("Export").clicked() {
|
||||||
if let Some(track_hash) = &race.track_hash {
|
if let Some(track_hash) = &race.track_hash {
|
||||||
save_learned_track(&mut learned, &race.track, track_hash);
|
if !race.inferred_track {
|
||||||
|
learned_tracks.learn_and_save(track_hash, &race.track).unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
match export_race_stats(race) {
|
match export_race_stats(race) {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
|
@ -474,7 +467,7 @@ impl eframe::App for AppUi {
|
||||||
races_to_remove.push(i);
|
races_to_remove.push(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.learned = learned;
|
state.learned_tracks = learned_tracks;
|
||||||
for index in races_to_remove {
|
for index in races_to_remove {
|
||||||
state.past_races.remove(index);
|
state.past_races.remove(index);
|
||||||
}
|
}
|
||||||
|
|
18
src/ocr.rs
18
src/ocr.rs
|
@ -8,8 +8,9 @@ use image::RgbImage;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::{Config, LearnedConfig},
|
config::Config,
|
||||||
image_processing::{extract_and_filter, hash_image}, state::OcrCache,
|
image_processing::{extract_and_filter, hash_image},
|
||||||
|
state::OcrCache,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
@ -81,9 +82,8 @@ async fn run_ocr_cached(
|
||||||
pub async fn ocr_all_regions(
|
pub async fn ocr_all_regions(
|
||||||
image: &RgbImage,
|
image: &RgbImage,
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
learned: Arc<LearnedConfig>,
|
|
||||||
ocr_cache: Arc<OcrCache>,
|
ocr_cache: Arc<OcrCache>,
|
||||||
should_sample: bool
|
should_sample: bool,
|
||||||
) -> HashMap<String, Option<String>> {
|
) -> HashMap<String, Option<String>> {
|
||||||
let results = Arc::new(Mutex::new(HashMap::new()));
|
let results = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
|
||||||
|
@ -93,19 +93,15 @@ pub async fn ocr_all_regions(
|
||||||
let region = region.clone();
|
let region = region.clone();
|
||||||
let results = results.clone();
|
let results = results.clone();
|
||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
let learned = learned.clone();
|
|
||||||
let ocr_cache = ocr_cache.clone();
|
let ocr_cache = ocr_cache.clone();
|
||||||
handles.push(tokio::spawn(async move {
|
handles.push(tokio::spawn(async move {
|
||||||
let filtered_image = filtered_image;
|
let filtered_image = filtered_image;
|
||||||
let hash = hash_image(&filtered_image);
|
let hash = hash_image(&filtered_image);
|
||||||
let value = if let Some(learned_value) = learned.learned_images.get(&hash) {
|
let value =
|
||||||
Some(learned_value.clone())
|
run_ocr_cached(ocr_cache, hash, ®ion, config.clone(), &filtered_image).await;
|
||||||
} else {
|
|
||||||
run_ocr_cached(ocr_cache, hash, ®ion, config.clone(), &filtered_image).await
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(sample_fraction) = &config.dump_frame_fraction {
|
if let Some(sample_fraction) = &config.dump_frame_fraction {
|
||||||
if rand::random::<f64>() < *sample_fraction {
|
if rand::random::<f64>() < *sample_fraction && should_sample {
|
||||||
let file_id = rand::random::<usize>();
|
let file_id = rand::random::<usize>();
|
||||||
let img_filename = format!("ocr_data/{}.png", file_id);
|
let img_filename = format!("ocr_data/{}.png", file_id);
|
||||||
filtered_image.save(img_filename).unwrap();
|
filtered_image.save(img_filename).unwrap();
|
||||||
|
|
|
@ -4,7 +4,7 @@ use egui_extras::RetainedImage;
|
||||||
use image::RgbImage;
|
use image::RgbImage;
|
||||||
use time::{OffsetDateTime, format_description};
|
use time::{OffsetDateTime, format_description};
|
||||||
|
|
||||||
use crate::config::{Config, LearnedConfig};
|
use crate::{config::Config, learned_tracks::LearnedTracks};
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
|
@ -74,7 +74,7 @@ fn median_wear(values: Vec<Option<usize>>) -> Option<usize> {
|
||||||
last_value = val;
|
last_value = val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
wear_values.sort();
|
wear_values.sort_unstable();
|
||||||
wear_values.get(wear_values.len() / 2).cloned()
|
wear_values.get(wear_values.len() / 2).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,6 +93,8 @@ pub struct RaceState {
|
||||||
pub car: String,
|
pub car: String,
|
||||||
pub track: String,
|
pub track: String,
|
||||||
pub comments: String,
|
pub comments: String,
|
||||||
|
|
||||||
|
pub inferred_track: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RaceState {
|
impl RaceState {
|
||||||
|
@ -143,7 +145,7 @@ pub struct AppState {
|
||||||
pub should_sample_ocr_data: bool,
|
pub should_sample_ocr_data: bool,
|
||||||
|
|
||||||
pub config: Arc<Config>,
|
pub config: Arc<Config>,
|
||||||
pub learned: Arc<LearnedConfig>,
|
pub learned_tracks: Arc<LearnedTracks>,
|
||||||
|
|
||||||
pub ocr_cache: Arc<OcrCache>,
|
pub ocr_cache: Arc<OcrCache>,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,10 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
|
||||||
io::Write,
|
io::Write,
|
||||||
path::{PathBuf, Path},
|
path::{PathBuf, Path},
|
||||||
sync::{Arc, Mutex},
|
|
||||||
thread,
|
|
||||||
time::Duration,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use eframe::{
|
use eframe::{
|
||||||
egui::{self, Ui, Visuals},
|
egui::{self, Visuals},
|
||||||
emath::Vec2,
|
|
||||||
epaint::Color32,
|
epaint::Color32,
|
||||||
};
|
};
|
||||||
use egui_extras::RetainedImage;
|
use egui_extras::RetainedImage;
|
||||||
|
@ -63,7 +58,7 @@ fn get_training_data_paths() -> Vec<(PathBuf, PathBuf)> {
|
||||||
|
|
||||||
fn predict_ocr(hashes: &[(String, char)], hash: &str) -> Option<char> {
|
fn predict_ocr(hashes: &[(String, char)], hash: &str) -> Option<char> {
|
||||||
let hash = img_hash::ImageHash::<Vec<u8>>::from_base64(hash).unwrap();
|
let hash = img_hash::ImageHash::<Vec<u8>>::from_base64(hash).unwrap();
|
||||||
let (_, best_char) = hashes.iter().min_by_key(|(learned_hash, c)| {
|
let (_, best_char) = hashes.iter().min_by_key(|(learned_hash, _c)| {
|
||||||
img_hash::ImageHash::from_base64(learned_hash)
|
img_hash::ImageHash::from_base64(learned_hash)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.dist(&hash)
|
.dist(&hash)
|
||||||
|
@ -106,8 +101,8 @@ fn load_learned_hashes() -> Vec<(String, char)> {
|
||||||
let data = String::from_utf8(std::fs::read(path).unwrap()).unwrap();
|
let data = String::from_utf8(std::fs::read(path).unwrap()).unwrap();
|
||||||
let mut parsed = Vec::new();
|
let mut parsed = Vec::new();
|
||||||
for line in data.lines() {
|
for line in data.lines() {
|
||||||
if let Some((c, hash)) = line.split_once(" ") {
|
if let Some((c, hash)) = line.split_once(' ') {
|
||||||
if let Some(c) = c.chars().nth(0) {
|
if let Some(c) = c.chars().next() {
|
||||||
parsed.push((hash.to_owned(), c));
|
parsed.push((hash.to_owned(), c));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -197,7 +192,7 @@ impl eframe::App for TrainingUi {
|
||||||
}
|
}
|
||||||
for c in ¤t_image.char_hashes {
|
for c in ¤t_image.char_hashes {
|
||||||
ui.label(c);
|
ui.label(c);
|
||||||
if let Some(predicted) = predict_ocr(&self.learned_char_hashes, &c) {
|
if let Some(predicted) = predict_ocr(&self.learned_char_hashes, c) {
|
||||||
ui.label(format!("Predicted: {}", predicted));
|
ui.label(format!("Predicted: {}", predicted));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue