#!/usr/bin/env python3
"""
Karaoke IG style - legenda word-level estilo Instagram/TikTok.

Gera filtros drawtext do ffmpeg que renderizam:
- Blocos curtos de 2-4 palavras
- Texto branco caixa-alta com contorno preto
- Palavra falada destacada com pill laranja por trás
- Posicionamento ~67% do topo do vídeo
- Fonte Manrope ExtraBold

Uso:
    from karaoke_legenda import gerar_karaoke_drawtext_filters
    filtros = gerar_karaoke_drawtext_filters(transcricao, 1080, 1920, fontsize=70)
    # filtros é uma string longa com 1+N drawtext (1 por bloco branco, N por palavras ativas)
    ffmpeg -i input.mp4 -vf "{filtros}" output.mp4
"""

from __future__ import annotations

import re
from pathlib import Path
from typing import List, Dict, Any

from PIL import ImageFont

# Caminho da fonte Manrope ExtraBold (relativo a este arquivo)
FONT_PATH = str(Path(__file__).parent / "fonts" / "Manrope-ExtraBold.ttf")

# Cor laranja do pill (#F47B3F)
PILL_COLOR = "0xF47B3F"

# Tamanho máximo de palavras por bloco
MAX_WORDS_POR_BLOCO = 4
MIN_WORDS_POR_BLOCO = 2

# Quanto do alto do vídeo (fração) onde a legenda fica
# 0.72 = ~72% do topo (mid-bottom, mais próximo do estilo IG)
Y_FRACAO = 0.72

# Padding horizontal do pill (em px na resolução de referência)
PILL_PADDING_X = 12
# Padding vertical do pill (em px) - não usado direto pelo drawtext (box é simétrico)
PILL_PADDING_Y = 8
# Espaço entre palavras (em px). Drawtext separa por espaço da própria fonte.
# Aqui usamos a largura do espaço da fonte real.


def _normalizar(palavra: str) -> str:
    """Caixa alta + remove caracteres não-imprimíveis simples."""
    return palavra.strip().upper()


def _escapar_drawtext(texto: str) -> str:
    """
    Escapa texto pra entrar no filter drawtext.
    Drawtext usa : como separador e \\ como escape. Aspas simples
    também precisam de cuidado quando o filtro é quotado.
    """
    # Ordem importa: barra invertida primeiro
    out = texto.replace("\\", "\\\\")
    out = out.replace(":", "\\:")
    out = out.replace("'", "’")  # troca por aspas curva pra evitar bug
    out = out.replace(",", "\\,")
    out = out.replace("%", "\\%")
    return out


def _coletar_palavras(transcricao: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Pega a transcrição do Whisper (que já vem com word_timestamps=True)
    e extrai a lista flat de palavras com {word, start, end}.

    A função tenta primeiro `transcricao["segmentos"][i]["words"]`
    (formato do nosso wrapper). Se não tiver, tenta `segments[i]["words"]`
    direto do Whisper.
    """
    palavras = []
    segmentos = transcricao.get("segmentos") or transcricao.get("segments") or []

    for seg in segmentos:
        words = seg.get("words") or seg.get("palavras") or []
        for w in words:
            txt = (w.get("word") or w.get("text") or w.get("palavra") or "").strip()
            if not txt:
                continue
            start = float(w.get("start", w.get("inicio", 0)))
            end = float(w.get("end", w.get("fim", start + 0.3)))
            if end <= start:
                end = start + 0.2
            palavras.append({
                "word": _normalizar(txt),
                "start": start,
                "end": end,
            })

    # Se não há word-level (fallback raro), particiona o texto do segmento
    # uniformemente. Estratégia degradada mas evita crash.
    if not palavras:
        for seg in segmentos:
            txt = (seg.get("texto") or seg.get("text") or "").strip()
            if not txt:
                continue
            ini = float(seg.get("inicio", seg.get("start", 0)))
            fim = float(seg.get("fim", seg.get("end", ini + 1.0)))
            tokens = [t for t in re.split(r"\s+", txt) if t]
            if not tokens:
                continue
            dt = (fim - ini) / max(1, len(tokens))
            for i, tk in enumerate(tokens):
                palavras.append({
                    "word": _normalizar(tk),
                    "start": ini + i * dt,
                    "end": ini + (i + 1) * dt,
                })

    # Sanitiza: garante ordem temporal monotônica estrita.
    # Whisper às vezes retorna palavras com starts iguais ou end > start_próxima.
    palavras.sort(key=lambda p: (p["start"], p["end"]))

    # Empurra starts pra frente quando há duplicatas, distribuindo
    # uniformemente entre [start, end_max] de cada cluster.
    i = 0
    while i < len(palavras):
        # Acha cluster com mesmo start (ou muito próximos)
        j = i
        while j + 1 < len(palavras) and palavras[j + 1]["start"] <= palavras[i]["start"] + 0.01:
            j += 1
        if j > i:
            # Distribui starts uniformemente dentro do cluster
            start_cluster = palavras[i]["start"]
            end_cluster = max(p["end"] for p in palavras[i:j + 1])
            # se proxima palavra existe, end_cluster não pode ultrapassar
            if j + 1 < len(palavras):
                end_cluster = min(end_cluster, palavras[j + 1]["start"])
            span = max(end_cluster - start_cluster, 0.05 * (j - i + 1))
            dt = span / (j - i + 1)
            for k in range(i, j + 1):
                palavras[k]["start"] = start_cluster + (k - i) * dt
                palavras[k]["end"] = start_cluster + (k - i + 1) * dt
        i = j + 1

    # Agora garante non-overlap final
    for i in range(len(palavras) - 1):
        if palavras[i]["end"] > palavras[i + 1]["start"]:
            palavras[i]["end"] = palavras[i + 1]["start"]
        if palavras[i]["end"] <= palavras[i]["start"]:
            palavras[i]["end"] = palavras[i]["start"] + 0.05

    return palavras


def _agrupar_blocos(palavras: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
    """
    Agrupa palavras em blocos de 2-4. Tenta quebrar em pontuação natural
    (final de segmento com pausa > 0.5s entre palavras), senão divide por tamanho.

    Retorna lista de blocos, cada bloco é lista de palavras consecutivas.
    """
    if not palavras:
        return []

    blocos = []
    atual = []
    for i, p in enumerate(palavras):
        if not atual:
            atual.append(p)
            continue

        gap = p["start"] - atual[-1]["end"]
        # Quebra se já tem MAX palavras, ou se gap grande, ou se palavra
        # anterior termina com pontuação forte
        ult = atual[-1]["word"]
        tem_pont = ult.endswith((".", "!", "?", ","))
        quebra_natural = gap > 0.45 or tem_pont
        cheio = len(atual) >= MAX_WORDS_POR_BLOCO
        razoavel = len(atual) >= MIN_WORDS_POR_BLOCO

        if cheio or (quebra_natural and razoavel):
            blocos.append(atual)
            atual = [p]
        else:
            atual.append(p)

    if atual:
        blocos.append(atual)

    return blocos


def _medir(font: ImageFont.FreeTypeFont, texto: str) -> float:
    """Retorna largura do texto na fonte (px)."""
    return float(font.getlength(texto))


def _posicoes_bloco(font: ImageFont.FreeTypeFont, bloco: List[Dict[str, Any]],
                    video_width: int) -> Dict[str, Any]:
    """
    Calcula a string completa do bloco e a posição x de cada palavra
    quando o bloco é centralizado horizontalmente.

    Retorna:
        {
            "texto": "MEU CACHORRO",
            "x_inicial": 256,   # x do canto esquerdo do bloco no vídeo
            "largura_total": 567,
            "altura": 76,
            "palavras_pos": [
                {"word": "MEU", "x": 256, "largura": 152, "start":.., "end":..},
                {"word": "CACHORRO", "x": 422, "largura": 400, ...},
            ]
        }
    """
    # Constrói texto separando por espaço
    palavras_txt = [p["word"] for p in bloco]
    texto_completo = " ".join(palavras_txt)

    # Largura total e do espaço
    largura_total = _medir(font, texto_completo)
    largura_espaco = _medir(font, "  ") - _medir(font, " ")  # largura média
    # mais simples: medir " "
    largura_espaco = _medir(font, " ")

    x_inicial = (video_width - largura_total) / 2.0

    pos = []
    cursor = x_inicial
    for i, p in enumerate(bloco):
        w = _medir(font, p["word"])
        pos.append({
            "word": p["word"],
            "x": cursor,
            "largura": w,
            "start": p["start"],
            "end": p["end"],
        })
        cursor += w
        if i < len(bloco) - 1:
            cursor += largura_espaco

    # Altura aprox: bbox da fonte numa string genérica
    bbox = font.getbbox("ÁGÊMQpgy")
    altura = bbox[3] - bbox[1]

    return {
        "texto": texto_completo,
        "x_inicial": x_inicial,
        "largura_total": largura_total,
        "altura": altura,
        "palavras_pos": pos,
    }


def _drawtext_branco(texto: str, x: float, y: float, fontsize: int,
                      t_inicio: float, t_fim: float, font_path: str) -> str:
    """Drawtext do texto branco completo (1 por bloco)."""
    texto_esc = _escapar_drawtext(texto)
    return (
        f"drawtext=fontfile='{font_path}'"
        f":text='{texto_esc}'"
        f":fontsize={fontsize}"
        f":fontcolor=white"
        f":borderw=3:bordercolor=black"
        f":x={x:.1f}:y={y:.1f}"
        f":enable='between(t,{t_inicio:.3f},{t_fim:.3f})'"
    )


def _drawtext_pill(palavra: str, x: float, y: float, fontsize: int,
                    t_inicio: float, t_fim: float, font_path: str,
                    pill_padding: int) -> str:
    """
    Drawtext da palavra ativa, com box laranja (pill) + borda preta no texto
    pra manter consistência com o texto branco abaixo.
    """
    texto_esc = _escapar_drawtext(palavra)
    return (
        f"drawtext=fontfile='{font_path}'"
        f":text='{texto_esc}'"
        f":fontsize={fontsize}"
        f":fontcolor=white"
        f":borderw=3:bordercolor=black"
        f":box=1:boxcolor={PILL_COLOR}@1.0:boxborderw={pill_padding}"
        f":x={x:.1f}:y={y:.1f}"
        f":enable='between(t,{t_inicio:.3f},{t_fim:.3f})'"
    )


def gerar_karaoke_drawtext_filters(
    transcricao: Dict[str, Any],
    video_width: int = 1080,
    video_height: int = 1920,
    fontsize: int = 70,
    y_fracao: float = Y_FRACAO,
    font_path: str = None,
) -> str:
    """
    Função principal. Gera string de filtros drawtext separados por vírgula
    pra ser usada em ffmpeg -vf.

    Args:
        transcricao: dict com chave 'segmentos' (cada um tem 'words' do Whisper)
        video_width, video_height: tamanho do vídeo de saída
        fontsize: tamanho da fonte (px). Padrão 70 (~3.6% da altura).
        y_fracao: posição vertical (0 = topo, 1 = base). Padrão 0.67.
        font_path: override do caminho da fonte. Padrão Manrope-ExtraBold.

    Returns:
        string única com todos os filtros drawtext separados por vírgula.
        Vazio se não houver palavras.
    """
    if font_path is None:
        font_path = FONT_PATH

    if not Path(font_path).exists():
        raise FileNotFoundError(f"Fonte não encontrada: {font_path}")

    palavras = _coletar_palavras(transcricao)
    if not palavras:
        return ""

    blocos = _agrupar_blocos(palavras)
    if not blocos:
        return ""

    # Carrega fonte pra medir
    font = ImageFont.truetype(font_path, fontsize)

    # Y base: 67% do topo, ajustado pra altura do texto
    bbox = font.getbbox("ÁGÊMQpgy")
    altura_texto = bbox[3] - bbox[1]
    y_base = (video_height * y_fracao) - (altura_texto / 2)

    # Re-quebra blocos que ficariam mais largos que ~92% da largura do vídeo.
    # Isso evita texto vazando pras laterais com fontes grandes / palavras longas.
    largura_max = video_width * 0.92
    blocos_ajustados = []
    for bloco in blocos:
        texto_full = " ".join(p["word"] for p in bloco)
        if float(font.getlength(texto_full)) <= largura_max or len(bloco) <= 1:
            blocos_ajustados.append(bloco)
            continue
        # Quebra em pedaços menores até caber
        parcial = []
        for p in bloco:
            parcial.append(p)
            texto_parcial = " ".join(pp["word"] for pp in parcial)
            if float(font.getlength(texto_parcial)) > largura_max and len(parcial) > 1:
                # remove o último, salva, e começa novo bloco com ele
                ultimo = parcial.pop()
                if parcial:
                    blocos_ajustados.append(parcial)
                parcial = [ultimo]
        if parcial:
            blocos_ajustados.append(parcial)
    blocos = blocos_ajustados

    # Ordena por tempo de início e garante não-sobreposição entre blocos
    blocos.sort(key=lambda b: b[0]["start"])

    filtros = []

    # Calcula t_fim de cada bloco como mínimo entre (último.end + respiro) e início do próximo
    n = len(blocos)
    for i, bloco in enumerate(blocos):
        info = _posicoes_bloco(font, bloco, video_width)

        t_inicio_bloco = bloco[0]["start"]
        t_fim_natural = bloco[-1]["end"] + 0.20  # respiro pós-última palavra
        if i + 1 < n:
            t_inicio_proximo = blocos[i + 1][0]["start"]
            # nunca sobrepõe com o próximo bloco
            t_fim_bloco = min(t_fim_natural, t_inicio_proximo - 0.01)
        else:
            t_fim_bloco = t_fim_natural

        if t_fim_bloco <= t_inicio_bloco:
            continue

        # 1) drawtext branco do bloco inteiro
        filtros.append(_drawtext_branco(
            info["texto"],
            info["x_inicial"],
            y_base,
            fontsize,
            t_inicio_bloco,
            t_fim_bloco,
            font_path,
        ))

        # 2) drawtext de cada palavra ativa com pill laranja
        # Estende cada pill até o início da próxima palavra do mesmo bloco (sem sobrepor)
        pps = info["palavras_pos"]
        for j, pp in enumerate(pps):
            if j + 1 < len(pps):
                end_pill = min(pp["end"] + 0.05, pps[j + 1]["start"] - 0.005)
            else:
                end_pill = min(pp["end"] + 0.05, t_fim_bloco)
            if end_pill <= pp["start"]:
                continue
            filtros.append(_drawtext_pill(
                pp["word"],
                pp["x"],
                y_base,
                fontsize,
                pp["start"],
                end_pill,
                font_path,
                PILL_PADDING_X,
            ))

    return ",".join(filtros)


# ---------------------------------------------------------------------------
# Auto-teste / CLI rápido
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import json
    import sys

    # Permite testar com um JSON de transcrição salvo
    if len(sys.argv) >= 2:
        with open(sys.argv[1], "r", encoding="utf-8") as f:
            tr = json.load(f)
        filtros = gerar_karaoke_drawtext_filters(tr)
        print(filtros[:1200])
        print(f"\n... ({len(filtros)} chars total)")
    else:
        # Mock simples pra ver se o pipeline interno funciona
        mock = {
            "segmentos": [{
                "inicio": 0.0,
                "fim": 2.0,
                "texto": "meu cachorro brinca",
                "words": [
                    {"word": "meu", "start": 0.0, "end": 0.45},
                    {"word": "cachorro", "start": 0.45, "end": 1.20},
                    {"word": "brinca", "start": 1.20, "end": 2.00},
                ]
            }]
        }
        out = gerar_karaoke_drawtext_filters(mock, 1080, 1920, 70)
        print(out)
