more testing

This commit is contained in:
Scott Pruett 2022-06-04 23:06:53 -04:00
parent f0b9a9e6f3
commit 1502110e62
14 changed files with 80 additions and 30 deletions

View File

@ -53,6 +53,13 @@
"height": 43, "height": 43,
"threshold": null, "threshold": null,
"use_ocr_cache": null "use_ocr_cache": null
},
{
"name": "position",
"x": 3,
"y": 52,
"width": 107,
"height": 79
} }
], ],
"track_region": { "track_region": {

View File

@ -63,13 +63,13 @@ fn merge_frames(prev: &LapState, next: &LapState) -> LapState {
..Default::default() ..Default::default()
} }
} }
fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) { fn handle_new_frame(state: &mut AppState, lap_state: LapState, image: &RgbImage) {
if frame.lap_time.is_some() { if lap_state.lap_time.is_some() {
state.last_frame = Some(frame.clone()); state.last_frame = Some(lap_state.clone());
state.frames_without_lap = 0; state.frames_without_lap = 0;
if state.current_race.is_none() { if state.current_race.is_none() {
let track_hash = get_track_hash(state.config.as_ref(), &image); let track_hash = get_track_hash(state.config.as_ref(), image);
let track_name = state let track_name = state
.learned_tracks .learned_tracks
.infer_track(&track_hash, state.config.as_ref()); .infer_track(&track_hash, state.config.as_ref());
@ -78,7 +78,7 @@ fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
screencap: Some( screencap: Some(
RetainedImage::from_image_bytes( RetainedImage::from_image_bytes(
"screencap", "screencap",
&image_processing::to_png_bytes(&image), &image_processing::to_png_bytes(image),
) )
.expect("failed to save screenshot"), .expect("failed to save screenshot"),
), ),
@ -100,12 +100,12 @@ fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
} }
} }
if is_finished_lap(state, &frame) { if is_finished_lap(state, &lap_state) {
let mut merged = merge_frames(state.buffered_frames.back().unwrap(), &frame); let mut merged = merge_frames(state.buffered_frames.back().unwrap(), &lap_state);
if let Some(lap) = &merged.lap { if let Some(lap) = &merged.lap {
merged.lap = Some(lap - 1); merged.lap = Some(lap - 1);
} }
merged.screenshot = Some(to_png_bytes(&image)); merged.screenshot = Some(to_png_bytes(image));
if let Some(race) = state.current_race.as_mut() { if let Some(race) = state.current_race.as_mut() {
if let Some(prev_lap) = race.laps.last() { if let Some(prev_lap) = race.laps.last() {
@ -122,7 +122,7 @@ fn handle_new_frame(state: &mut AppState, frame: LapState, image: RgbImage) {
} }
} }
state.buffered_frames.push_back(frame); state.buffered_frames.push_back(lap_state);
if state.buffered_frames.len() >= 20 { if state.buffered_frames.len() >= 20 {
state.buffered_frames.pop_front(); state.buffered_frames.pop_front();
} }
@ -176,9 +176,7 @@ pub fn ocr_all_regions(
results results
} }
fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> { fn analyze_frame(frame: &RgbImage, state: &SharedAppState) {
let frame = capture::get_frame(capturer)?;
let (ocr_db, config, should_sample) = { let (ocr_db, config, should_sample) = {
let locked = state.lock().unwrap(); let locked = state.lock().unwrap();
( (
@ -187,19 +185,23 @@ fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()>
locked.should_sample_ocr_data, locked.should_sample_ocr_data,
) )
}; };
let ocr_results = ocr_all_regions(ocr_db.as_ref(), &frame, config.as_ref(), should_sample); let ocr_results = ocr_all_regions(ocr_db.as_ref(), frame, config.as_ref(), should_sample);
if state.lock().unwrap().debug_frames { if state.lock().unwrap().debug_frames {
let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results); let debug_frames = save_frames_from(frame, config.as_ref(), &ocr_results);
state.lock().unwrap().saved_frames = debug_frames; state.lock().unwrap().saved_frames = debug_frames;
} }
{ {
let mut state = state.lock().unwrap(); let mut state = state.lock().unwrap();
let parsed = LapState::parse(&ocr_results); let parsed = LapState::parse(&ocr_results);
state.raw_data = ocr_results;
handle_new_frame(&mut state, parsed, frame); handle_new_frame(&mut state, parsed, frame);
} }
}
fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()> {
let frame = capture::get_frame(capturer)?;
analyze_frame(&frame, state);
Ok(()) Ok(())
} }
@ -228,3 +230,48 @@ pub fn run_control_loop(state: SharedAppState) {
thread::sleep(Duration::from_millis(interval)); thread::sleep(Duration::from_millis(interval));
} }
} }
#[cfg(test)]
mod test {
use std::{
sync::{Arc, Mutex},
time::Duration,
};
use crate::{
config::load_config_or_make_default,
ocr_db::OcrDatabase,
state::{AppState, SharedAppState},
};
use super::analyze_frame;
fn make_test_state() -> SharedAppState {
let state = AppState {
config: Arc::new(
load_config_or_make_default("src/configs/config.default.json", "").unwrap(),
),
ocr_db: Arc::new(OcrDatabase::load().unwrap()),
..Default::default()
};
Arc::new(Mutex::new(state))
}
#[test]
fn test_basic_analysis() {
let state = make_test_state();
let image = image::load_from_memory(include_bytes!("test_data/test-full-1.png")).unwrap();
analyze_frame(&image.to_rgb8(), &state);
let lap_state = state.lock().unwrap().last_frame.as_ref().unwrap().clone();
assert_eq!(4, lap_state.lap.unwrap());
assert_eq!(95, lap_state.health.unwrap());
assert_eq!(79, lap_state.gas.unwrap());
assert_eq!(76, lap_state.tyres.unwrap());
assert!(
Duration::from_secs(24) <= lap_state.lap_time.unwrap()
&& lap_state.lap_time.unwrap() <= Duration::from_secs(25)
);
}
}

View File

@ -13,7 +13,6 @@ mod training_ui;
use std::{ use std::{
collections::HashMap, collections::HashMap,
ops::DerefMut,
path::PathBuf, path::PathBuf,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
thread, thread,
@ -248,14 +247,12 @@ fn show_debug_frames(
ui.text_edit_singleline(&mut debug_image.recognized_text); ui.text_edit_singleline(&mut debug_image.recognized_text);
debug_image.image.show_max_size(ui, Vec2::new(300.0, 300.0)); debug_image.image.show_max_size(ui, Vec2::new(300.0, 300.0));
if name != "track" { if name != "track" && ui.button("Learn OCR").clicked() {
if ui.button("Learn OCR").clicked() {
let hashes = ocr::compute_box_hashes(&debug_image.rgb_image); let hashes = ocr::compute_box_hashes(&debug_image.rgb_image);
ocr_db ocr_db
.learn_phrase(&hashes, &debug_image.recognized_text) .learn_phrase(&hashes, &debug_image.recognized_text)
.unwrap(); .unwrap();
} }
}
ui.separator(); ui.separator();
} }
} }

View File

@ -116,7 +116,7 @@ pub fn compute_box_hashes(image: &RgbImage) -> Vec<ImageHash> {
#[test] #[test]
fn test_bounding_boxes() { fn test_bounding_boxes() {
let image_bytes = include_bytes!("test_data/test-image-2.png"); let image_bytes = include_bytes!("test_data/test-montserrat.png");
let image = image::load_from_memory(image_bytes).unwrap().to_rgb8(); let image = image::load_from_memory(image_bytes).unwrap().to_rgb8();
let boxes = get_character_bounding_boxes(&image); let boxes = get_character_bounding_boxes(&image);
assert_eq!(boxes.len(), 10); assert_eq!(boxes.len(), 10);
@ -128,7 +128,7 @@ fn test_bounding_boxes() {
#[test] #[test]
fn test_box_hashes() { fn test_box_hashes() {
let image_bytes = include_bytes!("test_data/test-image-2.png"); let image_bytes = include_bytes!("test_data/test-montserrat.png");
let image = image::load_from_memory(image_bytes).unwrap().to_rgb8(); let image = image::load_from_memory(image_bytes).unwrap().to_rgb8();
let hashes = compute_box_hashes(&image); let hashes = compute_box_hashes(&image);
assert_eq!(hashes.len(), 10); assert_eq!(hashes.len(), 10);

View File

@ -115,22 +115,22 @@ fn test_ocr() {
serde_json::from_str(include_str!("configs/ocr.default.json")).unwrap(); serde_json::from_str(include_str!("configs/ocr.default.json")).unwrap();
let db: OcrDatabase = (&raw).into(); let db: OcrDatabase = (&raw).into();
let image = image::load_from_memory(include_bytes!("test_data/test-image-3.png")) let image = image::load_from_memory(include_bytes!("test_data/test-time-1.png"))
.unwrap() .unwrap()
.to_rgb8(); .to_rgb8();
assert_eq!(db.ocr_image(&image), "00:30.625"); assert_eq!(db.ocr_image(&image), "00:30.625");
let image = image::load_from_memory(include_bytes!("test_data/test-image-4.png")) let image = image::load_from_memory(include_bytes!("test_data/test-time-2.png"))
.unwrap() .unwrap()
.to_rgb8(); .to_rgb8();
assert_eq!(db.ocr_image(&image), "00:20.296"); assert_eq!(db.ocr_image(&image), "00:20.296");
let image = image::load_from_memory(include_bytes!("test_data/test-image-num-1.png")) let image = image::load_from_memory(include_bytes!("test_data/test-num-1.png"))
.unwrap() .unwrap()
.to_rgb8(); .to_rgb8();
assert_eq!(db.ocr_image(&image), "1"); assert_eq!(db.ocr_image(&image), "1");
let image = image::load_from_memory(include_bytes!("test_data/test-image-blank.png")) let image = image::load_from_memory(include_bytes!("test_data/test-blank.png"))
.unwrap() .unwrap()
.to_rgb8(); .to_rgb8();
assert_eq!(db.ocr_image(&image), ""); assert_eq!(db.ocr_image(&image), "");

View File

@ -133,7 +133,6 @@ pub struct DebugOcrFrame {
#[derive(Default)] #[derive(Default)]
pub struct AppState { pub struct AppState {
pub raw_data: HashMap<String, String>,
pub last_frame: Option<LapState>, pub last_frame: Option<LapState>,
pub buffered_frames: VecDeque<LapState>, pub buffered_frames: VecDeque<LapState>,

View File

Before

Width:  |  Height:  |  Size: 183 B

After

Width:  |  Height:  |  Size: 183 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 53 KiB

View File

Before

Width:  |  Height:  |  Size: 51 KiB

After

Width:  |  Height:  |  Size: 51 KiB

View File

Before

Width:  |  Height:  |  Size: 322 B

After

Width:  |  Height:  |  Size: 322 B

View File

Before

Width:  |  Height:  |  Size: 707 B

After

Width:  |  Height:  |  Size: 707 B

View File

Before

Width:  |  Height:  |  Size: 724 B

After

Width:  |  Height:  |  Size: 724 B