__import__("pysqlite3")
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import os
import uuid
import traceback
import shutil

from fastapi import APIRouter
from pydantic import BaseModel

import chromadb
import whisper

import logging
from logging.handlers import RotatingFileHandler

from moviepy.editor import VideoFileClip

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings

# -------------------------------------------------------------------
# BASE CONFIG
# -------------------------------------------------------------------
BASE_DIR = "/home/devknestlms/public_html/chroma_project"
CHROMA_ROOT = BASE_DIR
SUPPORTED_MEDIA = (".mp4", ".avi", ".mov", ".mkv", ".mp3", ".wav")

# -------------------------------------------------------------------
# LOGGING (PRODUCTION SAFE)
# -------------------------------------------------------------------
LOG_FILE = os.path.join(BASE_DIR, "python_chunk_log.log")

if not os.path.exists(LOG_FILE):
    open(LOG_FILE, "a").close()

logger = logging.getLogger("chunk_logger")
logger.setLevel(logging.INFO)

handler = RotatingFileHandler(LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=3)
formatter = logging.Formatter("%(asctime)s [%(levelname)s] [%(process)d] %(message)s")
handler.setFormatter(formatter)

logger.propagate = False  # keep records out of the root logger (avoids duplicate lines)

if not logger.handlers:
    logger.addHandler(handler)

def log(msg):
    logger.info(msg)

# -------------------------------------------------------------------
# ROUTER
# -------------------------------------------------------------------
router = APIRouter()

# -------------------------------------------------------------------
# FFMPEG CONFIG
# -------------------------------------------------------------------
# A static ffmpeg build is expected under BASE_DIR/ffmpeg; prepend it to PATH
# and point imageio/moviepy at the binary explicitly.
FFMPEG_PATH = os.path.join(BASE_DIR, "ffmpeg")
os.environ["PATH"] = FFMPEG_PATH + os.pathsep + os.environ.get("PATH", "")
os.environ["IMAGEIO_FFMPEG_EXE"] = os.path.join(FFMPEG_PATH, "ffmpeg")

# -------------------------------------------------------------------
# LAZY LOAD MODELS
# -------------------------------------------------------------------
whisper_model = None
embeddings = None

def get_whisper():
    global whisper_model
    if whisper_model is None:
        log("[INIT] Loading Whisper model...")
        whisper_model = whisper.load_model("base")
        log("[INIT] Whisper loaded successfully")
    return whisper_model

def get_embeddings():
    global embeddings
    if embeddings is None:
        log("[INIT] Loading embeddings model...")
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        log("[INIT] Embeddings loaded successfully")
    return embeddings

# -------------------------------------------------------------------
# REQUEST SCHEMA
# -------------------------------------------------------------------
class ProcessRequest(BaseModel):
    company_id: int
    course_id: int
    lesson_id: int
    content_id: int
    file_path: str

# -------------------------------------------------------------------
# CHROMADB
# -------------------------------------------------------------------
def get_chroma_collection(company_id: int):
    path = os.path.join(CHROMA_ROOT, "chroma_db", str(company_id))
    os.makedirs(path, exist_ok=True)

    client = chromadb.PersistentClient(path=path)
    return client.get_or_create_collection(name="rag_chunks")
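
# Illustrative only: retrieval elsewhere (e.g. in search.py, whose internals are
# not shown here) would query this same collection along these lines:
#
#   collection = get_chroma_collection(company_id)
#   query_vec = get_embeddings().embed_query("student question")
#   hits = collection.query(
#       query_embeddings=[query_vec],
#       n_results=5,
#       where={"course_id": str(course_id)},
#   )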

# -------------------------------------------------------------------
# CHUNKING
# -------------------------------------------------------------------
def chunk_text(text: str):
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_documents([Document(page_content=text)])
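
# For example, a ~1,200-character transcript becomes roughly three Documents of
# up to 500 characters each, with 50 characters of overlap between neighbours.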

# -------------------------------------------------------------------
# STORE
# -------------------------------------------------------------------
def store_chunks(chunks, meta: dict):
    log(f"[DB] Storing {len(chunks)} chunks")

    collection = get_chroma_collection(meta["company_id"])

    texts = [c.page_content for c in chunks]

    embedder = get_embeddings()
    vectors = embedder.embed_documents(texts)

    ids = [str(uuid.uuid4()) for _ in texts]
    metas = [dict(meta) for _ in texts]  # copy per chunk so entries don't share one dict

    collection.add(
        documents=texts,
        embeddings=vectors,
        metadatas=metas,
        ids=ids
    )

    log("[DB] Storage completed")

# -------------------------------------------------------------------
# MEDIA PROCESSING
# -------------------------------------------------------------------
def extract_audio(file_path: str):
    """Return (audio_path, is_temp): a path Whisper can transcribe and whether
    the caller should delete it afterwards."""
    try:
        log(f"[VIDEO] Opening: {file_path}")

        # Audio files can be transcribed directly; nothing to extract.
        if file_path.lower().endswith((".mp3", ".wav")):
            return file_path, False

        audio_path = f"{BASE_DIR}/temp_{uuid.uuid4().hex}.wav"

        clip = VideoFileClip(file_path)

        if clip.audio is None:
            log("[VIDEO] No audio found")
            clip.close()
            return None, False

        log("[VIDEO] Extracting audio...")

        clip.audio.write_audiofile(audio_path, verbose=False, logger=None)
        clip.close()

        if not os.path.exists(audio_path):
            log("[ERROR] Audio file not created")
            return None, False

        log(f"[VIDEO] Audio ready: {audio_path}")
        return audio_path, True

    except Exception:
        logger.exception("[ERROR extract_audio]")
        return None, False

def transcribe_audio(file_path: str) -> str:
    try:
        model = get_whisper()
        result = model.transcribe(file_path)
        return result.get("text", "")
    except Exception:
        logger.exception("[ERROR transcribe_audio]")
        return ""

def process_media(file_path: str) -> str:
    ext = os.path.splitext(file_path)[1]
    local_temp = os.path.join(BASE_DIR, f"temp_{uuid.uuid4().hex}{ext}")

    try:
        shutil.copy2(file_path, local_temp)
        log(f"[MEDIA] Copied to temp: {local_temp}")

        audio_path, is_temp = extract_audio(local_temp)

        if not audio_path:
            return ""

        text = transcribe_audio(audio_path)

        if is_temp and os.path.exists(audio_path):
            os.remove(audio_path)

        return text

    finally:
        if os.path.exists(local_temp):
            os.remove(local_temp)
            log("[MEDIA] Temp cleaned")

# -------------------------------------------------------------------
# DOCUMENT PROCESSING
# -------------------------------------------------------------------
def process_doc(file_path: str) -> str:
    try:
        if file_path.lower().endswith(".pdf"):
            loader = PyPDFLoader(file_path)
            docs = loader.load()
            return " ".join([d.page_content for d in docs])
        return ""
    except Exception:
        logger.exception("[ERROR process_doc]")
        return ""

# -------------------------------------------------------------------
# API
# -------------------------------------------------------------------
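# Example call (values are illustrative; add any router prefix used by the app):
#   curl -X POST http://<host>:<port>/process \
#     -H "Content-Type: application/json" \
#     -d '{"company_id": 1, "course_id": 12, "lesson_id": 340,
#          "content_id": 5, "file_path": "/absolute/path/to/lesson.mp4"}'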
@router.post("/process")
def process(data: ProcessRequest):
    try:
        log(f"[API] Called | company={data.company_id}")

        path = data.file_path

        if not os.path.exists(path):
            log(f"[ERROR] File not found: {path}")
            return {"success": False, "error": f"File not found: {path}"}

        log(f"[PROCESS] Processing file: {path}")

        if path.lower().endswith(SUPPORTED_MEDIA):
            text = process_media(path)
        else:
            text = process_doc(path)

        log(f"[PROCESS] Extracted text length: {len(text)}")

        if not text.strip():
            log("[PROCESS] No text extracted")
            return {"success": True, "msg": "No text extracted", "chunks": 0}

        chunks = chunk_text(text)
        log(f"[CHUNK] Created {len(chunks)} chunks")

        meta = {
            "company_id": str(data.company_id),
            "course_id": str(data.course_id),
            "lesson_id": str(data.lesson_id),
            "content_id": str(data.content_id),
            "file_name": os.path.basename(path),
        }

        store_chunks(chunks, meta)

        # Invalidate cached search results for this course in search.py so the
        # new chunks are picked up on the next query.
        from search import clear_cache
        removed = clear_cache(str(data.company_id), str(data.course_id))

        log(f"[CACHE] Cleared | removed={removed}")

        log(f"[SUCCESS] Completed | chunks={len(chunks)}")

        return {
            "success": True,
            "msg": "Stored successfully",
            "chunks": len(chunks)
        }

    except Exception:
        logger.exception("[API ERROR]")
        return {"success": False, "error": "Internal server error"}