import argparse
import json
import math
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib


def _ensure_headless_backend() -> None:
    """Force the non-interactive Agg backend unless a suitable one is active.

    The Agg backend and the Jupyter inline backend are left untouched; for any
    other backend ``matplotlib.use("Agg")`` is attempted, and a failure to switch
    is ignored (best effort).
    """
    acceptable = ("agg", "module://matplotlib_inline.backend_inline")
    if matplotlib.get_backend().lower() not in acceptable:
        try:
            matplotlib.use("Agg")
        except Exception:  # best effort: keep whatever backend is already active
            pass


# Must run before ``matplotlib.pyplot`` is imported below.
_ensure_headless_backend()
import matplotlib.pyplot as plt  # noqa: E402
from matplotlib.lines import Line2D  # noqa: E402


def find_repo_root(start: Path) -> Path:
    """Locate the repository root by walking upwards from ``start``.

    Returns the nearest directory (``start`` itself included, after resolving)
    that contains a ``.git`` entry. When none is found, falls back to the
    directory holding ``start`` (``start`` itself if it is a directory), so that
    callers can always join sub-paths such as ``data/...`` onto the result even
    when ``start`` is a file path like ``Path(__file__)``.
    """
    current = start.resolve()
    for candidate in (current, *current.parents):
        if (candidate / ".git").exists():
            return candidate
    # No repository marker found: never hand back a *file* path as a "root".
    return current if current.is_dir() else current.parent


def parse_filename_metadata(file_path: Path) -> Tuple[str, Optional[int], str]:
    """Extract dataset name, max length, and a readable run tag from file name.

    Expected file name format: {dataset}_{max_len}_{run_tag}.json

    The dataset name may itself contain underscores (e.g. ``trivia_qa`` or
    ``open_thoughts_math``). The first purely numeric ``_``-separated field
    that is followed by at least one more field is therefore taken as
    ``max_len``. Everything before it is the dataset and everything after it is
    the run tag.

    When no such field exists, the name is split at its first two underscores
    instead. In that case ``max_len`` is None if the second field is not an
    integer. Names with fewer than three fields yield ``(stem, None, stem)``.

    A timestamp-like suffix is stripped from the run tag to keep labels short.
    The suffix starts at ``-19``/``-20`` followed by at least six more digits.
    """
    name = file_path.stem
    fields = name.split("_")
    length_idx = next(
        (i for i in range(1, len(fields) - 1) if re.fullmatch(r"\d+", fields[i])),
        None,
    )

    max_len: Optional[int]
    if length_idx is not None:
        dataset = "_".join(fields[:length_idx])
        max_len = int(fields[length_idx])
        raw_tag = "_".join(fields[length_idx + 1:])
    else:
        # Fallback: original "{dataset}_{max_len}_{rest}" split.
        parts = name.split("_", 2)
        if len(parts) < 3:
            return name, None, name
        dataset = parts[0]
        try:
            max_len = int(parts[1])
        except ValueError:
            max_len = None
        raw_tag = parts[2]

    # Trim trailing timestamp-like segments to keep labels short.
    # Matches patterns like -20250805_133928-538969_0-...
    match = re.search(r"-(?:19|20)\d{6,}", raw_tag)
    tag = raw_tag[: match.start()] if match else raw_tag

    return dataset, max_len, tag


def load_perplexities(file_path: Path) -> Optional[List[float]]:
    """Load the perplexities list from a length extrapolation result JSON file.

    Accepted schemas are ``{"results": {"perplexities": [...]}}`` and, when
    ``"results"`` is absent or not an object, ``{"perplexities": [...]}``.

    Returns:
        The perplexities list. Returns None if the file cannot be read or
        decoded, if it is not valid JSON, or if no list is found where
        expected.
    """
    try:
        # Binary mode lets json detect UTF-8/16/32 (and a UTF-8 BOM) itself,
        # independent of the platform's locale encoding.
        with file_path.open("rb") as f:
            data = json.load(f)
    except (OSError, ValueError, RecursionError):
        # Missing/unreadable file, undecodable bytes or malformed JSON:
        # treat as "no result" so callers can skip the file.
        return None

    if not isinstance(data, dict):
        return None
    results = data.get("results")
    if isinstance(results, dict):
        perplexities = results.get("perplexities")
    else:
        perplexities = data.get("perplexities")

    return perplexities if isinstance(perplexities, list) else None


def discover_json_files(target_dir: Path, pattern: Optional[str]) -> List[Path]:
    """List regular files in ``target_dir`` matching a glob, in sorted order.

    ``pattern`` defaults to ``*.json`` when it is None or empty. Directories
    that happen to match the pattern are excluded.
    """
    glob_pattern = pattern if pattern else "*.json"
    matches = sorted(target_dir.glob(glob_pattern))
    return [candidate for candidate in matches if candidate.is_file()]


def compute_token_losses(perplexities: List[float]) -> List[float]:
    """Map perplexities to per-token losses via the natural logarithm.

    Some entries do not yield a loss of their own: ``None``, NaN/Inf floats,
    values <= 0, and anything whose comparison or ``log`` raises. For those,
    the most recent valid loss is repeated, or NaN if no valid entry has been
    seen yet. The output has the same length as the input.
    """
    result: List[float] = []
    carried = float("nan")  # loss reused for invalid entries
    for raw in perplexities:
        try:
            invalid = (
                raw is None
                or (isinstance(raw, float) and not math.isfinite(raw))
                or raw <= 0
            )
            if not invalid:
                carried = math.log(float(raw))
                result.append(carried)
                continue
        except Exception:
            pass
        result.append(carried)
    return result


def moving_average(values: List[float], window: int) -> List[float]:
    """Centered moving average that ignores NaN/Inf samples.

    Output ``i`` averages the finite values inside a span of exactly ``window``
    samples around ``i``; the span is clipped at the sequence ends. For odd
    windows the span is symmetric (``i - window//2 .. i + window//2``). For
    even windows it extends one sample further to the left. Positions whose
    span holds no finite value yield NaN.

    ``window <= 1`` disables smoothing and returns ``values`` itself.
    """
    if window <= 1:
        return values
    window = int(window)
    half = window // 2

    # Prefix sums over finite values plus a prefix count of finite values, so
    # each window is O(1) regardless of its size.
    prefix_sum = [0.0]
    prefix_cnt = [0]
    for x in values:
        if math.isfinite(x):
            prefix_sum.append(prefix_sum[-1] + x)
            prefix_cnt.append(prefix_cnt[-1] + 1)
        else:
            prefix_sum.append(prefix_sum[-1])
            prefix_cnt.append(prefix_cnt[-1])

    n = len(values)
    out: List[float] = []
    for i in range(n):
        left = max(0, i - half)
        # i - half + window keeps the span exactly `window` wide (the previous
        # i + half + 1 gave window + 1 samples for even windows).
        right = min(n, i - half + window)
        cnt = prefix_cnt[right] - prefix_cnt[left]
        if cnt == 0:
            out.append(float("nan"))
        else:
            out.append((prefix_sum[right] - prefix_sum[left]) / cnt)
    return out


def prettify_dataset_name(dataset: str) -> str:
    """Return a human-readable title for a dataset identifier.

    Known identifiers (including their underscore variants) map to curated
    names. Anything else has underscores replaced by spaces and is title-cased.
    """
    known_titles = {
        "trivia": "Trivia QA",
        "trivia_qa": "Trivia QA",
        "codeparrot": "CodeParrot",
        "open_thoughts_math": "OpenThoughts Math",
        "openthoughts_math": "OpenThoughts Math",
        "open_thoughts_114k_math": "OpenThoughts 114K Math",
        "openthoughts_114k_math": "OpenThoughts 114K Math",
    }
    curated = known_titles.get(dataset)
    if curated is not None:
        return curated
    return dataset.replace("_", " ").title()


def shorten_model_tag(tag: str) -> str:
    """Build a compact, stable model label from a run tag.

    If the tag contains ``paninetto<d>``, ``ane<d>`` (first occurrence each) or
    the substring ``Array``, those pieces are joined with ``/`` in that order.
    Otherwise the tag is returned as-is when at most 40 characters. Longer tags
    are cut to their first 37 characters plus an ellipsis.
    """
    labels: List[str] = []
    for prefix in ("paninetto", "ane"):
        hit = re.search(prefix + r"(\d)", tag)
        if hit is not None:
            labels.append(prefix + hit.group(1))
    if "Array" in tag:
        labels.append("Array")

    if labels:
        return "/".join(labels)
    return tag if len(tag) <= 40 else tag[:37] + "…"


def extract_flags_from_tag(tag: str) -> Dict[str, Optional[bool]]:
    """Read the on/off flags encoded in a run tag.

    Recognised flags (first occurrence of ``<name><digit>`` wins):
      - paninetto{0|1}
      - ane{0|1}  (negative eigenvalues)
      - sc{0|1}   (simple combination)
      - iq{0|1}   (independent query projections)

    Returns a dict keyed 'paninetto', 'ane', 'sc', 'iq' (in that order). A value
    is True when the digit is ``1``, False for any other digit, and None when
    the flag does not appear in the tag.
    """
    flags: Dict[str, Optional[bool]] = {}
    for flag_name in ("paninetto", "ane", "sc", "iq"):
        hit = re.search(flag_name + r"(\d)", tag)
        flags[flag_name] = None if hit is None else hit.group(1) == "1"
    return flags


def group_files_by_dataset_and_model(
    files: List[Path],
) -> Dict[str, Dict[str, List[Path]]]:
    """Bucket result files first by dataset, then by shortened model label.

    The dataset and run tag come from ``parse_filename_metadata``; the label is
    ``shorten_model_tag(run_tag)``. The returned mapping (and each inner
    mapping) is a ``defaultdict``; paths keep their input order inside a bucket.
    """
    buckets: Dict[str, Dict[str, List[Path]]] = defaultdict(lambda: defaultdict(list))
    for path in files:
        ds_name, _, run_tag = parse_filename_metadata(path)
        buckets[ds_name][shorten_model_tag(run_tag)].append(path)
    return buckets


# Display titles for the supported on/off factors; the keys are also the set of
# factor names accepted by ``plot_losses_subplots``.
_FACTOR_TITLES: Dict[str, str] = {
    "paninetto": "Paninetto",
    "sc": "Simple combination",
    "iq": "Independent queries",
    "ane": "Negative eigenvalues",
}

# Figure styling. It is applied through ``plt.rc_context`` while the figure is
# built and saved, so global matplotlib state is left untouched afterwards.
_PLOT_RC: Dict[str, object] = {
    "font.family": "serif",
    "font.size": 18,
    "axes.labelsize": 22,
    "axes.titlesize": 30,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 18,
    "axes.spines.top": False,
    "axes.spines.right": False,
}

# Legend layout in figure coordinates: bottom margin for a single tier, and
# the extra height (and vertical spacing) added for each further tier.
_LEGEND_BASE_MARGIN = 0.16
_LEGEND_TIER_HEIGHT = 0.08


def _normalize_factors(selected_factors: Optional[List[str]]) -> List[str]:
    """Validate factor names; their order maps to color, linestyle, marker.

    ``None`` selects the default ``["paninetto", "ane", "iq"]``. Names are
    stripped and lower-cased, and blank entries are dropped.

    Raises:
        ValueError: on an unknown factor or when more than three are given.
    """
    if selected_factors is None:
        return ["paninetto", "ane", "iq"]
    factors = [f.strip().lower() for f in selected_factors if f.strip()]
    for f in factors:
        if f not in _FACTOR_TITLES:
            raise ValueError(f"Unknown factor '{f}'. Allowed: paninetto, ane, sc, iq.")
    if len(factors) > 3:
        raise ValueError("At most three factors can be selected (color, linestyle, marker)")
    return factors


def _pick_longest(paths: List[Path]) -> Optional[List[float]]:
    """Return the longest non-empty perplexity list among ``paths``.

    On ties the first path wins. Returns None if no path yields a non-empty
    list. Each file is loaded only once.
    """
    best: Optional[List[float]] = None
    for path in paths:
        ppl = load_perplexities(path)
        if ppl and (best is None or len(ppl) > len(best)):
            best = ppl
    return best


def _loss_curve(
    ppl: List[float], smooth_window: int, downsample: int
) -> Tuple[List[int], List[float]]:
    """Convert perplexities into (1-based position, loss) series.

    Smoothing is applied when ``smooth_window > 1``, and every
    ``downsample``-th point is kept when ``downsample > 1``.
    """
    losses = compute_token_losses(ppl)
    if smooth_window and smooth_window > 1:
        losses = moving_average(losses, smooth_window)
    x_values = list(range(1, len(losses) + 1))
    if downsample > 1:
        return x_values[::downsample], losses[::downsample]
    return x_values, losses


def _extend_bounds(values: List[float], lo: float, hi: float) -> Tuple[float, float]:
    """Widen ``[lo, hi]`` to cover the finite entries of ``values``; NaN/Inf are ignored."""
    finite = [v for v in values if math.isfinite(v)]
    if not finite:
        return lo, hi
    return min(lo, min(finite)), max(hi, max(finite))


def _on_off(value: Optional[bool], on: object, off: object, missing: object) -> object:
    """Return ``on`` for True, ``off`` for False, and ``missing`` for anything else."""
    if value is True:
        return on
    if value is False:
        return off
    return missing


def _legend_tiers(
    factors: List[str],
    present_flags: Dict[str, bool],
    baseline_present: bool,
    single_factor_dual_style: bool,
    color_on: object,
    color_off: object,
    baseline_color: str,
) -> List[Tuple[List[Line2D], int]]:
    """Build legend tiers, top to bottom, as ``(handles, ncol)`` pairs.

    The order is: baseline, then the color, linestyle and marker factors. A
    factor tier is included only when that factor occurs in some file name.
    """
    tiers: List[Tuple[List[Line2D], int]] = []
    if baseline_present:
        tiers.append(
            (
                [
                    Line2D([0], [0], color=baseline_color, linestyle="-", linewidth=2.5, label=r"DeltaProduct$_3$"),
                ],
                1,
            )
        )
    if len(factors) >= 1 and present_flags[factors[0]]:
        title = _FACTOR_TITLES[factors[0]]
        if single_factor_dual_style:
            # Single-factor mode encodes the factor via color, linestyle and marker.
            handles = [
                Line2D([0], [0], color=color_on, linestyle="-", linewidth=2.5, marker="o", markersize=2, label=f"{title} On"),
                Line2D([0], [0], color=color_off, linestyle="--", linewidth=2.5, marker="s", markersize=2, label=f"{title} Off"),
            ]
        else:
            handles = [
                Line2D([0], [0], color=color_on, linewidth=2.5, label=f"{title} On"),
                Line2D([0], [0], color=color_off, linewidth=2.5, label=f"{title} Off"),
            ]
        tiers.append((handles, 2))
    if len(factors) >= 2 and present_flags[factors[1]]:
        title = _FACTOR_TITLES[factors[1]]
        tiers.append(
            (
                [
                    Line2D([0], [0], color="black", linestyle="-", linewidth=2.5, label=f"{title} On"),
                    Line2D([0], [0], color="black", linestyle="--", linewidth=2.5, label=f"{title} Off"),
                ],
                2,
            )
        )
    if len(factors) >= 3 and present_flags[factors[2]]:
        title = _FACTOR_TITLES[factors[2]]
        tiers.append(
            (
                [
                    Line2D([0], [0], color="black", linestyle="None", marker="o", markersize=2, label=f"{title} On"),
                    Line2D([0], [0], color="black", linestyle="None", marker="s", markersize=2, label=f"{title} Off"),
                ],
                2,
            )
        )
    return tiers


def plot_losses_subplots(
    grouped: Dict[str, Dict[str, List[Path]]],
    downsample: int,
    smooth_window: int,
    train_ctx: Optional[int],
    output_path: Path,
    selected_factors: Optional[List[str]] = None,
    baseline_by_dataset: Optional[Dict[str, List[Path]]] = None,
) -> None:
    """Draw one token-loss subplot per dataset and save it as PNG (and PDF).

    Args:
        grouped: dataset -> model label -> result JSON paths.
        downsample: plot every k-th point when > 1.
        smooth_window: moving-average window; values <= 1 disable smoothing.
        train_ctx: x position of a vertical dash-dotted marker, or None.
        output_path: destination path; its suffix is replaced by ``.png``
            and ``.pdf``.
        selected_factors: up to three flag names, mapped in order to color,
            linestyle and marker. Defaults to ``["paninetto", "ane", "iq"]``.
            Runs that lack any selected flag are skipped.
        baseline_by_dataset: optional dataset -> baseline JSON paths. The
            longest run per dataset is drawn as a reference curve.

    Raises:
        ValueError: if ``grouped`` is empty, a factor is unknown, or more than
            three factors are given.
    """
    import warnings

    datasets = list(grouped.keys())
    if not datasets:
        raise ValueError("No datasets discovered for plotting.")

    # Validate before any figure is created so a bad argument leaks nothing.
    factors = _normalize_factors(selected_factors)
    # With a single factor, make it highly distinguishable by using color,
    # linestyle and (light) markers for that one factor.
    single_factor_dual_style = len(factors) == 1

    # Which factor drives which visual channel (loop-invariant).
    color_name = factors[0] if factors else None
    dual_name = color_name if single_factor_dual_style else None
    linestyle_name = factors[1] if len(factors) >= 2 else dual_name
    marker_name = factors[2] if len(factors) >= 3 else dual_name

    # Detect which selected flags appear in at least one file (drives legends).
    present_flags = {name: False for name in _FACTOR_TITLES}
    for models in grouped.values():
        for paths in models.values():
            for p in paths:
                file_flags = extract_flags_from_tag(parse_filename_metadata(p)[2])
                for name in factors:
                    if file_flags[name] is not None:
                        present_flags[name] = True

    cmap = plt.get_cmap("viridis")
    color_on = cmap(0.98)   # yellow-ish
    color_off = cmap(0.55)  # green-ish
    baseline_color = "#E24A33"  # orange, distinct from the viridis ramp

    with plt.rc_context(_PLOT_RC):
        num_cols = len(datasets)
        fig, axes_grid = plt.subplots(
            1, num_cols, figsize=(5 * num_cols, 4.5), sharey=False, squeeze=False
        )
        try:
            for ax, dataset in zip(axes_grid[0], datasets):
                y_lo, y_hi = math.inf, -math.inf

                # Baseline first so experimental curves draw on top of it.
                if baseline_by_dataset is not None and dataset in baseline_by_dataset:
                    bppl = _pick_longest(baseline_by_dataset[dataset])
                    if bppl:
                        bx, by = _loss_curve(bppl, smooth_window, downsample)
                        ax.plot(
                            bx,
                            by,
                            color=baseline_color,
                            linestyle="-",
                            linewidth=2.5,
                            label="DeltaProduct_3",
                            zorder=1,
                        )
                        y_lo, y_hi = _extend_bounds(by, y_lo, y_hi)

                for paths in grouped[dataset].values():
                    # One curve per combination of selected-factor values.
                    subgroups: Dict[Tuple[Optional[bool], ...], List[Path]] = defaultdict(list)
                    for p in paths:
                        file_flags = extract_flags_from_tag(parse_filename_metadata(p)[2])
                        key = tuple(file_flags[name] for name in factors)
                        if any(v is None for v in key):
                            continue  # run lacks a selected factor
                        subgroups[key].append(p)

                    for key, subpaths in subgroups.items():
                        ppl = _pick_longest(subpaths)
                        if not ppl:
                            continue
                        x_values, y_values = _loss_curve(ppl, smooth_window, downsample)

                        values = dict(zip(factors, key))
                        color_value = values[color_name] if color_name else None
                        linestyle = (
                            _on_off(values[linestyle_name], "-", "--", "-")
                            if linestyle_name
                            else "-"
                        )
                        marker = (
                            _on_off(values[marker_name], "o", "s", None)
                            if marker_name
                            else None
                        )
                        ax.plot(
                            x_values,
                            y_values,
                            color=_on_off(color_value, color_on, color_off, "#666666"),
                            linestyle=linestyle,
                            linewidth=2.5,
                            marker=marker,
                            markevery=160 if marker is not None else None,
                            markersize=2,
                            zorder=3 if color_value is True else 2,
                        )
                        y_lo, y_hi = _extend_bounds(y_values, y_lo, y_hi)

                if train_ctx is not None and train_ctx > 0:
                    ax.axvline(train_ctx, color="black", linestyle="dashdot", linewidth=2)

                ax.set_title(prettify_dataset_name(dataset))
                ax.grid(True, linestyle=":", linewidth=0.75, alpha=0.6)
                ax.set_xlim(left=0)
                # Fixed x ticks at powers-of-two window as requested
                ax.set_xticks([4096, 8192, 16384])

                if math.isfinite(y_lo) and math.isfinite(y_hi):
                    # Enforce a minimum visible span so ticks stay readable,
                    # then pad BOTH ends by 12% of it. The previous upper
                    # limit y_max - 8 * margin cut off almost all data.
                    span = max(y_hi - y_lo, 0.15)
                    margin = 0.12 * span
                    ax.set_ylim(y_lo - margin, y_hi + margin)

            baseline_present = baseline_by_dataset is not None and any(
                baseline_by_dataset.get(ds) for ds in datasets
            )
            tiers = _legend_tiers(
                factors,
                present_flags,
                baseline_present,
                single_factor_dual_style,
                color_on,
                color_off,
                baseline_color,
            )

            # Reserve bottom space for the legend tiers.
            bottom_margin = _LEGEND_BASE_MARGIN + _LEGEND_TIER_HEIGHT * max(0, len(tiers) - 1)
            fig.subplots_adjust(bottom=bottom_margin)

            # Shared labels above the legend area
            fig.supxlabel("Sequence Length", y=bottom_margin + 0.12)
            fig.supylabel("Token Loss")

            # Stack the tiers downward from just below the axes, one tier
            # height apart, so they never overlap each other or the plots.
            top_y = bottom_margin - 0.02
            for idx, (handles, ncol) in enumerate(tiers):
                fig.legend(
                    handles,
                    [h.get_label() for h in handles],
                    loc="center",
                    ncol=ncol,
                    bbox_to_anchor=(0.5, top_y - idx * _LEGEND_TIER_HEIGHT),
                    bbox_transform=fig.transFigure,
                )

            output_path.parent.mkdir(parents=True, exist_ok=True)
            fig.tight_layout(rect=(0, bottom_margin + 0.04, 1, 1))
            fig.savefig(output_path.with_suffix(".png"), dpi=300)
            try:
                fig.savefig(output_path.with_suffix(".pdf"))
            except Exception as exc:
                # PDF output is optional (PNG is already written); report
                # instead of silently ignoring the failure.
                warnings.warn(f"Could not save PDF figure: {exc}")
        finally:
            plt.close(fig)


def main() -> None:
    """Command-line entry point: discover result JSONs, group them and plot.

    Raises:
        FileNotFoundError: if ``--dir`` is not a directory or contains no
            matching JSON files.
        ValueError: if no file holds a non-empty ``perplexities`` list. Also
            raised by ``plot_losses_subplots`` for invalid ``--factors``.
    """
    script_path = Path(__file__).resolve()
    repo_root = find_repo_root(script_path)
    default_data_dir = repo_root / "data" / "length_extrapolation"

    parser = argparse.ArgumentParser(
        description="Auto-detect length extrapolation results and plot per-token loss subplots",
    )
    parser.add_argument(
        "--dir",
        type=Path,
        default=default_data_dir,
        help="Directory containing JSON result files (default: repo_root/data/length_extrapolation)",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default=None,
        help="Optional glob pattern to filter files, e.g., 'codeparrot_*.json'",
    )
    parser.add_argument(
        "--downsample",
        type=int,
        default=1,
        help="Plot every k-th point to reduce clutter (default: 1)",
    )
    parser.add_argument(
        "--smooth",
        type=int,
        default=0,
        help="Apply moving-average smoothing window (tokens). 0 disables.",
    )
    parser.add_argument(
        "--factors",
        type=str,
        default=None,
        help=(
            "Comma-separated list of flags to visualize in order (color,linestyle,marker). "
            "Allowed: paninetto, ane, sc, iq. Example: --factors paninetto,ane or --factors sc,iq"
        ),
    )
    parser.add_argument(
        "--train-ctx",
        type=int,
        default=4096,
        # The marker is drawn with linestyle="dashdot" by plot_losses_subplots.
        help="Draw a vertical dash-dotted line at this training context length (0 disables)",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="Output figure path without extension (default: <dir>/length_extrapolation_loss)",
    )
    parser.add_argument(
        "--baseline-dir",
        type=Path,
        default=None,
        help=(
            "Optional directory containing baseline JSONs to overlay (default: data/length_extrapolation/dp3_baseline). "
            "Baselines are matched by dataset prefix and labeled as DeltaProduct_3."
        ),
    )

    args = parser.parse_args()

    target_dir: Path = args.dir
    # is_dir (not exists): a file path would otherwise surface as a confusing
    # "No JSON files found" error.
    if not target_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {target_dir}")

    files = discover_json_files(target_dir, args.pattern)
    if not files:
        raise FileNotFoundError(
            f"No JSON files found in {target_dir} with pattern {args.pattern or '*.json'}"
        )

    output_path = (
        args.out if args.out is not None else (target_dir / "length_extrapolation_loss")
    )

    # Keep only JSONs that actually contain a non-empty perplexities list
    valid_files: List[Path] = [p for p in files if load_perplexities(p)]
    if not valid_files:
        raise ValueError("No valid result files with 'perplexities' found.")

    grouped = group_files_by_dataset_and_model(valid_files)

    # Baseline directory: explicit flag, else repo_root/data/length_extrapolation/dp3_baseline
    baseline_dir: Path = (
        args.baseline_dir
        if args.baseline_dir is not None
        else default_data_dir / "dp3_baseline"
    )

    baseline_by_dataset: Optional[Dict[str, List[Path]]] = None
    if baseline_dir.is_dir():
        # Map by dataset inferred from file name, keeping only datasets that
        # are present in the main grouped set.
        tmp: Dict[str, List[Path]] = defaultdict(list)
        for p in discover_json_files(baseline_dir, None):
            ds, _ml, _tag = parse_filename_metadata(p)
            if ds in grouped:
                tmp[ds].append(p)
        baseline_by_dataset = dict(tmp)

    plot_losses_subplots(
        grouped=grouped,
        downsample=max(1, int(args.downsample)),
        smooth_window=max(0, int(args.smooth)),
        train_ctx=(None if int(args.train_ctx) <= 0 else int(args.train_ctx)),
        output_path=output_path,
        selected_factors=(
            [s.strip() for s in args.factors.split(",") if s.strip()]
            if args.factors
            else None
        ),
        baseline_by_dataset=baseline_by_dataset,
    )


if __name__ == "__main__":
    main()
