local OCR working

This commit is contained in:
Scott Pruett 2022-06-03 22:18:46 -04:00
parent 8baad566aa
commit e195158218
14 changed files with 294 additions and 237 deletions

View File

@ -33,15 +33,18 @@
"x": 1744,
"y": 127,
"width": 137,
"height": 32
"height": 32,
"threshold": 0.88
},
{
"name": "lap_time",
"x": 1744,
"y": 166,
"width": 137,
"height": 32
"height": 32,
"threshold": 0.88
}
],
"track_recognition_threshold": 10,
"ocr_server_endpoint": "https://tesserver.spruett.dev/"
}

101
ocr.json Normal file
View File

@ -0,0 +1,101 @@
{
"learned_chars": {
"//8P+Afw4+fzx/HP+c/5z/nP8c/zx/Pnx/MP+H/+//8=": "0",
"//8H4Afg/+P/4//j/+P/4//j/+P/4//j/+P/4//j//8=": "1",
"//8f+A/g5+fzz/PP88/zn/PP88/zz+fP5+Mf8H/+//8=": "0",
"/////P/8f/4//p//3//H++P78/sBwAGA//v/+//7//8=": "4",
"//8P4A/gD+AP/w//D/gP4AfA/8P/w4fAA+AD8B/8//8=": "5",
"//8P4A/gD+AP/4//D/gH4Afg/8H/wYfgA+AD8B/8//8=": "5",
"//9/4H/wP/Af+A/8B/4DggODAYABgAGAAYD/g/+D//8=": "4",
"/////f/8f/w//5//j//H++f78/MBgP/z//v/+//7//8=": "4",
"/////wfwB+DD4cPBA+AP8Afgw8PDw4PBA+AH8P////8=": "8",
"//8BgAGAAYDhgeGB/8D/wH/gP+A/8B/wH/gP+A/8//8=": "7",
"////z/+fH/4H+AP4w/Dh8eHx4fHh8cPwA/gP/B/+//8=": "0",
"//8H4Afg/+H/4f/h/+H/4f/h/+H/4f/h/+H/4f/h//8=": "1",
"//8P+Afw4+fz5/HP+c/5z/nP+c/zx/Pnx/MP+H////8=": "0",
"///f/w/wB+DH48fhB+AH4APgw8PDw8PDB+AP8P////8=": "8",
"/////w/wB+DH48fjB+AH4Afgw8PDw8PDB+AP8P////8=": "8",
"//8f/Efx8+fz7/vP88fjww/I/8//z//n/+MH+A/+//8=": "9",
"//8/4I/g5//z//P/O/wL4OPH88/zz/fPx8cP4H/8//8=": "6",
"//8P8AfgA8CDgcGDwYPBg8GDwYPBg4OBA8AH4A/w//8=": "0",
"//8f+Afi88/zz/PPx+EP8OPH88/73/Pf48cH4H/+//8=": "8",
"/////wPwA+DD4P/h/+H/4H/wP/gP/gf+A8ADwP////8=": "2",
"//8f+A/g5+fzz/PP88/zz/PP88/zz+fH5+cP8H/+//8=": "0",
"//8D4APg//P/+X/8P/4/8P/H/8//z//P8+cD8D////8=": "3",
"/////wPAA8DDwcPBw+H/4P/wf/A/8D/4P/w/+P////8=": "7",
"//8P+APw++f/7//v/+f/4//x//w//h//z/8DwAPA//8=": "2",
"/////wPAA8Dj4OPg4+B/8H/wP/g/+D/8H/wf/P////8=": "7",
"//8/+A/w5+fjz/PP88/zz/PP88/zz+PP5+cP8D/8//8=": "0",
"//8P+CPw8+f77/vP88fjwQ/M/8//z//n//EH+D////8=": "9",
"//8HwAfA//P/8f/8f/5/8P/H/8//3//P88cD8H/+//8=": "3",
"//8DwAPA+8/75//n//P/8//5//n//H/+f/4//z////8=": "7",
"//8/4Ifg5//z//P/O/wL4OPH88/zz/PP5+cP4H/8//8=": "6",
"//+//Q/4B/CD4cPjw8PDw+PDw+PD4QfgD/Af+H/+//8=": "0",
"//9/7x/AD+AH74f/g/AD4APAg8OHwwfDD8Af4H/7//8=": "6",
"//8f+Afw++f/7//v/+f/4//4f/w//4//x/8DwAPA//8=": "2",
"///f/Q/wB+CH4cPDw8PDw8PDw8PDwwfgB/AP+L/9//8=": "0",
"//9/7x/gD+AH/4P/g/gD4APAg8ODw4fBB+Af8H/9//8=": "6",
"//8H4Afi//P/+P/8f/4/8P/j/8f/z//H8+MD8D/+//8=": "3",
"//8P4A/g7//n/+f/B/wH8P/H/8//3//P88cD8H/+//8=": "5",
"//8DwAPA88/z5//j/+P/8f/5//j//H/+P/4//h////8=": "7",
"//8P8A/wD/CH/4f/B/gH8Afg/8H/wYfgA+AD+B/+//8=": "5",
"//8P8APgAcDBgcGDwYEBgAOAD4L/g//AA8AD4AP4//8=": "9",
"//8f+Afg5+fzz/PP+9/73/vf+9/zz+PP5+cP4B/4//8=": "0",
"/////wfgA+DHwf/B/8H/4H/gP/gf/A/+B8ADwP////8=": "2",
"//9/7x/AD+AH7of/g/AD4APAg8ODwwfDD+Af4H/7//8=": "6",
"/////wfgB+B/4H/gf+B/4H/gf+B/4H/gf+B/4P////8=": "1",
"//8BgAGAAYDhgeGB7cD/wH/gf+A/8D/wH/gP+A/87/8=": "7",
"////+z/4P/wf/g/+D/+H8cPwA8ADwAPAA8D/8P////8=": "4",
"//8f+Afg5+fzz/Pf+9/73/vf+9/zz+PHx+MP8H/+//8=": "0",
"//////////8//B/4H/gf+B/4H/gf+D/8//////////8=": ".",
"//8H4Afg//P/+f/8f/w/4P/H/8//z//P4+cH4D/8//8=": "3",
"//8H4A/g/+P/4//j/+P/4//j/+P/4//j/+P/4//j//8=": "1",
"//8D4APg//P/+X/8f/4/+P/j/8//z//P+8cD8B/8//8=": "3",
"//8DgAOAA8D/4H/wH/AfwB+A/4H/g+OBA4ABwAPw//8=": "3",
"///f+w/wB+CDwcPDw8MHwA/AP8P/w0/gB/AH+F////8=": "9",
"/////f/8f/x//j//n//P8+fz4/MBwAOA//P/8//z//8=": "4",
"//8HwAfA/+P/8//9f/x/8P/H/8//3//P888D4B/4//8=": "3",
"///f/Q/4B/DD48PDw8EHwA/AH8P/4U/gB/AH+F////8=": "9",
"//8/8A/lx//3//P/8/8b4MPH48/z3/PP58eH4D/8//8=": "6",
"/////w/wD/D//////////////////w/wD/AP8H/+//8=": ":",
"//8f/Efh4+fzz/PP88fnwwfY/9//z//v/+On8A/8//8=": "9",
"////////////+D/wP/Af8A/wD/AP8B/4//////////8=": ".",
"//8D4APg//P/+f/8P/4/8P/H/8//z//P88cD8H////8=": "3",
"/////wfgB+Af4H/gf+B/4H/gf+B/4H/gf+B/4P////8=": "1",
"/////wfgB+A/4D/gP+A/4D/gP+A/4D/gP+A/4P////8=": "1",
"//9/7x/AD+AH74f/g/AD4APAg8ODwwfDD8Af4H/7//8=": "6",
"//8f+Afi8+fzz/PP48fHww/Y/8//z//v/+MH+D////8=": "9",
"//8HwAfAB8CD/wP/A+ADwAOA/4H/g+OBA4ABwAPw//8=": "5",
"//8f+Efi5+fzz/vf+9/73/vf+9/zz+PHx+MP8D/8//8=": "0",
"/////wfwA+DD4P/h/+H/4H/wP/gf/Af+B8AHwP////8=": "2",
"//8D4APg//P/+X/8f/4/+P/j/8//z//P+8cD8A/8//8=": "3",
"//8f+Afg5+fzz/vf+9/73/vf+9/zz+PHx+MP8H/+//8=": "0",
"//8P/APw++P/5//n/+f/8//4f/w//4//x/8DwAPA//8=": "2",
"//8P4A/g7//v/+//B/wH4P/P/8//3//P88cD8H/+//8=": "5",
"//8/gA+AB4AD34P/geABgAGAAYOBh4GHA4AHgA/g//8=": "6",
"/////w/wD/D/////////////////////D/AP8P////8=": ":",
"//8f/Afw5+fz5/HP+c/5z/nP+c/xz/Pn5+cP8B/8//8=": "0",
"//9/7x/gD+AH/4P/g/gD4APAg8ODw4fBD+Af8H/9//8=": "6",
"//8BgAGAAYA/gD+AP4A/gD+AP4A/gD+AP4A/gD+A//8=": "1",
"///f/Q/4B/DD4cPDw8EDwA/AH+P/4W/gB/AH+F////8=": "9",
"//8DwAPA+8/75//n//P/8f/5//n//H/+f/4//z////8=": "7",
"//8/gA+AB4AD34H/geABgAGAAYOBh4GHA4AHgA/g//8=": "6",
"//8P+CPw8+f77/vP88fjwQ/M/8//z//n//EH+J////8=": "9",
"//8H8APAAcDDgP+B/4H/wH/gH/AP+Af8A4ADgAOA//8=": "2",
"//8f+A/g5+fzz/PP88/zz/PP88/zz+fH5+cf8H/+//8=": "0",
"///v/QfwA+DD4cPBw8EDwAfAn8H/4UfwB/AD/G////8=": "9",
"/////P/8f/4//p//3//n++P78/sBwAGA//v/+//7//8=": "4",
"/////////////wfgB+AH4AfgB+AH4P////////////8=": "-",
"//8H4Afg/+H/4f/h/+H/4f/h/+H/4f/h/+H/4f/j//8=": "1",
"//8H4AfgB+D/8H/4P/w/8D/g/8P/w4fAA+AD8B/8//8=": "3",
"//8P8APAA8CBgYGDA8AH4APAgYHBg8GDgYEDwAfg//8=": "8",
"//8P+APw++f/7//v/+f/8//x//w//h//z/8DwAPA//8=": "2",
"//8f+Afw4+fzx/HP+c/5z/nP8c/zx/Pnx/MP+H////8=": "0",
"//8HwAfAB8D/4H/wP/g/4D/A/8H/w4fAA8AD8B/8//8=": "3",
"//+//R/4D/CH4cPDw8PDw8PDw8PHwwfgD/Af+H/+//8=": "0",
"//+//R/wD+CH4cfDw8PDw8PDw8PHwwfgD/Af+H/+//8=": "0",
"//////////8//D/8H/gP8A/wH/g//D/8//////////8=": ".",
"////+3/4P/w//B/+D/8H8YfhA8ADwAPAA8D/4f////8=": "4",
"//////////8BgAGAAYABgAGAAYABgAGA//////////8=": "-"
}
}

View File

@ -178,3 +178,4 @@
2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,11,21.062,20.296,100,48,69,
2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,12,20.942,20.296,100,39,61,
2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,13,20.576,20.296,100,30,52,
2022-06-04-02:10,,,3,25.000,25.000,75,75,75,

1 2022-05-22-20:39,whistle valley,gp,1,22.349,22.349,100,,
178 2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,11,21.062,20.296,100,48,69,
179 2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,12,20.942,20.296,100,39,61,
180 2022-06-03-00:30 (Magdalena Club),Magdalena Club,60s GP,13,20.576,20.296,100,30,52,
181 2022-06-04-02:10,,,3,25.000,25.000,75,75,75,

View File

@ -1,7 +1,7 @@
use std::{
collections::HashMap,
thread,
time::{Duration, Instant, SystemTime},
time::{Duration, Instant, SystemTime}, sync::Arc,
};
use anyhow::Result;
@ -14,8 +14,7 @@ use crate::{
capture,
config::Config,
image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes},
remote_ocr,
state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash,
state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash, ocr_db::OcrDatabase,
};
@ -130,36 +129,62 @@ fn add_saved_frame(
saved_frames: &mut HashMap<String, DebugOcrFrame>,
frame: &RgbImage,
region: &Region,
ocr_results: &HashMap<String, Option<String>>,
ocr_results: &HashMap<String, String>,
) {
let extracted = extract_and_filter(frame, region);
let retained =
RetainedImage::from_image_bytes(&region.name, &image_processing::to_png_bytes(&extracted))
.unwrap();
let hash = hash_image(&extracted);
let hash = hash_image(&extracted).to_base64();
saved_frames.insert(
region.name.clone(),
DebugOcrFrame {
image: retained,
rgb_image: extracted,
img_hash: hash,
recognized_text: ocr_results.get(&region.name).and_then(|p| p.clone()),
recognized_text: ocr_results.get(&region.name).map(|p| p.clone()),
},
);
}
pub fn ocr_all_regions(
ocr_db: &OcrDatabase,
image: &RgbImage,
config: &Config,
should_sample: bool,
) -> HashMap<String, String> {
let mut results = HashMap::new();
for region in &config.ocr_regions {
let filtered_image = extract_and_filter(image, region);
let value = ocr_db.ocr_image(&filtered_image);
if let Some(sample_fraction) = &config.dump_frame_fraction {
if rand::random::<f64>() < *sample_fraction && should_sample {
let file_id = rand::random::<usize>();
let img_filename = format!("ocr_data/{}.png", file_id);
filtered_image.save(img_filename).unwrap();
let value_filename = format!("ocr_data/{}.txt", file_id);
std::fs::write(value_filename, value.clone()).unwrap();
}
}
results.insert(region.name.clone(), value);
}
results
}
fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> {
let frame = capture::get_frame(capturer)?;
let (config, ocr_cache, should_sample) = {
let (ocr_db, config, should_sample) = {
let locked = state.lock().unwrap();
(
locked.ocr_db.clone(),
locked.config.clone(),
locked.ocr_cache.clone(),
locked.should_sample_ocr_data
)
};
let ocr_results = remote_ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample);
let ocr_results = ocr_all_regions(ocr_db.as_ref(), &frame, config.as_ref(), should_sample);
if state.lock().unwrap().debug_frames {
let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results);
@ -178,7 +203,7 @@ fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()>
pub fn save_frames_from(
frame: &RgbImage,
config: &Config,
ocr_results: &HashMap<String, Option<String>>,
ocr_results: &HashMap<String, String>,
) -> HashMap<String, DebugOcrFrame> {
let mut saved_frames = HashMap::new();
for region in &config.ocr_regions {

View File

@ -9,9 +9,6 @@ use crate::image_processing::Region;
pub struct Config {
pub ocr_regions: Vec<Region>,
pub track_region: Option<Region>,
pub ocr_server_endpoint: String,
pub filter_threshold: Option<f64>,
pub use_ocr_cache: Option<bool>,
pub ocr_interval_ms: Option<u64>,
pub track_recognition_threshold: Option<u32>,
pub dump_frame_fraction: Option<f64>,

View File

@ -1,4 +1,5 @@
use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage};
use img_hash::ImageHash;
use serde::{Deserialize, Serialize};
#[derive(Clone, Deserialize, Serialize)]
@ -98,7 +99,7 @@ pub fn from_png_bytes(bytes: &[u8]) -> RgbImage {
image.to_rgb8()
}
pub fn hash_image(image: &RgbImage) -> String {
pub fn hash_image(image: &RgbImage) -> ImageHash {
let hasher = img_hash::HasherConfig::new()
.hash_alg(img_hash::HashAlg::Mean)
.hash_size(16, 16)
@ -107,5 +108,5 @@ pub fn hash_image(image: &RgbImage) -> String {
img_hash::image::RgbImage::from_raw(image.width(), image.height(), image.as_raw().to_vec())
.unwrap();
let hash = hasher.hash_image(&have_to_use_other_image_library_version);
hash.to_base64()
hash
}

View File

@ -47,5 +47,5 @@ impl LearnedTracks {
pub fn get_track_hash(config: &Config, image: &RgbImage) -> Option<String> {
let track_region = config.track_region.as_ref()?;
let extracted = extract_and_filter(image, track_region);
Some(hash_image(&extracted))
Some(hash_image(&extracted).to_base64())
}

View File

@ -5,11 +5,11 @@ mod capture;
mod config;
mod image_processing;
mod ocr;
mod remote_ocr;
mod state;
mod stats_writer;
mod training_ui;
mod learned_tracks;
mod ocr_db;
use std::{
collections::HashMap,
@ -30,7 +30,8 @@ use eframe::{
use egui_extras::RetainedImage;
use image_processing::{from_png_bytes, to_png_bytes};
use learned_tracks::LearnedTracks;
use state::{AppState, DebugOcrFrame, LapState, OcrCache, RaceState, SharedAppState};
use ocr_db::OcrDatabase;
use state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState};
use stats_writer::export_race_stats;
fn main() -> anyhow::Result<()> {
@ -41,6 +42,7 @@ fn main() -> anyhow::Result<()> {
let app_state = AppState {
config: Arc::new(Config::load().unwrap()),
learned_tracks: Arc::new(LearnedTracks::load().unwrap()),
ocr_db: Arc::new(OcrDatabase::load().unwrap()),
..Default::default()
};
let state = Arc::new(Mutex::new(app_state));
@ -174,8 +176,8 @@ fn show_race_state(
ui_state: &mut UiState,
race_name: &str,
race: &mut RaceState,
config: Arc<Config>,
ocr_cache: Arc<OcrCache>,
config: &Config,
ocr_db: &OcrDatabase,
) {
egui::Grid::new(format!("race:{}", race_name)).show(ui, |ui| {
ui.label("Lap");
@ -228,9 +230,9 @@ fn show_race_state(
if ui.button("Debug").clicked() {
open_debug_lap(
ui_state,
ocr_db,
lap,
config.clone(),
ocr_cache.clone(),
config,
)
}
@ -241,10 +243,8 @@ fn show_race_state(
});
}
fn show_debug_frames(ui: &mut Ui, debug_frames: &HashMap<String, DebugOcrFrame>) {
let mut screenshots_sorted: Vec<_> = debug_frames.iter().collect();
screenshots_sorted.sort_by_key(|(name, _)| *name);
for (name, debug_image) in screenshots_sorted {
fn show_debug_frames(ui: &mut Ui, debug_frames: &mut HashMap<String, DebugOcrFrame>) {
for (name, debug_image) in debug_frames.iter_mut() {
ui.label(name);
if let Some(text) = &debug_image.recognized_text {
ui.label(text);
@ -280,13 +280,13 @@ fn show_config_controls(ui: &mut Ui, ui_state: &mut UiState, state: &mut AppStat
fn open_debug_lap(
ui_state: &mut UiState,
ocr_db: &OcrDatabase,
lap: &LapState,
config: Arc<Config>,
ocr_cache: Arc<OcrCache>,
config: &Config,
) {
if let Some(screenshot_bytes) = &lap.screenshot {
let screenshot = from_png_bytes(screenshot_bytes);
let ocr_results = remote_ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false);
let ocr_results = analysis::ocr_all_regions(ocr_db, &screenshot, config, false);
let debug_lap = DebugLap {
screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot))
.unwrap(),
@ -308,6 +308,7 @@ fn show_combo_box(ui: &mut Ui, name: &str, label: &str, options: &[String], valu
impl eframe::App for AppUi {
fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
let mut state = self.state.lock().unwrap();
let ocr_db = state.ocr_db.clone();
let mut debug_lap_window = self.ui_state.debug_lap.is_some();
let window = egui::Window::new("Debug Lap").open(&mut debug_lap_window);
@ -318,8 +319,8 @@ impl eframe::App for AppUi {
.screenshot
.show_max_size(ui, Vec2::new(800.0, 600.0));
ui.separator();
if let Some(debug_lap) = &self.ui_state.debug_lap {
show_debug_frames(ui, &debug_lap.debug_regions);
if let Some(debug_lap) = &mut self.ui_state.debug_lap {
show_debug_frames(ui, &mut debug_lap.debug_regions);
}
show_config_controls(ui, &mut self.ui_state, state.deref_mut());
}
@ -380,16 +381,17 @@ impl eframe::App for AppUi {
ui.separator();
ui.checkbox(&mut state.debug_frames, "Debug OCR regions");
ui.checkbox(
&mut state.should_sample_ocr_data,
"Dump OCR training frames",
);
if state.config.dump_frame_fraction.is_some() {
ui.checkbox(
&mut state.should_sample_ocr_data,
"Dump OCR training frames",
);
}
});
egui::CentralPanel::default().show(ctx, |ui| {
egui::ScrollArea::vertical().show(ui, |ui| {
let config = state.config.clone();
let _learned_tracks = state.learned_tracks.clone();
let ocr_cache = state.ocr_cache.clone();
if let Some(race) = &mut state.current_race {
ui.heading(&format!("Current Race: {}", race.name()));
show_race_state(
@ -397,8 +399,8 @@ impl eframe::App for AppUi {
&mut self.ui_state,
"current",
race,
config.clone(),
ocr_cache.clone(),
config.as_ref(),
ocr_db.as_ref(),
);
}
let len = state.past_races.len();
@ -412,8 +414,8 @@ impl eframe::App for AppUi {
&mut self.ui_state,
&format!("race {}:", i),
race,
config.clone(),
ocr_cache.clone(),
config.as_ref(),
ocr_db.as_ref(),
);
if let Some(img) = &race.screencap {
img.show_max_size(ui, Vec2::new(600.0, 500.0));
@ -471,7 +473,7 @@ impl eframe::App for AppUi {
if state.debug_frames {
egui::SidePanel::right("screenshots").show(ctx, |ui| {
egui::ScrollArea::vertical().show(ui, |ui| {
show_debug_frames(ui, &state.saved_frames);
show_debug_frames(ui, &mut state.saved_frames);
show_config_controls(ui, &mut self.ui_state, state.deref_mut());
});
});

View File

@ -1,4 +1,5 @@
use image::{Rgb, RgbImage};
use img_hash::ImageHash;
use crate::image_processing;
@ -102,7 +103,7 @@ pub fn bounding_box_images(image: &RgbImage) -> Vec<RgbImage> {
trimmed
}
pub fn compute_box_hashes(image: &RgbImage) -> Vec<String> {
pub fn compute_box_hashes(image: &RgbImage) -> Vec<ImageHash> {
let mut hashes = Vec::new();
let boxes = get_character_bounding_boxes(image);

100
src/ocr_db.rs Normal file
View File

@ -0,0 +1,100 @@
use std::{collections::HashMap, sync::{Arc, RwLock}};
use crate::{
config::{load_config_or_make_default, save_json_config, Config},
image_processing::{extract_and_filter, hash_image}, ocr,
};
use anyhow::Result;
use image::RgbImage;
use img_hash::ImageHash;
use serde::{Deserialize, Serialize};
/// Serializable form of the OCR database, as loaded from and saved to `ocr.json`.
///
/// Keys are base64-encoded image hashes; each value is a string whose first
/// character is the character learned for that hash.
#[derive(Default, Serialize, Deserialize, Clone)]
struct RawOcrDatabase {
learned_chars: HashMap<String, String>,
}
/// In-memory table of learned character hashes used for local OCR.
#[derive(Default)]
pub struct OcrDatabase {
/// Learned `(hash, character)` pairs. Kept behind an `RwLock` so that
/// `learn` can append through a shared `&self` reference.
learned_chars: RwLock<Vec<(ImageHash, char)>>,
}
impl From<&RawOcrDatabase> for OcrDatabase {
fn from(raw: &RawOcrDatabase) -> Self {
let db = Self::default();
{
let mut state = db.learned_chars.write().unwrap();
for (hash, s) in &raw.learned_chars {
state.push((
ImageHash::from_base64(&hash).unwrap(),
s.chars().nth(0).unwrap(),
));
}
}
db
}
}
impl From<&OcrDatabase> for RawOcrDatabase {
fn from(db: &OcrDatabase) -> Self {
Self {
learned_chars: db
.learned_chars
.read()
.unwrap()
.iter()
.map(|(hash, c)| (hash.to_base64(), c.to_string()))
.collect(),
}
}
}
impl OcrDatabase {
    /// Loads the database from `ocr.json` via `load_config_or_make_default`,
    /// passing the bundled `configs/learned.default.json` as the default
    /// contents.
    ///
    /// # Errors
    ///
    /// Returns whatever error `load_config_or_make_default` reports.
    pub fn load() -> Result<Self> {
        let raw: RawOcrDatabase =
            load_config_or_make_default("ocr.json", include_str!("configs/learned.default.json"))?;
        Ok((&raw).into())
    }

    /// Writes the current set of learned characters to `ocr.json`.
    pub fn save(&self) -> Result<()> {
        let raw: RawOcrDatabase = self.into();
        save_json_config("ocr.json", &raw)
    }

    /// Records that the base64-encoded image hash `hash` represents `val`.
    ///
    /// # Panics
    ///
    /// Panics if `hash` is not a valid base64-encoded image hash.
    pub fn learn(&self, hash: &str, val: char) {
        self.learned_chars
            .write()
            .unwrap()
            .push((ImageHash::from_base64(hash).unwrap(), val));
    }

    /// Learns `phrase` character by character from the per-character `hashes`.
    ///
    /// The i-th character of `phrase` is associated with the i-th hash.
    /// `hashes` may contain one more entry than `phrase` has characters; the
    /// extra trailing hash is left unlearned.
    ///
    /// # Errors
    ///
    /// Fails, without learning anything, if `phrase` contains non-ASCII
    /// characters, has more characters than there are hashes, or is more than
    /// one character shorter than `hashes`.
    pub fn learn_phrase(&self, hashes: &[ImageHash], phrase: &str) -> Result<()> {
        // Validate ASCII first: only then does the byte length used below
        // equal the number of characters.
        if !phrase.is_ascii() {
            anyhow::bail!("cannot learn non-ASCII characters")
        }
        let char_count = phrase.len();
        if char_count > hashes.len() {
            anyhow::bail!("too many characters detected in OCR result for learned phrase")
        }
        // Written as an addition so that an empty `hashes` slice cannot
        // underflow the subtraction `hashes.len() - 1`.
        if char_count + 1 < hashes.len() {
            anyhow::bail!("too few characters detected in OCR result for learned phrase")
        }
        // `zip` stops at the shorter side, so the permitted extra trailing
        // hash is skipped instead of indexing past the end of the phrase.
        for (hash, c) in hashes.iter().zip(phrase.chars()) {
            self.learn(&hash.to_base64(), c);
        }
        Ok(())
    }

    /// Returns the learned character whose hash has the smallest
    /// `ImageHash::dist` to `hash`, or `None` if nothing has been learned.
    pub fn ocr_char(&self, hash: &ImageHash) -> Option<char> {
        let state = self.learned_chars.read().unwrap();
        let (_, c) = state
            .iter()
            .min_by_key(|(learned_hash, _)| hash.dist(learned_hash))?;
        Some(*c)
    }

    /// Recognizes one character per hash and concatenates the results.
    ///
    /// Hashes for which `ocr_char` returns `None` are skipped, and trailing
    /// non-alphanumeric characters are trimmed from the result.
    pub fn ocr_hashes(&self, hashes: &[ImageHash]) -> String {
        let buffer: String = hashes
            .iter()
            .filter_map(|hash| self.ocr_char(hash))
            .collect();
        buffer
            .trim_end_matches(|c: char| !c.is_alphanumeric())
            .to_owned()
    }

    /// Recognizes the text in `image` by hashing the character boxes found by
    /// `ocr::compute_box_hashes` and looking each hash up in the database.
    pub fn ocr_image(&self, image: &RgbImage) -> String {
        let hashes = ocr::compute_box_hashes(image);
        self.ocr_hashes(&hashes)
    }
}

View File

@ -1,121 +0,0 @@
use std::{
collections::HashMap,
sync::{Arc, Mutex, RwLock},
};
use anyhow::Result;
use image::RgbImage;
use serde::{Deserialize, Serialize};
use crate::{
config::Config,
image_processing::{extract_and_filter, hash_image},
state::OcrCache,
};
/// One entry of the `regions` list in the OCR server's JSON response.
#[derive(Serialize, Deserialize, Debug)]
pub struct OcrRegion {
/// Confidence value reported for this entry; not read by the code in this file.
pub confidence: f64,
/// Recognized text for this entry.
pub value: String,
}
/// JSON body returned by the OCR server for one image.
#[derive(Serialize, Deserialize, Debug)]
struct OcrResult {
/// Recognized entries; `run_ocr` concatenates their `value`s in order.
regions: Vec<OcrRegion>,
/// Presumably a server-side error message; it is deserialized but never
/// inspected in this file. TODO confirm against the server.
error: Option<String>,
}
/// POSTs `image` (PNG-encoded) to the OCR server at `url` and returns the
/// recognized text.
///
/// Returns `Ok(None)` when the server reports no regions; otherwise the
/// `value`s of all regions concatenated in order.
///
/// # Errors
///
/// Fails if the request cannot be sent, the server answers with a
/// non-success status, or the response body is not the expected JSON.
async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
    let png_bytes = crate::image_processing::to_png_bytes(image);
    let response = reqwest::Client::new()
        .post(url)
        .body(png_bytes)
        .send()
        .await?;

    if !response.status().is_success() {
        eprintln!("failed to run OCR query");
        anyhow::bail!("failed to run OCR query")
    }

    let parsed: OcrResult = response.json().await?;
    if parsed.regions.is_empty() {
        return Ok(None);
    }
    let text: String = parsed
        .regions
        .iter()
        .map(|region| region.value.as_str())
        .collect();
    Ok(Some(text))
}
/// Recognizes `filtered_image`, consulting `ocr_cache` (keyed by `hash`) first.
///
/// Caching applies only when neither the region nor the global config
/// disables it (both default to enabled). With caching on, a cached entry is
/// returned directly and a fresh server result is stored before being
/// returned. Any error from the OCR server yields `None` and is not cached.
async fn run_ocr_cached(
    ocr_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
    hash: String,
    region: &crate::image_processing::Region,
    config: Arc<Config>,
    filtered_image: &image::ImageBuffer<image::Rgb<u8>, Vec<u8>>,
) -> Option<String> {
    // The read guard is a temporary of this statement, so it is released
    // before any `.await` below.
    let cache_hit = ocr_cache.read().unwrap().get(&hash).cloned();
    let caching_enabled =
        region.use_ocr_cache.unwrap_or(true) && config.use_ocr_cache.unwrap_or(true);

    match cache_hit {
        Some(text) if caching_enabled => return text,
        _ => {}
    }

    let text = run_ocr(filtered_image, &config.ocr_server_endpoint)
        .await
        .ok()?;
    if caching_enabled {
        ocr_cache.write().unwrap().insert(hash, text.clone());
    }
    text
}
#[tokio::main(flavor = "current_thread")]
pub async fn ocr_all_regions(
image: &RgbImage,
config: Arc<Config>,
ocr_cache: Arc<OcrCache>,
should_sample: bool,
) -> HashMap<String, Option<String>> {
let results = Arc::new(Mutex::new(HashMap::new()));
let mut handles = Vec::new();
for region in &config.ocr_regions {
let filtered_image = extract_and_filter(image, region);
let region = region.clone();
let results = results.clone();
let config = config.clone();
let ocr_cache = ocr_cache.clone();
handles.push(tokio::spawn(async move {
let filtered_image = filtered_image;
let hash = hash_image(&filtered_image);
let value =
run_ocr_cached(ocr_cache, hash, &region, config.clone(), &filtered_image).await;
if let Some(sample_fraction) = &config.dump_frame_fraction {
if rand::random::<f64>() < *sample_fraction && should_sample {
let file_id = rand::random::<usize>();
let img_filename = format!("ocr_data/{}.png", file_id);
filtered_image.save(img_filename).unwrap();
let value_filename = format!("ocr_data/{}.txt", file_id);
std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap();
}
}
results.lock().unwrap().insert(region.name, value);
}));
}
for handle in handles {
handle.await.expect("failed to join task in OCR");
}
let results = results.lock().unwrap().clone();
results
}

View File

@ -4,7 +4,7 @@ use egui_extras::RetainedImage;
use image::RgbImage;
use time::{OffsetDateTime, format_description};
use crate::{config::Config, learned_tracks::LearnedTracks};
use crate::{config::Config, learned_tracks::LearnedTracks, ocr_db::OcrDatabase};
#[derive(Debug, Clone, Default)]
@ -33,8 +33,8 @@ fn parse_duration(time: &str) -> Option<Duration> {
Some(Duration::from_secs_f64(60.0 * minutes + secs))
}
fn parse_to_duration(time: Option<&Option<String>>) -> Option<Duration> {
parse_duration(&time?.clone()?)
/// Parses an optional OCR string with `parse_duration`; `None` in, `None` out.
fn parse_to_duration(time: Option<&String>) -> Option<Duration> {
    time.and_then(|raw| parse_duration(raw))
}
fn check_0_100(v: usize) -> Option<usize> {
@ -45,12 +45,12 @@ fn check_0_100(v: usize) -> Option<usize> {
}
}
fn parse_to_0_100(v: Option<&Option<String>>) -> Option<usize> {
check_0_100(v?.clone()?.parse::<usize>().ok()?)
/// Parses an optional OCR string as a `usize` and passes it through
/// `check_0_100`; yields `None` if the value is missing or not a number.
fn parse_to_0_100(v: Option<&String>) -> Option<usize> {
    let parsed = v?.parse::<usize>().ok()?;
    check_0_100(parsed)
}
impl LapState {
pub fn parse(raw: &HashMap<String, Option<String>>) -> Self {
pub fn parse(raw: &HashMap<String, String>) -> Self {
Self {
lap: parse_to_0_100(raw.get("lap")),
health: parse_to_0_100(raw.get("health")),
@ -126,10 +126,9 @@ pub struct DebugOcrFrame {
pub recognized_text: Option<String>,
}
pub type OcrCache = RwLock<HashMap<String, Option<String>>>;
#[derive(Default)]
pub struct AppState {
pub raw_data: HashMap<String, Option<String>>,
pub raw_data: HashMap<String, String>,
pub last_frame: Option<LapState>,
pub buffered_frames: VecDeque<LapState>,
@ -144,8 +143,7 @@ pub struct AppState {
pub config: Arc<Config>,
pub learned_tracks: Arc<LearnedTracks>,
pub ocr_cache: Arc<OcrCache>,
pub ocr_db: Arc<OcrDatabase>,
}
pub type SharedAppState = Arc<Mutex<AppState>>;

Binary file not shown.

After

Width:  |  Height:  |  Size: 707 B

View File

@ -9,8 +9,9 @@ use eframe::{
};
use egui_extras::RetainedImage;
use image::RgbImage;
use img_hash::ImageHash;
use crate::{image_processing::to_png_bytes, ocr};
use crate::{image_processing::to_png_bytes, ocr, ocr_db::OcrDatabase};
#[derive(Default)]
struct TrainingUi {
@ -18,7 +19,7 @@ struct TrainingUi {
current_image_index: usize,
learned_char_hashes: Vec<(String, char)>,
ocr_db: OcrDatabase,
}
struct TrainingImage {
@ -28,7 +29,7 @@ struct TrainingImage {
ui_image: RetainedImage,
char_images: Vec<RetainedImage>,
char_hashes: Vec<String>,
char_hashes: Vec<ImageHash>,
}
impl TrainingImage {
@ -55,16 +56,6 @@ fn get_training_data_paths() -> Vec<(PathBuf, PathBuf)> {
data_paths
}
/// Returns the character of the learned entry whose hash has the smallest
/// `dist` to `hash`, or `None` when `hashes` is empty.
///
/// # Panics
///
/// Panics if `hash`, or any learned hash that gets compared, is not valid
/// base64.
fn predict_ocr(hashes: &[(String, char)], hash: &str) -> Option<char> {
    let target = img_hash::ImageHash::<Vec<u8>>::from_base64(hash).unwrap();
    hashes
        .iter()
        .min_by_key(|(encoded, _)| {
            let learned = img_hash::ImageHash::<Vec<u8>>::from_base64(encoded).unwrap();
            learned.dist(&target)
        })
        .map(|(_, symbol)| *symbol)
}
fn get_training_data() -> Vec<TrainingImage> {
let mut data = Vec::new();
for (img_file, ocr_file) in get_training_data_paths() {
@ -90,24 +81,6 @@ fn get_training_data() -> Vec<TrainingImage> {
data
}
/// Reads `learned_chars.txt` from the working directory.
///
/// Each line has the form `<char> <hash>`: the first character before the
/// first space is the label, everything after that space is the hash. Lines
/// without a space, or with nothing before it, are ignored. Returns an empty
/// list when the file does not exist.
///
/// # Panics
///
/// Panics if the file exists but cannot be read or is not valid UTF-8.
fn load_learned_hashes() -> Vec<(String, char)> {
    let path = Path::new("learned_chars.txt");
    if !path.exists() {
        return Vec::new();
    }
    let bytes = std::fs::read(path).unwrap();
    let contents = String::from_utf8(bytes).unwrap();
    contents
        .lines()
        .filter_map(|line| {
            let (label, hash) = line.split_once(' ')?;
            let symbol = label.chars().next()?;
            Some((hash.to_owned(), symbol))
        })
        .collect()
}
/// Encodes `image` as PNG via `to_png_bytes` and loads it into a
/// `RetainedImage` with an empty debug name.
///
/// Panics if `RetainedImage::from_image_bytes` rejects the encoded bytes.
fn load_retained_image(image: &RgbImage) -> RetainedImage {
RetainedImage::from_image_bytes("", &to_png_bytes(image)).unwrap()
}
@ -116,7 +89,7 @@ pub fn training_ui() -> anyhow::Result<()> {
let options = eframe::NativeOptions::default();
let state = TrainingUi {
training_images: get_training_data(),
learned_char_hashes: load_learned_hashes(),
ocr_db: OcrDatabase::load().unwrap(),
..Default::default()
};
eframe::run_native("OCR Trainer", options, Box::new(|_cc| Box::new(state)));
@ -139,46 +112,22 @@ impl eframe::App for TrainingUi {
current_image.save_ocr_text();
}
if ui.button("Learn").clicked() {
for (i, char) in current_image.text.chars().enumerate() {
if let Some(hash) = current_image.char_hashes.get(i) {
self.learned_char_hashes.push((hash.clone(), char));
eprintln!("Learned {}={}", hash, char);
}
}
self.ocr_db.learn_phrase(&current_image.char_hashes, &current_image.text).unwrap();
self.current_image_index += 1;
}
if ui.button("Learn and delete").clicked() {
for (i, char) in current_image.text.chars().enumerate() {
if let Some(hash) = current_image.char_hashes.get(i) {
self.learned_char_hashes.push((hash.clone(), char));
eprintln!("Learned {}={}", hash, char);
}
}
self.ocr_db.learn_phrase(&current_image.char_hashes, &current_image.text).unwrap();
current_image.delete_data();
self.current_image_index += 1;
}
if ui.button("Save learned results").clicked() {
let mut buffer = String::new();
for (hash, c) in &self.learned_char_hashes {
buffer += &format!("{} {}\n", c, hash);
}
let mut file = std::fs::OpenOptions::new()
.append(true)
.create(true)
.write(true)
.open("learned_chars.txt")
.unwrap();
file.write_all(buffer.as_bytes()).unwrap();
self.ocr_db.save().unwrap();
}
current_image.ui_image.show(ui);
ui.label("OCR value");
ui.text_edit_singleline(&mut current_image.text);
let predicted: String = current_image
.char_hashes
.iter()
.filter_map(|hash| predict_ocr(&self.learned_char_hashes, hash))
.collect();
let predicted: String = self.ocr_db.ocr_hashes(&current_image.char_hashes);
if predicted == current_image.text {
ui.colored_label(Color32::GREEN, format!("Predicted: {}", predicted));
} else {
@ -189,8 +138,8 @@ impl eframe::App for TrainingUi {
c.show(ui);
}
for c in &current_image.char_hashes {
ui.label(c);
if let Some(predicted) = predict_ocr(&self.learned_char_hashes, c) {
ui.label(c.to_base64());
if let Some(predicted) = self.ocr_db.ocr_char(c) {
ui.label(format!("Predicted: {}", predicted));
}
}