#!/usr/bin/env python3
import os
import sys
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import json
import numpy as np
import matplotlib.pyplot as plt


def normalize_answer(val: Optional[str]) -> str:
    """Canonicalize an answer string so equal answers compare equal.

    ``None`` becomes the empty string. Any other value is converted with
    ``str``, stripped of surrounding whitespace and lower-cased. If the result
    is wrapped as ``\\boxed{...}``, the wrapper is removed and the inner text is
    stripped again.
    """
    if val is None:
        return ""
    text = str(val).strip().lower()
    prefix = "\\boxed{"
    # Lower-casing happens first, so "\BOXED{...}" is unwrapped as well.
    is_boxed = text.startswith(prefix) and text.endswith("}")
    return text[len(prefix):-1].strip() if is_boxed else text


def parse_meta_from_name(name: str) -> Tuple[Optional[str], bool, Optional[float]]:
    """Extract ``(which, quant, temp)`` metadata from a sample-cache file name.

    The basename of ``name`` has a trailing ``.jsonl`` extension removed and is
    split on ``_``. Each token is interpreted as follows:

    * ``base`` / ``post``                        -> ``which`` (last one wins)
    * ``temp<float>`` (e.g. ``temp1.0``)         -> ``temp`` (last parseable one wins)
    * ``qb`` / ``qp`` / ``quant`` / ``quantized`` -> ``quant = True``

    Args:
        name: File name or path of the cache.

    Returns:
        ``(which, quant, temp)``; ``which`` and ``temp`` are ``None`` when no
        matching token is found.
    """
    base = os.path.basename(name)
    suffix = ".jsonl"
    # Only strip the real extension, not ".jsonl" occurring elsewhere in the name.
    core = base[: -len(suffix)] if base.endswith(suffix) else base
    which: Optional[str] = None
    temp: Optional[float] = None
    quant = False
    for tok in core.split("_"):
        if tok in ("base", "post"):
            which = tok
        elif tok.startswith("temp"):
            try:
                temp = float(tok[len("temp"):])
            except ValueError:
                # Tokens like "template"/"temperature" are not temperatures;
                # keep any value already parsed instead of resetting it.
                pass
        elif tok in ("qb", "qp", "quant", "quantized"):
            quant = True
    return which, quant, temp


def _max_vote_share_curve(answers: List[Optional[str]], length: int) -> np.ndarray:
    """Return the running max-vote share over the first ``length`` answers.

    Entry ``i`` (0-based) is ``max_votes / (i + 1)``, where ``max_votes`` is the
    size of the largest group of identical non-None answers in
    ``answers[: i + 1]``. ``None`` entries cast no vote but still count in the
    denominator; a prefix made only of ``None`` scores 0.0.
    """
    from collections import Counter

    votes: Counter = Counter()
    leader = 0
    curve = np.zeros(length, dtype=float)
    for i, ans in enumerate(answers[:length]):
        if ans is not None:
            votes[ans] += 1
            # Only the group that just grew can overtake the current leader,
            # so one comparison keeps the running maximum exact (O(L) overall).
            leader = max(leader, votes[ans])
        # Denominator counts every response, including None (disagreements).
        curve[i] = float(leader) / float(i + 1)
    return curve


def load_sampling_consistency(samples_path: Path, n: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute sampling consistency over t for a JSONL cache.

    Each non-blank line of ``samples_path`` is parsed as JSON and its
    ``"parsed"`` list of answers is read. Non-None answers are normalized with
    ``normalize_answer``; ``None`` entries are kept as disagreements. For each
    prompt and each t in [1..min(n, len(parsed))], the max vote share among the
    first t answers is computed: the largest group of identical non-None
    answers divided by t. The per-prompt curves are then averaged across
    prompts for each t.

    Args:
        samples_path: JSONL cache; records with a missing/empty ``"parsed"``
            list are skipped.
        n: Maximum t to consider. Values <= 0 yield empty arrays.

    Returns:
        t: array [T] with values 1..T, where T is the longest per-prompt curve
        mean_consistency: array [T] with mean of max_vote_share per t
        counts: array [T] with number of prompts contributing at each t
        All three are empty when no prompt contributes.
    """
    cap = max(0, n)
    per_prompt_curves: List[np.ndarray] = []

    with open(samples_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            rec = json.loads(line)
            parsed: List[Optional[str]] = rec.get("parsed", [])
            if not parsed:
                continue
            length = min(cap, len(parsed))
            if length == 0:
                continue
            # Keep None values as disagreements; only the first `length` answers matter.
            norm = [normalize_answer(a) if a is not None else None for a in parsed[:length]]
            per_prompt_curves.append(_max_vote_share_curve(norm, length))

    if not per_prompt_curves:
        return np.array([]), np.array([]), np.array([])

    max_t = max(len(c) for c in per_prompt_curves)
    sums = np.zeros(max_t, dtype=float)
    counts = np.zeros(max_t, dtype=int)
    for c in per_prompt_curves:
        # Each curve covers a prefix of t values, so it contributes to [0, len(c)).
        sums[: len(c)] += c
        counts[: len(c)] += 1
    # The longest curve covers every position, so counts > 0 everywhere;
    # `where` is only a guard against division by zero.
    means = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    t = np.arange(1, max_t + 1)
    return t, means, counts


def _select_quant_caches(plots_dir: Path) -> Dict[str, Path]:
    """Return ``{"base"|"post": path}`` for 4-bit, temperature-1.0 caches in ``plots_dir``.

    Candidates come from a glob on the cache naming scheme and are re-validated
    with ``parse_meta_from_name`` (quantized, temp == 1.0, which in
    {"base", "post"}). They are visited in sorted order, so when several files
    match the same key the last one wins.
    """
    selected: Dict[str, Path] = {}
    for sp in sorted(plots_dir.glob("self_consistency_samples_*temp1.0_*_quant.jsonl")):
        which, quant, temp = parse_meta_from_name(sp.name)
        if not quant:
            continue
        if temp is None or float(temp) != 1.0:
            continue
        if which in ("base", "post"):
            selected[which] = sp
    return selected


def _model_dataset_suffix(exp_dir: Path) -> str:
    """Return ``"_<model>_<dataset>"`` read from ``exp_dir/config.json``, or ``""``.

    The suffix is empty when the config is missing, unreadable, not a JSON
    object, or lacks a non-empty ``model`` or ``dataset`` entry. Path separators
    in either value are replaced by ``_`` so ids such as ``org/model`` yield a
    single file name instead of a path into a non-existent sub-directory.
    """
    try:
        with open(exp_dir / "config.json", "r", encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, ValueError):
        # A missing/broken config only costs us the filename suffix.
        return ""
    if not isinstance(cfg, dict):
        return ""
    model = cfg.get("model", "")
    dataset = cfg.get("dataset", "")
    if not (model and dataset):
        return ""

    def _clean(value) -> str:
        text = str(value)
        for sep in ("/", "\\"):
            text = text.replace(sep, "_")
        return text

    return f"_{_clean(model)}_{_clean(dataset)}"


def main():
    """CLI entry point: plot sampling consistency for 4-bit base/post caches.

    Resolves the experiment directory (the argument as given, otherwise under
    ``experiments/``), selects the quantized temperature-1.0 sample caches in
    its ``plots`` directory, and writes a PDF line plot of mean max-vote share
    versus t. Exits with status 1 if the experiment is not found and status 0
    if no matching caches exist.
    """
    ap = argparse.ArgumentParser(description="Plot Sampling Consistency (mean max vote share over t) for 4-bit Base and Post")
    ap.add_argument("experiment", help="Experiment ID or path")
    ap.add_argument("-n", "--n", type=int, default=50, help="Max t to consider (caps per-sample length)")
    ap.add_argument("--out", default=None, help="Optional explicit output PDF path")
    ap.add_argument("--verbose", action="store_true", help="Print debug info")
    args = ap.parse_args()

    exp_dir = Path(args.experiment)
    if not exp_dir.exists():
        exp_dir = Path("experiments") / args.experiment
    if not exp_dir.exists():
        print(f"Experiment not found: {args.experiment}")
        sys.exit(1)

    plots_dir = exp_dir / "plots"
    plots_dir.mkdir(exist_ok=True)

    selected = _select_quant_caches(plots_dir)

    if args.verbose:
        print("Found sample caches:")
        for k, v in selected.items():
            print(f" - {k}: {v}")

    if "base" not in selected and "post" not in selected:
        print("No 4-bit temp=1.0 caches found for base/post")
        sys.exit(0)

    fig, ax = plt.subplots(figsize=(5.6, 4.5))
    order = ["base", "post"]
    label_map = {"base": "Base Model", "post": "Post-trained"}
    color_map = {"base": "tab:orange", "post": "tab:blue"}
    marker_map = {"base": "^", "post": "o"}  # Triangle for base (orange), circle for post (blue)

    # Markers are drawn only at the major x ticks; the line itself is unlabelled,
    # so the legend shows the marker style for each model.
    marker_positions = [5, 10, 15, 20]

    for key in order:
        if key not in selected:
            continue
        t, mean_consistency, counts = load_sampling_consistency(selected[key], args.n)
        if t.size == 0:
            continue

        ax.plot(t, mean_consistency, color=color_map.get(key, None), linewidth=2.2)

        # t[i] == i + 1, so position p lives at index p - 1.
        positions = [p for p in marker_positions if p <= len(t)]
        if positions:
            ax.plot(
                [t[p - 1] for p in positions],
                [mean_consistency[p - 1] for p in positions],
                marker=marker_map[key],
                color=color_map.get(key, None),
                markersize=8,
                markeredgewidth=1.5,
                markeredgecolor='white',
                linestyle='',
                label=label_map[key],
                zorder=5,
            )

    # NOTE(review): the x range is fixed at 1..20 regardless of --n.
    ax.set_xticks([1, 5, 10, 15, 20])
    ax.set_xlim(1, 20)

    ax.set_xlabel("# Sampled Reasoning Paths (t)", fontsize=20)
    ax.set_ylabel("Sampling Consistency", fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.grid(True, linestyle=":", alpha=0.6)

    # Model and dataset only go into the filename (no title on the plot).
    model_dataset_suffix = _model_dataset_suffix(exp_dir)

    ax.legend(loc="upper right", frameon=False, fontsize=14)
    fig.tight_layout(pad=2.0)

    out_path = Path(args.out) if args.out else (plots_dir / f"sampling_consistency{model_dataset_suffix}.pdf")
    # An explicit --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=200, bbox_inches='tight', pad_inches=0.3)
    plt.close(fig)
    print(f"Saved plot: {out_path}")


if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()


