"""Compute alignment between LoRA update subspaces and base weight singular vectors."""

from __future__ import annotations

import argparse
import csv
import json
import math
import os
import random
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import matplotlib.pyplot as plt
import torch
from safetensors import safe_open

# Sweep sub-directory names: "r" followed by digits (e.g. "r8"). Used with
# ``.match`` so the pattern is effectively anchored at both ends of the name.
RANK_DIR_RE = re.compile(r"r\d+$")
# Captures (layer index, projection name) from a module key; the
# "self_attn." segment between them is optional.
LAYER_RE = re.compile(
    r"layers\.(\d+)\.(?:self_attn\.)?(q_proj|k_proj|v_proj|o_proj)")
# Defaults for ``bootstrap_stats``: number of resamples and the confidence
# level of the reported interval.
BOOTSTRAP_SAMPLES = 1000
CI_LEVEL = 0.95


@dataclass
class BaseBasis:
    """Cached top singular-vector bases of one base-model weight matrix."""

    # First output of ``top_singular_vectors`` (left singular vectors as columns).
    left: torch.Tensor
    # Second output of ``top_singular_vectors`` (right singular vectors as columns).
    right: torch.Tensor
    # Number of singular vectors kept (column count of ``left``).
    k: int
    # Shape of the weight matrix the bases were computed from.
    shape: tuple[int, int]


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``sweep_dir`` (positional) plus the optional
        ``base_model``, ``hf_home``, ``max_k``, ``niter``, ``output``,
        ``summary`` and ``time_plot`` attributes.
    """
    description = (
        "Compute alignment between LoRA update subspaces and base model singular vectors. "
        "Uses top-k singular vectors of the base weight and principal angles to the LoRA "
        "column/row subspaces."
    )
    # (flag, type, default, help) for every optional argument, in display order.
    optional_arguments = (
        (
            "--base-model",
            Path,
            None,
            "Path to a base model directory or model.safetensors file. "
            "If omitted, resolves from adapter_config.json via the HF cache.",
        ),
        (
            "--hf-home",
            Path,
            None,
            "Optional HF_HOME override (defaults to ~/.cache/huggingface).",
        ),
        (
            "--max-k",
            int,
            8,
            "Maximum number of singular vectors to compare per module.",
        ),
        (
            "--niter",
            int,
            4,
            "Number of power iterations for randomized SVD.",
        ),
        (
            "--output",
            Path,
            None,
            "Output CSV path for per-module alignment records.",
        ),
        (
            "--summary",
            Path,
            None,
            "Output CSV path for rank-level summary statistics.",
        ),
        (
            "--time-plot",
            Path,
            None,
            "Output PNG path for alignment-over-time plot.",
        ),
    )

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "sweep_dir",
        type=Path,
        help="Path to the sweep directory (e.g. sweep_1614967).",
    )
    for flag, value_type, default_value, help_text in optional_arguments:
        parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text)
    return parser.parse_args()


def _pattern_value(patterns: Dict[str, float], key: str) -> float | None:
    for pattern, value in patterns.items():
        try:
            if re.search(pattern, key):
                return float(value)
        except re.error:
            if pattern in key:
                return float(value)
    return None


def adapter_rank_for_module(adapter_config: dict, module_key: str) -> int:
    """Resolve the LoRA rank configured for ``module_key``.

    A matching entry in the config's ``rank_pattern`` takes precedence over
    the default rank ``r``.

    Args:
        adapter_config: Parsed adapter configuration; reads ``r`` and the
            optional ``rank_pattern`` mapping.
        module_key: Module name that the rank patterns are matched against.

    Returns:
        The resolved rank as an ``int``.

    Raises:
        ValueError: If ``r`` is missing or zero, or if a matching rank
            pattern resolves to zero.
    """
    default_r = adapter_config.get("r")
    if default_r in (None, 0):
        raise ValueError("Adapter config missing valid rank (r).")
    rank_pattern = adapter_config.get("rank_pattern") or {}
    rank_value = _pattern_value(rank_pattern, module_key)
    # Fall back to the default only when no pattern matched. Using ``or``
    # here would also swallow a matched value of 0 and make the check below
    # unreachable.
    if rank_value is None:
        rank_value = default_r
    if rank_value == 0:
        raise ValueError("Rank pattern resolved to zero.")
    return int(rank_value)


def load_adapter_config(path: Path) -> dict:
    """Read and parse the JSON file at ``path``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if path.exists():
        return json.loads(path.read_text())
    raise FileNotFoundError(f"Missing {path}")


def resolve_hf_snapshot(model_id: str, hf_home: Optional[Path]) -> Path:
    """Locate a snapshot directory for ``model_id`` under a Hugging Face cache.

    The cache root is ``hf_home`` when given, otherwise the ``HF_HOME``
    environment variable, otherwise ``~/.cache/huggingface``. The snapshot
    named by ``refs/main`` is preferred; if that reference or its directory
    is absent, the most recently modified snapshot is used.

    Raises:
        FileNotFoundError: If the model's cache directory does not exist or
            contains no snapshots.
    """
    if hf_home:
        home = hf_home
    else:
        home = Path(os.environ.get(
            "HF_HOME", Path.home() / ".cache" / "huggingface"))

    hub_dir = home / "hub" / f"models--{model_id.replace('/', '--')}"
    if not hub_dir.exists():
        raise FileNotFoundError(
            f"HF cache not found for model {model_id} at {hub_dir}")

    snapshots_root = hub_dir / "snapshots"
    main_ref = hub_dir / "refs" / "main"
    if main_ref.exists():
        pinned = snapshots_root / main_ref.read_text().strip()
        if pinned.exists():
            return pinned

    # No usable refs/main entry: pick the newest snapshot by mtime.
    by_mtime = sorted(snapshots_root.glob("*"),
                      key=lambda entry: entry.stat().st_mtime)
    if not by_mtime:
        raise FileNotFoundError(f"No snapshots found in {hub_dir}")
    return by_mtime[-1]


def resolve_base_model_path(
    adapter_config: dict, base_model_path: Optional[Path], hf_home: Optional[Path]
) -> Path:
    """Return the path of the base model's ``model.safetensors`` file.

    An explicit ``base_model_path`` (a directory containing
    ``model.safetensors``, or a file) wins. Otherwise the model id in
    ``adapter_config["base_model_name_or_path"]`` is looked up in the HF
    cache via ``resolve_hf_snapshot``.

    Raises:
        FileNotFoundError: If no single weights file can be located
            (including the case where only sharded weights exist).
        ValueError: If no explicit path is given and the adapter config has
            no ``base_model_name_or_path``.
    """
    if base_model_path is not None:
        if base_model_path.is_dir():
            weights_file = base_model_path / "model.safetensors"
            if not weights_file.exists():
                raise FileNotFoundError(
                    f"No model.safetensors in {base_model_path}")
            return weights_file
        if not base_model_path.is_file():
            raise FileNotFoundError(
                f"Base model path not found: {base_model_path}")
        return base_model_path

    model_id = adapter_config.get("base_model_name_or_path")
    if not model_id:
        raise ValueError("Adapter config missing base_model_name_or_path.")

    snapshot = resolve_hf_snapshot(model_id, hf_home)
    weights_file = snapshot / "model.safetensors"
    if weights_file.exists():
        return weights_file
    # Sharded checkpoints are detected only to give a more helpful error.
    if sorted(snapshot.glob("model-*.safetensors")):
        raise FileNotFoundError(
            f"Model is sharded in {snapshot}; please pass --base-model to a merged file."
        )
    raise FileNotFoundError(f"Missing model.safetensors in {snapshot}")


def collect_rank_dirs(sweep_dir: Path) -> List[Path]:
    """List the rank sub-directories of ``sweep_dir`` ordered by rank.

    A sub-directory qualifies when its name matches ``RANK_DIR_RE``; the
    numeric part after the leading character is used as the sort key.
    """
    matching: List[Path] = []
    for entry in sweep_dir.iterdir():
        if entry.is_dir() and RANK_DIR_RE.match(entry.name):
            matching.append(entry)
    matching.sort(key=lambda entry: int(entry.name[1:]))
    return matching


def max_step_from_metrics(metrics_path: Path) -> Optional[int]:
    """Return the largest ``step`` recorded in a metrics JSON file.

    Records are read from the ``forgetting_curve`` key, falling back to
    ``records`` when the former is missing or empty.

    Returns:
        The maximum step as an ``int``, or ``None`` when the file does not
        exist or no record carries a non-null ``step``.
    """
    if not metrics_path.exists():
        return None
    payload = json.loads(metrics_path.read_text())
    entries = payload.get("forgetting_curve") or payload.get("records") or []
    known_steps = (
        entry["step"] for entry in entries if entry.get("step") is not None
    )
    highest = max(known_steps, default=None)
    return None if highest is None else int(highest)


def lora_pairs(tensors) -> Dict[str, Dict[str, str]]:
    """Group LoRA tensor names by module prefix.

    Args:
        tensors: Any object exposing ``keys()`` that yields tensor names.

    Returns:
        Mapping from the name prefix (everything before ``lora_A.weight`` /
        ``lora_B.weight``) to a dict holding the full key under ``"A"``
        and/or ``"B"``. Keys with neither suffix are ignored.
    """
    suffix_roles = (("lora_A.weight", "A"), ("lora_B.weight", "B"))
    grouped: Dict[str, Dict[str, str]] = {}
    for name in tensors.keys():
        for suffix, role in suffix_roles:
            if name.endswith(suffix):
                grouped.setdefault(name[: -len(suffix)], {})[role] = name
                break
    return grouped


def _candidate_keys(module_key: str) -> List[str]:
    candidates = []
    base = module_key
    for _ in range(4):
        candidates.append(base + "weight")
        if base.startswith("base_model."):
            base = base[len("base_model."):]
        elif base.startswith("model."):
            base = base[len("model."):]
        else:
            break
    return list(dict.fromkeys(candidates))


def resolve_base_weight_key(base_keys: Iterable[str], module_key: str) -> Optional[str]:
    """Find the base-model tensor name corresponding to an adapter module.

    Candidates from ``_candidate_keys`` are first tried as exact names. If
    none is present, each candidate is tried as a suffix and accepted only
    when exactly one base key ends with it.

    Returns:
        The matching base key, or ``None`` when nothing matches unambiguously.
    """
    available = base_keys if isinstance(base_keys, set) else set(base_keys)
    candidates = _candidate_keys(module_key)

    exact = next((name for name in candidates if name in available), None)
    if exact is not None:
        return exact

    for suffix in candidates:
        hits = [name for name in available if name.endswith(suffix)]
        if len(hits) == 1:
            return hits[0]
    return None


def orthonormal_basis(matrix: torch.Tensor, eps: float = 1e-6) -> Optional[torch.Tensor]:
    """Return an orthonormal basis for the column space of ``matrix``.

    The basis consists of the left singular vectors whose singular value
    exceeds ``eps`` (an absolute threshold).

    Returns:
        A tensor whose columns are the retained left singular vectors, or
        ``None`` if the matrix is empty, the SVD raises ``RuntimeError``, or
        no singular value exceeds ``eps``.
    """
    if matrix.numel() == 0:
        return None
    try:
        left_vectors, singular_values, _ = torch.linalg.svd(
            matrix, full_matrices=False)
    except RuntimeError:
        return None
    num_significant = int((singular_values > eps).sum().item())
    return left_vectors[:, :num_significant] if num_significant else None


def top_singular_vectors(
    matrix: torch.Tensor, k: int, niter: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the top-``k`` left and right singular vectors of ``matrix``.

    Uses randomized SVD (``torch.svd_lowrank``) when available and falls back
    to a full SVD otherwise.

    Args:
        matrix: 2-D weight matrix.
        k: Requested number of singular vectors; clamped to
            ``min(matrix.shape)``.
        niter: Number of power iterations for the randomized SVD.

    Returns:
        ``(u, v)`` where the columns of ``u`` / ``v`` are the leading left /
        right singular vectors; each has ``k`` columns after clamping.

    Raises:
        ValueError: If the clamped ``k`` is not positive.
    """
    max_rank = min(matrix.shape)
    k = min(k, max_rank)
    if k <= 0:
        raise ValueError("k must be positive for SVD.")
    if hasattr(torch, "svd_lowrank"):
        # Randomized SVD needs q to overestimate the wanted rank: with q == k
        # the recovered subspace is noticeably off when singular values decay
        # slowly. Oversample, then keep only the leading k vectors.
        oversample = 8
        q = min(k + oversample, max_rank)
        u, _, v = torch.svd_lowrank(matrix, q=q, niter=niter)
        return u[:, :k], v[:, :k]
    u, _, vh = torch.linalg.svd(matrix, full_matrices=False)
    return u[:, :k], vh.transpose(-2, -1)[:, :k]


def principal_cosines(
    base_vectors: torch.Tensor, lora_basis: torch.Tensor
) -> torch.Tensor:
    """Return the singular values of ``base_vectors.T @ lora_basis``.

    When both inputs have orthonormal columns these are the cosines of the
    principal angles between the two column spaces.
    """
    cross_gram = torch.matmul(base_vectors.T, lora_basis)
    return torch.linalg.svdvals(cross_gram)


def alignment_stats(
    base_vectors: torch.Tensor,
    lora_basis: torch.Tensor,
    max_k: int,
) -> Optional[dict]:
    """Summarise the principal cosines between base and LoRA subspaces.

    Only the first ``min(max_k, base_vectors.shape[1])`` base vectors are
    compared against the full ``lora_basis``.

    Returns:
        A dict with ``k`` (number of cosines), ``k_base`` (base vectors
        used) and the mean / median / min / max cosine, or ``None`` when no
        base vectors are available.
    """
    k_base = min(max_k, base_vectors.shape[1])
    if k_base <= 0:
        return None

    cosines = principal_cosines(base_vectors[:, :k_base], lora_basis)
    stats = {"k": int(cosines.numel()), "k_base": int(k_base)}
    stats["mean_cos"] = float(cosines.mean().item())
    stats["median_cos"] = float(cosines.median().item())
    stats["min_cos"] = float(cosines.min().item())
    stats["max_cos"] = float(cosines.max().item())
    return stats


def module_metadata(module_key: str) -> tuple[Optional[int], Optional[str]]:
    """Extract ``(layer index, projection name)`` from a module key.

    Returns ``(None, None)`` when ``LAYER_RE`` does not match the key.
    """
    found = LAYER_RE.search(module_key)
    if found is None:
        return None, None
    layer_text, projection = found.groups()
    return int(layer_text), projection


def write_csv(path: Path, records: List[dict]) -> None:
    """Write ``records`` to ``path`` as CSV, creating parent directories.

    The column order is taken from the keys of the first record.

    Raises:
        ValueError: If ``records`` is empty.
    """
    if not records:
        raise ValueError("No records to write.")
    columns = list(records[0])
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        for row in records:
            writer.writerow(row)


def summarize_records(records: List[dict]) -> List[dict]:
    """Average ``mean_cos`` and ``median_cos`` per ``(rank, side)`` group.

    Returns:
        One row per group with ``rank``, ``side``, ``count`` and the two
        averaged columns, sorted by ``(rank, side)``.
    """
    # group -> [count, running mean_cos total, running median_cos total]
    totals: Dict[tuple, list] = {}
    for record in records:
        group = (record["rank"], record["side"])
        if group not in totals:
            totals[group] = [0, 0.0, 0.0]
        bucket = totals[group]
        bucket[0] += 1
        bucket[1] += record["mean_cos"]
        bucket[2] += record["median_cos"]

    rows = [
        {
            "rank": rank,
            "side": side,
            "count": count,
            "mean_cos": mean_total / count,
            "median_cos": median_total / count,
        }
        for (rank, side), (count, mean_total, median_total) in totals.items()
    ]
    rows.sort(key=lambda row: (row["rank"], row["side"]))
    return rows


def summarize_alignment_over_time(records: List[dict]) -> List[dict]:
    """Aggregate baseline-normalised alignment per ``(rank, checkpoint_step)``.

    Each record's ``mean_cos`` is divided by
    ``sqrt(max(lora_rank, k_base) / dim)``, where ``dim`` is
    ``out_features`` for ``side == "left"`` and ``in_features`` for
    ``side == "right"``. Records with no step, an unknown side, a missing or
    zero dimension, or a zero baseline are skipped.

    Returns:
        One row per group with the mean normalised alignment, its bootstrap
        confidence interval and p-value against a baseline of 1.0, and the
        number of contributing records; sorted by ``(rank, checkpoint_step)``.
    """
    dim_field_for_side = {"left": "out_features", "right": "in_features"}
    grouped: Dict[tuple, List[float]] = {}

    for record in records:
        step = record.get("checkpoint_step")
        if step is None:
            continue
        side = record.get("side")
        if side not in ("left", "right"):
            continue
        dim = record.get(dim_field_for_side[side])
        if not dim:
            continue

        lora_rank = record.get("lora_rank") or record["rank"]
        k_base = record.get("k_base", 0)
        larger_subspace = max(int(lora_rank), int(k_base))
        random_baseline = (float(larger_subspace) / float(dim)) ** 0.5
        if random_baseline == 0.0:
            continue

        normalized = float(record["mean_cos"]) / random_baseline
        grouped.setdefault((record["rank"], int(step)), []).append(normalized)

    rows = []
    for (rank, step), values in grouped.items():
        # Seed derived from the group so every group gets a reproducible,
        # distinct resampling stream.
        mean, ci_low, ci_high, p_value = bootstrap_stats(
            values,
            baseline=1.0,
            n_boot=BOOTSTRAP_SAMPLES,
            ci=CI_LEVEL,
            seed=int(rank) * 1000003 + step,
        )
        rows.append(
            {
                "rank": rank,
                "checkpoint_step": step,
                "normalized_alignment": mean,
                "ci_low": ci_low,
                "ci_high": ci_high,
                "p_value": p_value,
                "count": len(values),
            }
        )
    rows.sort(key=lambda row: (row["rank"], row["checkpoint_step"]))
    return rows


def bootstrap_stats(
    values: List[float],
    baseline: float = 1.0,
    n_boot: int = BOOTSTRAP_SAMPLES,
    ci: float = CI_LEVEL,
    seed: int = 0,
) -> tuple[float, float, float, float]:
    """Bootstrap the mean of ``values``.

    Args:
        values: Sample to resample with replacement.
        baseline: Reference value for the two-sided p-value.
        n_boot: Number of bootstrap resamples.
        ci: Confidence level of the percentile interval.
        seed: Seed for the private ``random.Random`` generator.

    Returns:
        ``(mean, ci_low, ci_high, p_value)``. With fewer than two values or
        ``n_boot <= 1`` the interval collapses to the mean and the p-value
        is 1.0.

    Raises:
        ValueError: If ``values`` is empty.
    """
    if not values:
        raise ValueError("No values provided for bootstrap statistics.")
    n = len(values)
    mean = sum(values) / n
    if n < 2 or n_boot <= 1:
        return mean, mean, mean, 1.0

    rng = random.Random(seed)

    def resampled_mean() -> float:
        # Draw n indices with replacement and average the picked values.
        total = 0.0
        for _ in range(n):
            total += values[rng.randrange(n)]
        return total / n

    boot_means = []
    for _ in range(n_boot):
        boot_means.append(resampled_mean())
    boot_means.sort()

    tail = (1.0 - ci) / 2.0
    ci_low = percentile(boot_means, tail)
    ci_high = percentile(boot_means, 1.0 - tail)

    # Two-sided p-value: twice the smaller tail mass relative to baseline.
    frac_below = len([m for m in boot_means if m <= baseline]) / n_boot
    frac_above = len([m for m in boot_means if m >= baseline]) / n_boot
    p_value = 2.0 * min(frac_below, frac_above)
    p_value = max(0.0, min(1.0, p_value))
    return mean, ci_low, ci_high, p_value


def percentile(sorted_values: List[float], q: float) -> float:
    """Return the ``q``-quantile of an ascending list by linear interpolation.

    ``q`` is a fraction; values at or below 0 give the first element and
    values at or above 1 give the last.

    Raises:
        ValueError: If ``sorted_values`` is empty.
    """
    if not sorted_values:
        raise ValueError("Percentile requires non-empty list.")
    if q <= 0.0:
        return sorted_values[0]
    if q >= 1.0:
        return sorted_values[-1]

    position = q * (len(sorted_values) - 1)
    lower = math.floor(position)
    weight = position - lower
    if weight == 0:
        # Position falls exactly on an element; no interpolation needed.
        return sorted_values[lower]
    upper = lower + 1
    return sorted_values[lower] * (1.0 - weight) + sorted_values[upper] * weight


def plot_alignment_over_time(series: List[dict], output_path: Path) -> None:
    """Plot normalised alignment against training step, one line per rank.

    Each rank gets a line with markers plus a shaded band between its
    ``ci_low`` and ``ci_high`` values; a dashed horizontal line marks 1.0.
    The figure is saved to ``output_path`` (parent directories are created)
    and then closed.

    Raises:
        ValueError: If ``series`` is empty.
    """
    if not series:
        raise ValueError("No alignment-over-time data to plot.")

    palette = (
        "#1f77b4",
        "#2ca02c",
        "#9467bd",
        "#d62728",
        "#8c564b",
        "#e377c2",
    )
    ranks = sorted({row["rank"] for row in series})

    fig, ax = plt.subplots(figsize=(8, 5))
    for idx, rank in enumerate(ranks):
        # Colours cycle when there are more ranks than palette entries.
        color = palette[idx % len(palette)]
        rank_rows = sorted(
            (row for row in series if row["rank"] == rank),
            key=lambda row: row["checkpoint_step"],
        )
        steps = [row["checkpoint_step"] for row in rank_rows]
        ax.plot(
            steps,
            [row["normalized_alignment"] for row in rank_rows],
            marker="o",
            linewidth=2,
            color=color,
            label=f"r{rank}",
        )
        ax.fill_between(
            steps,
            [row["ci_low"] for row in rank_rows],
            [row["ci_high"] for row in rank_rows],
            color=color,
            alpha=0.15,
            linewidth=0,
        )

    ax.set_xlabel("Training step")
    ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5, label="Random chance")
    ax.set_ylabel("Normalized alignment (x times random)")
    ax.set_title("Alignment efficiency (normalized by rank)")
    ax.grid(True, linewidth=0.3, alpha=0.5)
    ax.legend(frameon=False, title="Rank")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


def main() -> None:
    """Run the full alignment analysis for one sweep directory.

    For every rank directory and every adapter checkpoint inside it, compares
    each LoRA module's A/B subspaces with the top singular vectors of the
    matching base weight, then writes a per-module CSV, a per-rank summary
    CSV and an alignment-over-time plot.

    Raises:
        ValueError: If the sweep contains no rank directories (and, via
            ``write_csv``, if no alignment records were produced).
        FileNotFoundError: If a checkpoint lacks ``adapter_model.safetensors``
            or the base model weights cannot be located.
    """
    args = parse_args()
    sweep_dir = args.sweep_dir
    rank_dirs = collect_rank_dirs(sweep_dir)
    if not rank_dirs:
        raise ValueError(f"No rank directories found in {sweep_dir}")

    # The base model is resolved once, from the first rank's adapter config.
    first_config = load_adapter_config(
        rank_dirs[0] / "adapter" / "adapter_config.json")
    base_model_path = resolve_base_model_path(
        first_config, args.base_model, args.hf_home)
    print(f"Using base model weights at {base_model_path}")

    # Base-weight SVDs are cached by tensor name so each weight is decomposed
    # only once across all ranks and checkpoints.
    base_basis_cache: Dict[str, BaseBasis] = {}
    records: List[dict] = []

    with safe_open(base_model_path, framework="pt", device="cpu") as base_tensors:
        base_keys = set(base_tensors.keys())
        for rank_dir in rank_dirs:
            metrics_path = rank_dir / "metrics.json"
            adapter_dir = rank_dir / "adapter"
            # Directory names look like "r<digits>"; the digits are the sweep rank.
            rank_value = int(rank_dir.name[1:])

            # Intermediate checkpoints live in "checkpoint-<step>" sub-directories,
            # ordered by their numeric step suffix.
            checkpoint_dirs = sorted(
                [p for p in adapter_dir.glob("checkpoint-*") if p.is_dir()],
                key=lambda p: int(p.name.split("-")[-1]),
            )
            if checkpoint_dirs:
                checkpoints = [
                    (checkpoint_dir, int(checkpoint_dir.name.split("-")[-1]))
                    for checkpoint_dir in checkpoint_dirs
                ]
            else:
                # No intermediate checkpoints: analyse the adapter directory
                # itself, labelled with the last step from metrics.json (or 0).
                fallback_step = max_step_from_metrics(metrics_path)
                if fallback_step is None:
                    fallback_step = 0
                checkpoints = [(adapter_dir, fallback_step)]

            for checkpoint_dir, checkpoint_step in checkpoints:
                adapter_path = checkpoint_dir / "adapter_model.safetensors"
                adapter_config_path = checkpoint_dir / "adapter_config.json"
                if not adapter_path.exists():
                    raise FileNotFoundError(f"Missing {adapter_path}")
                # Checkpoints without their own config reuse the adapter-level one.
                if not adapter_config_path.exists():
                    adapter_config_path = adapter_dir / "adapter_config.json"
                adapter_config = load_adapter_config(adapter_config_path)

                with safe_open(adapter_path, framework="pt", device="cpu") as adapter_tensors:
                    pairs = lora_pairs(adapter_tensors)
                    for module_key, pair in pairs.items():
                        # Both LoRA factors are required for a comparison.
                        if "A" not in pair or "B" not in pair:
                            continue
                        base_key = resolve_base_weight_key(
                            base_keys, module_key)
                        if base_key is None:
                            print(
                                f"Skipping {module_key} (base weight not found)")
                            continue
                        if base_key not in base_basis_cache:
                            weight = base_tensors.get_tensor(
                                base_key).to(torch.float32)
                            u, v = top_singular_vectors(
                                weight, args.max_k, args.niter)
                            base_basis_cache[base_key] = BaseBasis(
                                left=u,
                                right=v,
                                k=u.shape[1],
                                shape=tuple(weight.shape),
                            )
                        basis = base_basis_cache[base_key]

                        a = adapter_tensors.get_tensor(
                            pair["A"]).to(torch.float32)
                        b = adapter_tensors.get_tensor(
                            pair["B"]).to(torch.float32)
                        # NOTE(review): presumably A is (r, in_features) and B is
                        # (out_features, r), so A.T spans the input-side subspace
                        # and B the output-side one — confirm against the adapter format.
                        right_basis = orthonormal_basis(a.T)
                        left_basis = orthonormal_basis(b)
                        # Skip modules where either factor has no usable basis.
                        if right_basis is None or left_basis is None:
                            continue

                        layer, proj = module_metadata(module_key)
                        lora_rank = adapter_rank_for_module(
                            adapter_config, module_key)
                        # Output-side comparison: base left vectors vs. span of B.
                        left_stats = alignment_stats(
                            basis.left, left_basis, args.max_k)
                        if left_stats:
                            records.append(
                                {
                                    "rank": rank_value,
                                    "checkpoint_step": checkpoint_step,
                                    "module_key": module_key.rstrip("."),
                                    "base_weight_key": base_key,
                                    "layer": layer,
                                    "proj": proj,
                                    "side": "left",
                                    "k": left_stats["k"],
                                    "k_base": left_stats["k_base"],
                                    # Configured rank capped by the basis size actually found.
                                    "lora_rank": int(min(lora_rank, left_basis.shape[1])),
                                    "out_features": basis.shape[0],
                                    "in_features": basis.shape[1],
                                    "mean_cos": left_stats["mean_cos"],
                                    "median_cos": left_stats["median_cos"],
                                    "min_cos": left_stats["min_cos"],
                                    "max_cos": left_stats["max_cos"],
                                }
                            )
                        # Input-side comparison: base right vectors vs. span of A.T.
                        right_stats = alignment_stats(
                            basis.right, right_basis, args.max_k)
                        if right_stats:
                            records.append(
                                {
                                    "rank": rank_value,
                                    "checkpoint_step": checkpoint_step,
                                    "module_key": module_key.rstrip("."),
                                    "base_weight_key": base_key,
                                    "layer": layer,
                                    "proj": proj,
                                    "side": "right",
                                    "k": right_stats["k"],
                                    "k_base": right_stats["k_base"],
                                    # Configured rank capped by the basis size actually found.
                                    "lora_rank": int(min(lora_rank, right_basis.shape[1])),
                                    "out_features": basis.shape[0],
                                    "in_features": basis.shape[1],
                                    "mean_cos": right_stats["mean_cos"],
                                    "median_cos": right_stats["median_cos"],
                                    "min_cos": right_stats["min_cos"],
                                    "max_cos": right_stats["max_cos"],
                                }
                            )

    # Output paths default to files inside the sweep directory.
    if args.output is None:
        args.output = sweep_dir / "lora_eigen_alignment.csv"
    write_csv(args.output, records)
    print(f"Saved alignment CSV to {args.output}")

    if args.summary is None:
        args.summary = sweep_dir / "lora_eigen_alignment_summary.csv"
    summary_records = summarize_records(records)
    write_csv(args.summary, summary_records)
    print(f"Saved summary CSV to {args.summary}")

    time_series = summarize_alignment_over_time(records)
    plot_path = args.time_plot or sweep_dir / "lora_alignment_over_time.png"
    plot_alignment_over_time(time_series, plot_path)
    print(f"Saved alignment-over-time plot to {plot_path}")


if __name__ == "__main__":
    # Autograd is switched off globally before the analysis runs.
    torch.set_grad_enabled(False)
    main()
