148 lines
5.6 KiB
Python
148 lines
5.6 KiB
Python
import os
|
|
import httpx
|
|
from celery import Celery
|
|
from ocr_tasks import ocr_task # noqa: F401
|
|
|
|
# --- Runtime configuration (environment-driven, with defaults) ---

REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379/0")

# faster-whisper settings
MODEL_SIZE = os.getenv("WHISPER_MODEL", "medium")
DEVICE = os.getenv("WHISPER_DEVICE", "cpu")
COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
LANGUAGE = os.getenv("WHISPER_LANGUAGE", "ko") or None  # empty string -> let Whisper auto-detect
BEAM_SIZE = int(os.getenv("WHISPER_BEAM_SIZE", "5"))
INITIAL_PROMPT = os.getenv("WHISPER_INITIAL_PROMPT", "") or None  # empty string -> no prompt

# Where transcription reports are written
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "/data/outputs")

# Ollama endpoint used for optional transcript post-processing
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.0.126:11434")
OLLAMA_TIMEOUT = int(os.getenv("OLLAMA_TIMEOUT", "600"))

# CPU_THREADS <= 0 means "let faster-whisper pick" (encoded as None below)
_cpu_threads_env = int(os.getenv("CPU_THREADS", "0"))
CPU_THREADS = _cpu_threads_env if _cpu_threads_env > 0 else None  # None = auto
|
|
|
|
# Celery application: Redis acts as both the message broker and result backend.
celery_app = Celery("whisper_tasks", broker=REDIS_URL, backend=REDIS_URL)

_CELERY_SETTINGS = {
    "task_serializer": "json",
    "result_serializer": "json",
    "accept_content": ["json"],
    "task_track_started": True,  # expose STARTED so clients can poll progress
    "result_expires": 3600,      # drop stored results after one hour
}
celery_app.conf.update(**_CELERY_SETTINGS)
|
|
|
|
# Process-wide lazily-created WhisperModel singleton.
_model = None


def get_model():
    """Return the shared WhisperModel, loading it on first use.

    The faster-whisper import lives inside the function so worker start-up
    does not pay the model-loading cost until the first task arrives.
    """
    global _model
    if _model is not None:
        return _model

    from faster_whisper import WhisperModel

    load_opts = {"device": DEVICE, "compute_type": COMPUTE_TYPE}
    if CPU_THREADS is not None:
        load_opts["cpu_threads"] = CPU_THREADS

    print(f"[Whisper] 로딩: {MODEL_SIZE} / {DEVICE} / {COMPUTE_TYPE} / threads={CPU_THREADS or 'auto'}")
    _model = WhisperModel(MODEL_SIZE, **load_opts)
    print("[Whisper] 로드 완료")
    return _model
|
|
|
|
|
|
def _ollama_postprocess(text: str, model: str) -> str:
|
|
if not model or not text.strip():
|
|
return text
|
|
prompt = (
|
|
"다음은 음성 인식으로 추출된 텍스트입니다. "
|
|
"내용은 절대 변경하지 말고, 문장 부호를 추가하고 자연스럽게 다듬어줘. "
|
|
"결과 텍스트만 출력하고 설명은 하지 마.\n\n"
|
|
f"{text}"
|
|
)
|
|
try:
|
|
resp = httpx.post(
|
|
f"{OLLAMA_URL}/api/chat",
|
|
json={"model": model,
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"stream": False, "options": {"temperature": 0.1}},
|
|
timeout=float(OLLAMA_TIMEOUT),
|
|
)
|
|
resp.raise_for_status()
|
|
result = resp.json().get("message", {}).get("content", "").strip()
|
|
return result if result else text
|
|
except Exception as e:
|
|
print(f"[Ollama 후처리 실패] {e}")
|
|
return text
|
|
|
|
|
|
@celery_app.task(bind=True, name="tasks.transcribe_task", queue="stt")
def transcribe_task(self, file_id: str, audio_path: str,
                    use_ollama: bool = False, ollama_model: str = ""):
    """Transcribe an audio file with faster-whisper and write a text report.

    Args:
        file_id: Identifier used to name the output file (``<file_id>.txt``).
        audio_path: Path to the uploaded audio; removed after transcription.
        use_ollama: When true (and ``ollama_model`` is set), run the raw
            transcript through an Ollama model for punctuation clean-up.
        ollama_model: Name of the Ollama model used for post-processing.

    Returns:
        dict with final/raw text, per-segment timestamps, detected language,
        duration, output filename, and the Ollama usage flags.

    Raises:
        Exception: any underlying failure is wrapped as ``변환 실패: ...`` so
            the Celery result backend stores a readable message; the original
            exception is chained for the full traceback.
    """
    self.update_state(state="PROGRESS", meta={"progress": 5, "message": "모델 준비 중..."})
    try:
        model = get_model()
        self.update_state(state="PROGRESS", meta={"progress": 15, "message": "오디오 분석 중..."})

        segments_gen, info = model.transcribe(
            audio_path,
            language=LANGUAGE,
            beam_size=BEAM_SIZE,
            initial_prompt=INITIAL_PROMPT,
            vad_filter=True,  # skip long silences instead of transcribing them
            vad_parameters=dict(min_silence_duration_ms=500),
            word_timestamps=False,
        )

        self.update_state(state="PROGRESS", meta={"progress": 30, "message": "텍스트 변환 중..."})

        segments, parts = [], []
        duration = info.duration

        # The generator yields segments as they are decoded; map the audio
        # position onto the 30-80% progress window while collecting them.
        for seg in segments_gen:
            segments.append({"start": round(seg.start, 2),
                             "end": round(seg.end, 2),
                             "text": seg.text.strip()})
            parts.append(seg.text.strip())
            if duration > 0:
                pct = 30 + int((seg.end / duration) * 50)
                self.update_state(
                    state="PROGRESS",
                    meta={"progress": min(pct, 80),
                          "message": f"변환 중... {seg.end:.0f}s / {duration:.0f}s"},
                )

        raw_text = "\n".join(parts)
        full_text = raw_text

        if use_ollama and ollama_model:
            self.update_state(state="PROGRESS",
                              meta={"progress": 85,
                                    "message": f"Ollama({ollama_model}) 후처리 중..."})
            full_text = _ollama_postprocess(raw_text, ollama_model)

        self.update_state(state="PROGRESS", meta={"progress": 95, "message": "파일 저장 중..."})
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        output_filename = f"{file_id}.txt"

        with open(os.path.join(OUTPUT_DIR, output_filename), "w", encoding="utf-8") as f:
            f.write(f"# 변환 결과\n# 언어: {info.language} | 재생 시간: {duration:.1f}초")
            if use_ollama and ollama_model:
                f.write(f" | Ollama 후처리: {ollama_model}")
            f.write("\n\n## 전체 텍스트\n\n" + full_text + "\n\n")
            f.write("## 타임스탬프별 세그먼트\n\n")
            for seg in segments:
                f.write(f"[{_fmt(seg['start'])} → {_fmt(seg['end'])}] {seg['text']}\n")

        # Best-effort cleanup of the uploaded audio. A bare `except:` here
        # would also swallow KeyboardInterrupt/SystemExit; only file-system
        # errors (missing file, permissions) should be ignored.
        try:
            os.remove(audio_path)
        except OSError:
            pass

        return {
            "text": full_text,
            "raw_text": raw_text,
            "segments": segments,
            "language": info.language,
            "duration": round(duration, 1),
            "output_file": output_filename,
            "ollama_used": use_ollama and bool(ollama_model),
            "ollama_model": ollama_model if (use_ollama and ollama_model) else "",
        }

    except Exception as e:
        # Chain the original exception so the real traceback is preserved
        # in logs while the backend still sees a friendly message.
        raise Exception(f"변환 실패: {str(e)}") from e
|
|
|
|
|
|
def _fmt(s):
|
|
m, sec = divmod(int(s), 60)
|
|
return f"{m:02d}:{sec:02d}"
|