from __future__ import annotations

import math
import re
from collections import Counter
from typing import Any, Dict, Iterable, List, Optional

from ..utils.text import THINK_SPLIT, normalize_text

# Affirmative vocabulary: parse_answer reports "yes" when a whitespace-delimited
# token of the normalized answer equals one of these strings.
YES_TOKENS = {
    # direct answers
    "yes", "y", "true",
    # agreement / confirmation
    "affirmative", "agree", "correct", "indeed", "sure", "certainly",
    # hedged answer
    "likely",
    # contains a space, so it can never equal a single whitespace-split token
    "probably yes",
}
# Negative vocabulary: parse_answer reports "no" when a whitespace-delimited
# token of the normalized answer equals one of these strings.
NO_TOKENS = {
    # direct answers
    "no", "n", "false", "nope",
    # disagreement / rejection
    "negative", "disagree", "incorrect",
    # hedged answer
    "unlikely",
    # contains a space, so it can never equal a single whitespace-split token
    "probably no",
}


def parse_answer(generated_text: str) -> str:
    """Classify a model generation as ``'yes'``, ``'no'`` or ``'abstain'``.

    Only the text after the last ``THINK_SPLIT`` match is inspected (the
    whole text when nothing matches).  ``THINK_SPLIT`` comes from
    ``..utils.text``; presumably it matches the ``</think>`` marker.  The
    tail is passed through ``normalize_text`` and scanned left to right: the
    first whitespace-delimited token that is a member of ``YES_TOKENS`` or
    ``NO_TOKENS`` decides the answer.

    Args:
        generated_text: Raw model output.  Any value that is not a ``str``
            is treated as an abstention.

    Returns:
        ``'yes'`` or ``'no'`` for the first decisive token, ``'abstain'``
        when no token is decisive.

    Note:
        Matching is per token, so multi-word entries of the token sets
        (e.g. ``"probably yes"``) never match as a unit; the trailing
        ``yes`` / ``no`` word is what gets matched.  Whether punctuation
        attached to a token (``"yes."``) is handled depends on
        ``normalize_text`` -- TODO confirm.
    """
    if not isinstance(generated_text, str):
        return "abstain"

    segments = THINK_SPLIT.split(generated_text)
    tail = segments[-1] if segments else generated_text

    for token in normalize_text(tail).split():
        if token in YES_TOKENS:
            return "yes"
        if token in NO_TOKENS:
            return "no"

    # No extra substring search for " yes " / " no " is needed: a
    # space-delimited "yes" or "no" is itself one of the tokens scanned above,
    # and both words belong to the token sets, so reaching this point means
    # neither occurs.
    return "abstain"


def safe_logit(p: float, eps: float = 1e-6) -> float:
    """Return the log-odds of *p* after clamping it into ``[eps, 1 - eps]``.

    Args:
        p: Probability-like value; converted with ``float()`` first.
        eps: Clamp margin.  With the default, inputs of exactly 0 or 1 still
            produce a finite logit.

    Returns:
        ``log(q / (1 - q))`` where ``q`` is the clamped probability.
    """
    lower = eps
    upper = 1.0 - eps
    clamped = float(p)
    # Raise to the lower bound first, then cap at the upper bound.
    if clamped < lower:
        clamped = lower
    if clamped > upper:
        clamped = upper
    return math.log(clamped / (1.0 - clamped))


def sample_weight(token_probs: List[float], k_tail: int = 3, kappa: float = 0.1) -> float:
    """Return a positive confidence weight from a sample's last token probabilities.

    The weight is ``exp(kappa * m)`` where ``m`` is the mean ``safe_logit`` of
    the numeric entries among the last ``k_tail`` items of ``token_probs``.

    Args:
        token_probs: Per-token probabilities; must be a ``list``.
        k_tail: Number of trailing entries to use.  Shorter lists are used in
            full.  A non-positive value selects no tokens.
        kappa: Scale applied to the mean logit before exponentiation.

    Returns:
        The weight, or ``0.0`` when no usable weight exists: ``token_probs``
        is not a non-empty list, ``k_tail`` is not positive, the tail has no
        ``int``/``float`` entries, or the result is not a finite positive
        number (including overflow of the exponential).
    """
    if not isinstance(token_probs, list) or not token_probs:
        return 0.0
    # Guard explicitly: ``token_probs[-0:]`` is the whole list, and a negative
    # ``k_tail`` would slice from the front instead of the back.
    if k_tail <= 0:
        return 0.0

    # Slicing already copes with lists shorter than k_tail.
    tail = token_probs[-k_tail:]
    logits = [safe_logit(float(p)) for p in tail if isinstance(p, (int, float))]
    if not logits:
        return 0.0

    mean_logit = sum(logits) / len(logits)
    try:
        weight = math.exp(kappa * mean_logit)
    except OverflowError:
        # math.exp raises instead of returning inf; treat it like any other
        # non-finite weight.
        return 0.0
    if not (weight > 0.0 and math.isfinite(weight)):
        return 0.0
    return weight


def kish_effective_n(weights: Iterable[float]) -> float:
    """Return Kish's effective sample size ``(sum w)^2 / sum(w^2)``.

    Args:
        weights: Any iterable of weights; it is consumed once.

    Returns:
        The effective sample size, or ``0.0`` when there are no weights or
        the sum of squared weights is not positive.
    """
    values = tuple(weights)  # materialize: the iterable is traversed twice
    if len(values) == 0:
        return 0.0
    total = sum(values)
    sum_of_squares = sum(v * v for v in values)
    if sum_of_squares > 0:
        return (total * total) / sum_of_squares
    return 0.0


def aggregate_outputs(sample_dicts: List[Dict[str, Any]], k_tail: int = 3, kappa: float = 0.1) -> Dict[str, Any]:
    """Combine several sampled generations into a weighted yes-probability.

    Each sample's ``"generated_text"`` is classified with ``parse_answer``.
    Decisive samples (``yes`` / ``no``) contribute the weight returned by
    ``sample_weight`` for their ``"token_probs"``; a non-positive or
    non-finite weight is replaced by ``1.0``.  Abstaining samples are counted
    but carry no weight.

    Args:
        sample_dicts: Samples, each read via ``.get("generated_text", "")``
            and ``.get("token_probs", [])``.  ``None`` or an empty list means
            no samples.
        k_tail: Forwarded to ``sample_weight``.
        kappa: Forwarded to ``sample_weight``.

    Returns:
        A dict with keys ``"p_yes"`` (weighted yes share, ``None`` when no
        sample was decisive), ``"abstain_rate"`` (abstentions over the number
        of samples), ``"counts"`` (answer label -> count) and ``"n_eff"``
        (``kish_effective_n`` of the weights used, ``0.0`` when none).
    """
    samples = sample_dicts or []
    weight_by_answer = {"yes": 0.0, "no": 0.0}
    decisive_weights: List[float] = []
    answer_counts: Counter = Counter()

    for sample in samples:
        text = sample.get("generated_text", "")
        token_probs = sample.get("token_probs", [])
        answer = parse_answer(text)
        answer_counts[answer] += 1
        if answer not in weight_by_answer:
            continue  # abstentions only show up in the counts
        weight = sample_weight(token_probs, k_tail=k_tail, kappa=kappa)
        if not (weight > 0.0 and math.isfinite(weight)):
            # An unusable weight falls back to a plain unweighted vote.
            weight = 1.0
        weight_by_answer[answer] += weight
        decisive_weights.append(weight)

    weight_yes = weight_by_answer["yes"]
    total_weight = weight_yes + weight_by_answer["no"]
    result: Dict[str, Any] = {
        "p_yes": None,
        "abstain_rate": answer_counts["abstain"] / max(1, len(samples)),
        "counts": dict(answer_counts),
        "n_eff": 0.0,
    }
    if total_weight != 0.0:
        result["p_yes"] = weight_yes / total_weight
        result["n_eff"] = kish_effective_n(decisive_weights)
    return result


def compute_predicted_trend(p_no_news: Optional[float], p_with_news: Optional[float], threshold: float = 0.05) -> str:
    """Label the change from ``p_no_news`` to ``p_with_news``.

    Args:
        p_no_news: Baseline probability, or ``None`` when unavailable.
        p_with_news: Probability given the news, or ``None`` when unavailable.
        threshold: Minimum absolute change that counts as a move.

    Returns:
        ``"Unknown"`` if either probability is ``None``; ``"Up"`` if the
        change is at least ``threshold``; ``"Down"`` if it is at most
        ``-threshold``; ``"Still"`` otherwise.
    """
    missing = p_no_news is None or p_with_news is None
    if missing:
        return "Unknown"

    change = p_with_news - p_no_news
    # "Up" is tested first, so it wins when both bounds are satisfied.
    if change >= threshold:
        label = "Up"
    elif change <= -threshold:
        label = "Down"
    else:
        label = "Still"
    return label


# Single-sample helpers (from evaluate_logits_single_sample)

def _sigmoid(x: float) -> float:
    """Return the logistic function ``1 / (1 + exp(-x))``.

    The exponential is only ever taken of a non-positive argument, so it
    cannot overflow for large ``|x|``.
    """
    decay = math.exp(-abs(x))  # exp(-x) for x >= 0, exp(x) for x < 0
    return 1.0 / (1.0 + decay) if x >= 0 else decay / (1.0 + decay)


def tail_confidence_from_token_probs(token_probs: List[float], k_tail: int = 3, alpha: float = 0.5) -> float:
    """Map a sample's last token probabilities to a confidence in ``(0, 1)``.

    The confidence is ``sigmoid(alpha * m)`` where ``m`` is the mean
    ``safe_logit`` of the numeric entries among the last ``k_tail`` items of
    ``token_probs``, clamped to ``[1e-6, 1 - 1e-6]``.

    Args:
        token_probs: Per-token probabilities; must be a ``list``.
        k_tail: Number of trailing entries to use.  Shorter lists are used in
            full.  A non-positive value selects no tokens.
        alpha: Scale applied to the mean logit before the sigmoid.

    Returns:
        The clamped confidence, or the neutral value ``0.5`` when there is
        nothing usable: ``token_probs`` is not a non-empty list, ``k_tail``
        is not positive, the tail has no ``int``/``float`` entries, or any
        of them is NaN.
    """
    neutral = 0.5
    if not isinstance(token_probs, list) or not token_probs:
        return neutral
    # Guard explicitly: ``token_probs[-0:]`` is the whole list, and a negative
    # ``k_tail`` would slice from the front instead of the back.
    if k_tail <= 0:
        return neutral

    # Slicing already copes with lists shorter than k_tail.
    numeric_tail = [float(p) for p in token_probs[-k_tail:] if isinstance(p, (int, float))]
    if not numeric_tail:
        return neutral
    # NaN survives the logit, the sigmoid and the min/max clamp below, so it
    # has to be rejected up front to keep the result inside (0, 1).
    if any(math.isnan(p) for p in numeric_tail):
        return neutral

    logits = [safe_logit(p) for p in numeric_tail]
    mean_logit = sum(logits) / len(logits)
    conf = _sigmoid(alpha * mean_logit)
    eps = 1e-6
    return min(max(conf, eps), 1 - eps)


def single_sample_p_yes(sample_dict: Optional[Dict[str, Any]], k_tail: int = 3, alpha: float = 0.5) -> Optional[float]:
    """Estimate the yes-probability from a single sampled generation.

    The sample's ``"generated_text"`` is classified with ``parse_answer`` and
    its ``"token_probs"`` are turned into a confidence by
    ``tail_confidence_from_token_probs``.

    Args:
        sample_dict: Sample read via ``.get("generated_text", "")`` and
            ``.get("token_probs", [])``; ``None`` or empty means no sample.
        k_tail: Forwarded to ``tail_confidence_from_token_probs``.
        alpha: Forwarded to ``tail_confidence_from_token_probs``.

    Returns:
        ``None`` when there is no sample or the answer is an abstention;
        the confidence for a ``yes`` answer; one minus the confidence for a
        ``no`` answer.
    """
    if not sample_dict:
        return None

    text = sample_dict.get("generated_text", "")
    token_probs = sample_dict.get("token_probs", [])
    answer = parse_answer(text)
    if answer == "abstain":
        return None

    confidence = tail_confidence_from_token_probs(token_probs, k_tail=k_tail, alpha=alpha)
    if answer == "yes":
        return confidence
    # A confident "no" is a low probability of "yes".
    return 1.0 - confidence
