# VoiceScript API: FastAPI service for login, speech-to-text and OCR uploads,
# Celery task-status polling, and result-file download/cleanup.
import contextlib
import glob
import logging
import os
import time
import uuid

import aiofiles
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Form, Request
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from auth import authenticate, create_access_token, require_auth
from tasks import celery_app, transcribe_task
from ocr_tasks import ocr_task


app = FastAPI(title="VoiceScript API")

# Runtime configuration, overridable through environment variables.
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "/data/uploads")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "/data/outputs")
# Upload cap: MAX_UPLOAD_MB megabytes, expressed in bytes.
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "500")) * 1024 ** 2
# Output retention: OUTPUT_KEEP_HOURS hours, expressed in seconds.
OUTPUT_KEEP_SECS = int(os.getenv("OUTPUT_KEEP_HOURS", "48")) * 60 * 60

# Create the working directories up front (no error if they already exist).
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Accepted file extensions per pipeline (lower-case, without the dot).
AUDIO_EXT = set("mp3 mp4 wav m4a ogg flac aac wma webm mkv avi mov".split())
IMAGE_EXT = set("jpg jpeg png bmp tiff tif webp gif".split())


# ── Authentication ───────────────────────────────────────────

@app.post("/api/login")
def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form-posted credentials for a bearer access token.

    Raises:
        HTTPException: 401 when ``authenticate`` rejects the credentials.
    """
    if authenticate(username, password):
        token = create_access_token(username)
        return {"access_token": token, "token_type": "bearer"}
    raise HTTPException(status_code=401, detail="아이디 또는 비밀번호가 올바르지 않습니다")


@app.get("/api/me")
def me(user: str = Depends(require_auth)):
    """Return the username that ``require_auth`` resolved for this request."""
    payload = {"username": user}
    return payload


# ── Speech-to-text ───────────────────────────────────────────

@app.post("/api/transcribe")
async def transcribe(request: Request, file: UploadFile = File(...),
                     _: str = Depends(require_auth)):
    """Accept an audio/video upload and queue a transcription task.

    The upload is stored as ``<UPLOAD_DIR>/<uuid>.<ext>`` and its path is
    handed to ``transcribe_task``.

    Returns:
        dict with the Celery ``task_id``, the generated ``file_id`` and the
        client-supplied ``filename``.

    Raises:
        HTTPException: 413 when the upload exceeds MAX_UPLOAD_BYTES, 400 when
            the extension is not in AUDIO_EXT.
    """
    _check_size(request)
    # The upload's filename can be None; treat that as "no extension" so the
    # client gets a 400 instead of a TypeError/500.
    ext = _ext(file.filename or "")
    if ext not in AUDIO_EXT:
        raise HTTPException(400, f"지원하지 않는 형식. 지원: {', '.join(sorted(AUDIO_EXT))}")
    file_id = str(uuid.uuid4())
    save_path = os.path.join(UPLOAD_DIR, f"{file_id}.{ext}")
    await _save(file, save_path)
    try:
        task = transcribe_task.delay(file_id, save_path)
    except Exception:
        # Enqueueing failed (e.g. broker unreachable): no worker will ever see
        # this file, so remove it rather than leak it in UPLOAD_DIR.
        with contextlib.suppress(OSError):
            os.remove(save_path)
        raise
    return {"task_id": task.id, "file_id": file_id, "filename": file.filename}


# ── OCR ──────────────────────────────────────────────────────

@app.post("/api/ocr")
async def ocr(
    request: Request,
    file: UploadFile = File(...),
    mode: str = Form("text"),
    backend: str = Form("paddle"),
    ollama_model: str = Form("granite3.2-vision"),
    custom_prompt: str = Form(""),
    _: str = Depends(require_auth),
):
    """Accept an image upload and queue an OCR task.

    ``mode`` must be "text" or "structure" and ``backend`` must be "paddle"
    or "ollama"; unknown values silently fall back to "text" / "paddle".
    ``ollama_model`` and ``custom_prompt`` are passed through unchanged.

    Returns:
        dict with ``task_id``, ``file_id``, the client ``filename`` and the
        effective ``mode`` / ``backend``.

    Raises:
        HTTPException: 413 when the upload exceeds MAX_UPLOAD_BYTES, 400 when
            the extension is not in IMAGE_EXT.
    """
    _check_size(request)
    # The upload's filename can be None; treat that as "no extension" (-> 400).
    ext = _ext(file.filename or "")
    if ext not in IMAGE_EXT:
        raise HTTPException(400, f"지원하지 않는 형식. 지원: {', '.join(sorted(IMAGE_EXT))}")
    if mode not in ("text", "structure"):
        mode = "text"
    if backend not in ("paddle", "ollama"):
        backend = "paddle"
    file_id = str(uuid.uuid4())
    save_path = os.path.join(UPLOAD_DIR, f"{file_id}.{ext}")
    await _save(file, save_path)
    try:
        task = ocr_task.delay(file_id, save_path, mode, backend, ollama_model, custom_prompt)
    except Exception:
        # Enqueueing failed: no worker will consume the file, so delete it
        # instead of leaking it in UPLOAD_DIR.
        with contextlib.suppress(OSError):
            os.remove(save_path)
        raise
    return {"task_id": task.id, "file_id": file_id,
            "filename": file.filename, "mode": mode, "backend": backend}


# ── Task status (via celery_app.AsyncResult) ─────────────────

@app.get("/api/status/{task_id}")
def get_status(task_id: str, _: str = Depends(require_auth)):
    """Report a queued STT/OCR task's state in a UI-friendly shape.

    Returns a dict with at least ``state`` (lower-case) and ``progress``;
    "progress" and "failure" states add a ``message``, and "success" merges
    the task's result dict into the response.

    Note: Celery also reports unknown task ids as PENDING, so "pending" does
    not prove the id exists.
    """
    r = celery_app.AsyncResult(task_id)
    # Read the state once: each ``.state`` access on an unfinished task goes
    # back to the result backend, and the branches below must agree.
    state = r.state
    if state == "PENDING":
        return {"state": "pending", "progress": 0, "message": "대기 중..."}
    if state == "PROGRESS":
        info = r.info
        # Expected to be the meta dict reported by the task; ignore anything else.
        meta = info if isinstance(info, dict) else {}
        return {"state": "progress", "progress": meta.get("progress", 0),
                "message": meta.get("message", "처리 중...")}
    if state == "SUCCESS":
        result = r.result
        # ``**`` on a non-dict result (None, str, ...) would raise TypeError
        # and surface as a 500; expose such values under "result" instead.
        extra = result if isinstance(result, dict) else {"result": result}
        return {"state": "success", "progress": 100, **extra}
    if state == "FAILURE":
        return {"state": "failure", "progress": 0, "message": str(r.info)}
    return {"state": state.lower(), "progress": 0}


# ── Download ─────────────────────────────────────────────────

@app.get("/api/download/{filename}")
def download(filename: str, _: str = Depends(require_auth)):
    """Serve a generated file from OUTPUT_DIR, addressed by bare file name.

    ``.xlsx`` files are served with the spreadsheet MIME type; everything
    else is served as text/plain.

    Raises:
        HTTPException: 400 for names that could escape OUTPUT_DIR, 404 when
            no file of that name exists there.
    """
    # Accept only plain file names: no parent references and no POSIX or
    # Windows path separators.
    if ".." in filename or "/" in filename or "\\" in filename:
        raise HTTPException(400, "잘못된 파일명")
    path = os.path.join(OUTPUT_DIR, filename)
    # isfile rather than exists: a directory (e.g. ".") must be a 404, not a
    # 500 raised later by FileResponse for a non-regular file.
    if not os.path.isfile(path):
        raise HTTPException(404, "파일을 찾을 수 없습니다")
    media = ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
             if filename.endswith(".xlsx") else "text/plain")
    return FileResponse(path, media_type=media, filename=filename)


# ── Output-file cleanup ──────────────────────────────────────

@app.post("/api/cleanup")
def cleanup(_: str = Depends(require_auth)):
    """Run the output-retention sweep now and report how many files it removed."""
    count = _cleanup_outputs()
    return {"removed": count}


@app.on_event("startup")
async def on_startup():
    """Run one output-retention sweep when the application starts."""
    # The number of removed files is not used here.
    _cleanup_outputs()


# ── Helpers ──────────────────────────────────────────────────

def _check_size(request: Request):
    """Reject a request early when its declared Content-Length is too large.

    This is only a cheap pre-check. The header may be absent, and it covers
    the whole (multipart) body rather than just the file. ``_save`` enforces
    the limit on the bytes actually streamed.

    Raises:
        HTTPException: 413 when Content-Length exceeds MAX_UPLOAD_BYTES.
    """
    cl = request.headers.get("content-length")
    if not cl:
        return
    try:
        declared = int(cl)
    except ValueError:
        # Malformed header: skip the pre-check instead of failing with a 500;
        # the streaming limit still applies.
        return
    if declared > MAX_UPLOAD_BYTES:
        raise HTTPException(413, f"파일이 너무 큽니다. 최대 {MAX_UPLOAD_BYTES//1024//1024}MB")


def _cleanup_outputs() -> int:
|
|
if OUTPUT_KEEP_SECS == 0:
|
|
return 0
|
|
cutoff = time.time() - OUTPUT_KEEP_SECS
|
|
removed = 0
|
|
for f in glob.glob(os.path.join(OUTPUT_DIR, "*")):
|
|
try:
|
|
if os.path.getmtime(f) < cutoff:
|
|
os.remove(f); removed += 1
|
|
except Exception:
|
|
pass
|
|
return removed
|
|
|
|
def _ext(fn):
|
|
return fn.rsplit(".", 1)[-1].lower() if "." in fn else ""
|
|
|
|
async def _save(file: UploadFile, path: str):
    """Stream *file* to *path* in 1 MiB chunks, enforcing MAX_UPLOAD_BYTES.

    On any failure the partially written file at *path* is removed before the
    exception propagates. Failures include exceeding the size limit, a write
    error, or the request task being cancelled mid-upload.

    Raises:
        HTTPException: 413 once more than MAX_UPLOAD_BYTES have been read.
    """
    written = 0
    try:
        async with aiofiles.open(path, "wb") as f:
            while chunk := await file.read(1024 * 1024):
                written += len(chunk)
                if written > MAX_UPLOAD_BYTES:
                    raise HTTPException(413, f"파일이 너무 큽니다. 최대 {MAX_UPLOAD_BYTES//1024//1024}MB")
                await f.write(chunk)
    except BaseException:
        # BaseException so cancellation also cleans up. By now ``async with``
        # has closed the handle, so removal works even where open files cannot
        # be deleted. A missing file (open itself failed) is ignored.
        with contextlib.suppress(OSError):
            os.remove(path)
        raise


# Serve the front-end from the ./static directory (html=True enables
# index.html for directory paths). This is registered after every /api route
# above, so those routes are matched first.
app.mount(path="/", app=StaticFiles(directory="static", html=True), name="static")