2022-05-22 18:03:47 +00:00
|
|
|
use std::{
|
|
|
|
collections::HashMap,
|
2022-05-22 20:25:47 +00:00
|
|
|
sync::{Arc, Mutex, RwLock},
|
2022-05-22 18:03:47 +00:00
|
|
|
};
|
2022-05-21 18:12:10 +00:00
|
|
|
|
|
|
|
use anyhow::Result;
|
|
|
|
use image::RgbImage;
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
2022-05-22 18:03:47 +00:00
|
|
|
use crate::{
|
|
|
|
config::{Config, LearnedConfig},
|
2022-05-22 23:01:24 +00:00
|
|
|
image_processing::{hash_image, extract_and_filter},
|
2022-05-22 18:03:47 +00:00
|
|
|
};
|
2022-05-21 18:12:10 +00:00
|
|
|
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
|
|
pub struct OcrRegion {
|
|
|
|
pub confidence: f64,
|
|
|
|
pub value: String,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
|
|
struct OcrResult {
|
|
|
|
regions: Vec<OcrRegion>,
|
|
|
|
error: Option<String>,
|
|
|
|
}
|
|
|
|
|
2022-05-22 18:03:47 +00:00
|
|
|
async fn run_ocr(image: &RgbImage, url: &str) -> Result<Option<String>> {
|
2022-05-21 18:12:10 +00:00
|
|
|
let client = reqwest::Client::new();
|
|
|
|
let response = client
|
2022-05-22 18:03:47 +00:00
|
|
|
.post(url)
|
|
|
|
.body(crate::image_processing::to_png_bytes(image))
|
2022-05-21 18:12:10 +00:00
|
|
|
.send()
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
if !response.status().is_success() {
|
2022-05-22 18:03:47 +00:00
|
|
|
eprintln!("failed to run OCR query");
|
2022-05-21 18:12:10 +00:00
|
|
|
anyhow::bail!("failed to run OCR query")
|
|
|
|
}
|
|
|
|
let result: OcrResult = response.json().await?;
|
2022-05-22 18:03:47 +00:00
|
|
|
let result = if result.regions.is_empty() {
|
|
|
|
None
|
|
|
|
} else {
|
|
|
|
let mut buffer = String::new();
|
|
|
|
for r in &result.regions {
|
|
|
|
buffer += &r.value;
|
|
|
|
}
|
|
|
|
Some(buffer)
|
|
|
|
};
|
|
|
|
Ok(result)
|
2022-05-21 18:12:10 +00:00
|
|
|
}
|
|
|
|
|
2022-05-22 18:03:47 +00:00
|
|
|
#[tokio::main(flavor = "current_thread")]
|
|
|
|
pub async fn ocr_all_regions(
|
|
|
|
image: &RgbImage,
|
|
|
|
config: Arc<Config>,
|
|
|
|
learned: Arc<LearnedConfig>,
|
2022-05-22 20:25:47 +00:00
|
|
|
ocr_cache: Arc<RwLock<HashMap<String, Option<String>>>>,
|
2022-05-22 18:03:47 +00:00
|
|
|
) -> HashMap<String, Option<String>> {
|
2022-05-21 18:12:10 +00:00
|
|
|
let results = Arc::new(Mutex::new(HashMap::new()));
|
|
|
|
|
|
|
|
let mut handles = Vec::new();
|
2022-05-22 18:03:47 +00:00
|
|
|
for region in &config.ocr_regions {
|
2022-05-22 20:25:47 +00:00
|
|
|
let filtered_image = extract_and_filter(image, region);
|
2022-05-21 18:12:10 +00:00
|
|
|
let region = region.clone();
|
|
|
|
let results = results.clone();
|
2022-05-22 18:03:47 +00:00
|
|
|
let config = config.clone();
|
|
|
|
let learned = learned.clone();
|
2022-05-22 20:25:47 +00:00
|
|
|
let ocr_cache = ocr_cache.clone();
|
2022-05-21 18:12:10 +00:00
|
|
|
handles.push(tokio::spawn(async move {
|
2022-05-22 20:25:47 +00:00
|
|
|
let filtered_image = filtered_image;
|
|
|
|
let hash = hash_image(&filtered_image);
|
2022-05-22 18:03:47 +00:00
|
|
|
let value = if let Some(learned_value) = learned.learned_images.get(&hash) {
|
|
|
|
Some(learned_value.clone())
|
|
|
|
} else {
|
2022-05-22 20:25:47 +00:00
|
|
|
let cached = {
|
|
|
|
let locked = ocr_cache.read().unwrap();
|
|
|
|
locked.get(&hash).cloned()
|
|
|
|
};
|
|
|
|
if let Some(cached) = cached {
|
|
|
|
cached
|
|
|
|
} else {
|
|
|
|
match run_ocr(&filtered_image, &config.ocr_server_endpoint).await {
|
|
|
|
Ok(v) => {
|
|
|
|
if config.use_ocr_cache.unwrap_or(true) {
|
|
|
|
ocr_cache.write().unwrap().insert(hash.clone(), v.clone());
|
|
|
|
}
|
|
|
|
v
|
|
|
|
}
|
|
|
|
Err(_) => None
|
|
|
|
}
|
|
|
|
}
|
2022-05-21 18:12:10 +00:00
|
|
|
};
|
|
|
|
results.lock().unwrap().insert(region.name, value);
|
|
|
|
}));
|
|
|
|
}
|
|
|
|
for handle in handles {
|
|
|
|
handle.await.expect("failed to join task in OCR");
|
|
|
|
}
|
|
|
|
|
|
|
|
let results = results.lock().unwrap().clone();
|
|
|
|
results
|
2022-05-22 18:03:47 +00:00
|
|
|
}
|