rename local_ocr to ocr

This commit is contained in:
Scott Pruett 2022-06-03 18:11:49 -04:00
parent eaed474ccf
commit 8baad566aa
7 changed files with 254 additions and 265 deletions

View File

@ -14,7 +14,7 @@ use crate::{
capture, capture,
config::Config, config::Config,
image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes}, image_processing::{self, extract_and_filter, hash_image, Region, to_png_bytes},
ocr, remote_ocr,
state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash, state::{AppState, DebugOcrFrame, LapState, RaceState, SharedAppState}, learned_tracks::get_track_hash,
}; };
@ -159,7 +159,7 @@ fn run_loop_once(capturer: &mut Capturer, state: &SharedAppState) -> Result<()>
locked.should_sample_ocr_data locked.should_sample_ocr_data
) )
}; };
let ocr_results = ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample); let ocr_results = remote_ocr::ocr_all_regions(&frame, config.clone(), ocr_cache, should_sample);
if state.lock().unwrap().debug_frames { if state.lock().unwrap().debug_frames {
let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results); let debug_frames = save_frames_from(&frame, config.as_ref(), &ocr_results);

View File

@ -1,135 +0,0 @@
use image::{Rgb, RgbImage};
use img_hash::{image::GenericImageView};
use crate::image_processing;
/// Pixel rectangle around one detected glyph candidate inside an image.
///
/// Values are produced by `get_character_bounding_boxes`: the horizontal
/// extent is a maximal run of columns containing at least one dark pixel,
/// and the vertical extent is trimmed to the rows of that run that contain
/// at least one dark pixel.
#[derive(Debug)]
struct BoundingBox {
/// Left-most column of the box, in source-image pixel coordinates.
x: u32,
/// Top-most row of the box, in source-image pixel coordinates.
y: u32,
/// Number of columns covered by the box.
width: u32,
/// Number of rows covered by the box.
height: u32,
}
/// Reports whether column `x` of `image` contains at least one dark pixel.
///
/// A pixel counts as dark when every one of its three channels is below 100.
/// The whole image height is scanned, top to bottom, stopping at the first
/// dark pixel found.
fn column_has_any_dark(image: &RgbImage, x: u32) -> bool {
    (0..image.height()).any(|y| {
        let channels = image.get_pixel(x, y).0;
        channels.iter().all(|&channel| channel < 100)
    })
}
/// Reports whether row `y` of `image` contains a dark pixel within the
/// horizontal span `start_x..start_x + width`.
///
/// A pixel counts as dark when every one of its three channels is below 100.
/// Columns outside the span are not inspected.
fn row_has_any_dark(image: &RgbImage, y: u32, start_x: u32, width: u32) -> bool {
    let columns = start_x..(start_x + width);
    columns
        .map(|x| image.get_pixel(x, y).0)
        .any(|channels| channels.iter().all(|&channel| channel < 100))
}
/// Advances the cursor `x` one step at a time while it is below `max` and
/// the predicate `f` holds for the current position.
///
/// On return `*x` is either the first position at which `f` returned
/// `false`, or `max` if the predicate held all the way. If `*x` already is
/// `>= max` the cursor is left untouched and `f` is never called.
fn take_while<F: Fn(u32) -> bool>(x: &mut u32, max: u32, f: F) {
    loop {
        // The bound is checked first so `f` never sees an out-of-range index.
        if *x >= max || !f(*x) {
            break;
        }
        *x += 1;
    }
}
/// Splits `image` into glyph candidates by scanning for runs of columns that
/// contain dark pixels.
///
/// The image is walked left to right. Each maximal run of consecutive
/// "dark" columns becomes one box; its vertical extent is then tightened by
/// skipping rows, from the top and from the bottom, that have no dark pixel
/// inside the run. Boxes are returned in left-to-right order.
fn get_character_bounding_boxes(image: &RgbImage) -> Vec<BoundingBox> {
    let image_width = image.width();
    let image_height = image.height();

    let mut found = Vec::new();
    let mut cursor = 0;
    while cursor < image_width {
        // Skip the blank gap in front of the next glyph.
        take_while(&mut cursor, image_width, |col| {
            !column_has_any_dark(image, col)
        });
        let left = cursor;

        // Consume the glyph itself.
        take_while(&mut cursor, image_width, |col| {
            column_has_any_dark(image, col)
        });
        let run_width = cursor - left;
        if run_width < 1 {
            // Only happens once the cursor has reached the right edge.
            continue;
        }

        // Count blank rows above the glyph.
        let mut top = 0;
        take_while(&mut top, image_height, |row| {
            !row_has_any_dark(image, row, left, run_width)
        });

        // Count blank rows below the glyph, scanning upwards from the bottom.
        let mut blank_below = 0;
        take_while(&mut blank_below, image_height, |offset| {
            !row_has_any_dark(image, image_height - 1 - offset, left, run_width)
        });

        let bottom = image_height - blank_below;
        let run_height = bottom - top;
        if run_height < 1 {
            continue;
        }

        found.push(BoundingBox {
            x: left,
            y: top,
            width: run_width,
            height: run_height,
        });
    }
    found
}
fn trim_to_bounding_box(image: &RgbImage, bounding_box: &BoundingBox) -> RgbImage {
const PADDING: u32 = 2;
let mut buffer = RgbImage::from_pixel(
bounding_box.width + 2 * PADDING,
bounding_box.height + 2 * PADDING,
Rgb([0xFF, 0xFF, 0xFF]),
);
for y in 0..bounding_box.height {
for x in 0..bounding_box.width {
buffer.put_pixel(
x + PADDING,
y + PADDING,
*image.get_pixel(bounding_box.x + x, bounding_box.y + y),
);
}
}
buffer
}
/// Returns one padded sub-image per glyph candidate found in `image`,
/// in left-to-right order.
pub fn bounding_box_images(image: &RgbImage) -> Vec<RgbImage> {
    get_character_bounding_boxes(image)
        .iter()
        .map(|glyph_box| trim_to_bounding_box(image, glyph_box))
        .collect()
}
/// Returns the `image_processing::hash_image` hash of every padded glyph
/// candidate found in `image`, in left-to-right order.
pub fn compute_box_hashes(image: &RgbImage) -> Vec<String> {
    get_character_bounding_boxes(image)
        .iter()
        .map(|glyph_box| {
            let glyph = trim_to_bounding_box(image, glyph_box);
            image_processing::hash_image(&glyph)
        })
        .collect()
}
// The bundled sample image is expected to split into ten glyph boxes, and
// the first one must be a non-empty box that does not touch the top-left
// corner of the image.
#[test]
fn test_bounding_boxes() {
    let png = include_bytes!("test_data/test-image-2.png");
    let image = image::load_from_memory(png).unwrap().to_rgb8();

    let boxes = get_character_bounding_boxes(&image);
    assert_eq!(boxes.len(), 10);

    let first = &boxes[0];
    assert_ne!(first.x, 0);
    assert_ne!(first.y, 0);
    assert_ne!(first.height, 0);
    assert_ne!(first.width, 0);
}
// One hash must be produced for each of the ten glyph boxes in the bundled
// sample image.
#[test]
fn test_box_hashes() {
    let png = include_bytes!("test_data/test-image-2.png");
    let image = image::load_from_memory(png).unwrap().to_rgb8();
    assert_eq!(compute_box_hashes(&image).len(), 10);
}

View File

@ -4,8 +4,8 @@ mod analysis;
mod capture; mod capture;
mod config; mod config;
mod image_processing; mod image_processing;
mod local_ocr;
mod ocr; mod ocr;
mod remote_ocr;
mod state; mod state;
mod stats_writer; mod stats_writer;
mod training_ui; mod training_ui;
@ -125,9 +125,6 @@ struct DebugLap {
struct UiState { struct UiState {
config_load_err: Option<String>, config_load_err: Option<String>,
hash_to_learn: String,
value_to_learn: String,
debug_lap: Option<DebugLap>, debug_lap: Option<DebugLap>,
} }
@ -289,7 +286,7 @@ fn open_debug_lap(
) { ) {
if let Some(screenshot_bytes) = &lap.screenshot { if let Some(screenshot_bytes) = &lap.screenshot {
let screenshot = from_png_bytes(screenshot_bytes); let screenshot = from_png_bytes(screenshot_bytes);
let ocr_results = ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false); let ocr_results = remote_ocr::ocr_all_regions(&screenshot, config.clone(), ocr_cache, false);
let debug_lap = DebugLap { let debug_lap = DebugLap {
screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot)) screenshot: RetainedImage::from_image_bytes("debug-lap", &to_png_bytes(&screenshot))
.unwrap(), .unwrap(),
@ -308,9 +305,6 @@ fn show_combo_box(ui: &mut Ui, name: &str, label: &str, options: &[String], valu
*value = options[index].clone(); *value = options[index].clone();
} }
fn save_learned_track(_learned_tracks: &mut Arc<LearnedTracks>, _track: &str, _hash: &str) {
}
impl eframe::App for AppUi { impl eframe::App for AppUi {
fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();

View File

@ -1,121 +1,134 @@
use std::{ use image::{Rgb, RgbImage};
collections::HashMap,
sync::{Arc, Mutex, RwLock},
};
use anyhow::Result; use crate::image_processing;
use image::RgbImage;
use serde::{Deserialize, Serialize};
use crate::{ #[derive(Debug)]
config::Config, struct BoundingBox {
image_processing::{extract_and_filter, hash_image}, x: u32,
state::OcrCache, y: u32,
}; width: u32,
height: u32,
#[derive(Serialize, Deserialize, Debug)]
pub struct OcrRegion {
pub confidence: f64,
pub value: String,
} }
#[derive(Serialize, Deserialize, Debug)] fn column_has_any_dark(image: &RgbImage, x: u32) -> bool {
struct OcrResult { for y in 0..image.height() {
regions: Vec<OcrRegion>, let [r, g, b] = image.get_pixel(x, y).0;
error: Option<String>, if r < 100 && g < 100 && b < 100 {
} return true;
async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
let client = reqwest::Client::new();
let response = client
.post(url)
.body(crate::image_processing::to_png_bytes(image))
.send()
.await?;
if !response.status().is_success() {
eprintln!("failed to run OCR query");
anyhow::bail!("failed to run OCR query")
}
let result: OcrResult = response.json().await?;
let result = if result.regions.is_empty() {
None
} else {
let mut buffer = String::new();
for r in &result.regions {
buffer += &r.value;
}
Some(buffer)
};
Ok(result)
}
async fn run_ocr_cached(
ocr_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
hash: String,
region: &crate::image_processing::Region,
config: Arc<Config>,
filtered_image: &image::ImageBuffer<image::Rgb<u8>, Vec<u8>>,
) -> Option<String> {
let cached = {
let locked = ocr_cache.read().unwrap();
locked.get(&hash).cloned()
};
let use_cache = region.use_ocr_cache.unwrap_or(true) && config.use_ocr_cache.unwrap_or(true);
if let Some(cached) = cached {
if use_cache {
return cached;
} }
} }
match run_ocr(filtered_image, &config.ocr_server_endpoint).await { false
Ok(v) => { }
if use_cache {
ocr_cache.write().unwrap().insert(hash.clone(), v.clone()); fn row_has_any_dark(image: &RgbImage, y: u32, start_x: u32, width: u32) -> bool {
for x in start_x..(start_x + width) {
let [r, g, b] = image.get_pixel(x, y).0;
if r < 100 && g < 100 && b < 100 {
return true;
}
}
false
}
fn take_while<F: Fn(u32) -> bool>(x: &mut u32, max: u32, f: F) {
while *x < max && f(*x) {
*x += 1;
}
}
fn get_character_bounding_boxes(image: &RgbImage) -> Vec<BoundingBox> {
let mut x = 0;
let mut boxes = Vec::new();
while x < image.width() {
take_while(&mut x, image.width(), |x| !column_has_any_dark(image, x));
let start_x = x;
take_while(&mut x, image.width(), |x| column_has_any_dark(image, x));
let width = x - start_x;
if width >= 1 {
let mut y = 0;
take_while(&mut y, image.height(), |y| {
!row_has_any_dark(image, y, start_x, width)
});
let start_y = y;
let mut inverse_y = 0;
take_while(&mut inverse_y, image.height(), |y| {
!row_has_any_dark(image, image.height() - 1 - y, start_x, width)
});
let end_y = image.height() - inverse_y;
let height = end_y - start_y;
if height >= 1 {
boxes.push(BoundingBox {
x: start_x,
y: start_y,
width,
height,
});
} }
v
} }
Err(_) => None,
} }
boxes
} }
#[tokio::main(flavor = "current_thread")] fn trim_to_bounding_box(image: &RgbImage, bounding_box: &BoundingBox) -> RgbImage {
pub async fn ocr_all_regions( const PADDING: u32 = 2;
image: &RgbImage, let mut buffer = RgbImage::from_pixel(
config: Arc<Config>, bounding_box.width + 2 * PADDING,
ocr_cache: Arc<OcrCache>, bounding_box.height + 2 * PADDING,
should_sample: bool, Rgb([0xFF, 0xFF, 0xFF]),
) -> HashMap<String, Option<String>> { );
let results = Arc::new(Mutex::new(HashMap::new())); for y in 0..bounding_box.height {
for x in 0..bounding_box.width {
let mut handles = Vec::new(); buffer.put_pixel(
for region in &config.ocr_regions { x + PADDING,
let filtered_image = extract_and_filter(image, region); y + PADDING,
let region = region.clone(); *image.get_pixel(bounding_box.x + x, bounding_box.y + y),
let results = results.clone(); );
let config = config.clone(); }
let ocr_cache = ocr_cache.clone();
handles.push(tokio::spawn(async move {
let filtered_image = filtered_image;
let hash = hash_image(&filtered_image);
let value =
run_ocr_cached(ocr_cache, hash, &region, config.clone(), &filtered_image).await;
if let Some(sample_fraction) = &config.dump_frame_fraction {
if rand::random::<f64>() < *sample_fraction && should_sample {
let file_id = rand::random::<usize>();
let img_filename = format!("ocr_data/{}.png", file_id);
filtered_image.save(img_filename).unwrap();
let value_filename = format!("ocr_data/{}.txt", file_id);
std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap();
}
}
results.lock().unwrap().insert(region.name, value);
}));
} }
for handle in handles { buffer
handle.await.expect("failed to join task in OCR"); }
}
pub fn bounding_box_images(image: &RgbImage) -> Vec<RgbImage> {
let results = results.lock().unwrap().clone(); let mut trimmed = Vec::new();
results
let boxes = get_character_bounding_boxes(image);
for bounding_box in boxes {
trimmed.push(trim_to_bounding_box(image, &bounding_box));
}
trimmed
}
pub fn compute_box_hashes(image: &RgbImage) -> Vec<String> {
let mut hashes = Vec::new();
let boxes = get_character_bounding_boxes(image);
for bounding_box in boxes {
let trimmed = trim_to_bounding_box(image, &bounding_box);
hashes.push(image_processing::hash_image(&trimmed))
}
hashes
}
#[test]
fn test_bounding_boxes() {
let image_bytes = include_bytes!("test_data/test-image-2.png");
let image = image::load_from_memory(image_bytes).unwrap().to_rgb8();
let boxes = get_character_bounding_boxes(&image);
assert_eq!(boxes.len(), 10);
assert_ne!(boxes[0].x, 0);
assert_ne!(boxes[0].y, 0);
assert_ne!(boxes[0].height, 0);
assert_ne!(boxes[0].width, 0);
}
#[test]
fn test_box_hashes() {
let image_bytes = include_bytes!("test_data/test-image-2.png");
let image = image::load_from_memory(image_bytes).unwrap().to_rgb8();
let hashes = compute_box_hashes(&image);
assert_eq!(hashes.len(), 10);
} }

121
src/remote_ocr.rs Normal file
View File

@ -0,0 +1,121 @@
use std::{
collections::HashMap,
sync::{Arc, Mutex, RwLock},
};
use anyhow::Result;
use image::RgbImage;
use serde::{Deserialize, Serialize};
use crate::{
config::Config,
image_processing::{extract_and_filter, hash_image},
state::OcrCache,
};
/// One recognised text fragment, as it appears in the JSON reply that
/// `run_ocr` decodes from the OCR server.
#[derive(Serialize, Deserialize, Debug)]
pub struct OcrRegion {
/// NOTE(review): presumably the recogniser's confidence score for this
/// fragment; the field is decoded but never read in this file, so its
/// range and meaning should be confirmed against the server.
pub confidence: f64,
/// Text of this fragment. `run_ocr` concatenates the `value`s of all
/// regions, in order, to build the final string.
pub value: String,
}
/// Top-level JSON body that `run_ocr` decodes from the OCR server's reply.
#[derive(Serialize, Deserialize, Debug)]
struct OcrResult {
/// Recognised fragments; an empty list is reported by `run_ocr` as `None`.
regions: Vec<OcrRegion>,
/// NOTE(review): looks like a server-side error message; it is decoded but
/// not inspected anywhere in this file — confirm whether it should be.
error: Option<String>,
}
/// Sends `image` (PNG-encoded) to the OCR server at `url` and returns the
/// recognised text.
///
/// Returns `Ok(None)` when the server reports no regions, otherwise
/// `Ok(Some(text))` where `text` is the concatenation of every region's
/// value in the order the server returned them.
///
/// # Errors
///
/// Fails if the request cannot be sent, if the response status is not a
/// success code, or if the response body cannot be decoded as JSON.
async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
    let png_bytes = crate::image_processing::to_png_bytes(image);
    let response = reqwest::Client::new()
        .post(url)
        .body(png_bytes)
        .send()
        .await?;

    if !response.status().is_success() {
        eprintln!("failed to run OCR query");
        anyhow::bail!("failed to run OCR query");
    }

    let parsed: OcrResult = response.json().await?;
    if parsed.regions.is_empty() {
        return Ok(None);
    }

    let text: String = parsed
        .regions
        .iter()
        .map(|fragment| fragment.value.as_str())
        .collect();
    Ok(Some(text))
}
/// Looks up the OCR result for `filtered_image` in `ocr_cache` under `hash`,
/// falling back to a remote OCR query on a miss.
///
/// Caching is active only when neither the region nor the global config
/// disables it (both settings default to enabled). With caching active, a
/// hit is returned directly and a successful remote result, including
/// "no text" (`None`), is stored under `hash`. A failed remote query yields
/// `None` and is not cached.
async fn run_ocr_cached(
    ocr_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
    hash: String,
    region: &crate::image_processing::Region,
    config: Arc<Config>,
    filtered_image: &image::ImageBuffer<image::Rgb<u8>, Vec<u8>>,
) -> Option<String> {
    // The read guard is a temporary of this statement, so it is released
    // before any `.await` below.
    let cache_hit = ocr_cache.read().unwrap().get(&hash).cloned();

    let region_allows = region.use_ocr_cache.unwrap_or(true);
    let use_cache = region_allows && config.use_ocr_cache.unwrap_or(true);

    if use_cache {
        if let Some(stored) = cache_hit {
            return stored;
        }
    }

    // Any request/decoding failure is reported as "no text".
    let fresh = run_ocr(filtered_image, &config.ocr_server_endpoint)
        .await
        .ok()?;

    if use_cache {
        ocr_cache.write().unwrap().insert(hash, fresh.clone());
    }
    fresh
}
#[tokio::main(flavor = "current_thread")]
pub async fn ocr_all_regions(
image: &RgbImage,
config: Arc<Config>,
ocr_cache: Arc<OcrCache>,
should_sample: bool,
) -> HashMap<String, Option<String>> {
let results = Arc::new(Mutex::new(HashMap::new()));
let mut handles = Vec::new();
for region in &config.ocr_regions {
let filtered_image = extract_and_filter(image, region);
let region = region.clone();
let results = results.clone();
let config = config.clone();
let ocr_cache = ocr_cache.clone();
handles.push(tokio::spawn(async move {
let filtered_image = filtered_image;
let hash = hash_image(&filtered_image);
let value =
run_ocr_cached(ocr_cache, hash, &region, config.clone(), &filtered_image).await;
if let Some(sample_fraction) = &config.dump_frame_fraction {
if rand::random::<f64>() < *sample_fraction && should_sample {
let file_id = rand::random::<usize>();
let img_filename = format!("ocr_data/{}.png", file_id);
filtered_image.save(img_filename).unwrap();
let value_filename = format!("ocr_data/{}.txt", file_id);
std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap();
}
}
results.lock().unwrap().insert(region.name, value);
}));
}
for handle in handles {
handle.await.expect("failed to join task in OCR");
}
let results = results.lock().unwrap().clone();
results
}

View File

@ -66,13 +66,11 @@ impl LapState {
fn median_wear(values: Vec<Option<usize>>) -> Option<usize> { fn median_wear(values: Vec<Option<usize>>) -> Option<usize> {
let mut wear_values = Vec::new(); let mut wear_values = Vec::new();
let mut last_value = 100; let mut last_value = 100;
for val in values { for val in values.into_iter().flatten() {
if let Some(val) = val { if val < last_value {
if val < last_value { wear_values.push(last_value - val);
wear_values.push(last_value - val);
}
last_value = val;
} }
last_value = val;
} }
wear_values.sort_unstable(); wear_values.sort_unstable();
wear_values.get(wear_values.len() / 2).cloned() wear_values.get(wear_values.len() / 2).cloned()

View File

@ -10,7 +10,7 @@ use eframe::{
use egui_extras::RetainedImage; use egui_extras::RetainedImage;
use image::RgbImage; use image::RgbImage;
use crate::{image_processing::to_png_bytes, local_ocr}; use crate::{image_processing::to_png_bytes, ocr};
#[derive(Default)] #[derive(Default)]
struct TrainingUi { struct TrainingUi {
@ -24,7 +24,6 @@ struct TrainingUi {
struct TrainingImage { struct TrainingImage {
img_file: PathBuf, img_file: PathBuf,
data_file: PathBuf, data_file: PathBuf,
image: RgbImage,
text: String, text: String,
ui_image: RetainedImage, ui_image: RetainedImage,
@ -74,15 +73,14 @@ fn get_training_data() -> Vec<TrainingImage> {
let ocr_value = String::from_utf8(std::fs::read(&ocr_file).unwrap()).unwrap(); let ocr_value = String::from_utf8(std::fs::read(&ocr_file).unwrap()).unwrap();
let ui_image = load_retained_image(&image); let ui_image = load_retained_image(&image);
let char_images = local_ocr::bounding_box_images(&image) let char_images = ocr::bounding_box_images(&image)
.iter() .iter()
.map(load_retained_image) .map(load_retained_image)
.collect(); .collect();
let char_hashes = local_ocr::compute_box_hashes(&image); let char_hashes = ocr::compute_box_hashes(&image);
data.push(TrainingImage { data.push(TrainingImage {
img_file, img_file,
data_file: ocr_file, data_file: ocr_file,
image,
text: ocr_value, text: ocr_value,
ui_image, ui_image,
char_images, char_images,
@ -170,7 +168,7 @@ impl eframe::App for TrainingUi {
.write(true) .write(true)
.open("learned_chars.txt") .open("learned_chars.txt")
.unwrap(); .unwrap();
file.write(buffer.as_bytes()).unwrap(); file.write_all(buffer.as_bytes()).unwrap();
} }
current_image.ui_image.show(ui); current_image.ui_image.show(ui);