# test_metrics.py
# -*- coding: utf-8 -*-
import asyncio, json, math, statistics, argparse, csv, re
from pathlib import Path
from typing import List, Tuple, Dict, Any, Sequence, Optional, Set
import utils, openai_utils
import nli_utils
from sklearn.metrics import roc_auc_score


# ----------------------------- Generation & Evaluation -----------------------------
async def _async_generate_once(
        question: str,
        model: str,
        temperature: float,
        *,
        stop=None,
        excluded: list[str] | None = None,
):
    """
    Generate one reasoning path for `question` and parse its final answer.

    Args:
        question: Question text; wrapped by `utils.make_prompt_with_final_free`.
        model: Model name forwarded to the completion helper.
        temperature: Sampling temperature forwarded to the completion helper.
        stop: Stop sequence(s) forwarded unchanged to the completion helper.
        excluded: Answers the model should avoid. When non-empty they are appended to
            the question as an "EXCLUDED ANSWER LIST" line before the prompt is built.

    Returns:
        dict with keys:
          - "raw": the dict returned by the completion helper,
          - "final": the parsed final answer, or None if none could be parsed,
          - "index_groups", "sim": the other values returned by
            `utils.collect_answer_index_groups`,
          - "excluded": a copy of the exclusion list used for this call.
    """
    # Previously this line reset `excluded` to [] unconditionally, so the caller's
    # exclusion list never reached the prompt. Copy it so later mutation by the caller
    # cannot change what we report back.
    excluded = list(excluded) if excluded else []
    suffix = ""
    if excluded:
        excluded_str = ", ".join(map(str, excluded))
        suffix = f"\nEXCLUDED ANSWER LIST: [{excluded_str}]"

    prompt = utils.make_prompt_with_final_free(question + suffix)

    # Non-streaming call with top_logprobs=0 (presumably: no alternative tokens per
    # position; per-token probs are still read downstream from raw["tokens"] — confirm
    # in openai_utils).
    out = await openai_utils.get_openai_completion_with_token_probs_async(
        prompt=prompt,
        model=model,
        temperature=temperature,
        top_logprobs=0,
        stop=stop
    )

    # Note: using the simplified collector (case_insensitive_ok only)
    final_ans, index_groups, sim_ans = utils.collect_answer_index_groups(
        out, case_insensitive_ok=True
    )
    if not final_ans:  # Failed to parse a final answer: normalize to None
        final_ans = None

    return {
        "raw": out,
        "final": final_ans,
        "index_groups": index_groups,
        "sim": sim_ans,
        "excluded": list(excluded),  # Pass back the exclusion list used in this round
    }


async def _async_generate_explain(
        question: str,
        base_text: str,
        final_ans: str,
        model: str,
        temperature: float,
        *,
        stop=None,
) -> str:
    """
    Ask the model to explain a previously generated reasoning path (answer → explanation).

    The prompt comes from `utils.make_explain_prompt`. Because that helper's signature has
    varied, call shapes are tried in order, moving on after a TypeError:
    `(question, base_text, final_ans)`, then `(question, final_ans)`, and finally
    `(base_text)`. A TypeError from the last shape propagates.

    Returns:
        The completion's "text" field, or "" when it is missing or falsy.
    """
    for positional in ((question, base_text, final_ans), (question, final_ans)):
        try:
            exp_prompt = utils.make_explain_prompt(*positional)
            break
        except TypeError:
            continue
    else:
        exp_prompt = utils.make_explain_prompt(base_text)

    # Explanations typically don't need a </final> stop; callers pass stop=None.
    completion = await openai_utils.get_openai_completion_with_token_probs_async(
        prompt=exp_prompt,
        model=model,
        temperature=temperature,
        top_logprobs=0,
        stop=stop,
    )

    explanation = completion.get("text", "")
    return explanation if explanation else ""


def _normalize_answer(x: Any) -> str:
    return str(x).strip().lower()


async def sample_k_with_exclusion_unique(
        question: str,
        model: str,
        k: int,
        temperature: float,
        *,
        stop=None,
        normalize_fn=None,
        max_trials_factor: int = 3,
) -> List[Dict[str, Any]]:
    """
    Keep sampling until we collect k samples with UNIQUE answers or hit the max trial count.

    Rules:
      - At each attempt, inject the current `seen` set (normalized answers, sorted) into
        the prompt as EXCLUDED.
      - Attempts whose answer could not be parsed (`final` is None, or normalizes to "")
        are discarded; they still count toward trials.
      - If the generated answer ∈ EXCLUDED, skip that sample (it still counts toward
        trials), and retry with the same EXCLUDED list.
      - Only add a sample when it yields a NEW answer; also add that answer into `seen`.

    Args:
        question, model, temperature, stop: forwarded to `_async_generate_once`.
        k: Number of unique-answer samples wanted.
        normalize_fn: Maps an answer to its comparison key (default `_normalize_answer`).
        max_trials_factor: Attempts allowed = max(1, k) * max(1, max_trials_factor).

    Returns:
        At most `k` sample dicts in discovery order; fewer if the attempt budget runs out.
    """
    samples: List[Dict[str, Any]] = []
    seen: Set[str] = set()
    normalize = normalize_fn or _normalize_answer
    max_trials = max(1, k) * max(1, max_trials_factor)

    for _ in range(max_trials):
        if len(samples) >= k:
            break
        s = await _async_generate_once(
            question, model, temperature, stop=stop, excluded=sorted(seen)
        )
        final = s.get("final")
        # Check before normalizing: _normalize_answer(None) == "none", which is truthy
        # and would otherwise be accepted as a real answer.
        if final is None:
            continue
        key = normalize(final)
        if not key or key in seen:
            # Unparseable, or the model ignored the exclusion list: discard this attempt.
            continue

        # New answer discovered: keep it
        seen.add(key)
        samples.append(s)

    return samples


def _to_value_slope(seq, beta=0.5):
    if not seq: return []
    vecs = []
    prev = seq[0]
    for x in seq:
        slope = x - prev
        vecs.append((x, beta * slope))
        prev = x
    return vecs


def _dtw_vec(a, b):
    def _l2(u, v):
        return ((u[0] - v[0]) ** 2 + (u[1] - v[1]) ** 2) ** 0.5

    n, m = len(a), len(b)
    INF = float('inf')
    dp = [[INF] * (m + 1) for _ in range(n + 1)]
    dp[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = _l2(a[i - 1], b[j - 1])
            dp[i][j] = cost + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
    return dp[n][m]


def vecdtw_similarity(a, b, beta=0.5):
    """Similarity in [0, 1] between two numeric series.

    Both series are mapped to (value, beta * slope) points, and their DTW distance is
    divided by the mean series length (floored at 1). The result is `1 - that`, clipped
    below at 0.
    """
    left = _to_value_slope(a, beta=beta)
    right = _to_value_slope(b, beta=beta)
    normaliser = (len(left) + len(right)) / 2.0
    if normaliser < 1.0:
        normaliser = 1.0
    score = 1.0 - _dtw_vec(left, right) / normaliser
    return score if score > 0.0 else 0.0


def compute_metrics_for_sample(sample: Dict[str, Any], gold: str) -> Dict[str, Any]:
    """
    Compute per-sample confidence/similarity metrics and a correctness label.

    Args:
        sample: A dict as produced by `_async_generate_once`. Reads "final",
            "index_groups" (lists of token indices) and "raw"["tokens"]. Tokens are
            dicts whose "prob" value is averaged; a missing "prob" counts as 0.0.
        gold: Reference answer.

    Returns:
        dict with "final", "gold", "correct" (0/1), "n_spans" (number of index groups),
        "current_avg_prob" (mean over non-empty groups of each group's mean token prob)
        and "simG_vecdtw" (mean vecDTW similarity of every earlier span's prob series
        to the last span's; 0.0 with fewer than two spans).
    """
    raw = sample.get("raw", {}) or {}
    tokens = raw.get("tokens", []) or []
    groups = sample.get("index_groups") or []
    # Bound by the real length: the old `max(1, len(tokens))` let index 0 through when
    # `tokens` was empty, raising IndexError.
    n_tokens = len(tokens)

    def _span_probs(group: List[int]) -> List[float]:
        return [tokens[i].get("prob", 0.0) for i in group if 0 <= i < n_tokens]

    # Correctness label (whitespace/case-insensitive containment in either direction;
    # replace with your matcher if needed). A missing prediction or an empty gold never
    # counts: "" is a substring of everything.
    final = sample.get("final")
    pred_norm = "" if final is None else str(final).strip().lower()
    gold_norm = "" if gold is None else str(gold).strip().lower()
    correct = int(bool(pred_norm) and bool(gold_norm)
                  and (pred_norm in gold_norm or gold_norm in pred_norm))

    cur_means: List[float] = []     # one mean per non-empty group (0.0 if no valid index)
    span_series: List[List[float]] = []  # prob series of groups with >= 1 valid index
    for g in groups:
        if not g:
            continue
        probs = _span_probs(g)
        cur_means.append(sum(probs) / len(probs) if probs else 0.0)
        if probs:
            span_series.append(probs)

    # Current path average confidence
    current_avg_prob = (sum(cur_means) / len(cur_means)) if cur_means else 0.0

    # Per-question similarity: mean vecDTW similarity between earlier spans and the last.
    simG = 0.0
    if len(span_series) >= 2:
        last = span_series[-1]
        vals = [vecdtw_similarity(s, last, beta=0.5) for s in span_series[:-1]]
        simG = float(sum(vals) / len(vals))

    return {
        "final": final,
        "gold": gold,
        "correct": correct,
        "n_spans": len(groups),
        "current_avg_prob": current_avg_prob,
        "simG_vecdtw": simG,
    }


def _parse_final_index_from_output(nli_output, default=1) -> int:
    text = ""
    if isinstance(nli_output, dict):
        text = nli_output.get("text") or nli_output.get("raw") or ""
    else:
        text = str(nli_output or "")
    m = re.search(r"<\s*final\s*>\s*([0-9]+)\s*<\s*/\s*final\s*>", text, flags=re.I)
    if m:
        try:
            return int(m.group(1))
        except ValueError:
            pass
    # Fallback: grab the last number in the string
    nums = re.findall(r"\b([0-9]+)\b", text)
    if nums:
        try:
            return int(nums[-1])
        except ValueError:
            pass
    return int(default)


# ----------------------------- Main Pipeline (NLI-based Selection) -----------------------------
# Few-shot exemplars for date-understanding questions.
# NOTE(review): this text is not injected into any prompt. The previous code bound it to
# an unused local inside the per-question loop. It is kept here (with the broken closing
# </final> tag of the last example repaired) so it can be wired in deliberately.
_DATE_FEWSHOT_EXAMPLES = '''Examples:
Question: Today is Christmas Eve of 1937. What is the date tomorrow in MM/DD/YYYY?
Answer:
1) Premise/Evidence:
Today's date is December 24, 1937. We are asked to find the date tomorrow.
2) Reasoning:
- Step 1: Identify the current day, month, and year: 12/24/1937.
- Step 2: Add one day to the current date: 24 + 1 = 25.
- Step 3: Since December has 31 days, adding 1 day does not change the month or year.
- Step 4: Format the resulting date in MM/DD/YYYY: 12/25/1937.
3) Conclusion:
<final>12/25/1937</final>

Question: In the UK, people usually put the day before the month when formatting the date. Therefore, today is 02/01/1987 to them. What is the date a month ago in MM/DD/YYYY?
Answer:
1) Premise/Evidence:
Today's date is 02/01/1987 in UK format (day/month), which corresponds to February 1, 1987 in MM/DD/YYYY.
2) Reasoning:
- Step 1: Identify the current month and year: February 1987.
- Step 2: Determine the previous month: January 1987.
- Step 3: Keep the day the same (1) since it exists in January.
- Step 4: Format the resulting date in MM/DD/YYYY: 01/01/1987.
3) Conclusion:
<final>01/01/1987</final>

Question: Jane and John married on Jan 2, 1958. It is their 5-year anniversary today. What is the date one week from today in MM/DD/YYYY?
Answer:
1) Premise/Evidence:
Jane and John were married on January 2, 1958. Their 5-year anniversary is on January 2, 1963.
2) Reasoning:
- Step 1: Start from the anniversary date: 01/02/1963.
- Step 2: Add one week (7 days): 2 + 7 = 9.
- Step 3: Since January has more than 9 days, month and year remain unchanged.
- Step 4: Format the resulting date in MM/DD/YYYY: 01/09/1963.
3) Conclusion:
<final>01/09/1963</final>
'''

# NLI selection modes passed to nli_utils.nli_tournament_judge_pairs.
_NLI_MODES = ("mean", "b2a", "cont_mean", "cont_b2a", "b2a_penalized", "mean_penalized")
# All baselines, in report/CSV order.
_BASELINES = ("conf", "sim", "vote")
# Score-based baselines: name -> key in compute_metrics_for_sample's output.
_BASELINE_METRIC_KEYS = {"conf": "current_avg_prob", "sim": "simG_vecdtw"}


def _answers_match(pred: Any, gold: Any) -> int:
    """Return 1 if pred and gold match by normalized containment (either way), else 0.

    A missing prediction (None) or an empty prediction/gold never matches, because ""
    is a substring of every string.
    """
    if pred is None or gold is None:
        return 0
    p = _normalize_answer(pred)
    g = _normalize_answer(gold)
    if not p or not g:
        return 0
    return int(p in g or g in p)


def _safe_ratio(num: float, den: float) -> float:
    """num / den, or 0.0 when den is 0 (e.g. an empty dataset)."""
    return num / den if den else 0.0


def _print_auroc(label: str, labels: List[int], scores: List[float]) -> None:
    """Print the ROC AUC of scores against binary labels, or a notice when undefined."""
    # roc_auc_score needs both classes present in the labels.
    if len(set(labels)) > 1:
        print(f"{label} AUROC = {roc_auc_score(labels, scores):.4f}")
    else:
        print(f"{label} AUROC cannot be computed (all labels are the same).")


async def main(args):
    """
    Run the sample → explain → NLI-select pipeline over a JSONL dataset.

    Reading:
        Each non-blank line of `args.jsonl` is a JSON object; its "question" and
        "answer" keys are read.

    Per question:
        - up to `args.k` samples with distinct answers are drawn;
        - each sample gets an explanation path;
        - the (reasoning, explanation) pairs are judged by
          `nli_utils.nli_tournament_judge_pairs` under every mode in `_NLI_MODES`;
        - conf, sim and vote baselines are computed alongside.

    Output:
        Accuracies and AUROCs are printed, and a per-question CSV is written to
        `args.out`.

    Args:
        args: Namespace with jsonl, model, k, temperature, stop, out and
            (optionally) print_every.
    """
    lines = Path(args.jsonl).read_text(encoding="utf-8").splitlines()
    data = [json.loads(line) for line in lines if line.strip()]
    # --print-every was parsed but previously unused. It now throttles per-question
    # console output (default 1 = every question, as before).
    print_every = max(1, int(getattr(args, "print_every", 1) or 1))

    total_q = 0
    greedy_hits = 0
    nli_hits_by_mode = {m: 0 for m in _NLI_MODES}
    baseline_hits = {b: 0 for b in _BASELINES}

    # AUROC inputs, kept per selector: the score a selector used to make its pick, and
    # whether THAT pick was correct. Previously every selector was scored against the
    # first sample's correctness, which does not measure the selector's own score.
    scores_by_mode = {m: [] for m in _NLI_MODES}
    labels_by_mode = {m: [] for m in _NLI_MODES}
    scores_by_baseline = {b: [] for b in _BASELINE_METRIC_KEYS}
    labels_by_baseline = {b: [] for b in _BASELINE_METRIC_KEYS}

    rows = []

    for idx, ex in enumerate(data, 1):
        verbose = idx % print_every == 0
        if verbose:
            print(idx)
        q_full = ex.get("question", "")
        # Drop multiple-choice options; the "Question: " prefix is only added in that case.
        if "\nOptions:" in q_full:
            q = "Question: " + q_full.split("\nOptions:")[0]
        else:
            q = q_full
        gold = str(ex.get("answer", "")).strip()

        # Step 1) Sampling
        samples = await sample_k_with_exclusion_unique(
            q, args.model, args.k, args.temperature,
            stop=args.stop, normalize_fn=_normalize_answer
        )
        total_q += 1  # questions with no samples still count (as misses)
        if not samples:
            continue

        # "Greedy" baseline = first sample drawn (note: sampled at args.temperature).
        greedy_final = samples[0].get("final", "")
        greedy_correct = _answers_match(greedy_final, gold)
        greedy_hits += greedy_correct

        # Step 2) Generate explanations
        explain_texts = []
        for s in samples:
            base_text = (s.get("raw", {}) or {}).get("text", "") or ""
            exp_txt = await _async_generate_explain(
                question=q,
                base_text=base_text,
                final_ans=s.get("final", ""),
                model=args.model,
                temperature=args.temperature,
                stop=None,
            )
            explain_texts.append(exp_txt)
            if verbose:
                print("\n--- Reasoning Path ---")
                print(base_text)
                print("--- Explanation Path ---")
                print(exp_txt)
                print("-----------------------")

        # Step 3) NLI pairs; valid_indices holds the 1-based sample index of each pair.
        pairs, valid_indices = [], []
        for i, (s, e) in enumerate(zip(samples, explain_texts), 1):
            if not s.get("final"):
                continue
            pairs.append([(s.get("raw", {}) or {}).get("text", "") or "", e or ""])
            valid_indices.append(i)

        # Baselines: conf / sim pick the argmax of a per-sample metric (first on ties).
        metrics_list = [compute_metrics_for_sample(s, gold) for s in samples]
        baseline_results = {}
        baseline_scores = {}
        for name, metric_key in _BASELINE_METRIC_KEYS.items():
            metric_values = [mt[metric_key] for mt in metrics_list]
            best = max(range(len(metric_values)), key=metric_values.__getitem__)
            final_ans = samples[best].get("final")
            correct = _answers_match(final_ans, gold)
            baseline_results[name] = {"pick_idx": best + 1, "final": final_ans, "correct": correct}
            baseline_scores[name] = metric_values[best]
            scores_by_baseline[name].append(metric_values[best])
            labels_by_baseline[name].append(correct)

        # vote: majority over normalized answers (first-seen wins ties); unparsed skipped.
        votes: Dict[str, List[int]] = {}
        for i, s in enumerate(samples, 1):
            final_ans = s.get("final")
            if final_ans is None:
                continue
            ans = _normalize_answer(final_ans)
            if ans:
                votes.setdefault(ans, []).append(i)
        if votes:
            _, idxs = max(votes.items(), key=lambda kv: len(kv[1]))
            final_ans = samples[idxs[0] - 1].get("final")
            baseline_results["vote"] = {
                "pick_idx": idxs[0],
                "final": final_ans,
                "correct": _answers_match(final_ans, gold),
            }

        for b, res in baseline_results.items():
            baseline_hits[b] += res["correct"]

        # Step 4) NLI-based selection (best_score feeds AUROC)
        results_by_mode = {}
        for mode in _NLI_MODES:
            if pairs:
                best_idx_local, best_score, _ = nli_utils.nli_tournament_judge_pairs(pairs, mode=mode)
                pick_idx = valid_indices[best_idx_local - 1]
                nli_final = samples[pick_idx - 1].get("final")
                nli_correct = _answers_match(nli_final, gold)
            else:
                nli_final, nli_correct, pick_idx, best_score = None, 0, 0, 0.0
            results_by_mode[mode] = {
                "pick_idx": pick_idx,
                "nli_final": nli_final,
                "nli_correct": nli_correct,
                "score": best_score,
            }
            nli_hits_by_mode[mode] += nli_correct
            scores_by_mode[mode].append(best_score)
            labels_by_mode[mode].append(nli_correct)

        # Save row
        row = {"qid": idx, "k_collected": len(samples), "gold": gold,
               "greedy_final": greedy_final, "greedy_correct": greedy_correct}
        for m, res in results_by_mode.items():
            row[f"{m}_choice_idx"] = res["pick_idx"]
            row[f"{m}_final"] = res["nli_final"]
            row[f"{m}_correct"] = res["nli_correct"]
            row[f"{m}_score"] = res["score"]
        for b, res in baseline_results.items():
            row[f"{b}_choice_idx"] = res["pick_idx"]
            row[f"{b}_final"] = res["final"]
            row[f"{b}_correct"] = res["correct"]
        rows.append(row)

        # Single-line progress summary; baselines absent for this question are omitted
        # (previously a missing "vote" raised KeyError here).
        if verbose:
            parts = [f"[{idx}] gold={gold!r} | greedy={greedy_final!r}({greedy_correct})"]
            for m in _NLI_MODES:
                res = results_by_mode[m]
                parts.append(f"{m}@{res['pick_idx']}={res['nli_final']!r}"
                             f"({res['nli_correct']}, score={res['score']:.4f})")
            for b in _BASELINES:
                if b not in baseline_results:
                    continue
                res = baseline_results[b]
                piece = f"{b}@{res['pick_idx']}={res['final']!r}({res['correct']}"
                if b in baseline_scores:
                    piece += f", score={baseline_scores[b]:.4f}"
                parts.append(piece + ")")
            parts.append(f"k={len(samples)}")
            print(" | ".join(parts))

    # Step 5) Accuracies (guarded against an empty dataset)
    print("\n=== Accuracy ===")
    print(f"Questions: {total_q}")
    print(f"Greedy accuracy = {greedy_hits}/{total_q} = {_safe_ratio(greedy_hits, total_q):.4f}")
    for m, hits in nli_hits_by_mode.items():
        print(f"NLI-{m} accuracy = {hits}/{total_q} = {_safe_ratio(hits, total_q):.4f}")
    for b, hits in baseline_hits.items():
        print(f"Baseline-{b} accuracy = {hits}/{total_q} = {_safe_ratio(hits, total_q):.4f}")

    # Step 5.5) AUROC per selector
    print("\n=== AUROC ===")
    for m in _NLI_MODES:
        _print_auroc(f"NLI-{m}", labels_by_mode[m], scores_by_mode[m])
    for b in _BASELINE_METRIC_KEYS:
        _print_auroc(f"Baseline-{b}", labels_by_baseline[b], scores_by_baseline[b])

    # Step 6) Save CSV
    out_csv = args.out or "nli_pick_vs_greedy.csv"
    fieldnames = ["qid", "k_collected", "gold", "greedy_final", "greedy_correct"]
    for m in _NLI_MODES:
        fieldnames += [f"{m}_choice_idx", f"{m}_final", f"{m}_correct", f"{m}_score"]
    for b in _BASELINES:
        fieldnames += [f"{b}_choice_idx", f"{b}_final", f"{b}_correct"]

    with open(out_csv, "w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow({k: r.get(k) for k in fieldnames})
    print(f"[saved] {out_csv} (rows={len(rows)})")


# ----------------------------- CLI -----------------------------
if __name__ == "__main__":
    # Command-line flags as (flag, add_argument keyword options).
    cli_flags = (
        ("--jsonl", {"default": "./date_understanding/date_understanding_process.jsonl.jsonl"}),
        ("--model", {"default": "gpt-4o-mini"}),
        ("--k", {"type": int, "default": 3}),
        ("--temperature", {"type": float, "default": 0.7}),
        ("--stop", {"nargs": "*", "default": None}),
        ("--out", {"default": "out.csv"}),
        ("--print-every", {"type": int, "default": 1}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in cli_flags:
        cli.add_argument(flag, **options)
    asyncio.run(main(cli.parse_args()))
