import argparse
import json
import logging
import os
import re
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import uniform_filter1d


def _setup_matplotlib():
    """Apply the global matplotlib rcParams used by this script's figures.

    Sets the font family to Times New Roman (user preference), turns the axes
    grid on with a faint alpha, and enables constrained layout by default.
    The settings are written to ``plt.rcParams`` and therefore affect every
    figure created afterwards in this process.
    """
    plt.rcParams.update(
        {
            # User prefers Times New Roman for plots
            "font.family": "Times New Roman",
            "axes.grid": True,
            "grid.alpha": 0.3,
            "figure.constrained_layout.use": True,
        }
    )


def _find_result_jsons(root_dir: str) -> List[str]:
    """Collect result JSON files stored in ``length_extrapolation`` directories.

    Walks ``root_dir`` recursively and keeps every file whose name ends with
    ``.json`` and whose immediate parent directory is named exactly
    ``length_extrapolation``.

    Args:
        root_dir: Directory at which the recursive walk starts.

    Returns:
        The matching file paths, sorted lexicographically.
    """
    found = [
        os.path.join(current_dir, entry)
        for current_dir, _subdirs, entries in os.walk(root_dir)
        # Only consider length_extrapolation directories
        if os.path.basename(current_dir) == "length_extrapolation"
        for entry in entries
        if entry.endswith(".json")
    ]
    found.sort()
    return found


def _parse_dataset_and_len(json_path: str) -> Tuple[Optional[str], Optional[int]]:
    """Extract the dataset name and length from a ``<dataset>_<len>.json`` path.

    Only the final path component is inspected. The dataset part is matched
    greedily, so the length is the digit run after the last underscore.

    Args:
        json_path: Path (or bare filename) of a result file.

    Returns:
        ``(dataset, length)``; ``(None, None)`` when the filename does not
        follow the pattern, and ``(dataset, None)`` if the digits cannot be
        converted to an int.
    """
    found = re.match(r"(.+)_([0-9]+)\.json$", os.path.basename(json_path))
    if found is None:
        return None, None
    dataset_name, digits = found.groups()
    try:
        length = int(digits)
    except Exception:
        length = None
    return dataset_name, length


def _load_token_losses(json_path: str) -> Optional[np.ndarray]:
    """Read the ``token_losses`` values stored in a result JSON file.

    Two layouts are accepted: ``{"results": {"token_losses": [...]}}`` and a
    flat ``{"token_losses": [...]}``.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The losses as a float ``numpy`` array, or ``None`` (after logging the
        reason) when the file cannot be read or parsed, when its top level or
        its ``"results"`` value is not a JSON object, when ``token_losses`` is
        absent/null, or when the values cannot be converted to floats.
    """
    # Loading is deliberately best-effort: one bad file is logged and skipped
    # so that the remaining files can still be processed by the caller.
    try:
        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            payload = json.load(f)
    except Exception as e:
        logging.error(f"Failed to read {json_path}: {e}")
        return None

    # Support both {"results": {...}} and flat {...}
    results = payload.get("results", payload) if isinstance(payload, dict) else None
    if not isinstance(results, dict):
        # A top-level list/scalar (or a non-object "results") has no .get();
        # report it instead of crashing with AttributeError.
        logging.error(f"Unexpected JSON structure in {json_path}; skipping")
        return None

    token_losses = results.get("token_losses")
    if token_losses is None:
        logging.warning(f"No token_losses in {json_path}; skipping")
        return None

    try:
        return np.array(token_losses, dtype=float)
    except Exception as e:
        logging.error(f"Bad token_losses format in {json_path}: {e}")
        return None


def _smooth_losses(losses: np.ndarray, window_size: int = 50) -> np.ndarray:
    """Smooth ``losses`` with a uniform (moving-average) filter.

    Args:
        losses: Sequence of loss values to smooth.
        window_size: Filter width. When the input is shorter than this, the
            width falls back to a tenth of the input length (at least 1).

    Returns:
        The filtered values as produced by ``uniform_filter1d`` with
        ``mode='nearest'`` (edges are padded with the nearest sample).
    """
    n_points = len(losses)
    effective_window = window_size if n_points >= window_size else max(1, n_points // 10)
    return uniform_filter1d(losses, size=effective_window, mode='nearest')


def _run_label_from_json(json_path: str, logs_root: str) -> str:
    """Derive a legend label for the run that produced ``json_path``.

    The run directory is taken to be two levels above the file, following the
    layout ``.../<logs_root>/.../<run_id>/length_extrapolation/<dataset>_<len>.json``.

    Args:
        json_path: Path of a result JSON file.
        logs_root: Directory the label should be expressed relative to.

    Returns:
        The run directory relative to ``logs_root``; if that relative path
        cannot be computed, the run directory's base name.
    """
    run_directory = os.path.dirname(os.path.dirname(json_path))
    try:
        return os.path.relpath(run_directory, logs_root)
    except Exception:
        return os.path.basename(run_directory)


def _plot_all_runs(lines: List[Tuple[str, np.ndarray, Optional[str], Optional[int]]],
                   out_path: str,
                   dpi: int,
                   show: bool,
                   train_context_len: Optional[int] = 4096,
                   ylim: Optional[Tuple[float, float]] = (2.8, 4)):
    """Plot every run's smoothed per-token loss on one axes and save the figure.

    Args:
        lines: One ``(label, token_losses, dataset, max_len)`` tuple per curve.
            ``dataset``/``max_len`` are appended to the legend label when both
            are known, and are used for the title when all curves agree.
        out_path: Destination file passed to ``Figure.savefig``.
        dpi: Resolution passed to ``Figure.savefig``.
        show: When true, call ``plt.show()`` after saving.
        train_context_len: Token position at which a dashed red vertical
            marker is drawn; ``None`` omits the marker. Defaults to 4096.
        ylim: ``(bottom, top)`` limits for the loss axis; ``None`` leaves the
            limits to matplotlib. Defaults to ``(2.8, 4)``.
    """
    _setup_matplotlib()
    fig, ax = plt.subplots(figsize=(5.5, 3.6))

    # Close the figure even if plotting or saving fails, so pyplot does not
    # keep the figure registered after an exception.
    try:
        for label, token_losses, dataset, max_len in lines:
            # Token positions are 1-based on the x axis.
            x = np.arange(1, token_losses.shape[0] + 1)
            smoothed_losses = _smooth_losses(token_losses)
            extra = f" [{dataset}, L={max_len}]" if dataset is not None and max_len is not None else ""
            ax.plot(x, smoothed_losses, lw=2.0, label=f"{label}{extra}")

        # Add vertical line at training context length
        if train_context_len is not None:
            ax.axvline(x=train_context_len, color='red', linestyle='--', alpha=0.7,
                       label=f'Training context length ({train_context_len})')

        ax.set_xlabel("Token position")
        ax.set_ylabel("Per-token loss")
        ax.grid(True, linestyle="-")

        # Title summarizes dataset/len if uniform
        datasets = {dataset for _, _, dataset, _ in lines if dataset is not None}
        lens = {max_len for _, _, _, max_len in lines if max_len is not None}
        title = "Per-token loss"
        if len(datasets) == 1 and len(lens) == 1:
            only_d = next(iter(datasets))
            only_l = next(iter(lens))
            title = f"Per-token loss ({only_d}, len={only_l})"
        ax.set_title(title)

        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.legend(loc="best", framealpha=0.95, fontsize=9)

        fig.savefig(out_path, dpi=dpi)
        if show:
            plt.show()
    finally:
        plt.close(fig)


def main():
    """Command-line entry point: scan a log tree and save one aggregated plot.

    Parses the CLI options, collects every ``length_extrapolation`` result
    JSON under ``--logs_dir``, loads their token losses, and writes a single
    figure next to this script. The output name carries the dataset and
    length when all inputs agree on them, otherwise the ``_mixed`` suffix.
    An existing output file is left untouched unless ``--overwrite`` is given.
    Every early-exit condition is logged and the function simply returns.
    """
    logging.basicConfig(level=logging.INFO)

    cli = argparse.ArgumentParser(description="Plot per-token losses from length_extrapolation results")
    cli.add_argument("--logs_dir", type=str, required=True, help="Root directory to scan (e.g., 23-16-02 or a date folder)")
    cli.add_argument("--dpi", type=int, default=300)
    cli.add_argument("--ext", type=str, default="png", choices=["png", "pdf"])
    cli.add_argument("--overwrite", action="store_true")
    cli.add_argument("--show", action="store_true")
    opts = cli.parse_args()

    scan_root = os.path.abspath(opts.logs_dir)
    if not os.path.isdir(scan_root):
        logging.error(f"Directory not found: {scan_root}")
        return

    result_paths = _find_result_jsons(scan_root)
    if len(result_paths) == 0:
        logging.info(f"No length_extrapolation JSONs found under {scan_root}")
        return

    logging.info(f"Found {len(result_paths)} result files")

    # One (label, losses, dataset, max_len) entry per file that loaded cleanly.
    curves: List[Tuple[str, np.ndarray, Optional[str], Optional[int]]] = []
    for result_path in result_paths:
        dataset, max_len = _parse_dataset_and_len(result_path)
        losses = _load_token_losses(result_path)
        if losses is not None:
            run_label = _run_label_from_json(result_path, scan_root)
            curves.append((run_label, losses, dataset, max_len))

    if not curves:
        logging.info("No valid token_losses found; nothing to plot")
        return

    # Determine output path under the plotting/ directory (this script's dir)
    dataset_names = {entry[2] for entry in curves if entry[2] is not None}
    lengths = {entry[3] for entry in curves if entry[3] is not None}
    is_uniform = len(dataset_names) == 1 and len(lengths) == 1
    suffix = f"_{next(iter(dataset_names))}_{next(iter(lengths))}" if is_uniform else "_mixed"
    here = os.path.dirname(os.path.abspath(__file__))
    out_path = os.path.join(here, f"token_losses_all{suffix}.{opts.ext}")

    if not opts.overwrite and os.path.exists(out_path):
        logging.info(f"Exists, skip: {out_path}")
        return

    logging.info(f"Saving aggregated plot to {out_path}")
    _plot_all_runs(curves, out_path, dpi=opts.dpi, show=opts.show)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


