supper/src/ocr.rs

126 lines
3.8 KiB
Rust
Raw Normal View History

2022-05-22 18:03:47 +00:00
use std::{
collections::HashMap,
2022-05-22 20:25:47 +00:00
sync::{Arc, Mutex, RwLock},
2022-05-22 18:03:47 +00:00
};
2022-05-21 18:12:10 +00:00
use anyhow::Result;
use image::RgbImage;
use serde::{Deserialize, Serialize};
2022-05-22 18:03:47 +00:00
use crate::{
config::{Config, LearnedConfig},
2022-05-24 02:39:13 +00:00
image_processing::{extract_and_filter, hash_image}, state::OcrCache,
2022-05-22 18:03:47 +00:00
};
2022-05-21 18:12:10 +00:00
/// One text region in the OCR server's JSON response (an element of `OcrResult::regions`).
#[derive(Debug)]
#[derive(Serialize, Deserialize)]
pub struct OcrRegion {
    /// Confidence value the server attached to this region (its scale is not defined here).
    pub confidence: f64,
    /// Text recognized in this region.
    pub value: String,
}
/// Top-level JSON body returned by the OCR server; decoded from the HTTP response in `run_ocr`.
#[derive(Debug)]
#[derive(Serialize, Deserialize)]
struct OcrResult {
    /// Text regions the server found in the submitted image.
    regions: Vec<OcrRegion>,
    /// Error message reported by the server, if any (a missing or `null` field decodes as `None`).
    error: Option<String>,
}
2022-05-22 18:03:47 +00:00
async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
2022-05-21 18:12:10 +00:00
let client = reqwest::Client::new();
let response = client
2022-05-22 18:03:47 +00:00
.post(url)
.body(crate::image_processing::to_png_bytes(image))
2022-05-21 18:12:10 +00:00
.send()
.await?;
if !response.status().is_success() {
2022-05-22 18:03:47 +00:00
eprintln!("failed to run OCR query");
2022-05-21 18:12:10 +00:00
anyhow::bail!("failed to run OCR query")
}
let result: OcrResult = response.json().await?;
2022-05-22 18:03:47 +00:00
let result = if result.regions.is_empty() {
None
} else {
let mut buffer = String::new();
for r in &result.regions {
buffer += &r.value;
}
Some(buffer)
};
Ok(result)
2022-05-21 18:12:10 +00:00
}
2022-05-24 00:58:28 +00:00
async fn run_ocr_cached(
ocr_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
hash: String,
region: &crate::image_processing::Region,
config: Arc<Config>,
2022-06-03 02:57:45 +00:00
filtered_image: &image::ImageBuffer<image::Rgb<u8>, Vec<u8>>,
2022-05-24 00:58:28 +00:00
) -> Option<String> {
let cached = {
let locked = ocr_cache.read().unwrap();
locked.get(&hash).cloned()
};
let use_cache = region.use_ocr_cache.unwrap_or(true) && config.use_ocr_cache.unwrap_or(true);
if let Some(cached) = cached {
if use_cache {
return cached;
}
}
2022-06-03 02:57:45 +00:00
match run_ocr(filtered_image, &config.ocr_server_endpoint).await {
2022-05-24 00:58:28 +00:00
Ok(v) => {
if use_cache {
ocr_cache.write().unwrap().insert(hash.clone(), v.clone());
}
v
}
Err(_) => None,
}
}
2022-05-22 18:03:47 +00:00
#[tokio::main(flavor = "current_thread")]
pub async fn ocr_all_regions(
image: &RgbImage,
config: Arc<Config>,
learned: Arc<LearnedConfig>,
2022-05-24 02:39:13 +00:00
ocr_cache: Arc<OcrCache>,
2022-06-03 02:57:45 +00:00
should_sample: bool
2022-05-22 18:03:47 +00:00
) -> HashMap<String, Option<String>> {
2022-05-21 18:12:10 +00:00
let results = Arc::new(Mutex::new(HashMap::new()));
let mut handles = Vec::new();
2022-05-22 18:03:47 +00:00
for region in &config.ocr_regions {
2022-05-22 20:25:47 +00:00
let filtered_image = extract_and_filter(image, region);
2022-05-21 18:12:10 +00:00
let region = region.clone();
let results = results.clone();
2022-05-22 18:03:47 +00:00
let config = config.clone();
let learned = learned.clone();
2022-05-22 20:25:47 +00:00
let ocr_cache = ocr_cache.clone();
2022-05-21 18:12:10 +00:00
handles.push(tokio::spawn(async move {
2022-05-22 20:25:47 +00:00
let filtered_image = filtered_image;
let hash = hash_image(&filtered_image);
2022-05-22 18:03:47 +00:00
let value = if let Some(learned_value) = learned.learned_images.get(&hash) {
Some(learned_value.clone())
} else {
2022-06-03 02:57:45 +00:00
run_ocr_cached(ocr_cache, hash, &region, config.clone(), &filtered_image).await
2022-05-21 18:12:10 +00:00
};
2022-06-03 02:57:45 +00:00
if let Some(sample_fraction) = &config.dump_frame_fraction {
if rand::random::<f64>() < *sample_fraction {
let file_id = rand::random::<usize>();
let img_filename = format!("ocr_data/{}.png", file_id);
filtered_image.save(img_filename).unwrap();
let value_filename = format!("ocr_data/{}.txt", file_id);
std::fs::write(value_filename, value.clone().unwrap_or_default()).unwrap();
}
}
2022-05-21 18:12:10 +00:00
results.lock().unwrap().insert(region.name, value);
}));
}
for handle in handles {
handle.await.expect("failed to join task in OCR");
}
let results = results.lock().unwrap().clone();
results
2022-05-22 18:03:47 +00:00
}