"""Load saved Q/K and produce analysis plots per layer."""

from __future__ import annotations

import argparse
import os
from typing import List
import sys

import torch
import matplotlib.pyplot as plt

# Two directories must be importable before the project imports further down:
#   * the long_context_eval root, which provides the ``src`` package, and
#   * the repository root, which provides top-level helper modules such as
#     ``qk_collision_plot``.
_CUR_DIR = os.path.dirname(os.path.abspath(__file__))
_EVAL_ROOT = os.path.dirname(_CUR_DIR)  # .../long_context_eval
_REPO_ROOT = os.path.dirname(_EVAL_ROOT)  # .../sampling
for _path in (_EVAL_ROOT, _REPO_ROOT):
    if _path in sys.path:
        continue
    # Prepend so these locations win over any same-named installed modules.
    sys.path.insert(0, _path)

# Matplotlib defaults applied to every figure created after this point:
# small text sizes and a 3.5 x 2.5 default figure size.
_PLOT_STYLE = {
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
    "figure.figsize": (3.5, 2.5),
}
plt.rcParams.update(_PLOT_STYLE)

from src.analysis.vector_analyzer import plot_QK_hist, plot_QK_pca2d
from qk_collision_plot import extract_qk_from_gqa2  # single-head + mapped KV head


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the per-layer Q/K analysis.

    Returns:
        The populated namespace with ``model_key``, ``dataset``, ``exp_name``,
        ``layers``, ``sample_idx``, ``out_dir`` and ``q_head``.
    """
    cli = argparse.ArgumentParser(description="Analyze saved Q/K per layer")

    # Mandatory selectors; they share the same declaration.
    for required_flag in ("--model_key", "--dataset"):
        cli.add_argument(required_flag, type=str, required=True)

    cli.add_argument("--exp_name", type=str, default="extract_attention")
    cli.add_argument(
        "--layers",
        type=int,
        nargs="*",
        default=None,
        help="Optional explicit layer indices",
    )
    cli.add_argument(
        "--sample_idx",
        type=int,
        default=0,
        help="Which sample folder to analyze (sample_XXXX)",
    )
    cli.add_argument(
        "--out_dir",
        type=str,
        default=None,
        help="Optional override for output dir",
    )
    cli.add_argument(
        "--q_head",
        type=int,
        default=0,
        help="Q head index to match runtime single-head analysis",
    )
    return cli.parse_args()


def _discover_layer_indices(data_dir: str) -> List[int]:
    """Return the sorted, de-duplicated layer indices saved in *data_dir*.

    An index is read from every file named ``q_layer<N>.pt`` or
    ``k_layer<N>.pt``. A layer normally has both files, so indices are
    collected in a set to avoid reporting (and later plotting) it twice.
    Files whose text after ``layer`` is not an integer are ignored.

    Raises:
        OSError: If *data_dir* cannot be listed (e.g. it does not exist).
    """
    found = set()
    for fn in os.listdir(data_dir):
        if not (fn.startswith(("q_layer", "k_layer")) and fn.endswith(".pt")):
            continue
        stem = os.path.splitext(fn)[0]  # e.g., q_layer03
        try:
            found.add(int(stem.split("layer")[-1]))
        except ValueError:
            # Suffix is not a plain integer; this is not a per-layer dump.
            continue
    return sorted(found)


def main() -> None:
    """Save a 2-D PCA plot of the stored Q/K tensors for each layer.

    Inputs are read from
    ``<eval_root>/data/<model_key>/<dataset>/<exp_name>/sample_XXXX`` as
    ``q_layerNN.pt`` / ``k_layerNN.pt``. One ``pca_layerNN.png`` per layer is
    written to ``--out_dir``, or to the mirrored path under
    ``<eval_root>/figures`` when no override is given. Layers listed via
    ``--layers`` are used as-is; otherwise layers are discovered from the
    files on disk. A layer missing either file is skipped with a message.

    Raises:
        OSError: If layers must be discovered and the sample data directory
            cannot be listed.
    """
    args = parse_args()
    # Data and figures live under the long_context_eval root, i.e. the parent
    # of the directory that contains this file.
    base_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
    sample_parts = (args.model_key, args.dataset, args.exp_name, f"sample_{args.sample_idx:04d}")
    data_dir = os.path.join(base_dir, "data", *sample_parts)
    out_dir = args.out_dir or os.path.join(base_dir, "figures", *sample_parts)
    os.makedirs(out_dir, exist_ok=True)

    # An empty or absent --layers means "use whatever was saved".
    layer_indices: List[int]
    if args.layers:
        layer_indices = list(args.layers)
    else:
        layer_indices = _discover_layer_indices(data_dir)

    for layer_idx in layer_indices:
        q_path = os.path.join(data_dir, f"q_layer{layer_idx:02d}.pt")
        k_path = os.path.join(data_dir, f"k_layer{layer_idx:02d}.pt")
        if not (os.path.exists(q_path) and os.path.exists(k_path)):
            print(f"Skip layer {layer_idx}: missing files")
            continue
        # Load onto CPU so the analysis does not depend on the device the
        # tensors were saved from.
        q = torch.load(q_path, map_location="cpu")["q"]
        k = torch.load(k_path, map_location="cpu")["k"]

        # Match runtime: use single Q head mapped to its KV head (GQA-aware)
        Q_flat, K_flat, _meta = extract_qk_from_gqa2(
            q, k, q_head=args.q_head, map_gqa=True, use_all_kv_heads=False
        )

        fig_pca = plot_QK_pca2d(Q_flat, K_flat, title=f"{args.model_key} {args.dataset} layer {layer_idx}")
        fig_pca.savefig(os.path.join(out_dir, f"pca_layer{layer_idx:02d}.png"), dpi=300, bbox_inches="tight")
        fig_pca.clf()
        # clf() only empties the figure; close it as well so figures do not
        # pile up in pyplot's registry across layers.
        plt.close(fig_pca)
        print(f"Saved analysis for layer {layer_idx}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


