feat: VoiceScript STT+OCR 초기 버전
This commit is contained in:
288
app/ocr_tasks.py
Normal file
288
app/ocr_tasks.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
OCR Celery Tasks
|
||||
- PaddleOCR 3.x 호환 (use_gpu/show_log/cls 파라미터 제거, 결과구조 변경 반영)
|
||||
- backend="paddle" → PaddleOCR 로컬 실행
|
||||
- backend="ollama" → Ollama Vision API 호출
|
||||
"""
|
||||
import os
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
from celery import Celery
|
||||
import openpyxl
|
||||
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
|
||||
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379/0")
|
||||
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "/data/outputs")
|
||||
OCR_LANG = os.getenv("OCR_LANG", "korean")
|
||||
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.0.126:11434")
|
||||
OLLAMA_TIMEOUT = int(os.getenv("OLLAMA_TIMEOUT", "180"))
|
||||
|
||||
celery_app = Celery("ocr_tasks", broker=REDIS_URL, backend=REDIS_URL)
|
||||
celery_app.conf.update(
|
||||
task_serializer="json",
|
||||
result_serializer="json",
|
||||
accept_content=["json"],
|
||||
task_track_started=True,
|
||||
result_expires=3600,
|
||||
)
|
||||
|
||||
# PaddleOCR 싱글톤
|
||||
_ocr_engine = None
|
||||
_struct_engine = None
|
||||
|
||||
def get_ocr():
|
||||
global _ocr_engine
|
||||
if _ocr_engine is None:
|
||||
from paddleocr import PaddleOCR
|
||||
print(f"[PaddleOCR] 로딩 (lang={OCR_LANG})")
|
||||
# PaddleOCR 3.x: use_gpu/show_log 파라미터 제거됨
|
||||
_ocr_engine = PaddleOCR(use_angle_cls=True, lang=OCR_LANG)
|
||||
print("[PaddleOCR] 완료")
|
||||
return _ocr_engine
|
||||
|
||||
def get_structure():
|
||||
global _struct_engine
|
||||
if _struct_engine is None:
|
||||
from paddleocr import PPStructure
|
||||
print("[PPStructure] 로딩")
|
||||
_struct_engine = PPStructure(table=True, ocr=True, lang=OCR_LANG)
|
||||
print("[PPStructure] 완료")
|
||||
return _struct_engine
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 메인 Task
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
@celery_app.task(bind=True, name="tasks.ocr_task", queue="ocr")
|
||||
def ocr_task(self, file_id, image_path, mode="text",
|
||||
backend="paddle", ollama_model="granite3.2-vision", custom_prompt=""):
|
||||
self.update_state(state="PROGRESS", meta={"progress": 8, "message": "엔진 준비 중..."})
|
||||
try:
|
||||
if backend == "ollama":
|
||||
result = _run_ollama(self, file_id, image_path, mode, ollama_model, custom_prompt)
|
||||
else:
|
||||
result = _run_paddle(self, file_id, image_path, mode)
|
||||
try: os.remove(image_path)
|
||||
except: pass
|
||||
return result
|
||||
except Exception as e:
|
||||
try: os.remove(image_path)
|
||||
except: pass
|
||||
raise Exception(f"OCR 실패: {str(e)}")
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# Ollama 백엔드
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
_OLLAMA_PROMPTS = {
|
||||
"text": "이 이미지에서 모든 텍스트를 정확하게 추출해줘. 원본의 줄 구분과 단락 구조를 유지해줘.",
|
||||
"structure": "이 이미지를 분석해서 표는 마크다운 표 형식으로, 나머지 텍스트는 원본 구조를 유지하며 추출해줘.",
|
||||
}
|
||||
|
||||
def _run_ollama(task, file_id, image_path, mode, ollama_model, custom_prompt):
|
||||
task.update_state(state="PROGRESS",
|
||||
meta={"progress": 15, "message": f"Ollama ({ollama_model}) 연결 중..."})
|
||||
with open(image_path, "rb") as f:
|
||||
img_b64 = base64.b64encode(f.read()).decode()
|
||||
prompt = custom_prompt.strip() or _OLLAMA_PROMPTS.get(mode, _OLLAMA_PROMPTS["text"])
|
||||
task.update_state(state="PROGRESS", meta={"progress": 30, "message": "모델 추론 중..."})
|
||||
try:
|
||||
resp = httpx.post(f"{OLLAMA_URL}/api/chat", json={
|
||||
"model": ollama_model,
|
||||
"messages": [{"role": "user", "content": prompt, "images": [img_b64]}],
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.1},
|
||||
}, timeout=float(OLLAMA_TIMEOUT))
|
||||
resp.raise_for_status()
|
||||
except httpx.ConnectError:
|
||||
raise Exception(f"Ollama 서버 연결 실패 ({OLLAMA_URL})")
|
||||
except httpx.TimeoutException:
|
||||
raise Exception(f"Ollama 응답 시간 초과 ({OLLAMA_TIMEOUT}초). OLLAMA_TIMEOUT 값을 늘려주세요.")
|
||||
|
||||
task.update_state(state="PROGRESS", meta={"progress": 85, "message": "결과 저장 중..."})
|
||||
full_text = resp.json().get("message", {}).get("content", "").strip()
|
||||
if not full_text:
|
||||
raise Exception("Ollama 빈 응답. 모델이 설치되어 있는지 확인하세요.")
|
||||
|
||||
tables = _parse_md_tables(full_text) if mode == "structure" else []
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
txt_file = f"{file_id}_ocr.txt"
|
||||
with open(os.path.join(OUTPUT_DIR, txt_file), "w", encoding="utf-8") as f:
|
||||
f.write(f"# OCR 결과 (Ollama / {ollama_model})\n\n{full_text}")
|
||||
xlsx_file = None
|
||||
if tables:
|
||||
xlsx_file = f"{file_id}_tables.xlsx"
|
||||
_save_excel(tables, os.path.join(OUTPUT_DIR, xlsx_file))
|
||||
tables_html = [_md_table_to_html(t) for t in tables]
|
||||
lines = [{"text": l, "confidence": 1.0, "bbox": []}
|
||||
for l in full_text.splitlines() if l.strip()]
|
||||
return {
|
||||
"mode": mode, "backend": "ollama", "ollama_model": ollama_model,
|
||||
"full_text": full_text, "lines": lines, "line_count": len(lines),
|
||||
"txt_file": txt_file,
|
||||
"tables": [{"html": h, "rows": len(t),
|
||||
"cols": max(len(r) for r in t) if t else 0}
|
||||
for h, t in zip(tables_html, tables)],
|
||||
"xlsx_file": xlsx_file,
|
||||
}
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# PaddleOCR 백엔드
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
def _run_paddle(task, file_id, image_path, mode):
|
||||
import cv2
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
raise ValueError("이미지를 읽을 수 없습니다")
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
return _paddle_structure(task, file_id, img) if mode == "structure" \
|
||||
else _paddle_text(task, file_id, img)
|
||||
|
||||
|
||||
def _paddle_text(task, file_id, img):
|
||||
task.update_state(state="PROGRESS", meta={"progress": 30, "message": "텍스트 인식 중..."})
|
||||
# PaddleOCR 3.x: cls 파라미터 제거, 결과 구조 변경
|
||||
result = get_ocr().ocr(img)
|
||||
task.update_state(state="PROGRESS", meta={"progress": 80, "message": "결과 정리 중..."})
|
||||
|
||||
lines = []
|
||||
if result and len(result) > 0:
|
||||
r = result[0]
|
||||
# PaddleOCR 3.x 결과 구조: dict with rec_texts, rec_scores
|
||||
if isinstance(r, dict):
|
||||
texts = r.get("rec_texts", [])
|
||||
scores = r.get("rec_scores", [])
|
||||
for text, conf in zip(texts, scores):
|
||||
if text.strip():
|
||||
lines.append({"text": text,
|
||||
"confidence": round(float(conf), 3),
|
||||
"bbox": []})
|
||||
# 구버전 호환 (list of [bbox, (text, conf)])
|
||||
elif isinstance(r, list):
|
||||
for item in r:
|
||||
if item and len(item) == 2:
|
||||
_, (text, conf) = item
|
||||
if text.strip():
|
||||
lines.append({"text": text,
|
||||
"confidence": round(float(conf), 3),
|
||||
"bbox": []})
|
||||
|
||||
full_text = "\n".join(l["text"] for l in lines)
|
||||
txt_file = f"{file_id}_ocr.txt"
|
||||
with open(os.path.join(OUTPUT_DIR, txt_file), "w", encoding="utf-8") as f:
|
||||
f.write(full_text)
|
||||
return {"mode": "text", "backend": "paddle",
|
||||
"full_text": full_text, "lines": lines,
|
||||
"line_count": len(lines), "txt_file": txt_file,
|
||||
"tables": [], "xlsx_file": None}
|
||||
|
||||
|
||||
def _paddle_structure(task, file_id, img):
|
||||
task.update_state(state="PROGRESS", meta={"progress": 20, "message": "레이아웃 분석 중..."})
|
||||
result = get_structure()(img)
|
||||
task.update_state(state="PROGRESS", meta={"progress": 60, "message": "표 구조 추출 중..."})
|
||||
|
||||
text_blocks, tables_html, tables_data = [], [], []
|
||||
for region in result:
|
||||
rtype = region.get("type", "").lower()
|
||||
if rtype == "table":
|
||||
html = region.get("res", {}).get("html", "")
|
||||
if html:
|
||||
tables_html.append(html)
|
||||
tables_data.append(_html_table_to_list(html))
|
||||
elif rtype in ("text", "title", "figure_caption"):
|
||||
for line in (region.get("res", []) or []):
|
||||
if isinstance(line, (list, tuple)) and len(line) == 2:
|
||||
_, (text, _conf) = line
|
||||
text_blocks.append(text)
|
||||
|
||||
full_text = "\n".join(text_blocks)
|
||||
task.update_state(state="PROGRESS", meta={"progress": 80, "message": "Excel 생성 중..."})
|
||||
|
||||
xlsx_file = None
|
||||
if tables_data:
|
||||
xlsx_file = f"{file_id}_tables.xlsx"
|
||||
_save_excel(tables_data, os.path.join(OUTPUT_DIR, xlsx_file))
|
||||
|
||||
txt_file = f"{file_id}_ocr.txt"
|
||||
with open(os.path.join(OUTPUT_DIR, txt_file), "w", encoding="utf-8") as f:
|
||||
f.write("# 텍스트\n\n" + full_text)
|
||||
|
||||
lines = [{"text": t, "confidence": 1.0, "bbox": []} for t in text_blocks]
|
||||
tables_meta = [{"html": h, "rows": len(d),
|
||||
"cols": max(len(r) for r in d) if d else 0}
|
||||
for h, d in zip(tables_html, tables_data)]
|
||||
return {"mode": "structure", "backend": "paddle",
|
||||
"full_text": full_text, "lines": lines,
|
||||
"line_count": len(lines), "txt_file": txt_file,
|
||||
"tables": tables_meta, "xlsx_file": xlsx_file}
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 공통 유틸
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
def _parse_md_tables(text):
|
||||
tables, current = [], []
|
||||
for line in text.splitlines():
|
||||
s = line.strip()
|
||||
if s.startswith("|") and s.endswith("|"):
|
||||
if all(c in "| -:" for c in s): continue
|
||||
current.append([c.strip() for c in s.strip("|").split("|")])
|
||||
else:
|
||||
if len(current) >= 2: tables.append(current)
|
||||
current = []
|
||||
if len(current) >= 2: tables.append(current)
|
||||
return tables
|
||||
|
||||
def _md_table_to_html(table):
|
||||
if not table: return ""
|
||||
rows = ""
|
||||
for i, row in enumerate(table):
|
||||
tag = "th" if i == 0 else "td"
|
||||
cells = "".join(f"<{tag}>{c}</{tag}>" for c in row)
|
||||
rows += f"<tr>{cells}</tr>"
|
||||
return f"<table>{rows}</table>"
|
||||
|
||||
def _html_table_to_list(html):
|
||||
from html.parser import HTMLParser
|
||||
class P(HTMLParser):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.rows, self._row, self._cell, self._in = [], [], [], False
|
||||
def handle_starttag(self, tag, attrs):
|
||||
if tag == "tr": self._row = []
|
||||
elif tag in ("td","th"): self._cell = []; self._in = True
|
||||
def handle_endtag(self, tag):
|
||||
if tag in ("td","th"):
|
||||
self._row.append("".join(self._cell).strip()); self._in = False
|
||||
elif tag == "tr":
|
||||
if self._row: self.rows.append(self._row)
|
||||
def handle_data(self, data):
|
||||
if self._in: self._cell.append(data)
|
||||
p = P(); p.feed(html); return p.rows
|
||||
|
||||
def _save_excel(tables, path):
|
||||
wb = openpyxl.Workbook()
|
||||
wb.remove(wb.active)
|
||||
for i, table in enumerate(tables, 1):
|
||||
ws = wb.create_sheet(f"표 {i}")
|
||||
thin = Side(style="thin", color="2A2A33")
|
||||
bdr = Border(left=thin, right=thin, top=thin, bottom=thin)
|
||||
for r_idx, row in enumerate(table, 1):
|
||||
for c_idx, val in enumerate(row, 1):
|
||||
cell = ws.cell(row=r_idx, column=c_idx, value=val)
|
||||
cell.border = bdr
|
||||
cell.alignment = Alignment(horizontal="center",
|
||||
vertical="center", wrap_text=True)
|
||||
if r_idx == 1:
|
||||
cell.fill = PatternFill("solid", fgColor="1A1A2E")
|
||||
cell.font = Font(color="00E5A0", bold=True, size=10)
|
||||
else:
|
||||
cell.font = Font(size=10)
|
||||
for col in ws.columns:
|
||||
w = max((len(str(c.value or "")) for c in col), default=8)
|
||||
ws.column_dimensions[col[0].column_letter].width = min(w + 4, 40)
|
||||
if not wb.sheetnames: wb.create_sheet("Sheet1")
|
||||
wb.save(path)
|
||||
Reference in New Issue
Block a user