#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gradient similarity analysis for LABRAM finetune models (classification task only).

This mirrors `tools/allee/grad/cbramod_grad_ft_cls_ds.py` but adapts the entire pipeline to the
LABRAM baseline stack:
- Load a LABRAM config + checkpoints (encoder + dataset heads)
- Sample batches dataset-by-dataset using the LABRAM adapter/dataloaders
- Backpropagate the cross-entropy (CE) loss to collect encoder-only gradients, grouped by major submodules
- Aggregate gradient samples per dataset and report similarities via heatmaps,
 subspace affinity, PCA clouds, and temporal evolution plots

The implementation intentionally stays close to the Allee variant for parity while
respecting LABRAM-specific components (config, model, dataloaders, checkpoints).
"""

import argparse
import csv
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import datasets
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from omegaconf import OmegaConf

from baseline.analysis.grad import (
 GradStore,
 compute_axis_heatmap,
 compute_subspace_affinity,
 compute_2d_embeddings,
 compute_axis_shared_specific,
 r_sweep_shared_energy_axes,
)
from baseline.analysis.plot import (
 plot_cloud2d,
 plot_cloud2d_kde,
 plot_heatmaps,
 plot_series_with_ema,
 plot_subspace_affinity,
 plot_r_sweep_axes,
 plot_shared_bars_multi_rank_axes,
)
from baseline.analysis.utils import (
 HashingProjector,
 flatten_from_grads,
 make_run_base_dir,
 set_seeds,
 strip_all_suffixes,
 export_group_matrix_csv,
)
from baseline.labram.labram_config import LabramConfig
DESIRED_DATASET_ORDER = ['tuab', 'tuev', 'seed', 'hmc', 'workload', 'tusl']


def _order_dataset_names(names: List[str]) -> List[str]:
 order = {name.lower(): idx for idx, name in enumerate(DESIRED_DATASET_ORDER)}
 return sorted(names, key=lambda n: order.get(n.lower(), len(order)))

from baseline.labram.labram_adapter import LabramDatasetAdapter
from baseline.labram.labram_trainer import LabramUnifiedModel
from baseline.labram.model import NeuralTransformer
from baseline.abstract.classifier import MultiHeadClassifier
from common.path import get_conf_file_path
from data.processor.wrapper import get_dataset_n_class, load_concat_eeg_datasets
from training.distributed.loader import DistributedGroupBatchSampler

# -----------------------------
# CLI and run configuration
# -----------------------------

@dataclass
class RunArgs:
    """Options for one gradient-similarity run (normally built by ``parse_args``).

    The defaults below apply only when ``RunArgs`` is constructed directly.
    ``parse_args`` passes every field explicitly using its own CLI defaults,
    and those differ from the ones here for ``max_cloud_points`` (CLI: 2000)
    and ``log_interval`` (CLI: 1).
    """

    # Config file, resolved via get_conf_file_path (relative to assets/conf or absolute).
    conf: str
    # Checkpoints to analyse; each gets its own run (and a subdirectory when more than one).
    ckpts: List[str]
    # Optimization steps taken per dataset.
    steps: int = 512
    # Overrides cfg.data.batch_size in load_config.
    batch_size: int = 8
    # Seed for set_seeds, sampler shuffling and the hashing projector.
    seed: int = 100
    # Hashing-projection dimension for stored gradients; <= 0 stores them unprojected.
    proj_dim: int = 512
    # One subspace-affinity plot/CSV is produced per top-k value.
    subspace_k: List[int] = field(default_factory=lambda: [3])
    # Forwarded to compute_2d_embeddings as max_points.
    max_cloud_points: int = 512
    # Values starting with 'train' (case-insensitive) select TRAIN; anything else selects VALIDATION.
    split: str = "train"
    # Adam learning rate for the per-dataset optimization.
    lr: float = 5e-5
    # Print the CE loss every this many steps (and on the last step).
    log_interval: int = 25
    # Both forwarded to make_run_base_dir to choose the output directory.
    out_dir: Optional[str] = None
    run_name: Optional[str] = None
    # EMA smoothing factor for the dataset-evolution plot.
    ema_beta: float = 0.9
    # Merge all transformer blocks into a single 'backbone' gradient group.
    blocks_as_one: bool = False


# -----------------------------
# Config, dataloaders, models
# -----------------------------

def load_config(args: RunArgs) -> LabramConfig:
    """Load, resolve and validate the LABRAM config named by ``args.conf``.

    OmegaConf interpolations are resolved before pydantic validation. The CLI
    batch size overrides ``cfg.data.batch_size``, and the resolved path is
    recorded on ``cfg.conf_file``.
    """
    resolved_path = get_conf_file_path(args.conf)
    plain_cfg = OmegaConf.to_container(OmegaConf.load(resolved_path), resolve=True)
    config = LabramConfig.model_validate(plain_cfg)
    config.data.batch_size = int(args.batch_size)
    config.conf_file = resolved_path
    return config


def build_ds_dict(cfg: LabramConfig) -> Dict[str, int]:
    """Map every configured dataset name to its class count.

    A missing/empty ``cfg.data.datasets`` yields an empty mapping.
    """
    n_classes_by_ds: Dict[str, int] = {}
    for ds_name, ds_conf in (cfg.data.datasets or {}).items():
        n_classes_by_ds[ds_name] = get_dataset_n_class(ds_name, ds_conf)
    return n_classes_by_ds


def build_loaders(cfg: LabramConfig, split: str, seed: int) -> Tuple[List[str], List[DataLoader]]:
    """Create one shuffled DataLoader per configured dataset.

    Datasets are visited in ``_order_dataset_names`` order. A ``split`` starting
    with ``'train'`` (case-insensitive) loads the TRAIN split; anything else
    loads VALIDATION. Each dataset is loaded at ``fs=256``, wrapped in a
    ``LabramDatasetAdapter``, and batched by a single-replica
    ``DistributedGroupBatchSampler`` seeded with ``seed``.

    Returns:
        ``(ordered dataset names, loaders)``, with loaders aligned to the names.
    """
    wants_train = split.lower().startswith('train')
    split_enum = datasets.Split.TRAIN if wants_train else datasets.Split.VALIDATION
    ordered_names = _order_dataset_names(list(cfg.data.datasets.keys()))
    name_conf_pairs = [(ds_name, cfg.data.datasets[ds_name]) for ds_name in ordered_names]

    loaders: List[DataLoader] = []
    for ds_name, ds_conf in name_conf_pairs:
        eeg_dataset, _ = load_concat_eeg_datasets(
            dataset_names=[ds_name],
            builder_configs=[ds_conf],
            split=split_enum,
            add_ds_name=True,
            cast_label=True,
            fs=256,
        )
        adapter = LabramDatasetAdapter(dataset=eeg_dataset, dataset_names=[ds_name], dataset_configs=[ds_conf])
        batch_sampler = DistributedGroupBatchSampler(
            dataset=eeg_dataset,
            batch_size=int(cfg.data.batch_size),
            num_replicas=1,
            rank=0,
            shuffle=True,
            seed=int(seed),
            drop_last=False,
        )

        loader_kwargs: Dict[str, Any] = {
            'batch_sampler': batch_sampler,
            'pin_memory': torch.cuda.is_available(),
        }
        n_workers = cfg.data.num_workers
        if n_workers > 0:
            # Keep worker processes alive between epochs and prefetch ahead.
            loader_kwargs.update(num_workers=n_workers, persistent_workers=True, prefetch_factor=2)
        else:
            loader_kwargs['num_workers'] = 0
        loaders.append(DataLoader(adapter, **loader_kwargs))

    return ordered_names, loaders


def build_model(cfg: LabramConfig, ds_dict: Dict[str, int], device: torch.device) -> LabramUnifiedModel:
    """Instantiate the LABRAM encoder plus per-dataset classifier heads.

    The encoder is built with ``num_classes=0`` (no built-in head). One
    classifier head per entry of ``ds_dict`` (name -> number of classes) is
    created by ``MultiHeadClassifier``. The combined model is moved to
    ``device`` and put in train mode. No checkpoint is loaded here.
    """
    mcfg = cfg.model
    encoder_kwargs = dict(
        EEG_size=mcfg.eeg_size,
        patch_size=mcfg.patch_size,
        in_chans=mcfg.in_chans,
        out_chans=mcfg.out_chans,
        num_classes=0,
        embed_dim=mcfg.embed_dim,
        depth=mcfg.depth,
        num_heads=mcfg.num_heads,
        mlp_ratio=mcfg.mlp_ratio,
        qkv_bias=mcfg.qkv_bias,
        qk_norm=torch.nn.LayerNorm,
        qk_scale=None,
        drop_rate=mcfg.dropout_rate,
        attn_drop_rate=mcfg.attn_dropout_rate,
        drop_path_rate=mcfg.drop_path_rate,
        norm_layer=torch.nn.LayerNorm,
        init_values=mcfg.init_values,
        use_abs_pos_emb=mcfg.use_abs_pos_emb,
        use_rel_pos_bias=mcfg.use_rel_pos_bias,
        use_shared_rel_pos_bias=mcfg.use_shared_rel_pos_bias,
        use_mean_pooling=mcfg.use_mean_pooling,
        init_scale=mcfg.init_scale,
    )
    encoder = NeuralTransformer(**encoder_kwargs)

    classifier = MultiHeadClassifier(
        embed_dim=mcfg.embed_dim,
        head_configs=dict(ds_dict),
        head_cfg=mcfg.classifier_head,
        t_sne=mcfg.t_sne,
    )

    unified = LabramUnifiedModel(encoder, classifier, grad_cam=mcfg.grad_cam).to(device)
    unified.train()
    return unified


def load_model_weights(model: LabramUnifiedModel, ckpt_path: str, device: torch.device):
    """Load checkpoint weights into ``model`` (non-strictly) and leave it in train mode.

    Two checkpoint layouts are accepted:

    * ``{'model_state_dict': ...}``: a full unified-model state dict. A leading
      ``module.`` prefix (typically present when a DDP-wrapped model was saved)
      is stripped from each key.
    * ``{'model': ...}``: only keys starting with ``student.`` are kept. That
      prefix is removed and the result is loaded into ``model.encoder``.

    Missing/unexpected keys are printed rather than raised.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
        RuntimeError: if the checkpoint matches neither layout.
    """
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    ddp_prefix = 'module.'
    student_prefix = 'student.'
    if 'model_state_dict' in ckpt:
        # Strip only a *leading* prefix. A substring replace would also corrupt
        # keys that merely contain "module." (e.g. "...submodule.weight"), and
        # with strict=False those weights would then silently stay uninitialised.
        state = {
            (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
            for k, v in ckpt['model_state_dict'].items()
        }
        missing, unexpected = model.load_state_dict(state, strict=False)
    elif 'model' in ckpt:
        encoder_state = {
            k[len(student_prefix):]: v
            for k, v in ckpt['model'].items()
            if k.startswith(student_prefix)
        }
        missing, unexpected = model.encoder.load_state_dict(encoder_state, strict=False)
    else:
        raise RuntimeError(f"Unrecognized checkpoint format for {ckpt_path}")

    if missing:
        print(f"[ckpt] missing keys: {missing}")
    if unexpected:
        print(f"[ckpt] unexpected keys: {unexpected}")

    model.train()


# -----------------------------
# Parameter grouping & gradient collection helpers
# -----------------------------

def labram_encoder_group_of(
    name: str,
    num_blocks: int,
    bucket_size: int = 6,
    merge_blocks: bool = False,
) -> Optional[str]:
    """Return the gradient-analysis group label for parameter *name*.

    Returns:
        ``'non_encoder'`` for names outside ``encoder.``; ``None`` for the
        encoder's ``head.``/``norm*``/``fc_norm*`` parameters (excluded);
        ``'patch_embed'`` for ``cls_token``/``pos_embed``/``time_embed`` and
        ``patch_embed.*``. For ``blocks.<i>.*`` the label is ``'backbone'`` when
        *merge_blocks*, else ``'blocks_SS_EE'`` buckets of *bucket_size*
        consecutive blocks (clipped to ``num_blocks - 1``), or ``'blocks_ii'``
        when *bucket_size* or *num_blocks* is <= 0. Non-numeric block indices
        give ``'blocks_misc'``; everything else gives ``'other'``.
    """
    prefix = 'encoder.'
    if not name.startswith(prefix):
        return 'non_encoder'
    rel = name[len(prefix):]

    if rel.startswith(('head.', 'norm', 'fc_norm')):
        return None
    if rel in ('cls_token', 'pos_embed', 'time_embed') or rel.startswith('patch_embed.'):
        return 'patch_embed'
    if not rel.startswith('blocks.'):
        return 'other'
    if merge_blocks:
        return 'backbone'

    pieces = rel.split('.')
    if len(pieces) < 2 or not pieces[1].isdigit():
        return 'blocks_misc'
    block_idx = int(pieces[1])
    if bucket_size <= 0 or num_blocks <= 0:
        return f'blocks_{block_idx:02d}'

    first = (block_idx // bucket_size) * bucket_size
    last = min(num_blocks - 1, first + bucket_size - 1)
    return f'blocks_{first:02d}_{last:02d}'


def build_encoder_param_order_labram(
    model: torch.nn.Module,
    bucket_size: int = 6,
    merge_blocks: bool = False,
) -> Tuple[List[str], Dict[str, int], Dict[str, List[str]], Dict[str, int]]:
    """Select, group and order the encoder parameters whose gradients are analysed.

    Args:
        model: Unified model; only parameters named ``encoder.*`` are considered.
        bucket_size: Consecutive transformer blocks per group (``<= 0`` gives
            one group per block).
        merge_blocks: Put all transformer blocks into one ``backbone`` group.

    Returns:
        ``(names, name_to_idx, groups, sizes)``:
        - ``names``: sorted list of selected parameter names.
        - ``name_to_idx``: maps each name to its index in ``names``.
        - ``groups``: maps group label -> sorted member names.
        - ``sizes``: maps name -> element count.

    Parameters with ``requires_grad=False`` are skipped: passing them as
    inputs to ``torch.autograd.grad`` would raise.
    """
    encoder = getattr(model, 'encoder', None)
    blocks = getattr(encoder, 'blocks', None)
    # The block count is only known when blocks live in an nn.ModuleList;
    # 0 makes labram_encoder_group_of fall back to one group per block.
    num_blocks = len(blocks) if isinstance(blocks, torch.nn.ModuleList) else 0

    names: List[str] = []
    groups: Dict[str, List[str]] = {}
    sizes: Dict[str, int] = {}

    for name, p in model.named_parameters():
        if not name.startswith('encoder.') or not p.requires_grad:
            continue
        g = labram_encoder_group_of(
            name,
            num_blocks=num_blocks,
            bucket_size=bucket_size,
            merge_blocks=merge_blocks,
        )
        if g is None or g == 'non_encoder':
            continue
        names.append(name)
        groups.setdefault(g, []).append(name)
        sizes[name] = int(p.numel())

    names.sort()
    for members in groups.values():
        members.sort()
    name_to_idx = {n: i for i, n in enumerate(names)}
    return names, name_to_idx, groups, sizes


# -----------------------------
# Main loop helpers
# -----------------------------

def cycle_loader(loader: DataLoader):
    """Yield batches from *loader* forever, re-iterating it after each pass.

    Raises:
        RuntimeError: if a full pass over *loader* produces no batches.
            Without this check an empty loader would make the generator spin
            forever without ever yielding.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise RuntimeError('cycle_loader: loader produced no batches; cannot cycle an empty loader')


def move_batch_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Return a shallow copy of *batch* with every tensor value moved to *device*.

    Non-tensor values (e.g. dataset-name strings) are passed through unchanged.
    Copies are issued with ``non_blocking=True``.
    """
    return {
        key: (value.to(device, non_blocking=True) if torch.is_tensor(value) else value)
        for key, value in batch.items()
    }


def _collect_dataset_grads(
    args: RunArgs,
    cfg: LabramConfig,
    device: torch.device,
    ds_dict: Dict[str, int],
    ckpt_path: str,
    loader: DataLoader,
    ds_name: str,
    progress: str,
    projector: HashingProjector,
    ds_names: List[str],
    store: Optional[GradStore],
    out_dir: Path,
) -> GradStore:
    """Fine-tune a fresh copy of the checkpoint on one dataset, recording group gradients.

    A new model is built and loaded from ``ckpt_path``, so every dataset starts
    from the same weights. At each of ``args.steps`` steps:
    - the CE gradient of every encoder parameter group is flattened and
      L2-normalised;
    - it is optionally hash-projected (``args.proj_dim > 0``);
    - it is appended to ``store`` under ``ds_name``;
    - then an Adam step is taken.

    When ``store`` is None, a ``GradStore`` is created and ``group_params.json``
    is written (best-effort). The model, optimizer and autograd graph are local
    to this call, so they are released before the next dataset's model is built.

    Raises:
        RuntimeError: if a batch contains labels outside ``[0, num_classes)``.
    """
    model = build_model(cfg, ds_dict, device)
    load_model_weights(model, ckpt_path, device)
    optim = torch.optim.Adam(model.parameters(), lr=float(args.lr))

    names, name_to_idx, group_names, size_map = build_encoder_param_order_labram(
        model,
        merge_blocks=bool(args.blocks_as_one),
    )
    # Build the name -> parameter map once instead of once per selected name.
    named_params = dict(model.named_parameters())
    param_list = [named_params[n] for n in names]
    sizes_in_order = [size_map[n] for n in names]
    # Per-group indices into param_list / grads are loop-invariant.
    group_idxs = {g: [name_to_idx[n] for n in members] for g, members in group_names.items()}

    if store is None:
        store = GradStore(sorted(group_names.keys()), ds_names)
        try:
            with open(out_dir / 'group_params.json', 'w') as f:
                json.dump({g: sorted(list(gns)) for g, gns in group_names.items()}, f, indent=2)
        except Exception as e:  # best-effort metadata dump; don't abort the run
            print(f"[group_params] warn: {e}")

    amp_enabled = device.type == 'cuda'
    amp_dtype = torch.bfloat16
    steps = int(args.steps)
    log_every = max(1, int(args.log_interval))
    cyc = cycle_loader(loader)
    for step in range(steps):
        batch_gpu = move_batch_to_device(next(cyc), device)

        optim.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
            logits = model(batch_gpu)
            labels = batch_gpu['label'].to(device, dtype=torch.long)
            num_classes = logits.shape[-1]
            if labels.numel() > 0:
                lmin = int(labels.min().item())
                lmax = int(labels.max().item())
                if lmin < 0 or lmax >= num_classes:
                    raise RuntimeError(
                        f"Label out of range for dataset '{ds_name}': min={lmin}, max={lmax}, num_classes={num_classes}"
                    )
            loss_ce = F.cross_entropy(logits, labels, reduction='mean')

        # Read per-group gradients first (graph retained), then run a normal
        # backward so the optimizer can step on the full model.
        grads = torch.autograd.grad(loss_ce, param_list, retain_graph=True, allow_unused=True)
        for g, idxs in group_idxs.items():
            v = flatten_from_grads([grads[i] for i in idxs], [sizes_in_order[i] for i in idxs], device)
            v = v / (torch.norm(v) + 1e-12)
            if args.proj_dim > 0:
                v = projector.project_and_norm(v, key=g)
            store.add(ds_name, g, v.detach().cpu())

        loss_ce.backward()
        optim.step()

        if (step + 1) % log_every == 0 or (step + 1) == steps:
            print(f"[ds {progress}:{ds_name}] step {step+1}/{steps} CE={float(loss_ce.item()):.4f}")

    return store


def _evolution_curves(store: GradStore) -> Tuple[List[str], Dict[str, List[float]]]:
    """Per-dataset cosine between each step's gradient and the mean of all earlier ones.

    For dataset ``d`` and step ``t >= 1``, every group with a sample at ``t``
    contributes ``cos(v_t, mean(v_0..v_{t-1}))``, clipped to [-1, 1]. The curve
    value is the mean over those groups, or NaN when there are none (always
    the case at ``t == 0``). A running sum, accumulated in float64, replaces
    re-stacking all previous vectors every step, so this is O(T) rather than
    O(T^2) per group.
    """
    ds_names_sorted = store.get_axes()
    groups = store.get_groups()
    curves: Dict[str, List[float]] = {}
    for d in ds_names_sorted:
        per_g = {g: store.data[g][d] for g in groups}
        length = max((len(vs) for vs in per_g.values()), default=0)
        running: Dict[str, torch.Tensor] = {}
        series: List[float] = []
        for t in range(length):
            vals: List[float] = []
            for g in groups:
                xs = per_g[g]
                if t >= len(xs):
                    continue
                v_t = xs[t].to(torch.float64)
                prev_sum = running.get(g)  # sum of xs[:t], or None at t == 0
                if prev_sum is not None:
                    m = prev_sum / t
                    m = m / (torch.norm(m) + 1e-12)
                    vtn = v_t / (torch.norm(v_t) + 1e-12)
                    c = float(torch.dot(m.flatten(), vtn.flatten()).item())
                    vals.append(max(-1.0, min(1.0, c)))
                running[g] = v_t if prev_sum is None else prev_sum + v_t
            series.append(float(np.mean(vals)) if vals else float('nan'))
        curves[d] = series
    return ds_names_sorted, curves


def _write_evolution(store: GradStore, fig_evo: str, csv_path: Path, ema_beta: float) -> None:
    """Plot the per-dataset evolution curves (with EMA) and export them to CSV."""
    ds_names_sorted, curves = _evolution_curves(store)
    max_len = max((len(v) for v in curves.values()), default=0)
    # Subsample long runs (stride = max_len // 400) to keep the plot readable.
    stride = max(1, max_len // 400)
    steps_ds = list(range(1, max_len + 1))[::stride]
    series_ds = [curves[d][::stride] for d in ds_names_sorted]
    plot_series_with_ema(
        steps_ds,
        series_ds,
        [d.upper() for d in ds_names_sorted],
        fig_evo,
        y_label='Mean Cosine Similarity',
        ema_beta=float(ema_beta),
        alpha_raw=0.2,
    )

    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['step'] + ds_names_sorted)
        for i in range(max_len):
            row = [str(i + 1)]
            for d in ds_names_sorted:
                v = curves[d][i] if i < len(curves[d]) else None
                row.append(f"{v:.6f}" if (v is not None and np.isfinite(v)) else '')
            writer.writerow(row)


def run_one_ckpt(
    args: RunArgs,
    cfg: LabramConfig,
    device: torch.device,
    ds_list: List[str],
    loader_list: List[DataLoader],
    ds_dict: Dict[str, int],
    base_run_dir: Path,
    ckpt_path: str,
):
    """Collect per-dataset encoder gradients for one checkpoint and write all reports.

    Each dataset loader gets a freshly loaded model that is optimized on that
    dataset alone while group gradients are recorded into a shared
    ``GradStore``. The store is then summarised into:
    - heatmaps;
    - subspace affinities (one per ``args.subspace_k``);
    - 2D clouds;
    - shared-energy plots;
    - an evolution curve;
    plus a ``run_config.json`` describing the run. Outputs go to
    ``base_run_dir / <ckpt stem>`` when several checkpoints are analysed, else
    directly to ``base_run_dir``. Optional exports/plots are best-effort: their
    failures are printed as warnings.

    Raises:
        RuntimeError: if ``loader_list`` is empty, so no gradients were collected.
    """
    ckpt_short = strip_all_suffixes(Path(ckpt_path))
    multi_ckpt_run = len(args.ckpts) > 1
    out_dir = base_run_dir / ckpt_short if multi_ckpt_run else base_run_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n=== Running for ckpt: {ckpt_path} -> {out_dir} ===")

    projector = HashingProjector(args.proj_dim, int(args.seed))
    ds_names: List[str] = _order_dataset_names(list(cfg.data.datasets.keys()))
    store: Optional[GradStore] = None
    for ld_idx, loader in enumerate(loader_list):
        ds_name_cur = ds_names[ld_idx] if ld_idx < len(ds_names) else ds_list[ld_idx]
        store = _collect_dataset_grads(
            args, cfg, device, ds_dict, ckpt_path, loader, ds_name_cur,
            f"{ld_idx+1}/{len(loader_list)}", projector, ds_names, store, out_dir,
        )
    if store is None:
        raise RuntimeError("No dataset loaders were provided; no gradients were collected.")

    fig_heatmap = str(out_dir / 'heatmap.png')
    fig_cloud = str(out_dir / 'cloud2d.png')
    fig_cloud_kde = str(out_dir / 'cloud2d_kde.png')
    fig_evo = str(out_dir / 'dataset_evolution.png')
    fig_r_sweep = str(out_dir / 'r_sweep_datasets.png')
    fig_shared_multi = str(out_dir / 'shared_bars_multi_rank_datasets.png')
    json_cfg = str(out_dir / 'run_config.json')

    heatmaps_final = compute_axis_heatmap(store)
    ds_names_plot = store.get_axes()

    print('[plot] heatmap')
    plot_heatmaps(heatmaps_final, ds_names_plot, fig_heatmap)
    try:
        export_group_matrix_csv(out_dir / 'heatmaps.csv', heatmaps_final, ds_names_plot)
    except Exception as e:
        print(f"[export heatmaps] warn: {e}")

    subspace_figs: List[str] = []
    k_values = args.subspace_k or [3]
    multi_k = len(k_values) > 1
    print('[plot] subspace affinity')
    for k in k_values:
        subspace_final = compute_subspace_affinity(store, k=k)
        if not subspace_final:
            continue
        suffix = f'_topk_{k}' if multi_k else ''
        fig_subspace = str(out_dir / f'subspace_affinity{suffix}.png')
        plot_subspace_affinity(subspace_final, ds_names_plot, fig_subspace)
        subspace_figs.append(fig_subspace)
        try:
            export_group_matrix_csv(out_dir / f'subspace_affinity{suffix}.csv', subspace_final, ds_names_plot)
        except Exception as e:
            print(f"[export subspace_affinity k={k}] warn: {e}")

    try:
        emb = compute_2d_embeddings(store, max_points=args.max_cloud_points)
        print('[plot] cloud2d')
        plot_cloud2d(emb, fig_cloud)
        print('[plot] cloud2d (kde)')
        plot_cloud2d_kde(emb, fig_cloud_kde)
    except Exception as e:
        print(f"[plot cloud2d] warn: {e}")

    try:
        shared_stats_ranks: Dict[int, Dict[str, Dict[str, Any]]] = {}
        for r in [1, 2, 3, 4, 5]:
            stats_map, names_map = compute_axis_shared_specific(store, shared_rank=r)
            shared_stats_ranks[r] = {g: {**st, 'names': names_map.get(g)} for g, st in stats_map.items()}
        print('[plot] r_sweep (datasets as axes)')
        r_sweep = r_sweep_shared_energy_axes(store, ranks=list(range(1, 11)))
        plot_r_sweep_axes(r_sweep, fig_r_sweep, title='Gradient Shared Energy VS Rank')
        print('[plot] shared_bars_multi_rank (datasets as axes)')
        plot_shared_bars_multi_rank_axes(
            shared_stats_ranks,
            ds_names_plot,
            fig_shared_multi,
            title='Shared Energy Percentage by Dataset',
        )
    except Exception as e:
        print(f"[plot shared-specific datasets] warn: {e}")

    try:
        _write_evolution(store, fig_evo, out_dir / 'dataset_evolution.csv', args.ema_beta)
    except Exception as e:
        print(f"[plot evolution] warn: {e}")

    payload = {
        'conf': args.conf,
        'ckpt': ckpt_path,
        'seed': int(args.seed),
        'steps': int(args.steps),
        'batch_size': int(args.batch_size),
        'split': args.split,
        'subspace_k': args.subspace_k,
        'proj_dim': int(args.proj_dim),
        'lr': float(args.lr),
        'datasets': ds_names_plot,
        'blocks_as_one': bool(args.blocks_as_one),
        'figures': {
            'heatmap': fig_heatmap,
            'subspace_affinity': subspace_figs if len(subspace_figs) > 1 else (subspace_figs[0] if subspace_figs else None),
            'cloud2d': fig_cloud,
            'cloud2d_kde': fig_cloud_kde,
            'r_sweep': fig_r_sweep,
            'shared_bars_multi_rank': fig_shared_multi,
            'dataset_evolution': fig_evo,
        }
    }
    with open(json_cfg, 'w') as f:
        json.dump(payload, f, indent=2)

    print(json.dumps({'saved_dir': str(out_dir), 'figs': payload['figures']}, indent=2))


# -----------------------------
# Entrypoint
# -----------------------------

def main():
    """CLI entry point: build config, loaders and output dir, then analyse each checkpoint."""
    run_args = parse_args()
    set_seeds(run_args.seed)
    dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    cfg = load_config(run_args)
    ds_dict = build_ds_dict(cfg)
    ds_list, loader_list = build_loaders(cfg, split=run_args.split, seed=int(run_args.seed))

    base_run_dir = make_run_base_dir(__file__, run_args.out_dir, run_args.run_name, prefix='labram_grad_ft_cls_ds')
    # Loaders and config are shared; each checkpoint gets its own analysis run.
    for ckpt in run_args.ckpts:
        run_one_ckpt(run_args, cfg, dev, ds_list, loader_list, ds_dict, base_run_dir, ckpt)


def parse_args() -> RunArgs:
    """Parse the command line into a normalised ``RunArgs``.

    After parsing:
    - duplicate ``--ckpts`` entries are dropped, keeping first-seen order;
    - ``--subspace-k`` values are clamped to >= 1, de-duplicated and sorted.
    """
    parser = argparse.ArgumentParser(
        description="LABRAM dataset-wise encoder gradient similarity under finetune classification"
    )
    opt = parser.add_argument
    opt('--conf', type=str, required=True, help='LABRAM config file (relative to assets/conf or absolute)')
    opt('--ckpts', type=str, nargs='+', required=True, help='One or more checkpoint paths (.pt/.pth)')
    opt('--steps', type=int, default=512, help='Consecutive optimization steps per dataset')
    opt('--batch-size', type=int, default=8)
    opt('--seed', type=int, default=100)
    opt('--proj-dim', type=int, default=512, help='Dim of hashing projection before storing gradients; set >0 to enable')
    opt('--subspace-k', type=int, nargs='+', default=[3], help='One or more top-k values for subspace affinity plots')
    opt('--max-cloud-points', type=int, default=2000)
    opt('--split', type=str, default='train', choices=['train', 'validation'])
    opt('--lr', type=float, default=5e-5)
    opt('--log-interval', type=int, default=1)
    opt('--out-dir', type=str, default=None)
    opt('--run-name', type=str, default=None)
    opt('--ema-beta', type=float, default=0.9, help='EMA beta for smoothing evolution curves')
    opt(
        '--blocks-as-one',
        action='store_true',
        help='Treat all transformer blocks as one parameter group instead of bucketed ranges',
    )
    ns = parser.parse_args()

    # dict preserves insertion order, so this keeps the first occurrence of each path.
    unique_ckpts = list(dict.fromkeys(ns.ckpts))
    k_values = sorted({int(max(1, k)) for k in ns.subspace_k}) or [3]

    return RunArgs(
        conf=ns.conf,
        ckpts=unique_ckpts,
        steps=ns.steps,
        batch_size=ns.batch_size,
        seed=ns.seed,
        proj_dim=ns.proj_dim,
        subspace_k=k_values,
        max_cloud_points=ns.max_cloud_points,
        split=ns.split,
        lr=ns.lr,
        log_interval=ns.log_interval,
        out_dir=ns.out_dir,
        run_name=ns.run_name,
        ema_beta=ns.ema_beta,
        blocks_as_one=bool(ns.blocks_as_one),
    )


if __name__ == '__main__':
    # Run the analysis only when executed as a script, not when imported.
    main()
