# File: whisper-stt/app/main.py
# Snapshot: 2026-04-20 06:15:35 +09:00 (154 lines, 6.6 KiB, Python)

import os
import uuid
import time
import glob
import aiofiles
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from auth import authenticate, create_access_token, require_auth
from tasks import celery_app, transcribe_task
from ocr_tasks import ocr_task
# FastAPI application object that the route decorators in this module attach to.
app = FastAPI(title="VoiceScript API")

# Storage directories; both can be overridden through environment variables.
UPLOAD_DIR = os.environ.get("UPLOAD_DIR", "/data/uploads")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/data/outputs")

# Limits are configured in MB / hours and converted to bytes / seconds once.
MAX_UPLOAD_BYTES = 1024 * 1024 * int(os.environ.get("MAX_UPLOAD_MB", "500"))
OUTPUT_KEEP_SECS = 3600 * int(os.environ.get("OUTPUT_KEEP_HOURS", "48"))

# Create both directories at import time so later writes cannot fail on a
# missing parent directory.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Lower-case file extensions (without the dot) accepted for upload.
AUDIO_EXT = {
    "mp3", "mp4", "wav", "m4a", "ogg", "flac",
    "aac", "wma", "webm", "mkv", "avi", "mov",
}
IMAGE_EXT = {"jpg", "jpeg", "png", "bmp", "tiff", "tif", "webp", "gif"}
# ── 인증 ──────────────────────────────────────────────────────
@app.post("/api/login")
def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form-posted credentials for a bearer access token.

    Args:
        username: Login name from the form body.
        password: Password from the form body.

    Returns:
        A dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 401 when ``authenticate`` rejects the credentials.
    """
    if authenticate(username, password):
        token = create_access_token(username)
        return {"access_token": token, "token_type": "bearer"}
    raise HTTPException(status_code=401, detail="아이디 또는 비밀번호가 올바르지 않습니다")
@app.get("/api/me")
def me(user: str = Depends(require_auth)):
    """Return the value supplied by ``require_auth`` under the ``username`` key."""
    return dict(username=user)
# ── STT ───────────────────────────────────────────────────────
@app.post("/api/transcribe")
async def transcribe(request: Request, file: UploadFile = File(...),
                     _: str = Depends(require_auth)):
    """Accept an audio/video upload and queue it for transcription.

    The upload is stored in ``UPLOAD_DIR`` under a fresh UUID (keeping the
    original extension) and handed to ``transcribe_task``.

    Returns:
        A dict with the queued ``task_id``, the generated ``file_id`` and the
        client-supplied ``filename``.

    Raises:
        HTTPException: 400 when the extension is not in ``AUDIO_EXT``;
            413 may be raised by ``_check_size`` / ``_save``.
    """
    _check_size(request)

    suffix = _ext(file.filename)
    if suffix not in AUDIO_EXT:
        raise HTTPException(400, f"지원하지 않는 형식. 지원: {', '.join(sorted(AUDIO_EXT))}")

    upload_id = str(uuid.uuid4())
    destination = os.path.join(UPLOAD_DIR, f"{upload_id}.{suffix}")
    await _save(file, destination)

    job = transcribe_task.delay(upload_id, destination)
    return {
        "task_id": job.id,
        "file_id": upload_id,
        "filename": file.filename,
    }
# ── OCR ───────────────────────────────────────────────────────
@app.post("/api/ocr")
async def ocr(
    request: Request,
    file: UploadFile = File(...),
    mode: str = Form("text"),
    backend: str = Form("paddle"),
    ollama_model: str = Form("granite3.2-vision"),
    custom_prompt: str = Form(""),
    _: str = Depends(require_auth),
):
    """Accept an image upload and queue it for OCR.

    The upload is stored in ``UPLOAD_DIR`` under a fresh UUID (keeping the
    original extension) and handed to ``ocr_task`` together with the
    requested options.

    Returns:
        A dict with the queued ``task_id``, the generated ``file_id``, the
        client-supplied ``filename`` and the effective ``mode``/``backend``.

    Raises:
        HTTPException: 400 when the extension is not in ``IMAGE_EXT``;
            413 may be raised by ``_check_size`` / ``_save``.
    """
    _check_size(request)

    suffix = _ext(file.filename)
    if suffix not in IMAGE_EXT:
        raise HTTPException(400, f"지원하지 않는 형식. 지원: {', '.join(sorted(IMAGE_EXT))}")

    # Unrecognised option values are not rejected; they fall back to defaults.
    mode = mode if mode in ("text", "structure") else "text"
    backend = backend if backend in ("paddle", "ollama") else "paddle"

    upload_id = str(uuid.uuid4())
    destination = os.path.join(UPLOAD_DIR, f"{upload_id}.{suffix}")
    await _save(file, destination)

    job = ocr_task.delay(
        upload_id, destination, mode, backend, ollama_model, custom_prompt
    )
    return {
        "task_id": job.id,
        "file_id": upload_id,
        "filename": file.filename,
        "mode": mode,
        "backend": backend,
    }
# ── 상태 조회 (celery_app.AsyncResult 사용) ───────────────────
@app.get("/api/status/{task_id}")
def get_status(task_id: str, _: str = Depends(require_auth)):
    """Report the current state of a queued Celery task.

    Args:
        task_id: Identifier returned when the task was queued.

    Returns:
        A dict that always has ``state`` and ``progress``. ``pending``,
        ``progress`` and ``failure`` responses also carry a ``message``;
        a ``success`` response is merged with the task's result dict.
    """
    r = celery_app.AsyncResult(task_id)

    # Take one snapshot of the state. While a task is unfinished the result
    # object may re-query the backend on each attribute access, so comparing
    # ``r.state`` repeatedly could observe different states within one request.
    state = r.state

    if state == "PENDING":
        return {"state": "pending", "progress": 0, "message": "대기 중..."}

    if state == "PROGRESS":
        meta = r.info
        # The task may have finished or failed between the two reads, in which
        # case ``info`` is no longer the progress dict; fall back to defaults.
        if not isinstance(meta, dict):
            meta = {}
        return {"state": "progress", "progress": meta.get("progress", 0),
                "message": meta.get("message", "처리 중...")}

    if state == "SUCCESS":
        result = r.result
        if isinstance(result, dict):
            # Result keys are merged last, exactly as before.
            return {"state": "success", "progress": 100, **result}
        # A non-dict result cannot be unpacked; expose it under one key
        # instead of failing the request.
        return {"state": "success", "progress": 100, "result": result}

    if state == "FAILURE":
        return {"state": "failure", "progress": 0, "message": str(r.info)}

    return {"state": state.lower(), "progress": 0}
# ── 다운로드 ──────────────────────────────────────────────────
@app.get("/api/download/{filename}")
def download(filename: str, _: str = Depends(require_auth)):
    """Send a result file from ``OUTPUT_DIR`` to the client.

    Args:
        filename: Bare file name (no directory components) inside
            ``OUTPUT_DIR``.

    Returns:
        A ``FileResponse``; ``.xlsx`` files get the spreadsheet media type,
        everything else is served as ``text/plain``.

    Raises:
        HTTPException: 400 when the name contains path components,
            404 when no regular file with that name exists.
    """
    # Reject anything that is not a plain file name. The basename comparison
    # also catches platform-specific separators that the "/" test misses.
    if (".." in filename or "/" in filename
            or os.path.basename(filename) != filename):
        raise HTTPException(400, "잘못된 파일명")

    path = os.path.join(OUTPUT_DIR, filename)
    # isfile (not exists): a directory would pass an existence check but
    # cannot be served as a file.
    if not os.path.isfile(path):
        raise HTTPException(404, "파일을 찾을 수 없습니다")

    if filename.endswith(".xlsx"):
        media = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    else:
        media = "text/plain"
    return FileResponse(path, media_type=media, filename=filename)
# ── 결과 파일 정리 ────────────────────────────────────────────
@app.post("/api/cleanup")
def cleanup(_: str = Depends(require_auth)):
    """Run the output sweep on demand and report how many files were removed."""
    removed_count = _cleanup_outputs()
    return {"removed": removed_count}
@app.on_event("startup")
async def on_startup():
    """Sweep old output files once when the application starts.

    The count returned by ``_cleanup_outputs`` is intentionally discarded.
    """
    # NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour
    # of lifespan handlers - consider migrating when the app setup is touched.
    _cleanup_outputs()
# ── 유틸 ──────────────────────────────────────────────────────
def _check_size(request: Request):
    """Reject a request early when its declared body size exceeds the limit.

    This is only a fast pre-check based on the ``Content-Length`` header;
    ``_save`` enforces the real byte limit while streaming the upload.

    Args:
        request: Incoming request whose headers are inspected.

    Raises:
        HTTPException: 413 when the declared length exceeds
            ``MAX_UPLOAD_BYTES``.
    """
    declared = request.headers.get("content-length")
    if not declared:
        return
    try:
        size = int(declared)
    except ValueError:
        # A malformed header must not crash the handler; skip the pre-check
        # and let the streaming limit in ``_save`` decide.
        return
    if size > MAX_UPLOAD_BYTES:
        raise HTTPException(413, f"파일이 너무 큽니다. 최대 {MAX_UPLOAD_BYTES//1024//1024}MB")
def _cleanup_outputs() -> int:
    """Delete files in ``OUTPUT_DIR`` older than ``OUTPUT_KEEP_SECS``.

    Cleanup is disabled when the retention period is zero or negative; a
    negative value would otherwise put the cutoff in the future and delete
    every output. The sweep is best-effort: entries that cannot be inspected
    or removed (directories, files that vanished, permission problems) are
    skipped.

    Returns:
        Number of files removed.
    """
    if OUTPUT_KEEP_SECS <= 0:
        return 0

    cutoff = time.time() - OUTPUT_KEEP_SECS
    removed = 0
    for entry in glob.glob(os.path.join(OUTPUT_DIR, "*")):
        try:
            if os.path.getmtime(entry) < cutoff:
                os.remove(entry)
                removed += 1
        except OSError:
            # Only filesystem errors are expected here; anything else is a
            # programming error and should surface.
            continue
    return removed
def _ext(fn):
return fn.rsplit(".", 1)[-1].lower() if "." in fn else ""
async def _save(file: UploadFile, path: str):
    """Stream an upload to *path*, enforcing ``MAX_UPLOAD_BYTES``.

    The upload is read in 1 MiB chunks. If the limit is exceeded, or the
    transfer fails for any other reason (read/write error, cancellation),
    the partially written file is removed before the error propagates, so
    no incomplete upload is left behind.

    Args:
        file: Upload to read from.
        path: Destination file path; created or truncated.

    Raises:
        HTTPException: 413 when more than ``MAX_UPLOAD_BYTES`` are received.
    """
    written = 0
    try:
        async with aiofiles.open(path, "wb") as f:
            while chunk := await file.read(1024 * 1024):
                written += len(chunk)
                if written > MAX_UPLOAD_BYTES:
                    raise HTTPException(413, f"파일이 너무 큽니다. 최대 {MAX_UPLOAD_BYTES//1024//1024}MB")
                await f.write(chunk)
    except BaseException:
        # BaseException so that task cancellation also triggers the cleanup.
        # The ``async with`` has already closed the file at this point.
        try:
            os.remove(path)
        except OSError:
            # Nothing to remove (e.g. the open itself failed); keep the
            # original error as the one that propagates.
            pass
        raise
# Serve the front-end from ./static at the site root. This is registered
# after the API routes above.
# NOTE(review): presumably it must stay last so "/" does not shadow /api/* - confirm.
app.mount(
    "/",
    StaticFiles(directory="static", html=True),
    name="static",
)