diff --git a/Cargo.lock b/Cargo.lock index 7f20237..b8715ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,6 +106,18 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" +[[package]] +name = "bstr" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" +dependencies = [ + "lazy_static", + "memchr", + "regex-automata", + "serde", +] + [[package]] name = "bumpalo" version = "3.9.1" @@ -387,6 +399,28 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "csv" +version = "1.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +dependencies = [ + "bstr", + "csv-core", + "itoa 0.4.8", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +dependencies = [ + "memchr", +] + [[package]] name = "cty" version = "0.2.2" @@ -938,7 +972,7 @@ checksum = "ff8670570af52249509a86f5e3e18a08c60b177071826898fde8997cf5f6bfbb" dependencies = [ "bytes", "fnv", - "itoa", + "itoa 1.0.2", ] [[package]] @@ -979,7 +1013,7 @@ dependencies = [ "http-body", "httparse", "httpdate", - "itoa", + "itoa 1.0.2", "pin-project-lite", "socket2", "tokio", @@ -1102,6 +1136,12 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" +[[package]] +name = "itoa" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" + [[package]] name = "itoa" version = "1.0.2" @@ -1791,6 +1831,12 @@ dependencies = [ "bitflags", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" + [[package]] name = "remove_dir_all" version = "0.5.3" @@ -1999,7 +2045,7 @@ version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b7ce2b32a1aed03c558dc61a5cd328f15aff2dbc17daad8fb8af04d2100e15c" dependencies = [ - "itoa", + "itoa 1.0.2", "ryu", "serde", ] @@ -2011,7 +2057,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" dependencies = [ "form_urlencoded", - "itoa", + "itoa 1.0.2", "ryu", "serde", ] @@ -2129,6 +2175,7 @@ name = "supper" version = "0.1.0" dependencies = [ "anyhow", + "csv", "eframe", "egui_extras", "ehttp", diff --git a/Cargo.toml b/Cargo.toml index e2f0881..6cbb633 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,4 +22,6 @@ tokio = { version = "1", features = ["full"] } reqwest = { version = "0.11", features = ["json"] } -img_hash = "3" \ No newline at end of file +img_hash = "3" + +csv = "1" \ No newline at end of file diff --git a/config.json b/config.json index 6e39311..8a63bb6 100644 --- a/config.json +++ b/config.json @@ -2,7 +2,7 @@ "ocr_regions": [ { "name": "lap", - "x": 2290, + "x": 2300, "y": 46, "width": 145, "height": 90 @@ -29,20 +29,27 @@ "height": 24 }, { - "name": "lap_time", + "name": "best", "x": 2325, "y": 169, "width": 183, "height": 43 }, { - "name": "best_time", + "name": "lap_time", "x": 2325, "y": 222, "width": 183, "height": 43 } ], - + "track_region": { + "name": "track", + "x": 2020, + "y": 1030, + "width": 540, + "height": 410, + "threshold": 0.85 + }, "ocr_server_endpoint": "https://tesserver.spruett.dev/" } \ No newline at end of file diff --git a/config_1080p.json b/config_1080p.json new file mode 100644 index 0000000..971de45 --- /dev/null +++ b/config_1080p.json @@ -0,0 +1,47 @@ +{ + "ocr_regions": [ + { + "name": "lap", + "x": 1718, + "y": 34, + "width": 109, + "height": 68 + }, + { + "name": "health", + "x": 68, + "y": 1023, + "width": 39, + "height": 18 + }, + { + "name": "gas", + "x": 156, + "y": 1023, + "width": 39, + "height": 18 + }, + { + "name": "tyres", + "x": 244, + "y": 1023, + "width": 39, + "height": 18 + }, + { + "name": "best_time", + "x": 1744, + "y": 127, + "width": 137, + "height": 32 + }, + { + "name": "lap_time", + "x": 1744, + "y": 166, + "width": 137, + "height": 32 + } + ], + "ocr_server_endpoint": "https://tesserver.spruett.dev/" +} diff --git a/scale_config.py b/scale_config.py new file mode 100644 index 0000000..28ba38d --- /dev/null +++ b/scale_config.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +import argparse +import json +from typing import Tuple + + +def parse_resolution(resolution: str) -> Tuple[int, int]: + a, b = resolution.split('x') + return int(a), int(b) + +def scale_x_y(x, y, from_resolution, to_resolution): + return (x * to_resolution[0] / from_resolution[0], y * to_resolution[1] / from_resolution[1]) + +def scale_region(region, from_resolution, to_resolution): + x, y = scale_x_y(region['x'], region['y'], from_resolution, to_resolution) + width, height = scale_x_y(region['width'], region['height'], from_resolution, to_resolution) + region['x'] = round(x) + region['y'] = round(y) + region['width'] = round(width) + region['height'] = round(height) + +def main(): + argparser = argparse.ArgumentParser() + argparser.add_argument("--from_res", help="From resolution", default="2560x1440") + argparser.add_argument("--to_res", help="To resolution (e.g. 1920x1080)") + argparser.add_argument("--config", help="Config file", default="config.json") + + args = argparser.parse_args() + + from_resolution = parse_resolution(args.from_res) + to_resolution = parse_resolution(args.to_res) + config = json.load(open(args.config, 'r')) + for region in config['ocr_regions']: + scale_region(region, from_resolution, to_resolution) + print(json.dumps(config, indent=4)) + + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/capture.rs b/src/capture.rs index 74a21ac..0cf30e8 100644 --- a/src/capture.rs +++ b/src/capture.rs @@ -2,7 +2,7 @@ use std::{time::Duration, thread}; use anyhow::Result; use image::{RgbImage, Rgb}; -use scrap::{Capturer, Display}; +use scrap::Capturer; fn get_raw_frame(capturer: &mut Capturer) -> Result> { loop { diff --git a/src/config.rs b/src/config.rs index b5230fd..4ef6612 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,15 +10,14 @@ pub struct Config { pub ocr_regions: Vec, pub track_region: Option, pub ocr_server_endpoint: String, + pub filter_threshold: Option, + pub use_ocr_cache: Option, } impl Config { pub fn load() -> Result { load_or_make_default("config.json", include_str!("configs/config.default.json")) } - pub fn save(&self) -> Result<()> { - save_json_config("config.json", self) - } } #[derive(Default, Serialize, Deserialize, Clone)] diff --git a/src/configs/config.default.json b/src/configs/config.default.json index c021f96..e31994e 100644 --- a/src/configs/config.default.json +++ b/src/configs/config.default.json @@ -29,14 +29,14 @@ "height": 24 }, { - "name": "lap_time", + "name": "best", "x": 2325, "y": 169, "width": 183, "height": 43 }, { - "name": "best_time", + "name": "lap_time", "x": 2325, "y": 222, "width": 183, diff --git a/src/control_loop.rs b/src/control_loop.rs index 928773b..3755094 100644 --- a/src/control_loop.rs +++ b/src/control_loop.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, - time::{Duration, Instant}, thread, + thread, + time::{Duration, Instant}, }; use anyhow::Result; @@ -8,12 +9,11 @@ use egui_extras::RetainedImage; use image::RgbImage; use scrap::{Capturer, Display}; - use crate::{ capture, - image_processing::{self, hash_image}, + image_processing::{self, hash_image, Region, extract_and_filter}, ocr, - state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState}, + state::{AppState, DebugOcrFrame, ParsedFrame, RaceState, SharedAppState}, config::Config, }; fn is_finished_lap(state: &AppState, frame: &ParsedFrame) -> bool { @@ -104,34 +104,47 @@ fn handle_new_frame(state: &mut AppState, frame: ParsedFrame, image: &RgbImage) } } -fn run_loop_once(state: &SharedAppState) -> Result<()> { - let mut capturer = Capturer::new(Display::primary()?)?; - let config = state.lock().unwrap().config.clone(); - let learned_config = state.lock().unwrap().learned.clone(); - let frame = capture::get_frame(&mut capturer)?; - let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config.clone()); +fn add_saved_frame( + saved_frames: &mut HashMap, + frame: &RgbImage, + region: &Region, + config: &Config, +) { + let extracted = extract_and_filter(frame, region); + let retained = + RetainedImage::from_image_bytes(®ion.name, &image_processing::to_png_bytes(&extracted)) + .unwrap(); + let hash = hash_image(&extracted); + saved_frames.insert( + region.name.clone(), + DebugOcrFrame { + image: retained, + rgb_image: extracted, + img_hash: hash, + }, + ); +} + +fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> { + let (config, learned_config, ocr_cache) = { + let locked = state.lock().unwrap(); + ( + locked.config.clone(), + locked.learned.clone(), + locked.ocr_cache.clone(), + ) + }; + let frame = capture::get_frame(capturer)?; + let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), learned_config, ocr_cache); let mut saved_frames = HashMap::new(); if state.lock().unwrap().debug_frames { - let hasher = img_hash::HasherConfig::new().to_hasher(); for region in &config.ocr_regions { - let mut extracted = image_processing::extract_region(&frame, region); - image_processing::filter_to_white(&mut extracted); - let retained = RetainedImage::from_image_bytes( - ®ion.name, - &image_processing::to_png_bytes(&extracted), - ) - .unwrap(); - let hash = hash_image(&extracted); - saved_frames.insert( - region.name.clone(), - DebugOcrFrame { - image: retained, - rgb_image: extracted, - img_hash: hash, - }, - ); + add_saved_frame(&mut saved_frames, &frame, region, config.as_ref()); + } + if let Some(track_region) = &config.track_region { + add_saved_frame(&mut saved_frames, &frame, track_region, config.as_ref()); } } { @@ -145,8 +158,9 @@ fn run_loop_once(state: &SharedAppState) -> Result<()> { } pub fn run_control_loop(state: SharedAppState) { + let mut capturer = Capturer::new(Display::primary().unwrap()).unwrap(); loop { - if let Err(e) = run_loop_once(&state) { + if let Err(e) = run_loop_once(&mut capturer, &state) { eprintln!("Error in control loop: {:?}", e) } thread::sleep(Duration::from_millis(500)); diff --git a/src/image_processing.rs b/src/image_processing.rs index d6831bb..7188da3 100644 --- a/src/image_processing.rs +++ b/src/image_processing.rs @@ -1,6 +1,8 @@ use image::{codecs::png::PngEncoder, ColorType, ImageEncoder, Rgb, RgbImage}; use serde::{Deserialize, Serialize}; +use crate::config::Config; + #[derive(Clone, Deserialize, Serialize)] pub struct Region { pub name: String, @@ -8,6 +10,13 @@ pub struct Region { y: usize, width: usize, height: usize, + pub threshold: Option, +} + +pub fn extract_and_filter(image: &RgbImage, region: &Region) -> RgbImage { + let mut extracted = extract_region(image, region); + filter_to_white(&mut extracted, ®ion.threshold); + extracted } pub fn extract_region(image: &RgbImage, region: &Region) -> RgbImage { @@ -24,9 +33,9 @@ pub fn extract_region(image: &RgbImage, region: &Region) -> RgbImage { buffer } -pub fn filter_to_white(image: &mut RgbImage) { - let threshold = 0.98; - let variance_threshold = 0.02; +pub fn filter_to_white(image: &mut RgbImage, threshold: &Option) { + let threshold = threshold.unwrap_or(0.95); + let variance_threshold = 1.0 - threshold; let past_threshold_color = |v: u8| v as f64 >= (u8::MAX as f64 * threshold); let color_diff = |a: u8, b: u8| (a.abs_diff(b) as f64) / (u8::MAX as f64); for y in 0..image.height() { diff --git a/src/main.rs b/src/main.rs index fbabc49..4adc7e9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ mod image_processing; mod ocr; mod state; mod config; +mod stats_writer; use std::{ sync::{Arc, Mutex}, @@ -18,6 +19,7 @@ use eframe::{ epaint::Color32, emath::Vec2, }; use state::{AppState, RaceState, SharedAppState}; +use stats_writer::export_race_stats; fn main() -> anyhow::Result<()> { let mut app_state = AppState::default(); @@ -101,8 +103,8 @@ impl MyApp { } } -fn show_race_state(ui: &mut Ui, race: &RaceState) { - egui::Grid::new("current-race").show(ui, |ui| { +fn show_race_state(ui: &mut Ui, race_name: &str, race: &RaceState) { + egui::Grid::new(format!("race:{}", race_name)).show(ui, |ui| { ui.label("Lap"); ui.label("Time"); ui.label("Δ Previous"); @@ -187,22 +189,36 @@ impl eframe::App for MyApp { egui::CentralPanel::default().show(ctx, |ui| { if let Some(race) = &state.current_race { ui.heading("Current Race"); - show_race_state(ui, race); + show_race_state(ui, "current", race); } let len = state.past_races.len(); for (i, race) in state.past_races.iter_mut().enumerate() { ui.separator(); ui.heading(format!("Race #{}", len - i)); - show_race_state(ui, race); + show_race_state(ui, &format!("{}", i), race); if let Some(img) = &race.screencap { img.show_max_size(ui, Vec2::new(600.0, 500.0)); } - ui.label("Car:"); - ui.text_edit_singleline(&mut race.car); - ui.label("Track:"); - ui.text_edit_singleline(&mut race.track); - if ui.button("Export").clicked() { - println!("EXPORT: TODO"); + if !race.exported { + ui.label("Car:"); + ui.text_edit_singleline(&mut race.car); + ui.label("Track:"); + ui.text_edit_singleline(&mut race.track); + if ui.button("Export").clicked() { + match export_race_stats(race) { + Ok(_) => { + race.exported = true; + } + Err(e) => { + race.export_error = Some(format!("failed to export race: {:?}", e)); + } + } + } + if let Some(e) = &race.export_error { + ui.colored_label(Color32::RED, e); + } + } else { + ui.label("Exported ✅"); } } }); diff --git a/src/ocr.rs b/src/ocr.rs index 2c57c48..43d4b02 100644 --- a/src/ocr.rs +++ b/src/ocr.rs @@ -1,6 +1,6 @@ use std::{ collections::HashMap, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, RwLock}, }; use anyhow::Result; @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use crate::{ config::{Config, LearnedConfig}, - image_processing::{extract_region, filter_to_white, hash_image, Region}, + image_processing::{extract_region, filter_to_white, hash_image, extract_and_filter}, }; #[derive(Serialize, Deserialize, Debug)] @@ -54,26 +54,41 @@ pub async fn ocr_all_regions( image: &RgbImage, config: Arc, learned: Arc, + ocr_cache: Arc>>>, ) -> HashMap> { let results = Arc::new(Mutex::new(HashMap::new())); let mut handles = Vec::new(); for region in &config.ocr_regions { - let filtered_image = extract_region(image, region); + let filtered_image = extract_and_filter(image, region); let region = region.clone(); let results = results.clone(); let config = config.clone(); let learned = learned.clone(); + let ocr_cache = ocr_cache.clone(); handles.push(tokio::spawn(async move { - let mut image = filtered_image; - filter_to_white(&mut image); - let hash = hash_image(&image); + let filtered_image = filtered_image; + let hash = hash_image(&filtered_image); let value = if let Some(learned_value) = learned.learned_images.get(&hash) { Some(learned_value.clone()) } else { - run_ocr(&image, &config.ocr_server_endpoint) - .await - .unwrap_or(None) + let cached = { + let locked = ocr_cache.read().unwrap(); + locked.get(&hash).cloned() + }; + if let Some(cached) = cached { + cached + } else { + match run_ocr(&filtered_image, &config.ocr_server_endpoint).await { + Ok(v) => { + if config.use_ocr_cache.unwrap_or(true) { + ocr_cache.write().unwrap().insert(hash.clone(), v.clone()); + } + v + } + Err(_) => None + } + } }; results.lock().unwrap().insert(region.name, value); })); diff --git a/src/state.rs b/src/state.rs index af39f8f..2cdd0eb 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,4 +1,4 @@ -use std::{sync::{Arc, Mutex}, time::{Duration, Instant}, collections::{HashMap, VecDeque}}; +use std::{sync::{Arc, Mutex, RwLock}, time::{Duration, Instant}, collections::{HashMap, VecDeque}}; use egui_extras::RetainedImage; use image::RgbImage; @@ -65,6 +65,7 @@ pub struct RaceState { pub screencap: Option, pub exported: bool, + pub export_error: Option, pub car: String, pub track: String, @@ -92,6 +93,8 @@ pub struct AppState { pub config: Arc, pub learned: Arc, + + pub ocr_cache: Arc>>>, } pub type SharedAppState = Arc>; \ No newline at end of file diff --git a/src/stats_writer.rs b/src/stats_writer.rs new file mode 100644 index 0000000..3b663b2 --- /dev/null +++ b/src/stats_writer.rs @@ -0,0 +1,45 @@ +use std::{ + io::BufWriter, + time::{Duration, Instant}, +}; + +use crate::state::RaceState; + +use anyhow::Result; + +pub fn export_race_stats(race_stats: &mut RaceState) -> Result<()> { + let file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open("race_stats.csv")?; + let writer = BufWriter::new(file); + let mut csv_writer = csv::Writer::from_writer(writer); + + for lap in &race_stats.laps { + csv_writer.write_record(vec![ + race_stats.track.clone(), + race_stats.car.clone(), + format!( + "{:.3}", + lap.lap_time.unwrap_or(Duration::from_secs(0)).as_secs_f64() + ), + format!( + "{:.3}", + lap.best_time + .unwrap_or(Duration::from_secs(0)) + .as_secs_f64() + ), + lap.health + .map(|x| x.to_string()) + .unwrap_or_else(|| "".to_owned()), + lap.gas + .map(|x| x.to_string()) + .unwrap_or_else(|| "".to_owned()), + lap.tyres + .map(|x| x.to_string()) + .unwrap_or_else(|| "".to_owned()), + ])?; + } + + Ok(()) +}