"""Plot LoRA adapter norms vs eval perplexity for a sweep directory."""

from __future__ import annotations

import argparse
import bisect
import json
import math
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import cm, colors
from matplotlib.lines import Line2D
from safetensors import safe_open

# Names of per-rank sub-directories inside a sweep directory: "r" followed by
# digits (e.g. "r8"). Used with ``.match`` so the name must start with "r".
RANK_DIR_RE = re.compile(r"r\d+$")
# One base colour per rank series; the plot cycles through this list when
# there are more ranks than colours.
RANK_BASE_COLORS = [
    "#1f77b4",
    "#2ca02c",
    "#9467bd",
    "#d62728",
    "#8c564b",
    "#e377c2",
]
# Portion of each rank's light-to-dark colormap that is sampled to colour
# points from the first to the last checkpoint; also the colorbar limits.
PROGRESS_MIN = 0.15
PROGRESS_MAX = 0.95
# Axis-label fragments used by ``metric_label_for``: metrics prefixed
# "eval_" get NON_TRAINED_LABEL, metrics prefixed "heldout_eval_" get
# TRAINED_LABEL, and "loss" / "perplexity" map to the short suffixes below.
NON_TRAINED_LABEL = "Non-trained"
TRAINED_LABEL = "trained (validation)"
LOSS_LABEL = "loss"
PPL_LABEL = "ppl"


def gradient_cmap(base_color: str) -> colors.LinearSegmentedColormap:
    """Build a two-stop colormap running from a tint to a shade of ``base_color``.

    The first stop blends the colour towards white (keeping 35% of its distance
    from white); the second stop darkens it to 65% of its RGB values.

    Args:
        base_color: Any colour specification accepted by
            ``matplotlib.colors.to_rgb``.

    Returns:
        An unnamed ``LinearSegmentedColormap`` interpolating tint -> shade.
    """
    rgb = np.array(colors.to_rgb(base_color))
    tint = 1 - 0.35 * (1 - rgb)
    shade = 0.65 * rgb
    stops = [tint, shade]
    return colors.LinearSegmentedColormap.from_list("", stops)


def metric_label_for(metric: str) -> str:
    """Turn a metric key into a human-readable axis label.

    The key is decomposed as ``[delta_][heldout_eval_|eval_]<name>``:

    * a leading ``delta_`` adds the prefix ``"Delta "``;
    * ``heldout_eval_`` maps to ``TRAINED_LABEL`` and ``eval_`` maps to
      ``NON_TRAINED_LABEL``;
    * a remaining name of ``loss`` / ``perplexity`` maps to ``LOSS_LABEL`` /
      ``PPL_LABEL``, but only when a role prefix was found; otherwise the
      name is title-cased with underscores replaced by spaces.

    Args:
        metric: Metric key such as ``"delta_heldout_eval_loss"``.

    Returns:
        The label text for the plot axis and title.
    """
    delta_tag = "delta_"
    is_delta = metric.startswith(delta_tag)
    lead = "Delta " if is_delta else ""
    stem = metric[len(delta_tag):] if is_delta else metric

    # Check the longer prefix first; only one role prefix is ever stripped.
    role = None
    for marker in ("heldout_eval_", "eval_"):
        if stem.startswith(marker):
            role = TRAINED_LABEL if marker == "heldout_eval_" else NON_TRAINED_LABEL
            stem = stem[len(marker):]
            break

    if stem == "loss":
        short_name = LOSS_LABEL
    elif stem == "perplexity":
        short_name = PPL_LABEL
    else:
        short_name = None

    pretty = stem.replace("_", " ").title()
    if not role:
        return f"{lead}{pretty}"
    tail = short_name if short_name else pretty
    return f"{lead}{role} {tail}"


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the plotting script.

    Returns:
        A namespace with ``sweep_dir`` (``Path``), ``metric`` (one of the
        supported metric keys, default ``"eval_perplexity"``), and the optional
        ``output`` / ``csv`` paths (``None`` when not given).
    """
    supported_metrics = [
        "eval_loss",
        "eval_perplexity",
        "heldout_eval_loss",
        "heldout_eval_perplexity",
        "delta_eval_loss",
        "delta_eval_perplexity",
        "delta_heldout_eval_loss",
        "delta_heldout_eval_perplexity",
    ]

    cli = argparse.ArgumentParser(
        description="Plot eval perplexity vs adapter ||AB||_F for LoRA sweep checkpoints."
    )
    cli.add_argument(
        "sweep_dir",
        type=Path,
        help="Path to the sweep directory (e.g. sweep_1614967).",
    )
    cli.add_argument(
        "--metric",
        default="eval_perplexity",
        choices=supported_metrics,
        help="Metric to plot on the y-axis.",
    )
    # Both optional paths default to None so the caller can derive them from
    # the sweep directory.
    cli.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Output path for the plot PNG.",
    )
    cli.add_argument(
        "--csv",
        type=Path,
        default=None,
        help="Optional path to save the computed points as CSV.",
    )
    namespace = cli.parse_args()
    return namespace


def load_eval_metric(metrics_path: Path, metric: str) -> Dict[int, float]:
    """Read one metric per evaluation step from a ``metrics.json`` file.

    Records are taken from the ``"forgetting_curve"`` list, falling back to
    ``"records"`` when the former is missing or empty. Records lacking either
    ``"step"`` or the requested metric are ignored; a later record with the
    same step overwrites an earlier one.

    Args:
        metrics_path: Path to the JSON metrics file.
        metric: Key of the metric to extract from each record.

    Returns:
        Mapping of integer step to the metric value as ``float``.

    Raises:
        ValueError: If no record provides both a step and the metric.
    """
    payload = json.loads(metrics_path.read_text())
    rows = payload.get("forgetting_curve") or payload.get("records") or []

    values: Dict[int, float] = {}
    for row in rows:
        step, value = row.get("step"), row.get(metric)
        if step is not None and value is not None:
            values[int(step)] = float(value)

    if values:
        return values
    raise ValueError(f"No {metric} values found in {metrics_path}")


def nearest_step(target_step: int, steps: List[int]) -> int:
    """Return the entry of ``steps`` that is closest to ``target_step``.

    Args:
        target_step: Step to look up.
        steps: Candidate steps, sorted in ascending order (the lookup is a
            binary search and gives meaningless results on unsorted input).

    Returns:
        The closest step. Targets outside the covered range clamp to the first
        or last entry; when two neighbours are equally close the earlier one
        is returned.

    Raises:
        ValueError: If ``steps`` is empty.
    """
    if not steps:
        raise ValueError("No evaluation steps available.")
    # Index of the first entry >= target_step.
    idx = bisect.bisect_left(steps, target_step)
    if idx == 0:
        return steps[0]
    if idx == len(steps):
        return steps[-1]
    before, after = steps[idx - 1], steps[idx]
    # ``<=`` keeps ties on the earlier step; an exact hit has ``after`` at
    # distance 0 and ``before`` strictly further away, so ``after`` wins.
    return before if target_step - before <= after - target_step else after


def _pattern_value(patterns: Dict[str, float], key: str) -> float | None:
    """Return the value of the first pattern in ``patterns`` matching ``key``.

    Each pattern is tried as a regular expression via ``re.search``; a pattern
    that is not a valid regex falls back to a plain substring test. Patterns
    are tried in dict order.

    NOTE(review): ``re.search`` matches anywhere inside ``key`` -- confirm this
    agrees with how the training library resolved the same patterns.

    Returns:
        The matched value as ``float`` (which may legitimately be ``0.0``), or
        ``None`` when no pattern matches.
    """
    for pattern, value in patterns.items():
        try:
            matched = re.search(pattern, key) is not None
        except re.error:
            # Not a valid regex: treat the pattern as a literal substring.
            matched = pattern in key
        if matched:
            return float(value)
    return None


def load_adapter_config(adapter_config_path: Path) -> dict:
    """Load an adapter configuration JSON file.

    Raises:
        FileNotFoundError: If ``adapter_config_path`` does not exist.
    """
    if not adapter_config_path.exists():
        raise FileNotFoundError(f"Missing {adapter_config_path}")
    return json.loads(adapter_config_path.read_text())


def adapter_scale_for_module(adapter_config: dict, module_key: str) -> float:
    """Compute the LoRA scaling factor that applies to one module.

    The factor is ``alpha / r``, or ``alpha / sqrt(r)`` when the config sets
    ``use_rslora``. ``alpha`` and ``r`` come from ``alpha_pattern`` /
    ``rank_pattern`` when one of their patterns matches ``module_key`` and from
    the config-wide ``lora_alpha`` (default ``1.0``) / ``r`` otherwise.

    Args:
        adapter_config: Parsed adapter configuration.
        module_key: Key prefix identifying the module in the adapter file.

    Returns:
        The scaling factor as ``float``.

    Raises:
        ValueError: If the config has no non-zero ``r``, or if the rank that
            applies to this module is zero.
    """
    default_r = adapter_config.get("r")
    if default_r in (None, 0):
        raise ValueError("Adapter config missing valid rank (r).")
    default_alpha = adapter_config.get("lora_alpha", 1.0)
    alpha_pattern = adapter_config.get("alpha_pattern") or {}
    rank_pattern = adapter_config.get("rank_pattern") or {}

    # Compare against None rather than relying on truthiness: an explicit
    # per-module override of 0 is a real value and must not be replaced by the
    # config-wide default.
    alpha_value = _pattern_value(alpha_pattern, module_key)
    if alpha_value is None:
        alpha_value = default_alpha
    rank_value = _pattern_value(rank_pattern, module_key)
    if rank_value is None:
        rank_value = default_r
    if rank_value == 0:
        raise ValueError("Rank pattern resolved to zero.")

    use_rslora = bool(adapter_config.get("use_rslora", False))
    if use_rslora:
        return float(alpha_value) / math.sqrt(float(rank_value))
    return float(alpha_value) / float(rank_value)


def compute_adapter_norm(adapter_path: Path, adapter_config: dict) -> Tuple[float, float]:
    """Compute the Frobenius norm of the adapter update summed over all modules.

    Tensors whose names end in ``lora_A.weight`` / ``lora_B.weight`` are paired
    by their shared name prefix; prefixes missing either half are skipped. The
    squared norm ``||BA||_F^2`` of each pair is accumulated once unscaled and
    once multiplied by the square of the module's scaling factor.

    Args:
        adapter_path: Path to the adapter ``.safetensors`` file.
        adapter_config: Parsed adapter configuration, used for the scaling
            factor of each module.

    Returns:
        ``(unscaled_norm, scaled_norm)``.
    """
    suffix_a = "lora_A.weight"
    suffix_b = "lora_B.weight"
    sq_plain = 0.0
    sq_scaled = 0.0
    with safe_open(adapter_path, framework="numpy") as tensors:
        # Map each module prefix to the tensor names of its A and B factors.
        by_module: Dict[str, Dict[str, str]] = defaultdict(dict)
        for name in tensors.keys():
            if name.endswith(suffix_a):
                by_module[name[: -len(suffix_a)]]["A"] = name
            elif name.endswith(suffix_b):
                by_module[name[: -len(suffix_b)]]["B"] = name

        for module_key, names in by_module.items():
            if not ("A" in names and "B" in names):
                continue
            lora_a = np.asarray(tensors.get_tensor(names["A"]), dtype=np.float32)
            lora_b = np.asarray(tensors.get_tensor(names["B"]), dtype=np.float32)
            # ||BA||_F^2 = trace(B^T B A A^T) = sum(B^T B * A A^T), which avoids
            # forming the product BA itself.
            # (assumes both factors are 2-D matrices -- TODO confirm)
            gram_b = lora_b.T @ lora_b
            gram_a = lora_a @ lora_a.T
            frob_sq = float(np.sum(gram_b * gram_a))
            factor = adapter_scale_for_module(adapter_config, module_key)
            sq_plain += frob_sq
            sq_scaled += (factor * factor) * frob_sq

    # Clamp at zero before the square root to guard against a slightly
    # negative total.
    unscaled_norm = math.sqrt(max(sq_plain, 0.0))
    scaled_norm = math.sqrt(max(sq_scaled, 0.0))
    return unscaled_norm, scaled_norm


def default_adapter_scale(adapter_config: dict) -> Tuple[str, float | None]:
    """Describe the config-wide LoRA scaling of an adapter.

    Args:
        adapter_config: Parsed adapter configuration.

    Returns:
        ``(scale_mode, scale)`` where ``scale_mode`` is ``"alpha/sqrt(r)"`` when
        ``use_rslora`` is set and ``"alpha/r"`` otherwise. ``scale`` is the
        corresponding value computed from ``lora_alpha`` (default ``1.0``) and
        ``r``, or ``None`` when no single value applies: the config has a
        non-empty ``alpha_pattern`` / ``rank_pattern``, or ``r`` is missing or
        zero.
    """
    rslora = bool(adapter_config.get("use_rslora", False))
    mode = "alpha/sqrt(r)" if rslora else "alpha/r"

    rank = adapter_config.get("r")
    alpha = adapter_config.get("lora_alpha", 1.0)
    has_overrides = bool(
        adapter_config.get("alpha_pattern") or adapter_config.get("rank_pattern")
    )
    if has_overrides or rank in (None, 0):
        return mode, None

    numerator = float(alpha)
    denominator = math.sqrt(float(rank)) if rslora else float(rank)
    return mode, numerator / denominator


def collect_points(sweep_dir: Path, metric: str) -> pd.DataFrame:
    """Gather one (adapter norm, metric) row per checkpoint in a sweep.

    For every ``r<digits>`` directory in ``sweep_dir`` the metric is read from
    ``metrics.json`` and every ``adapter/checkpoint-<step>`` directory is paired
    with the nearest evaluated step. Progress is printed per checkpoint.

    Args:
        sweep_dir: Sweep directory containing the rank sub-directories.
        metric: Metric key to read from each ``metrics.json``.

    Returns:
        A DataFrame with columns ``rank``, ``checkpoint_step``, ``eval_step``,
        ``eval_step_offset``, ``metric``, ``adapter_norm``,
        ``adapter_norm_unscaled``, ``adapter_scale_mode`` and ``adapter_scale``.

    Raises:
        ValueError: If no rank directory exists, or a rank has no checkpoints.
        FileNotFoundError: If a ``metrics.json`` or adapter weights file is
            missing.
    """

    def trailing_step(path: Path) -> int:
        # "checkpoint-<step>" -> <step>
        return int(path.name.split("-")[-1])

    rank_dirs = [
        child
        for child in sweep_dir.iterdir()
        if child.is_dir() and RANK_DIR_RE.match(child.name)
    ]
    rank_dirs.sort(key=lambda child: int(child.name[1:]))
    if not rank_dirs:
        raise ValueError(f"No rank directories found in {sweep_dir}")

    rows: List[dict] = []
    for rank_dir in rank_dirs:
        metrics_path = rank_dir / "metrics.json"
        if not metrics_path.exists():
            raise FileNotFoundError(f"Missing metrics.json in {rank_dir}")
        metric_by_step = load_eval_metric(metrics_path, metric)
        eval_steps = sorted(metric_by_step)

        adapter_dir = rank_dir / "adapter"
        checkpoints = [
            entry for entry in adapter_dir.glob("checkpoint-*") if entry.is_dir()
        ]
        checkpoints.sort(key=trailing_step)
        if not checkpoints:
            raise ValueError(f"No checkpoints found in {adapter_dir}")

        rank = int(rank_dir.name[1:])
        for checkpoint in checkpoints:
            step = trailing_step(checkpoint)
            # Checkpoints and evaluations need not share steps; use the
            # closest evaluated step and record the offset.
            matched_step = nearest_step(step, eval_steps)
            weights_path = checkpoint / "adapter_model.safetensors"
            config_path = checkpoint / "adapter_config.json"
            if not weights_path.exists():
                raise FileNotFoundError(f"Missing {weights_path}")
            config = load_adapter_config(config_path)
            scale_mode, scale_value = default_adapter_scale(config)
            print(f"Computing norm for {rank_dir.name} step {step}...", flush=True)
            norm_unscaled, norm_scaled = compute_adapter_norm(weights_path, config)
            row = {
                "rank": rank,
                "checkpoint_step": step,
                "eval_step": matched_step,
                "eval_step_offset": step - matched_step,
                "metric": metric_by_step[matched_step],
                "adapter_norm": norm_scaled,
                "adapter_norm_unscaled": norm_unscaled,
                "adapter_scale_mode": scale_mode,
                "adapter_scale": scale_value,
            }
            rows.append(row)

    return pd.DataFrame(rows)


def plot_points(df: pd.DataFrame, metric_label: str, output: Path) -> None:
    """Draw the adapter-norm vs metric scatter plot and save it as an image.

    Each rank gets its own colour gradient: points are ordered by
    ``checkpoint_step`` and coloured from light to dark, and joined by a faint
    line when the rank has more than one point. A greyscale colorbar explains
    the light-to-dark progression.

    Args:
        df: Points with at least the columns ``rank``, ``checkpoint_step``,
            ``adapter_norm`` and ``metric``.
        metric_label: Text for the y-axis and the title.
        output: Destination image path; parent directories are created.
    """
    fig, ax = plt.subplots(figsize=(9, 6))
    handles: List[Line2D] = []
    palette_size = len(RANK_BASE_COLORS)

    for position, (rank, rows) in enumerate(df.groupby("rank", sort=True)):
        rows = rows.sort_values("checkpoint_step")
        shade = gradient_cmap(RANK_BASE_COLORS[position % palette_size])
        accent = shade(0.85)
        xs = rows["adapter_norm"]
        ys = rows["metric"]
        count = len(rows)

        # One colour per point, sampled evenly along the rank's gradient.
        progress_colors = [
            shade(t) for t in np.linspace(PROGRESS_MIN, PROGRESS_MAX, count)
        ]
        ax.scatter(xs, ys, c=progress_colors, s=42, edgecolor="none")
        if count > 1:
            ax.plot(xs, ys, color=accent, alpha=0.3, linewidth=1.0)

        # Proxy artist so the legend shows a single marker per rank.
        handles.append(
            Line2D(
                [0],
                [0],
                marker="o",
                color="none",
                markerfacecolor=accent,
                markersize=8,
                label=f"r{rank}",
            )
        )

    ax.set_xlabel("Adapter update norm (||ΔW||_F)")
    ax.set_ylabel(metric_label)
    ax.set_title(f"LoRA sweep: adapter norm vs {metric_label}")
    ax.grid(True, linewidth=0.3, alpha=0.5)
    ax.legend(handles=handles, frameon=False, title="Rank")

    # Rank-independent greyscale colorbar labelled only at its two ends.
    mappable = cm.ScalarMappable(
        norm=colors.Normalize(vmin=PROGRESS_MIN, vmax=PROGRESS_MAX),
        cmap=plt.get_cmap("Greys"),
    )
    mappable.set_array([])
    colorbar = fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
    colorbar.set_label("Training progress")
    colorbar.set_ticks([PROGRESS_MIN, PROGRESS_MAX])
    colorbar.set_ticklabels(["early", "late"])
    fig.tight_layout()

    output.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output, dpi=200)
    plt.close(fig)


def main() -> None:
    """Run the CLI: collect the sweep's points, save them as CSV, then plot.

    When ``--output`` / ``--csv`` are not given, the files are written into the
    sweep directory as ``adapter_norm_vs_eval_ppl.png`` / ``.csv`` (the default
    names do not vary with ``--metric``).
    """
    args = parse_args()
    sweep_dir = args.sweep_dir
    if args.output is None:
        args.output = sweep_dir / "adapter_norm_vs_eval_ppl.png"
    if args.csv is None:
        args.csv = sweep_dir / "adapter_norm_vs_eval_ppl.csv"

    df = collect_points(sweep_dir, args.metric)
    df = df.sort_values(["rank", "checkpoint_step"])
    # Create the CSV's parent directory, as is already done for the plot
    # output, so a --csv path in a not-yet-existing directory does not fail.
    args.csv.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(args.csv, index=False)
    print(f"Saved CSV to {args.csv}")

    metric_label = metric_label_for(args.metric)
    plot_points(df, metric_label, args.output)
    print(f"Saved plot to {args.output}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
