// supper/src/ocr.rs
// OCR helpers: send image regions to a local OCR HTTP service and collect the recognised text.
use std::{collections::HashMap, sync::{Arc, Mutex}};
use anyhow::Result;
use image::RgbImage;
use serde::{Deserialize, Serialize};
use crate::image_processing::{Region, extract_region, filter_to_white};
/// A single piece of text recognised by the OCR service.
#[derive(Debug, Serialize, Deserialize)]
pub struct OcrRegion {
    /// Confidence score the service reported for `value`.
    pub confidence: f64,
    /// The recognised text.
    pub value: String,
}
/// JSON body returned by the OCR service.
#[derive(Debug, Serialize, Deserialize)]
struct OcrResult {
    /// Text regions recognised in the submitted image.
    regions: Vec<OcrRegion>,
    /// Error message reported by the service, if any.
    error: Option<String>,
}
async fn run_ocr(image: &RgbImage) -> Result<Vec<OcrRegion>> {
let client = reqwest::Client::new();
let response = client
.post("http://localhost:3000/")
.body(crate::image_processing::to_png_bytes(&image))
.send()
.await?;
if !response.status().is_success() {
anyhow::bail!("failed to run OCR query")
}
let result: OcrResult = response.json().await?;
Ok(result.regions)
}
pub async fn ocr_all_regions(image: &RgbImage, regions: &[Region]) -> HashMap<String, Option<String>> {
let results = Arc::new(Mutex::new(HashMap::new()));
let mut handles = Vec::new();
for region in regions {
let filtered_image = extract_region(image, region);
let region = region.clone();
let results = results.clone();
handles.push(tokio::spawn(async move {
let mut image = filtered_image;
2022-05-22 17:19:13 +00:00
filter_to_white(&mut image);
2022-05-21 18:12:10 +00:00
let ocr_results = run_ocr(&image).await;
let value = match ocr_results {
Ok(ocr_regions) => {
if ocr_regions.is_empty() {
None
} else {
let mut out = String::new();
for r in ocr_regions {
out += &r.value;
}
Some(out)
}
}
Err(_) => None
};
results.lock().unwrap().insert(region.name, value);
}));
}
for handle in handles {
handle.await.expect("failed to join task in OCR");
}
let results = results.lock().unwrap().clone();
results
}