#!/usr/bin/env python3
import os
import sys
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
from collections import Counter

# Build labels

def label_for(which: str, quantized: bool, temp: float, metric: str) -> str:
    """Compose the legend label that identifies one curve.

    Args:
        which: ``"post"`` selects the "Post" side; any other value yields "Base".
        quantized: True renders as "4-bit", False as "full".
        temp: Sampling temperature; a value equal to 0 is rendered as "greedy",
            anything else as ``temp=<value>``.
        metric: Metric name appended as the last label component.

    Returns:
        A label of the form ``"<Side> <precision> – <decoding> – <metric>"``.
    """
    if which == "post":
        side_name = "Post"
    else:
        side_name = "Base"

    precision = "full"
    if quantized:
        precision = "4-bit"

    if float(temp) == 0.0:
        decoding = "greedy"
    else:
        decoding = f"temp={temp}"

    return "{} {} – {} – {}".format(side_name, precision, decoding, metric)


def parse_meta_from_name(name: str) -> Tuple[str, bool, float]:
    """Extract (side, quantized, temperature) from an artifact file name.

    The base name, with any ``.csv`` / ``.jsonl`` text removed, is split on
    underscores and each token is inspected:

    * ``base`` / ``post`` set the side (the last one seen wins),
    * ``temp<number>`` sets the temperature; an unparsable number clears it,
    * ``qb`` / ``qp`` / ``quant`` mark the run as quantized.

    Returns:
        ``(side, quantized, temperature)`` where the side defaults to
        ``"base"`` and the temperature defaults to ``1.0`` when not found.
    """
    stem = os.path.basename(name).replace(".csv", "").replace(".jsonl", "")

    side = None
    temperature = None
    is_quantized = False
    for token in stem.split("_"):
        if token == "base" or token == "post":
            side = token
            continue
        if token.startswith("temp"):
            try:
                temperature = float(token.replace("temp", ""))
            except Exception:
                # Unparsable temperature token: forget any earlier value too.
                temperature = None
            continue
        if token in ("qb", "qp", "quant"):
            is_quantized = True

    resolved_temp = 1.0 if temperature is None else float(temperature)
    return (side or "base"), bool(is_quantized), resolved_temp


def parse_k_from_name(name: str) -> Optional[int]:
    """Find the sample count ``k`` encoded in an artifact file name.

    Two spellings are recognised among the underscore-separated tokens of the
    base name (after removing ``.csv`` / ``.jsonl``): a fused token such as
    ``k500``, or a bare ``k`` token immediately followed by an all-digit token.

    Returns:
        The first ``k`` found as an int, or None when no token matches.
    """
    stem = os.path.basename(name)
    for extension in (".csv", ".jsonl"):
        stem = stem.replace(extension, "")
    pieces = stem.split("_")

    for pos, piece in enumerate(pieces):
        if not piece.startswith("k"):
            continue
        digits = piece[1:]
        if digits.isdigit():
            return int(digits)
        # A bare "k" token: the value may live in the following token.
        if digits == "" and pos + 1 < len(pieces):
            follower = pieces[pos + 1]
            if follower.isdigit():
                return int(follower)
    return None


def load_series_from_csv(csv_path: Path) -> List[Tuple[str, np.ndarray, np.ndarray, Optional[int]]]:
    """Read one pass@t CSV and turn its columns into labelled series.

    The CSV must have a ``t`` column. The recognised value columns are
    ``base_pass``, ``post_pass`` (labelled with the temperature parsed from the
    file name) and ``base_greedy``, ``post_greedy`` (labelled as temperature
    0). If none of them is present, the first non-``t`` column is used and
    labelled with the side parsed from the file name.

    Returns:
        A list of ``(label, t, y, k_total)`` tuples, where ``k_total`` is the
        ``k`` parsed from the file name (or None). The list is empty when the
        file has no value column at all.
    """
    frame = pd.read_csv(csv_path)
    side, is_quantized, temperature = parse_meta_from_name(csv_path.name)
    k_total = parse_k_from_name(csv_path.name)
    steps = frame["t"].to_numpy()

    # (column name, side shown in the label, temperature shown in the label)
    known_columns = (
        ("base_pass", "base", temperature),
        ("post_pass", "post", temperature),
        ("base_greedy", "base", 0.0),
        ("post_greedy", "post", 0.0),
    )
    series: List[Tuple[str, np.ndarray, np.ndarray, Optional[int]]] = [
        (label_for(col_side, is_quantized, col_temp, "Pass@t"), steps, frame[column].to_numpy(), k_total)
        for column, col_side, col_temp in known_columns
        if column in frame.columns
    ]
    if series:
        return series

    # No recognised column: fall back to the first column that is not "t".
    value_columns = [column for column in frame.columns if column != "t"]
    if value_columns:
        fallback_label = label_for(side, is_quantized, temperature, "Pass@t")
        series.append((fallback_label, steps, frame[value_columns[0]].to_numpy(), k_total))
    return series


def _normalize_answer(val) -> str:
    """Canonicalise an answer for voting: stringify, trim, lowercase, unwrap ``\\boxed{...}``."""
    if val is None:
        return ""
    text = str(val).strip().lower()
    # Strip common wrappers like \boxed{...}
    if text.startswith("\\boxed{") and text.endswith("}"):
        text = text[len("\\boxed{"):-1].strip()
    return text


def _majority_vote_curve(parsed, gt_norm: str, limit: int) -> List[float]:
    """Return the 0/1 majority-vote outcome after each of the first ``limit`` samples.

    The vote is tallied incrementally, so each sample is counted once instead
    of re-counting every prefix. ``None`` samples cast no vote. A prefix scores
    1.0 only when it has a unique most-frequent answer equal to ``gt_norm``;
    ties and prefixes without any vote score 0.0.
    """
    tallies: Counter = Counter()
    best_count = 0      # highest vote count seen so far
    leader = None       # an answer holding ``best_count`` votes
    n_leaders = 0       # how many answers currently hold ``best_count`` votes
    curve: List[float] = []
    for raw in parsed[:limit]:
        if raw is not None:
            answer = _normalize_answer(raw)
            tallies[answer] += 1
            votes = tallies[answer]
            if votes > best_count:
                best_count, leader, n_leaders = votes, answer, 1
            elif votes == best_count:
                n_leaders += 1
        has_unique_mode = n_leaders == 1
        curve.append(1.0 if has_unique_mode and leader == gt_norm else 0.0)
    return curve


def load_mv_from_jsonl(samples_path: Path, n: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute MV@t by voting over final parsed answers among first t samples.

    Each non-blank line of the file is a JSON record with a ``parsed`` list of
    answers and a ``ground_truth``. For every record and every ``t`` up to
    ``min(n, len(parsed))`` the modal answer of the first ``t`` samples is
    compared with the ground truth. Ties (no unique mode) count as incorrect.
    Records with an empty or missing ``parsed`` list are ignored, and blank
    lines are skipped instead of aborting the load.

    Args:
        samples_path: Path of the JSONL samples file.
        n: Maximum ``t`` to evaluate; values <= 0 yield no curve.

    Returns:
        ``(t, mv_avg, totals)``: ``t`` is ``1..max_t``, ``mv_avg[t-1]`` is the
        fraction of records whose vote is correct at ``t``, and
        ``totals[t-1]`` is the number of records that have at least ``t``
        samples. All three arrays are empty when no record contributes.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    mv_curves: List[np.ndarray] = []
    with open(samples_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Blank/trailing lines are not records; json.loads would raise on them.
                continue
            rec = json.loads(line)
            parsed = rec.get("parsed", [])
            if not parsed:
                continue
            gt_norm = _normalize_answer(rec.get("ground_truth", ""))
            # Clamp at 0 so a negative n cannot turn into a from-the-end slice.
            limit = max(0, min(n, len(parsed)))
            mv_binary = _majority_vote_curve(parsed, gt_norm, limit)
            if mv_binary:
                mv_curves.append(np.array(mv_binary, dtype=float))

    if not mv_curves:
        return np.array([]), np.array([]), np.array([])

    # Curves may differ in length; average each t over the records that reach it.
    max_t = max(len(c) for c in mv_curves)
    successes = np.zeros(max_t, dtype=float)
    totals = np.zeros(max_t, dtype=int)
    for curve in mv_curves:
        length = len(curve)
        successes[:length] += curve
        totals[:length] += 1
    with np.errstate(invalid="ignore"):
        mv_avg = np.divide(successes, totals, out=np.zeros_like(successes, dtype=float), where=totals > 0)
    t = np.arange(1, max_t + 1)
    return t, mv_avg, totals


def wilson_ci(p: np.ndarray, n: np.ndarray, z: float = 1.96) -> Tuple[np.ndarray, np.ndarray]:
    """Wilson score interval for observed proportions.

    Args:
        p: Observed proportions (array-like, converted to float).
        n: Number of trials behind each proportion (array-like, converted to float).
        z: Normal quantile controlling the interval width (default 1.96).

    Returns:
        ``(lower, upper)`` arrays, each clipped to ``[0, 1]``. Division-by-zero
        and invalid-value warnings are suppressed during the computation.
    """
    proportion = np.asarray(p, dtype=float)
    trials = np.asarray(n, dtype=float)

    with np.errstate(divide='ignore', invalid='ignore'):
        scale = 1 + (z**2) / trials
        midpoint = (proportion + (z**2) / (2 * trials)) / scale
        spread = (proportion * (1 - proportion) / trials) + (z**2) / (4 * (trials**2))
        half_width = (z * np.sqrt(spread)) / scale
        raw_lower = midpoint - half_width
        raw_upper = midpoint + half_width

    return np.clip(raw_lower, 0.0, 1.0), np.clip(raw_upper, 0.0, 1.0)


def _linestyle_for(label: str, mv_only: bool, pass_only: bool) -> str:
    """Return the matplotlib linestyle for a curve label under the active plot mode.

    Greedy curves are always dotted. In Pass-only mode the 4-bit ``temp=1.0``
    Pass@t curves are solid and everything else dashed. In MV-only mode all
    remaining curves are solid. In the combined mode MV@t is solid and Pass@t
    dashed.
    """
    if " – greedy – " in label:
        return ":"
    if pass_only:
        is_headline = ("Pass@t" in label) and (" 4-bit " in label) and ("temp=1.0" in label)
        return "-" if is_headline else "--"
    if mv_only:
        return "-"
    return "-" if "MV@t" in label else "--"


def _style_for_label(label: str) -> Tuple[object, float]:
    """Return ``(color, linewidth)`` for a curve label.

    Post 4-bit is blue, Base 4-bit orange, full precision grey. Any other label
    gets a tab10 colour picked by a CRC of the label, which (unlike the built-in
    string ``hash``) is the same on every run. 4-bit non-greedy curves are drawn
    thicker than the rest.
    """
    import zlib  # local import: only needed for the stable fallback colour

    is_post = label.startswith("Post ")
    is_4bit = " 4-bit " in label
    is_greedy = " – greedy – " in label
    if is_4bit and is_post:
        color = "tab:blue"
    elif is_4bit:
        color = "tab:orange"
    elif " full " in label:
        color = "tab:gray"
    else:
        palette = plt.colormaps.get_cmap("tab10")
        color = palette(zlib.crc32(label.encode("utf-8")) % 10)
    width = 2.2 if (is_4bit and not is_greedy) else 1.2
    return color, width


def _select_labels(candidates: List[str], mv_only: bool, pass_only: bool) -> List[str]:
    """Pick and order the labels to plot for the active mode.

    MV-only keeps every MV@t label; Pass-only keeps every Pass@t label. The
    combined chart keeps 4-bit Pass@t (greedy included) and 4-bit non-greedy
    MV@t, listing a fixed preferred order first and any other label after it.
    """
    if mv_only:
        return [label for label in candidates if "MV@t" in label]
    if pass_only:
        return [label for label in candidates if "Pass@t" in label]

    chosen = [
        label
        for label in candidates
        if " 4-bit " in label
        and ("Pass@t" in label or ("MV@t" in label and " – greedy – " not in label))
    ]
    # Deterministic ordering to avoid plot inconsistencies
    preferred_order = [
        "Base 4-bit – temp=1.0 – Pass@t",
        "Post 4-bit – temp=1.0 – Pass@t",
        "Base 4-bit – temp=1.0 – MV@t",
        "Post 4-bit – temp=1.0 – MV@t",
        "Base 4-bit – greedy – Pass@t",
        "Post 4-bit – greedy – Pass@t",
    ]
    chosen_set = set(chosen)
    ordered = [label for label in preferred_order if label in chosen_set]
    ordered_set = set(ordered)
    return ordered + [label for label in chosen if label not in ordered_set]


def _x_ticks(x_max: int) -> List[int]:
    """Return integer x ticks ``[1, step, 2*step, ...]`` up to ``x_max``.

    The step is 5 for ``x_max <= 20`` (giving 1, 5, 10, 15, 20) and grows in
    multiples of 5 for longer axes so the tick count stays small.
    """
    step = 5 * max(1, -(-x_max // 20))  # ceil(x_max / 20) without floats
    return [1] + list(range(step, x_max + 1, step))


def main():
    """CLI entry point: overlay Pass@t and MV@t curves of one experiment into a PDF.

    Reads ``self_consistency_passk_*.csv`` and ``self_consistency_samples_*.jsonl``
    from ``<experiment>/plots``, selects the curves for the requested mode
    (combined, ``--mv_only`` or ``--pass_only``), optionally adds Wilson
    confidence intervals, and saves the figure. When both mode flags are given,
    label selection and the output name follow ``--mv_only``.

    Exits with status 1 when the experiment directory cannot be found and with
    status 0 (after a message) when there is nothing to plot.
    """
    ap = argparse.ArgumentParser(description="Overlay pass@t and MV@t curves if available")
    ap.add_argument("experiment", help="Experiment ID or path")
    ap.add_argument("--out", default=None, help="Optional explicit output path for PDF")
    ap.add_argument("-n", "--n", type=int, default=20, help="Assumed max t for MV@t parsing")
    ap.add_argument("--mv_only", action="store_true", help="If set, plot only MV@t curves and save as self_consistency_mv_only.pdf")
    ap.add_argument("--pass_only", action="store_true", help="If set, plot only Pass@t curves and save as self_consistency_combined.pdf")
    # argparse %-formats help strings, so a literal percent sign must be doubled.
    ap.add_argument("--with_ci", action="store_true", help="If set, add Wilson 95%% CI")
    ap.add_argument("--ci_style", choices=["ribbon", "bars"], default="ribbon", help="CI visualization style: shaded ribbon or error bars")
    ap.add_argument("--verbose", action="store_true", help="Print selected labels and debug info")
    args = ap.parse_args()

    # Accept either a path or an ID resolved under ./experiments.
    exp_dir = Path(args.experiment)
    if not exp_dir.exists():
        exp_dir = Path("experiments") / args.experiment
    if not exp_dir.exists():
        print(f"Experiment not found: {args.experiment}")
        sys.exit(1)

    plots_dir = exp_dir / "plots"
    plots_dir.mkdir(exist_ok=True)

    csv_paths = sorted(plots_dir.glob("self_consistency_passk_*.csv"))
    sample_paths = sorted(plots_dir.glob("self_consistency_samples_*.jsonl"))

    # Each entry: (label, t, y, n) where n is None, a constant int, or a per-t array.
    all_series: List[Tuple[str, np.ndarray, np.ndarray, object]] = []
    for csv_path in csv_paths:
        try:
            all_series.extend(load_series_from_csv(csv_path))
        except Exception as exc:
            # Best effort: one unreadable CSV must not block the others, but say so.
            print(f"Warning: skipping {csv_path.name}: {exc}", file=sys.stderr)

    # Add MV@t curves. For combined chart keep only 4-bit; for mv_only include both
    # 4-bit and full. Pure Pass-only plots never show MV@t, so skip the parsing.
    combined_mode = (not args.mv_only) and (not args.pass_only)
    if args.mv_only or not args.pass_only:
        for sample_path in sample_paths:
            which, quant, temp = parse_meta_from_name(sample_path.name)
            if combined_mode and not quant:
                continue
            t, mv, totals = load_mv_from_jsonl(sample_path, args.n)
            if t.size > 0:
                all_series.append((label_for(which, quant, temp, "MV@t"), t, mv, totals))

    if not all_series:
        print("No plottable series found")
        sys.exit(0)

    # Later series with the same label replace earlier ones.
    unique: Dict[str, Tuple[np.ndarray, np.ndarray, object]] = {}
    for label, t, y, n_per_t in all_series:
        unique[label] = (t, y, n_per_t)

    labels = _select_labels(list(unique), args.mv_only, args.pass_only)
    if not labels:
        # Series exist, but none belongs to the requested mode.
        print("No plottable series found")
        sys.exit(0)

    if args.verbose:
        print("Selected series:")
        for label in labels:
            print(f" - {label}")

    markers = ["o", "s", "^", "D", "v", "P", "*", "X", "+", "<"]
    fallback_n = 500  # hard-coded sample count used when a series carries none

    fig, ax = plt.subplots(figsize=(6.1, 4.5))
    for idx, label in enumerate(labels):
        t, y, n_per_t = unique[label]
        linestyle = _linestyle_for(label, args.mv_only, args.pass_only)
        color, line_width = _style_for_label(label)
        ax.plot(t, y, label=label, color=color, linewidth=line_width, linestyle=linestyle,
                marker=markers[idx % len(markers)], markersize=3)
        if not args.with_ci:
            continue

        # Mode-aware CI visibility: greedy curves never get an interval.
        non_greedy = " – greedy – " not in label
        if args.mv_only:
            show_ci = non_greedy and ("MV@t" in label)
        elif args.pass_only:
            show_ci = non_greedy and ("Pass@t" in label)
        else:
            show_ci = (" 4-bit " in label) and non_greedy and ("Pass@t" in label or "MV@t" in label)
        if not show_ci:
            continue

        # Determine n per t
        if n_per_t is None:
            n_vec = np.full_like(t, fill_value=fallback_n, dtype=float)
        elif isinstance(n_per_t, (int, np.integer)):
            n_vec = np.full_like(t, fill_value=int(n_per_t), dtype=float)
        else:
            n_vec = np.asarray(n_per_t, dtype=float)
        lo, hi = wilson_ci(y, n_vec)
        if args.ci_style == "ribbon":
            ax.fill_between(t, lo, hi, color=color, alpha=0.15, linewidth=0, zorder=1)
        else:
            # Rounding can make y - lo or hi - y marginally negative, which
            # errorbar rejects; floor both lengths at zero.
            yerr = np.vstack([np.maximum(y - lo, 0.0), np.maximum(hi - y, 0.0)])
            ax.errorbar(t, y, yerr=yerr, color=color, alpha=0.6, linewidth=0, elinewidth=1.2, capsize=2, zorder=2)

    # X axis: 1..20 by default, widened when the plotted data extends further.
    x_max = 20
    try:
        t_all = np.concatenate([np.asarray(unique[label][0], dtype=float) for label in labels])
        finite_t = t_all[np.isfinite(t_all)]
        if finite_t.size > 0:
            x_max = max(x_max, int(np.ceil(finite_t.max())))
    except (TypeError, ValueError):
        pass  # non-numeric t: keep the default axis
    ax.set_xticks(_x_ticks(x_max))
    ax.set_xlim(1, x_max)

    # Increase font sizes for labels and ticks
    ax.set_xlabel("# Sampled Reasoning Paths (t)", fontsize=20)
    ax.set_ylabel("Accuracy", fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.grid(True, linestyle=":", alpha=0.6)

    # Y axis: hug the finite data range with a small padding.
    try:
        y_all = np.concatenate([np.asarray(unique[label][1], dtype=float) for label in labels])
        finite_y = y_all[np.isfinite(y_all)]
        if finite_y.size > 0:
            ymin = float(finite_y.min())
            ymax = float(finite_y.max())
            pad = max(1e-3, 0.02 * (ymax - ymin if ymax > ymin else 1.0))
            ax.set_ylim(ymin - pad, ymax + pad)
    except (TypeError, ValueError):
        pass  # non-numeric y: leave matplotlib's automatic limits

    # Rebuild the legend handles so each entry carries the intended linestyle.
    import matplotlib.lines as mlines
    handles, lbls = ax.get_legend_handles_labels()
    new_handles = []
    for handle, label in zip(handles, lbls):
        legend_style = _linestyle_for(label, args.mv_only, args.pass_only)
        if legend_style == "--":
            # Dashed legend entries use an explicit (4 on, 3 off) dash pattern.
            legend_style = (0, (4, 3))
        new_handles.append(
            mlines.Line2D([], [], color=handle.get_color(),
                          linestyle=legend_style, linewidth=handle.get_linewidth(),
                          marker=handle.get_marker(), markersize=handle.get_markersize())
        )

    # Place legend below the plot with more columns to save vertical space
    fig.legend(new_handles, lbls, loc="lower center", ncol=min(len(lbls), 3), frameon=False,
               bbox_to_anchor=(0.5, -0.05))

    # Get model and dataset for filename (no title on plot). Best effort: a
    # missing or unreadable config just means no suffix.
    model_dataset_suffix = ""
    try:
        cfg = pd.read_json(exp_dir / "config.json", typ="series")
        model = cfg.get("model", "")
        dataset = cfg.get("dataset", "")
        if model and dataset:
            # Names such as "org/model" would otherwise point into a
            # sub-directory that does not exist.
            model_dataset_suffix = f"_{model}_{dataset}".replace("/", "-").replace("\\", "-")
    except Exception:
        pass

    # Use tight_layout with padding to ensure nothing gets cut off
    fig.tight_layout(pad=2.0)

    if args.mv_only:
        default_name = f"self_consistency_mv_only{model_dataset_suffix}.pdf"
    elif args.pass_only:
        default_name = f"self_consistency_combined{model_dataset_suffix}.pdf"
    else:
        default_name = f"self_consistency_combined_with_mv{model_dataset_suffix}.pdf"
    out_path = Path(args.out) if args.out else (plots_dir / default_name)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Save with bbox_inches='tight' to include all elements including legend
    fig.savefig(out_path, dpi=200, bbox_inches='tight', pad_inches=0.3)
    print(f"Saved plot: {out_path}")


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
