local OCR working
This commit is contained in:
parent
8baad566aa
commit
e195158218
|
@ -33,15 +33,18 @@
|
|||
"x": 1744,
|
||||
"y": 127,
|
||||
"width": 137,
|
||||
"height": 32
|
||||
"height": 32,
|
||||
"threshold": 0.88
|
||||
},
|
||||
{
|
||||
"name": "lap_time",
|
||||
"x": 1744,
|
||||
"y": 166,
|
||||
"width": 137,
|
||||
"height": 32
|
||||
"height": 32,
|
||||
"threshold": 0.88
|
||||
}
|
||||
],
|
||||
"track_recognition_threshold": 10,
|
||||
"ocr_server_endpoint": "https://tesserver.spruett.dev/"
|
||||
}
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
{
|
||||
"learned_chars": {
|
||||
"//8P+Afw4+fzx/HP+c/5z/nP8c/zx/Pnx/MP+H/+//8=": "0",
|
||||
"//8H4Afg/+P/4//j/+P/4//j/+P/4//j/+P/4//j//8=": "1",
|
||||
"//8f+A/g5+fzz/PP88/zn/PP88/zz+fP5+Mf8H/+//8=": "0",
|
||||
"/////P/8f/4//p//3//H++P78/sBwAGA//v/+//7//8=": "4",
|
||||
"//8P4A/gD+AP/w//D/gP4AfA/8P/w4fAA+AD8B/8//8=": "5",
|
||||
"//8P4A/gD+AP/4//D/gH4Afg/8H/wYfgA+AD8B/8//8=": "5",
|
||||
"//9/4H/wP/Af+A/8B/4DggODAYABgAGAAYD/g/+D//8=": "4",
|
||||
"/////f/8f/w//5//j//H++f78/MBgP/z//v/+//7//8=": "4",
|
||||
"/////wfwB+DD4cPBA+AP8Afgw8PDw4PBA+AH8P////8=": "8",
|
||||
"//8BgAGAAYDhgeGB/8D/wH/gP+A/8B/wH/gP+A/8//8=": "7",
|
||||
"////z/+fH/4H+AP4w/Dh8eHx4fHh8cPwA/gP/B/+//8=": "0",
|
||||
"//8H4Afg/+H/4f/h/+H/4f/h/+H/4f/h/+H/4f/h//8=": "1",
|
||||
"//8P+Afw4+fz5/HP+c/5z/nP+c/zx/Pnx/MP+H////8=": "0",
|
||||
"///f/w/wB+DH48fhB+AH4APgw8PDw8PDB+AP8P////8=": "8",
|
||||
"/////w/wB+DH48fjB+AH4Afgw8PDw8PDB+AP8P////8=": "8",
|
||||
"//8f/Efx8+fz7/vP88fjww/I/8//z//n/+MH+A/+//8=": "9",
|
||||
"//8/4I/g5//z//P/O/wL4OPH88/zz/fPx8cP4H/8//8=": "6",
|
||||
"//8P8AfgA8CDgcGDwYPBg8GDwYPBg4OBA8AH4A/w//8=": "0",
|
||||
"//8f+Afi88/zz/PPx+EP8OPH88/73/Pf48cH4H/+//8=": "8",
|
||||
"/////wPwA+DD4P/h/+H/4H/wP/gP/gf+A8ADwP////8=": "2",
|
||||
"//8f+A/g5+fzz/PP88/zz/PP88/zz+fH5+cP8H/+//8=": "0",
|
||||
"//8D4APg//P/+X/8P/4/8P/H/8//z//P8+cD8D////8=": "3",
|
||||
"/////wPAA8DDwcPBw+H/4P/wf/A/8D/4P/w/+P////8=": "7",
|
||||
"//8P+APw++f/7//v/+f/4//x//w//h//z/8DwAPA//8=": "2",
|
||||
"/////wPAA8Dj4OPg4+B/8H/wP/g/+D/8H/wf/P////8=": "7",
|
||||
"//8/+A/w5+fjz/PP88/zz/PP88/zz+PP5+cP8D/8//8=": "0",
|
||||
"//8P+CPw8+f77/vP88fjwQ/M/8//z//n//EH+D////8=": "9",
|
||||
"//8HwAfA//P/8f/8f/5/8P/H/8//3//P88cD8H/+//8=": "3",
|
||||
"//8DwAPA+8/75//n//P/8//5//n//H/+f/4//z////8=": "7",
|
||||
"//8/4Ifg5//z//P/O/wL4OPH88/zz/PP5+cP4H/8//8=": "6",
|
||||
"//+//Q/4B/CD4cPjw8PDw+PDw+PD4QfgD/Af+H/+//8=": "0",
|
||||
"//9/7x/AD+AH74f/g/AD4APAg8OHwwfDD8Af4H/7//8=": "6",
|
||||
"//8f+Afw++f/7//v/+f/4//4f/w//4//x/8DwAPA//8=": "2",
|
||||
"///f/Q/wB+CH4cPDw8PDw8PDw8PDwwfgB/AP+L/9//8=": "0",
|
||||
"//9/7x/gD+AH/4P/g/gD4APAg8ODw4fBB+Af8H/9//8=": "6",
|
||||
"//8H4Afi//P/+P/8f/4/8P/j/8f/z//H8+MD8D/+//8=": "3",
|
||||
"//8P4A/g7//n/+f/B/wH8P/H/8//3//P88cD8H/+//8=": "5",
|
||||
"//8DwAPA88/z5//j/+P/8f/5//j//H/+P/4//h////8=": "7",
|
||||
"//8P8A/wD/CH/4f/B/gH8Afg/8H/wYfgA+AD+B/+//8=": "5",
|
||||
"//8P8APgAcDBgcGDwYEBgAOAD4L/g//AA8AD4AP4//8=": "9",
|
||||
"//8f+Afg5+fzz/PP+9/73/vf+9/zz+PP5+cP4B/4//8=": "0",
|
||||
"/////wfgA+DHwf/B/8H/4H/gP/gf/A/+B8ADwP////8=": "2",
|
||||
"//9/7x/AD+AH7of/g/AD4APAg8ODwwfDD+Af4H/7//8=": "6",
|
||||
"/////wfgB+B/4H/gf+B/4H/gf+B/4H/gf+B/4P////8=": "1",
|
||||
"//8BgAGAAYDhgeGB7cD/wH/gf+A/8D/wH/gP+A/87/8=": "7",
|
||||
"////+z/4P/wf/g/+D/+H8cPwA8ADwAPAA8D/8P////8=": "4",
|
||||
"//8f+Afg5+fzz/Pf+9/73/vf+9/zz+PHx+MP8H/+//8=": "0",
|
||||
"//////////8//B/4H/gf+B/4H/gf+D/8//////////8=": ".",
|
||||
"//8H4Afg//P/+f/8f/w/4P/H/8//z//P4+cH4D/8//8=": "3",
|
||||
"//8H4A/g/+P/4//j/+P/4//j/+P/4//j/+P/4//j//8=": "1",
|
||||
"//8D4APg//P/+X/8f/4/+P/j/8//z//P+8cD8B/8//8=": "3",
|
||||
"//8DgAOAA8D/4H/wH/AfwB+A/4H/g+OBA4ABwAPw//8=": "3",
|
||||
"///f+w/wB+CDwcPDw8MHwA/AP8P/w0/gB/AH+F////8=": "9",
|
||||
"/////f/8f/x//j//n//P8+fz4/MBwAOA//P/8//z//8=": "4",
|
||||
"//8HwAfA/+P/8//9f/x/8P/H/8//3//P888D4B/4//8=": "3",
|
||||
"///f/Q/4B/DD48PDw8EHwA/AH8P/4U/gB/AH+F////8=": "9",
|
||||
"//8/8A/lx//3//P/8/8b4MPH48/z3/PP58eH4D/8//8=": "6",
|
||||
"/////w/wD/D//////////////////w/wD/AP8H/+//8=": ":",
|
||||
"//8f/Efh4+fzz/PP88fnwwfY/9//z//v/+On8A/8//8=": "9",
|
||||
"////////////+D/wP/Af8A/wD/AP8B/4//////////8=": ".",
|
||||
"//8D4APg//P/+f/8P/4/8P/H/8//z//P88cD8H////8=": "3",
|
||||
"/////wfgB+Af4H/gf+B/4H/gf+B/4H/gf+B/4P////8=": "1",
|
||||
"/////wfgB+A/4D/gP+A/4D/gP+A/4D/gP+A/4P////8=": "1",
|
||||
"//9/7x/AD+AH74f/g/AD4APAg8ODwwfDD8Af4H/7//8=": "6",
|
||||
"//8f+Afi8+fzz/PP48fHww/Y/8//z//v/+MH+D////8=": "9",
|
||||
"//8HwAfAB8CD/wP/A+ADwAOA/4H/g+OBA4ABwAPw//8=": "5",
|
||||
"//8f+Efi5+fzz/vf+9/73/vf+9/zz+PHx+MP8D/8//8=": "0",
|
||||
"/////wfwA+DD4P/h/+H/4H/wP/gf/Af+B8AHwP////8=": "2",
|
||||
"//8D4APg//P/+X/8f/4/+P/j/8//z//P+8cD8A/8//8=": "3",
|
||||
"//8f+Afg5+fzz/vf+9/73/vf+9/zz+PHx+MP8H/+//8=": "0",
|
||||
"//8P/APw++P/5//n/+f/8//4f/w//4//x/8DwAPA//8=": "2",
|
||||
"//8P4A/g7//v/+//B/wH4P/P/8//3//P88cD8H/+//8=": "5",
|
||||
"//8/gA+AB4AD34P/geABgAGAAYOBh4GHA4AHgA/g//8=": "6",
|
||||
"/////w/wD/D/////////////////////D/AP8P////8=": ":",
|
||||
"//8f/Afw5+fz5/HP+c/5z/nP+c/xz/Pn5+cP8B/8//8=": "0",
|
||||
"//9/7x/gD+AH/4P/g/gD4APAg8ODw4fBD+Af8H/9//8=": "6",
|
||||
"//8BgAGAAYA/gD+AP4A/gD+AP4A/gD+AP4A/gD+A//8=": "1",
|
||||
"///f/Q/4B/DD4cPDw8EDwA/AH+P/4W/gB/AH+F////8=": "9",
|
||||
"//8DwAPA+8/75//n//P/8f/5//n//H/+f/4//z////8=": "7",
|
||||
"//8/gA+AB4AD34H/geABgAGAAYOBh4GHA4AHgA/g//8=": "6",
|
||||
"//8P+CPw8+f77/vP88fjwQ/M/8//z//n//EH+J////8=": "9",
|
||||
"//8H8APAAcDDgP+B/4H/wH/gH/AP+Af8A4ADgAOA//8=": "2",
|
||||
"//8f+A/g5+fzz/PP88/zz/PP88/zz+fH5+cf8H/+//8=": "0",
|
||||
"///v/QfwA+DD4cPBw8EDwAfAn8H/4UfwB/AD/G////8=": "9",
|
||||
"/////P/8f/4//p//3//n++P78/sBwAGA//v/+//7//8=": "4",
|
||||
"/////////////wfgB+AH4AfgB+AH4P////////////8=": "-",
|
||||
"//8H4Afg/+H/4f/h/+H/4f/h/+H/4f/h/+H/4f/j//8=": "1",
|
||||
"//8H4AfgB+D/8H/4P/w/8D/g/8P/w4fAA+AD8B/8//8=": "3",
|
||||
"//8P8APAA8CBgYGDA8AH4APAgYHBg8GDgYEDwAfg//8=": "8",
|
||||
"//8P+APw++f/7//v/+f/8//x//w//h//z/8DwAPA//8=": "2",
|
||||
"//8f+Afw4+fzx/HP+c/5z/nP8c/zx/Pnx/MP+H////8=": "0",
|
||||
"//8HwAfAB8D/4H/wP/g/4D/A/8H/w4fAA8AD8B/8//8=": "3",
|
||||
"//+//R/4D/CH4cPDw8PDw8PDw8PHwwfgD/Af+H/+//8=": "0",
|
||||
"//+//R/wD+CH4cfDw8PDw8PDw8PHwwfgD/Af+H/+//8=": "0",
|
||||
"//////////8//D/8H/gP8A/wH/g//D/8//////////8=": ".",
|
||||
"////+3/4P/w//B/+D/8H8YfhA8ADwAPAA8D/4f////8=": "4",
|
||||
"//////////8BgAGAAYABgAGAAYABgAGA//////////8=": "-"
|
||||
}
|
||||
}
|
|
@ -178,3 +178,4 @@
|
|||
2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,11,21.062,20.296,100,48,69,
|
||||
2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,12,20.942,20.296,100,39,61,
|
||||
2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,13,20.576,20.296,100,30,52,
|
||||
2022-06-04-02:10,,,3,25.000,25.000,75,75,75,
|
||||
|
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
thread,
|
||||
time::{Duration, Instant, SystemTime},
|
||||
time::{Duration, Instant, SystemTime}, sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
|
@ -14,8 +14,7 @@ use crate::{
|
|||
capture,
|
||||
config::Config,
|
||||
image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes},
|
||||
remote_ocr,
|
||||
state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash,
|
||||
state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash, ocr_db::OcrDatabase,
|
||||
};
|
||||
|
||||
|
||||
|
@ -130,36 +129,62 @@ fn add_saved_frame(
|
|||
saved_frames: &mut HashMap<String, DebugOcrFrame>,
|
||||
frame: &RgbImage,
|
||||
region: &Region,
|
||||
ocr_results: &HashMap<String, Option<String>>,
|
||||
ocr_results: &HashMap<String, String>,
|
||||
) {
|
||||
let extracted = extract_and_filter(frame, region);
|
||||
let retained =
|
||||
RetainedImage::from_image_bytes(®ion.name, &image_processing::to_png_bytes(&extracted))
|
||||
.unwrap();
|
||||
let hash = hash_image(&extracted);
|
||||
let hash = hash_image(&extracted).to_base64();
|
||||
saved_frames.insert(
|
||||
region.name.clone(),
|
||||
DebugOcrFrame {
|
||||
image: retained,
|
||||
rgb_image: extracted,
|
||||
img_hash: hash,
|
||||
recognized_text: ocr_results.get(®ion.name).and_then(|p| p.clone()),
|
||||
recognized_text: ocr_results.get(®ion.name).map(|p| p.clone()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn ocr_all_regions(
|
||||
ocr_db: &OcrDatabase,
|
||||
image: &RgbImage,
|
||||
config: &Config,
|
||||
should_sample: bool,
|
||||
) -> HashMap<String, String> {
|
||||
let mut results = HashMap::new();
|
||||
|
||||
for region in &config.ocr_regions {
|
||||
let filtered_image = extract_and_filter(image, region);
|
||||
let value = ocr_db.ocr_image(&filtered_image);
|
||||
|
||||
if let Some(sample_fraction) = &config.dump_frame_fraction {
|
||||
if rand::random::<f64>() < *sample_fraction && should_sample {
|
||||
let file_id = rand::random::<usize>();
|
||||
let img_filename = format!("ocr_data/{}.png", file_id);
|
||||
filtered_image.save(img_filename).unwrap();
|
||||
let value_filename = format!("ocr_data/{}.txt", file_id);
|
||||
std::fs::write(value_filename, value.clone()).unwrap();
|
||||
}
|
||||
}
|
||||
results.insert(region.name.clone(), value);
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> {
|
||||
let frame = capture::get_frame(capturer)?;
|
||||
|
||||
let (config, ocr_cache, should_sample) = {
|
||||
let (ocr_db, config, should_sample) = {
|
||||
let locked = state.lock().unwrap();
|
||||
(
|
||||
locked.ocr_db.clone(),
|
||||
locked.config.clone(),
|
||||
locked.ocr_cache.clone(),
|
||||
locked.should_sample_ocr_data
|
||||
)
|
||||
};
|
||||
let ocr_results = remote_ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample);
|
||||
let ocr_results = ocr_all_regions(ocr_db.as_ref(), &frame, config.as_ref(), should_sample);
|
||||
|
||||
if state.lock().unwrap().debug_frames {
|
||||
let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results);
|
||||
|
@ -178,7 +203,7 @@ fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()>
|
|||
pub fn save_frames_from(
|
||||
frame: &RgbImage,
|
||||
config: &Config,
|
||||
ocr_results: &HashMap<String, Option<String>>,
|
||||
ocr_results: &HashMap<String, String>,
|
||||
) -> HashMap<String, DebugOcrFrame> {
|
||||
let mut saved_frames = HashMap::new();
|
||||
for region in &config.ocr_regions {
|
||||
|
|
|
@ -9,9 +9,6 @@ use crate::image_processing::Region;
|
|||
pub struct Config {
|
||||
pub ocr_regions: Vec<Region>,
|
||||
pub track_region: Option<Region>,
|
||||
pub ocr_server_endpoint: String,
|
||||
pub filter_threshold: Option<f64>,
|
||||
pub use_ocr_cache: Option<bool>,
|
||||
pub ocr_interval_ms: Option<u64>,
|
||||
pub track_recognition_threshold: Option<u32>,
|
||||
pub dump_frame_fraction: Option<f64>,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage};
|
||||
use img_hash::ImageHash;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
|
@ -98,7 +99,7 @@ pub fn from_png_bytes(bytes: &[u8]) -> RgbImage {
|
|||
image.to_rgb8()
|
||||
}
|
||||
|
||||
pub fn hash_image(image: &RgbImage) -> String {
|
||||
pub fn hash_image(image: &RgbImage) -> ImageHash {
|
||||
let hasher = img_hash::HasherConfig::new()
|
||||
.hash_alg(img_hash::HashAlg::Mean)
|
||||
.hash_size(16, 16)
|
||||
|
@ -107,5 +108,5 @@ pub fn hash_image(image: &RgbImage) -> String {
|
|||
img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec())
|
||||
.unwrap();
|
||||
let hash = hasher.hash_image(&have_to_use_other_image_library_version);
|
||||
hash.to_base64()
|
||||
hash
|
||||
}
|
||||
|
|
|
@ -47,5 +47,5 @@ impl LearnedTracks {
|
|||
pub fn get_track_hash(config: &Config, image: &RgbImage) -> Option<String> {
|
||||
let track_region = config.track_region.as_ref()?;
|
||||
let extracted = extract_and_filter(image, track_region);
|
||||
Some(hash_image(&extracted))
|
||||
Some(hash_image(&extracted).to_base64())
|
||||
}
|
||||
|
|
52
src/main.rs
52
src/main.rs
|
@ -5,11 +5,11 @@ mod capture;
|
|||
mod config;
|
||||
mod image_processing;
|
||||
mod ocr;
|
||||
mod remote_ocr;
|
||||
mod state;
|
||||
mod stats_writer;
|
||||
mod training_ui;
|
||||
mod learned_tracks;
|
||||
mod ocr_db;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
|
@ -30,7 +30,8 @@ use eframe::{
|
|||
use egui_extras::RetainedImage;
|
||||
use image_processing::{from_png_bytes, to_png_bytes};
|
||||
use learned_tracks::LearnedTracks;
|
||||
use state::{AppState, DebugOcrFrame, LapState, OcrCache, RaceState, SharedAppState};
|
||||
use ocr_db::OcrDatabase;
|
||||
use state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState};
|
||||
use stats_writer::export_race_stats;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
|
@ -41,6 +42,7 @@ fn main() -> anyhow::Result<()> {
|
|||
let app_state = AppState {
|
||||
config: Arc::new(Config::load().unwrap()),
|
||||
learned_tracks: Arc::new(LearnedTracks::load().unwrap()),
|
||||
ocr_db: Arc::new(OcrDatabase::load().unwrap()),
|
||||
..Default::default()
|
||||
};
|
||||
let state = Arc::new(Mutex::new(app_state));
|
||||
|
@ -174,8 +176,8 @@ fn show_race_state(
|
|||
ui_state: &mut UiState,
|
||||
race_name: &str,
|
||||
race: &mut RaceState,
|
||||
config: Arc<Config>,
|
||||
ocr_cache: Arc<OcrCache>,
|
||||
config: &Config,
|
||||
ocr_db: &OcrDatabase,
|
||||
) {
|
||||
egui::Grid::new(format!("race:{}", race_name)).show(ui, |ui| {
|
||||
ui.label("Lap");
|
||||
|
@ -228,9 +230,9 @@ fn show_race_state(
|
|||
if ui.button("Debug").clicked() {
|
||||
open_debug_lap(
|
||||
ui_state,
|
||||
ocr_db,
|
||||
lap,
|
||||
config.clone(),
|
||||
ocr_cache.clone(),
|
||||
config,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -241,10 +243,8 @@ fn show_race_state(
|
|||
});
|
||||
}
|
||||
|
||||
fn show_debug_frames(ui: &mut Ui, debug_frames: &HashMap<String, DebugOcrFrame>) {
|
||||
let mut screenshots_sorted: Vec<_> = debug_frames.iter().collect();
|
||||
screenshots_sorted.sort_by_key(|(name, _)| *name);
|
||||
for (name, debug_image) in screenshots_sorted {
|
||||
fn show_debug_frames(ui: &mut Ui, debug_frames: &mut HashMap<String, DebugOcrFrame>) {
|
||||
for (name, debug_image) in debug_frames.iter_mut() {
|
||||
ui.label(name);
|
||||
if let Some(text) = &debug_image.recognized_text {
|
||||
ui.label(text);
|
||||
|
@ -280,13 +280,13 @@ fn show_config_controls(ui: &mut Ui, ui_state: &mut UiState, state: &mut AppStat
|
|||
|
||||
fn open_debug_lap(
|
||||
ui_state: &mut UiState,
|
||||
ocr_db: &OcrDatabase,
|
||||
lap: &LapState,
|
||||
config: Arc<Config>,
|
||||
ocr_cache: Arc<OcrCache>,
|
||||
config: &Config,
|
||||
) {
|
||||
if let Some(screenshot_bytes) = &lap.screenshot {
|
||||
let screenshot = from_png_bytes(screenshot_bytes);
|
||||
let ocr_results = remote_ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false);
|
||||
let ocr_results = analysis::ocr_all_regions(ocr_db, &screenshot, config, false);
|
||||
let debug_lap = DebugLap {
|
||||
screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot))
|
||||
.unwrap(),
|
||||
|
@ -308,6 +308,7 @@ fn show_combo_box(ui: &mut Ui, name: &str, label: &str, options: &[String], valu
|
|||
impl eframe::App for AppUi {
|
||||
fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let ocr_db = state.ocr_db.clone();
|
||||
|
||||
let mut debug_lap_window = self.ui_state.debug_lap.is_some();
|
||||
let window = egui::Window::new("Debug Lap").open(&mut debug_lap_window);
|
||||
|
@ -318,8 +319,8 @@ impl eframe::App for AppUi {
|
|||
.screenshot
|
||||
.show_max_size(ui, Vec2::new(800.0, 600.0));
|
||||
ui.separator();
|
||||
if let Some(debug_lap) = &self.ui_state.debug_lap {
|
||||
show_debug_frames(ui, &debug_lap.debug_regions);
|
||||
if let Some(debug_lap) = &mut self.ui_state.debug_lap {
|
||||
show_debug_frames(ui, &mut debug_lap.debug_regions);
|
||||
}
|
||||
show_config_controls(ui, &mut self.ui_state, state.deref_mut());
|
||||
}
|
||||
|
@ -380,16 +381,17 @@ impl eframe::App for AppUi {
|
|||
|
||||
ui.separator();
|
||||
ui.checkbox(&mut state.debug_frames, "Debug OCR regions");
|
||||
ui.checkbox(
|
||||
&mut state.should_sample_ocr_data,
|
||||
"Dump OCR training frames",
|
||||
);
|
||||
if state.config.dump_frame_fraction.is_some() {
|
||||
ui.checkbox(
|
||||
&mut state.should_sample_ocr_data,
|
||||
"Dump OCR training frames",
|
||||
);
|
||||
}
|
||||
});
|
||||
egui::CentralPanel::default().show(ctx, |ui| {
|
||||
egui::ScrollArea::vertical().show(ui, |ui| {
|
||||
let config = state.config.clone();
|
||||
let _learned_tracks = state.learned_tracks.clone();
|
||||
let ocr_cache = state.ocr_cache.clone();
|
||||
if let Some(race) = &mut state.current_race {
|
||||
ui.heading(&format!("Current Race: {}", race.name()));
|
||||
show_race_state(
|
||||
|
@ -397,8 +399,8 @@ impl eframe::App for AppUi {
|
|||
&mut self.ui_state,
|
||||
"current",
|
||||
race,
|
||||
config.clone(),
|
||||
ocr_cache.clone(),
|
||||
config.as_ref(),
|
||||
ocr_db.as_ref(),
|
||||
);
|
||||
}
|
||||
let len = state.past_races.len();
|
||||
|
@ -412,8 +414,8 @@ impl eframe::App for AppUi {
|
|||
&mut self.ui_state,
|
||||
&format!("race {}:", i),
|
||||
race,
|
||||
config.clone(),
|
||||
ocr_cache.clone(),
|
||||
config.as_ref(),
|
||||
ocr_db.as_ref(),
|
||||
);
|
||||
if let Some(img) = &race.screencap {
|
||||
img.show_max_size(ui, Vec2::new(600.0, 500.0));
|
||||
|
@ -471,7 +473,7 @@ impl eframe::App for AppUi {
|
|||
if state.debug_frames {
|
||||
egui::SidePanel::right("screenshots").show(ctx, |ui| {
|
||||
egui::ScrollArea::vertical().show(ui, |ui| {
|
||||
show_debug_frames(ui, &state.saved_frames);
|
||||
show_debug_frames(ui, &mut state.saved_frames);
|
||||
show_config_controls(ui, &mut self.ui_state, state.deref_mut());
|
||||
});
|
||||
});
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use image::{Rgb, RgbImage};
|
||||
use img_hash::ImageHash;
|
||||
|
||||
use crate::image_processing;
|
||||
|
||||
|
@ -102,7 +103,7 @@ pub fn bounding_box_images(image: &RgbImage) -> Vec<RgbImage> {
|
|||
trimmed
|
||||
}
|
||||
|
||||
pub fn compute_box_hashes(image: &RgbImage) -> Vec<String> {
|
||||
pub fn compute_box_hashes(image: &RgbImage) -> Vec<ImageHash> {
|
||||
let mut hashes = Vec::new();
|
||||
|
||||
let boxes = get_character_bounding_boxes(image);
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
use std::{collections::HashMap, sync::{Arc, RwLock}};
|
||||
|
||||
use crate::{
|
||||
config::{load_config_or_make_default, save_json_config, Config},
|
||||
image_processing::{extract_and_filter, hash_image}, ocr,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use image::RgbImage;
|
||||
use img_hash::ImageHash;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Default, Serialize, Deserialize, Clone)]
|
||||
struct RawOcrDatabase {
|
||||
learned_chars: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct OcrDatabase {
|
||||
learned_chars: RwLock<Vec<(ImageHash, char)>>,
|
||||
}
|
||||
|
||||
impl From<&RawOcrDatabase> for OcrDatabase {
|
||||
fn from(raw: &RawOcrDatabase) -> Self {
|
||||
let db = Self::default();
|
||||
{
|
||||
let mut state = db.learned_chars.write().unwrap();
|
||||
for (hash, s) in &raw.learned_chars {
|
||||
state.push((
|
||||
ImageHash::from_base64(&hash).unwrap(),
|
||||
s.chars().nth(0).unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
db
|
||||
}
|
||||
}
|
||||
impl From<&OcrDatabase> for RawOcrDatabase {
|
||||
fn from(db: &OcrDatabase) -> Self {
|
||||
Self {
|
||||
learned_chars: db
|
||||
.learned_chars
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|(hash, c)| (hash.to_base64(), c.to_string()))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OcrDatabase {
|
||||
pub fn load() -> Result<Self> {
|
||||
let raw: RawOcrDatabase =
|
||||
load_config_or_make_default("ocr.json", include_str!("configs/learned.default.json"))?;
|
||||
Ok((&raw).into())
|
||||
}
|
||||
pub fn save(&self) -> Result<()> {
|
||||
let raw: RawOcrDatabase = self.into();
|
||||
save_json_config("ocr.json", &raw)
|
||||
}
|
||||
|
||||
pub fn learn(&self, hash: &str, val: char) {
|
||||
self.learned_chars.write().unwrap().push((ImageHash::from_base64(hash).unwrap(), val));
|
||||
}
|
||||
|
||||
pub fn learn_phrase(&self, hashes: &[ImageHash], phrase: &str) -> Result<()> {
|
||||
if phrase.len() > hashes.len() {
|
||||
anyhow::bail!("too many characters detected in OCR result for learned phrase")
|
||||
}
|
||||
if phrase.len() < hashes.len() - 1 {
|
||||
anyhow::bail!("too few characters detected in OCR result for learned phrase")
|
||||
}
|
||||
if !phrase.is_ascii() {
|
||||
anyhow::bail!("cannot learn non-ASCII characters")
|
||||
}
|
||||
|
||||
let chars: Vec<char> = phrase.chars().collect();
|
||||
for (i, hash) in hashes.iter().enumerate() {
|
||||
self.learn(&hash.to_base64(), chars[i]);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn ocr_char(&self, hash: &ImageHash) -> Option<char> {
|
||||
let state = self.learned_chars.read().unwrap();
|
||||
let (_, c) = state.iter().min_by_key(|(learned_hash, _)| hash.dist(learned_hash))?;
|
||||
Some(*c)
|
||||
}
|
||||
|
||||
pub fn ocr_hashes(&self, hashes: &[ImageHash]) -> String {
|
||||
let buffer: String = hashes.iter().filter_map(|hash| self.ocr_char(hash)).collect();
|
||||
buffer.trim_end_matches(|c: char| !c.is_alphanumeric()).to_owned()
|
||||
}
|
||||
|
||||
pub fn ocr_image(&self, image: &RgbImage) -> String {
|
||||
let hashes = ocr::compute_box_hashes(image);
|
||||
self.ocr_hashes(&hashes)
|
||||
}
|
||||
}
|
|
@ -1,121 +0,0 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use image::RgbImage;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
config::Config,
|
||||
image_processing::{extract_and_filter, hash_image},
|
||||
state::OcrCache,
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct OcrRegion {
|
||||
pub confidence: f64,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct OcrResult {
|
||||
regions: Vec<OcrRegion>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.post(url)
|
||||
.body(crate::image_processing::to_png_bytes(image))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
eprintln!("failed to run OCR query");
|
||||
anyhow::bail!("failed to run OCR query")
|
||||
}
|
||||
let result: OcrResult = response.json().await?;
|
||||
let result = if result.regions.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut buffer = String::new();
|
||||
for r in &result.regions {
|
||||
buffer += &r.value;
|
||||
}
|
||||
Some(buffer)
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn run_ocr_cached(
|
||||
ocr_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
|
||||
hash: String,
|
||||
region: &crate::image_processing::Region,
|
||||
config: Arc<Config>,
|
||||
filtered_image: &image::ImageBuffer<image::Rgb<u8>, Vec<u8>>,
|
||||
) -> Option<String> {
|
||||
let cached = {
|
||||
let locked = ocr_cache.read().unwrap();
|
||||
locked.get(&hash).cloned()
|
||||
};
|
||||
let use_cache = region.use_ocr_cache.unwrap_or(true) && config.use_ocr_cache.unwrap_or(true);
|
||||
if let Some(cached) = cached {
|
||||
if use_cache {
|
||||
return cached;
|
||||
}
|
||||
}
|
||||
match run_ocr(filtered_image, &config.ocr_server_endpoint).await {
|
||||
Ok(v) => {
|
||||
if use_cache {
|
||||
ocr_cache.write().unwrap().insert(hash.clone(), v.clone());
|
||||
}
|
||||
v
|
||||
}
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
pub async fn ocr_all_regions(
|
||||
image: &RgbImage,
|
||||
config: Arc<Config>,
|
||||
ocr_cache: Arc<OcrCache>,
|
||||
should_sample: bool,
|
||||
) -> HashMap<String, Option<String>> {
|
||||
let results = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for region in &config.ocr_regions {
|
||||
let filtered_image = extract_and_filter(image, region);
|
||||
let region = region.clone();
|
||||
let results = results.clone();
|
||||
let config = config.clone();
|
||||
let ocr_cache = ocr_cache.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let filtered_image = filtered_image;
|
||||
let hash = hash_image(&filtered_image);
|
||||
let value =
|
||||
run_ocr_cached(ocr_cache, hash, ®ion, config.clone(), &filtered_image).await;
|
||||
|
||||
if let Some(sample_fraction) = &config.dump_frame_fraction {
|
||||
if rand::random::<f64>() < *sample_fraction && should_sample {
|
||||
let file_id = rand::random::<usize>();
|
||||
let img_filename = format!("ocr_data/{}.png", file_id);
|
||||
filtered_image.save(img_filename).unwrap();
|
||||
let value_filename = format!("ocr_data/{}.txt", file_id);
|
||||
std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap();
|
||||
}
|
||||
}
|
||||
results.lock().unwrap().insert(region.name, value);
|
||||
}));
|
||||
}
|
||||
for handle in handles {
|
||||
handle.await.expect("failed to join task in OCR");
|
||||
}
|
||||
|
||||
let results = results.lock().unwrap().clone();
|
||||
results
|
||||
}
|
18
src/state.rs
18
src/state.rs
|
@ -4,7 +4,7 @@ use egui_extras::RetainedImage;
|
|||
use image::RgbImage;
|
||||
use time::{OffsetDateTime, format_description};
|
||||
|
||||
use crate::{config::Config, learned_tracks::LearnedTracks};
|
||||
use crate::{config::Config, learned_tracks::LearnedTracks, ocr_db::OcrDatabase};
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
|
@ -33,8 +33,8 @@ fn parse_duration(time: &str) -> Option<Duration> {
|
|||
Some(Duration::from_secs_f64(60.0 * minutes + secs))
|
||||
}
|
||||
|
||||
fn parse_to_duration(time: Option<&Option<String>>) -> Option<Duration> {
|
||||
parse_duration(&time?.clone()?)
|
||||
fn parse_to_duration(time: Option<&String>) -> Option<Duration> {
|
||||
parse_duration(&time?)
|
||||
}
|
||||
|
||||
fn check_0_100(v: usize) -> Option<usize> {
|
||||
|
@ -45,12 +45,12 @@ fn check_0_100(v: usize) -> Option<usize> {
|
|||
}
|
||||
}
|
||||
|
||||
fn parse_to_0_100(v: Option<&Option<String>>) -> Option<usize> {
|
||||
check_0_100(v?.clone()?.parse::<usize>().ok()?)
|
||||
fn parse_to_0_100(v: Option<&String>) -> Option<usize> {
|
||||
check_0_100(v?.parse::<usize>().ok()?)
|
||||
}
|
||||
|
||||
impl LapState {
|
||||
pub fn parse(raw: &HashMap<String, Option<String>>) -> Self {
|
||||
pub fn parse(raw: &HashMap<String, String>) -> Self {
|
||||
Self {
|
||||
lap: parse_to_0_100(raw.get("lap")),
|
||||
health: parse_to_0_100(raw.get("health")),
|
||||
|
@ -126,10 +126,9 @@ pub struct DebugOcrFrame {
|
|||
pub recognized_text: Option<String>,
|
||||
}
|
||||
|
||||
pub type OcrCache = RwLock<HashMap<String, Option<String>>>;
|
||||
#[derive(Default)]
|
||||
pub struct AppState {
|
||||
pub raw_data: HashMap<String, Option<String>>,
|
||||
pub raw_data: HashMap<String, String>,
|
||||
pub last_frame: Option<LapState>,
|
||||
|
||||
pub buffered_frames: VecDeque<LapState>,
|
||||
|
@ -144,8 +143,7 @@ pub struct AppState {
|
|||
|
||||
pub config: Arc<Config>,
|
||||
pub learned_tracks: Arc<LearnedTracks>,
|
||||
|
||||
pub ocr_cache: Arc<OcrCache>,
|
||||
pub ocr_db: Arc<OcrDatabase>,
|
||||
}
|
||||
|
||||
pub type SharedAppState = Arc<Mutex<AppState>>;
|
Binary file not shown.
After Width: | Height: | Size: 707 B |
|
@ -9,8 +9,9 @@ use eframe::{
|
|||
};
|
||||
use egui_extras::RetainedImage;
|
||||
use image::RgbImage;
|
||||
use img_hash::ImageHash;
|
||||
|
||||
use crate::{image_processing::to_png_bytes, ocr};
|
||||
use crate::{image_processing::to_png_bytes, ocr, ocr_db::OcrDatabase};
|
||||
|
||||
#[derive(Default)]
|
||||
struct TrainingUi {
|
||||
|
@ -18,7 +19,7 @@ struct TrainingUi {
|
|||
|
||||
current_image_index: usize,
|
||||
|
||||
learned_char_hashes: Vec<(String, char)>,
|
||||
ocr_db: OcrDatabase,
|
||||
}
|
||||
|
||||
struct TrainingImage {
|
||||
|
@ -28,7 +29,7 @@ struct TrainingImage {
|
|||
|
||||
ui_image: RetainedImage,
|
||||
char_images: Vec<RetainedImage>,
|
||||
char_hashes: Vec<String>,
|
||||
char_hashes: Vec<ImageHash>,
|
||||
}
|
||||
|
||||
impl TrainingImage {
|
||||
|
@ -55,16 +56,6 @@ fn get_training_data_paths() -> Vec<(PathBuf, PathBuf)> {
|
|||
data_paths
|
||||
}
|
||||
|
||||
fn predict_ocr(hashes: &[(String, char)], hash: &str) -> Option<char> {
|
||||
let hash = img_hash::ImageHash::<Vec<u8>>::from_base64(hash).unwrap();
|
||||
let (_, best_char) = hashes.iter().min_by_key(|(learned_hash, _c)| {
|
||||
img_hash::ImageHash::from_base64(learned_hash)
|
||||
.unwrap()
|
||||
.dist(&hash)
|
||||
})?;
|
||||
Some(*best_char)
|
||||
}
|
||||
|
||||
fn get_training_data() -> Vec<TrainingImage> {
|
||||
let mut data = Vec::new();
|
||||
for (img_file, ocr_file) in get_training_data_paths() {
|
||||
|
@ -90,24 +81,6 @@ fn get_training_data() -> Vec<TrainingImage> {
|
|||
data
|
||||
}
|
||||
|
||||
fn load_learned_hashes() -> Vec<(String, char)> {
|
||||
let path = Path::new("learned_chars.txt");
|
||||
if !path.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let data = String::from_utf8(std::fs::read(path).unwrap()).unwrap();
|
||||
let mut parsed = Vec::new();
|
||||
for line in data.lines() {
|
||||
if let Some((c, hash)) = line.split_once(' ') {
|
||||
if let Some(c) = c.chars().next() {
|
||||
parsed.push((hash.to_owned(), c));
|
||||
}
|
||||
}
|
||||
}
|
||||
parsed
|
||||
}
|
||||
|
||||
fn load_retained_image(image: &RgbImage) -> RetainedImage {
|
||||
RetainedImage::from_image_bytes("", &to_png_bytes(image)).unwrap()
|
||||
}
|
||||
|
@ -116,7 +89,7 @@ pub fn training_ui() -> anyhow::Result<()> {
|
|||
let options = eframe::NativeOptions::default();
|
||||
let state = TrainingUi {
|
||||
training_images: get_training_data(),
|
||||
learned_char_hashes: load_learned_hashes(),
|
||||
ocr_db: OcrDatabase::load().unwrap(),
|
||||
..Default::default()
|
||||
};
|
||||
eframe::run_native("OCR Trainer", options, Box::new(|_cc| Box::new(state)));
|
||||
|
@ -139,46 +112,22 @@ impl eframe::App for TrainingUi {
|
|||
current_image.save_ocr_text();
|
||||
}
|
||||
if ui.button("Learn").clicked() {
|
||||
for (i, char) in current_image.text.chars().enumerate() {
|
||||
if let Some(hash) = current_image.char_hashes.get(i) {
|
||||
self.learned_char_hashes.push((hash.clone(), char));
|
||||
eprintln!("Learned {}={}", hash, char);
|
||||
}
|
||||
}
|
||||
self.ocr_db.learn_phrase(¤t_image.char_hashes, ¤t_image.text).unwrap();
|
||||
self.current_image_index += 1;
|
||||
}
|
||||
if ui.button("Learn and delete").clicked() {
|
||||
for (i, char) in current_image.text.chars().enumerate() {
|
||||
if let Some(hash) = current_image.char_hashes.get(i) {
|
||||
self.learned_char_hashes.push((hash.clone(), char));
|
||||
eprintln!("Learned {}={}", hash, char);
|
||||
}
|
||||
}
|
||||
self.ocr_db.learn_phrase(¤t_image.char_hashes, ¤t_image.text).unwrap();
|
||||
current_image.delete_data();
|
||||
self.current_image_index += 1;
|
||||
}
|
||||
if ui.button("Save learned results").clicked() {
|
||||
let mut buffer = String::new();
|
||||
for (hash, c) in &self.learned_char_hashes {
|
||||
buffer += &format!("{} {}\n", c, hash);
|
||||
}
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.create(true)
|
||||
.write(true)
|
||||
.open("learned_chars.txt")
|
||||
.unwrap();
|
||||
file.write_all(buffer.as_bytes()).unwrap();
|
||||
self.ocr_db.save().unwrap();
|
||||
}
|
||||
|
||||
current_image.ui_image.show(ui);
|
||||
ui.label("OCR value");
|
||||
ui.text_edit_singleline(&mut current_image.text);
|
||||
let predicted: String = current_image
|
||||
.char_hashes
|
||||
.iter()
|
||||
.filter_map(|hash| predict_ocr(&self.learned_char_hashes, hash))
|
||||
.collect();
|
||||
let predicted: String = self.ocr_db.ocr_hashes(¤t_image.char_hashes);
|
||||
if predicted == current_image.text {
|
||||
ui.colored_label(Color32::GREEN, format!("Predicted: {}", predicted));
|
||||
} else {
|
||||
|
@ -189,8 +138,8 @@ impl eframe::App for TrainingUi {
|
|||
c.show(ui);
|
||||
}
|
||||
for c in ¤t_image.char_hashes {
|
||||
ui.label(c);
|
||||
if let Some(predicted) = predict_ocr(&self.learned_char_hashes, c) {
|
||||
ui.label(c.to_base64());
|
||||
if let Some(predicted) = self.ocr_db.ocr_char(c) {
|
||||
ui.label(format!("Predicted: {}", predicted));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue