from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple, Dict, Callable
import math
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import DebertaV2Tokenizer
from refinement.scenarios import SCENARIO_LABELS
from config import CFG_CLF
from langchain_deepseek import ChatDeepSeek
import os, re, json
# NOTE(review): transformers is imported above, before this runs; confirm that the
# installed version still honors TRANSFORMERS_NO_FAST_TOKENIZER when set this late.
os.environ["TRANSFORMERS_NO_FAST_TOKENIZER"] = "1"   # Force disable all fast tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Only provide an empty placeholder: never overwrite a key the user already
# exported in the environment (an unconditional assignment would wipe it and
# make the DeepSeek fallback fail authentication). Never hard-code a real key here.
os.environ.setdefault("DEEPSEEK_API_KEY", "")


def _pick_device(dev: str) -> int:
    if dev == "auto":
        return 0 if torch.cuda.is_available() else -1
    return 0 if dev.startswith("cuda") else -1

@dataclass
class ScenarioPrediction:
    label: str
    score: float
    all_scores: List[Tuple[str, float]]
    stats: Dict[str, float] | None = None
    source: str = "primary"   # "finetuned" | "zeroshot" | "fallback-llm"

def _entropy(ps: List[Tuple[str, float]]) -> float:
    """Shannon entropy over a list of (label, prob) pairs; uses natural log."""
    eps = 1e-12
    return -sum(p * math.log(max(p, eps)) for _, p in ps)

def _margin(ps_sorted: List[Tuple[str, float]]) -> float:
    """Top1 - Top2 probability margin; ps_sorted must be sorted desc by prob."""
    if len(ps_sorted) < 2:
        return 1.0
    return ps_sorted[0][1] - ps_sorted[1][1]

def _trigger(ps_sorted: List[Tuple[str, float]], cfg=CFG_CLF) -> Tuple[bool, Dict[str, float]]:
    """
    Decide whether a score distribution is too uncertain to trust.

    ``ps_sorted`` must be non-empty and sorted by probability, highest first.
    Fallback is requested when ANY of these holds (tested in this order,
    short-circuiting):
    - top-1 probability (MSP) < ``cfg.conf_threshold``
    - top-1 minus top-2 margin < ``cfg.margin_threshold``
    - natural-log entropy > ``cfg.entropy_threshold``

    Returns ``(needs_fallback, stats)``. ``stats`` holds the three signals
    under the keys "msp", "margin" and "entropy".
    """
    stats = {
        "msp": ps_sorted[0][1],
        "margin": _margin(ps_sorted),
        "entropy": _entropy(ps_sorted),
    }
    needs_fallback = (
        stats["msp"] < cfg.conf_threshold
        or stats["margin"] < cfg.margin_threshold
        or stats["entropy"] > cfg.entropy_threshold
    )
    return needs_fallback, stats

# ---------- LLM fallback interface (inject your own API) ----------
def default_llm_call(sys_prompt: str, user_prompt: str) -> dict:
    """Ask DeepSeek (via LangChain) and return the JSON object found in its reply.

    Args:
        sys_prompt: system message (the classification instructions).
        user_prompt: human message (the text to classify).

    Returns:
        The parsed JSON object as a ``dict``, not the raw reply text.

    Raises:
        ValueError: if the reply contains no ``{...}`` span, or that span is not
            valid JSON (``json.JSONDecodeError`` subclasses ``ValueError``).
    """
    llm = ChatDeepSeek(
        model="deepseek-chat",
        temperature=0,   # lowest sampling temperature for stable labels
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )
    reply = llm.invoke([("system", sys_prompt), ("human", user_prompt)])
    raw_text = reply.content.strip()
    # Greedy match from the first "{" to the last "}": keeps nested braces inside
    # the object and tolerates prose or code fences around the JSON.
    match = re.search(r'\{.*\}', raw_text, re.S)
    if not match:
        raise ValueError("No valid JSON object found in the LLM response.")
    return json.loads(match.group(0))


# Label set accepted from the LLM fallback. It is an alias of SCENARIO_LABELS, so
# the LLM path and the zero-shot path share one vocabulary.
LLM_LABELS_EN = SCENARIO_LABELS
# System prompt for the LLM fallback: label list, definitions, tie-breaking rules,
# few-shot examples and the strict JSON output contract.
# NOTE: the text contains literal JSON braces, so str.format() cannot be used on it.
# NOTE(review): the trailing "{P_in}" marker appears to be a placeholder for the
# prompt under classification; if it is filled, use str.replace, not str.format.
# NOTE(review): the label strings listed below must match SCENARIO_LABELS exactly,
# because llm_fallback_classify maps any other label to "non-difficult". ZSH_HYP
# spells one label "Temporal Consistency (Video)"; confirm which spelling
# SCENARIO_LABELS uses.
SYS_PROMPT_TMPL = r"""
You are a few-shot classifier for Text-to-Video (T2V) prompt *difficulty scenarios*.
Return ONLY a valid JSON object of the exact form:
{"label": "<one of SCENARIO_LABELS>", "reason": "<short phrase (<= 20 words)>"}

Allowed labels (must match EXACTLY one string in SCENARIO_LABELS):
1) Abstract Descriptions
2) Complex Spatial Relationships
3) Multi-Element Scenes
4) Fine-Grained Details
5) Temporal Consistency
6) Stylistic Hybrids
7) Causality and Physics
8) non-difficult

## Task
Given a short English prompt P_in, decide which SINGLE label best describes the *dominant difficulty* that a T2V model would face when generating a video.

## Diagnostic definitions:
- Abstract Descriptions: Figurative language, metaphors, emotions as objects, surreal imagery.
- Complex Spatial Relationships: Explicit positions/orientations between ≥2 entities; lots of prepositions (“on top”, “behind”, “between”).
- Multi-Element Scenes: ≥3 different entities or activities; dense environments with many elements in one shot.
- Fine-Grained Details: Micro-level attributes (textures, tiny objects, reflections, accessories); often close-up.
- Temporal Consistency: Clear time progression or motion over time (bloom, melting, time-lapse).
- Stylistic Hybrids: Mixing multiple visual or artistic styles; style blending is central.
- Causality and Physics: Cause-effect chains or physical forces (gravity, splashes, collisions).
- non-difficult: None of the above applies.

## Tie-breaking rules:
1) Figurative language dominates → Abstract Descriptions
2) Spatial focus dominates → Complex Spatial Relationships
3) Many varied elements, no strong spatial focus → Multi-Element Scenes
4) Close-up or micro details dominate → Fine-Grained Details
5) Time progression dominates → Temporal Consistency
6) Mixed styles dominate → Stylistic Hybrids
7) Physics/cause-effect dominate → Causality and Physics
8) Otherwise choose non-difficult.

## Few-shot examples (prompt → label):
- "Hope dances in a field of forgotten dreams." → Abstract Descriptions
- "A cat and a dog sit back-to-back; a parrot hovers above." → Complex Spatial Relationships
- "A neon street with vendors, robots, and flashing billboards." → Multi-Element Scenes
- "A gold pocket watch with a cracked rim on velvet." → Fine-Grained Details
- "A bud opens into a flower in slow motion." → Temporal Consistency
- "A medieval castle with neon cyberpunk signs." → Stylistic Hybrids
- "A glass tips; wine splashes and forms ripples." → Causality and Physics
- "A child runs across a field." → non-difficult

Classify this prompt:
P_in: ```{P_in}```
"""


# Hypothesis template for the zero-shot NLI pipeline (passed as
# hypothesis_template in ScenarioClassifier.predict). The single "{}" is filled
# with each candidate label, and every hypothesis also carries compact
# definitions of all labels.
ZSH_HYP = (
    "The text primarily exhibits the following T2V difficulty: {}. "
    "Definitions: "
    "Abstract Descriptions=figurative personification/non-literal imagery; "
    "Complex Spatial Relationships=explicit relative positions/orientations; "
    "Multi-Element Scenes=many heterogeneous elements or activities; "
    "Fine-Grained Details=micro attributes/close-up textures; "
    "Temporal Consistency (Video)=explicit time progression; "
    "Stylistic Hybrids=mixed art/media styles; "
    "Causality and Physics=cause→effect or physical laws; "
    "non-difficult=none dominates."
)



def llm_fallback_classify(
    user_prompt: str,
    llm_call: Callable[[str, str], Dict[str, str] | str] = default_llm_call,
) -> Dict[str, str]:
    """Classify ``user_prompt`` with an LLM and normalize its JSON answer.

    ``llm_call(sys_prompt, user_prompt)`` may return either an already-parsed
    dict (as ``default_llm_call`` does) or the raw JSON text of the answer.
    Before the call, the ``{P_in}`` placeholder of ``SYS_PROMPT_TMPL`` is
    filled with the prompt.

    Returns:
        ``{"label": <label>, "reason": <reason>}``. A label that is missing, is
        not a string, or is not in ``LLM_LABELS_EN`` becomes ``"non-difficult"``.

    Raises:
        ValueError: if a string answer is not valid JSON
            (``json.JSONDecodeError``), plus anything ``llm_call`` raises.
    """
    # str.format() is unusable here: the template contains literal JSON braces.
    sys_prompt = SYS_PROMPT_TMPL.replace("{P_in}", user_prompt)
    raw = llm_call(sys_prompt, user_prompt)
    # Accept both parsed dicts and JSON text. A dumps/loads round-trip would turn
    # a JSON string into a plain str and break .get() below.
    data = json.loads(raw) if isinstance(raw, str) else raw
    if not isinstance(data, dict):
        data = {}
    label = data.get("label")
    label = label.strip() if isinstance(label, str) else ""
    if label not in LLM_LABELS_EN:
        label = "non-difficult"
    return {"label": label, "reason": data.get("reason", "")}

# ---------- Main classifier ----------
class ScenarioClassifier:
    """
    Hybrid difficulty-scenario classifier for T2V prompts.

    ``predict`` tries, in order:
    1. The fine-tuned text classifier (only if ``cfg.finetuned_model`` is set).
       Its top label is returned unless the uncertainty gate ``_trigger`` fires.
    2. The zero-shot NLI pipeline over ``SCENARIO_LABELS`` with hypothesis
       ``ZSH_HYP``. Its top label is returned unless ``_trigger`` fires.
    3. The LLM fallback (``llm_fallback_classify`` using ``llm_call``).
    """

    def __init__(
        self,
        cfg=CFG_CLF,
        llm_call: Callable[[str, str], Dict[str, str] | str] = default_llm_call,
    ):
        """
        Args:
            cfg: classifier config. Reads ``device``, ``zeroshot_model``,
                ``tokenizer_override``, ``finetuned_model``, ``use_fast_tokenizer``
                and, via ``_trigger``, the three gating thresholds.
            llm_call: ``(sys_prompt, user_prompt) -> answer`` callable used by
                the LLM fallback.
        """
        self.cfg = cfg
        self.device = _pick_device(cfg.device)
        # NOTE(review): the zero-shot tokenizer is always a DebertaV2Tokenizer, so
        # cfg.zeroshot_model (or cfg.tokenizer_override) is assumed to be a
        # DeBERTa-v2-family checkpoint; confirm before switching models.
        tok_zeroshot = DebertaV2Tokenizer.from_pretrained(
            cfg.tokenizer_override or cfg.zeroshot_model
        )
        self.zeroshot = pipeline(
            "zero-shot-classification",
            model=cfg.zeroshot_model,
            tokenizer=tok_zeroshot,
            device=self.device,
        )
        self.ft_pipe = None
        if cfg.finetuned_model:
            tok = AutoTokenizer.from_pretrained(
                cfg.finetuned_model, use_fast=cfg.use_fast_tokenizer
            )
            mdl = AutoModelForSequenceClassification.from_pretrained(cfg.finetuned_model)
            # top_k=None: return the score of every label, not only the argmax.
            self.ft_pipe = pipeline(
                "text-classification", model=mdl, tokenizer=tok,
                device=self.device, top_k=None,
            )
        self.llm_call = llm_call

    @staticmethod
    def _sorted_pipe_scores(out) -> List[Tuple[str, float]]:
        """Convert text-classification output for ONE prompt to (label, prob) pairs, sorted desc."""
        # For a single string input with top_k=None, transformers returns either
        # [[{label, score}, ...]] or [{label, score}, ...], depending on the
        # version. Accept both shapes.
        if out and isinstance(out[0], list):
            out = out[0]
        pairs = [(d["label"], float(d["score"])) for d in out]
        pairs.sort(key=lambda x: x[1], reverse=True)
        return pairs

    def predict(self, prompt: str) -> ScenarioPrediction:
        """Classify ``prompt``; see the class docstring for the order of attempts.

        The returned ``source`` is "finetuned", "zeroshot" or "fallback-llm".
        On the LLM path, ``all_scores``/``stats`` come from the zero-shot
        pipeline, and ``score`` is the zero-shot probability of the label the
        LLM chose (the top zero-shot probability if that label is absent).
        """
        # Path A: fine-tuned classifier.
        if self.ft_pipe is not None:
            all_scores = self._sorted_pipe_scores(self.ft_pipe(prompt))
            if all_scores:  # an empty result is treated as uncertain
                need_fb, stats = _trigger(all_scores, self.cfg)
                if not need_fb:
                    top_label, top_score = all_scores[0]
                    return ScenarioPrediction(label=top_label, score=top_score,
                                              all_scores=all_scores, stats=stats,
                                              source="finetuned")
            # Uncertain: fall through to the zero-shot path.

        # Path B: zero-shot classification.
        z = self.zeroshot(prompt, candidate_labels=SCENARIO_LABELS,
                          hypothesis_template=ZSH_HYP)
        # The pipeline already orders labels by score. The (stable) re-sort makes
        # _trigger's sorted-input precondition explicit.
        all_scores = sorted(zip(z["labels"], map(float, z["scores"])),
                            key=lambda x: x[1], reverse=True)
        need_fb, stats = _trigger(all_scores, self.cfg)
        if not need_fb:
            top_label, top_score = all_scores[0]
            return ScenarioPrediction(label=top_label, score=top_score,
                                      all_scores=all_scores, stats=stats,
                                      source="zeroshot")

        # Path C: LLM fallback.
        label = llm_fallback_classify(prompt, self.llm_call)["label"]
        # Report the probability of the label actually returned, not the top
        # zero-shot probability, which may belong to a different label.
        score = dict(all_scores).get(label, all_scores[0][1])
        return ScenarioPrediction(label=label, score=score,
                                  all_scores=all_scores, stats=stats,
                                  source="fallback-llm")