import argparse
import csv
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, cast

import datasets
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset as TorchDataset
from omegaconf import OmegaConf

from baseline.cbramod.cbramod_config import CBraModConfig
from baseline.cbramod.cbramod_adapter import CBraModDatasetAdapter
from baseline.cbramod.cbramod_trainer import CBraModUnifiedModel
from baseline.cbramod.model import CBraMod
from baseline.abstract.classifier import MultiHeadClassifier
from common.path import get_conf_file_path
from data.processor.wrapper import get_dataset_n_class, load_concat_eeg_datasets
from training.distributed.loader import DistributedGroupBatchSampler
from baseline.analysis.grad import (
 GradStore,
 compute_axis_heatmap,
 compute_subspace_affinity,
 compute_2d_embeddings,
 compute_axis_shared_specific,
 r_sweep_shared_energy_axes,
)
from baseline.analysis.plot import (
 plot_cloud2d,
 plot_cloud2d_kde,
 plot_heatmaps,
 plot_series_with_ema,
 plot_subspace_affinity,
 plot_r_sweep_axes,
 plot_shared_bars_multi_rank_axes,
)
from baseline.analysis.utils import (
 HashingProjector,
 flatten_from_grads,
 make_run_base_dir,
 set_seeds,
 strip_all_suffixes,
 export_group_matrix_csv,
)


# Preferred presentation order of datasets. `_order_dataset_names` matches
# these entries case-insensitively and places unlisted names after them.
DESIRED_DATASET_ORDER = [
    'tuab',
    'tuev',
    'seed',
    'hmc',
    'workload',
    'tusl',
]


def _order_dataset_names(names: List[str]) -> List[str]:
    """Return `names` sorted by their position in `DESIRED_DATASET_ORDER`.

    Matching is case-insensitive. Names that are not listed share the same
    (last) rank, so the stable sort keeps them in their input order after all
    listed names.
    """
    rank_of = {known.lower(): pos for pos, known in enumerate(DESIRED_DATASET_ORDER)}
    unlisted_rank = len(rank_of)

    def sort_key(ds_name: str) -> int:
        return rank_of.get(ds_name.lower(), unlisted_rank)

    return sorted(names, key=sort_key)


# -----------------------------
# CLI and run configuration
# -----------------------------

@dataclass
class RunArgs:
    """Resolved options for one run of the gradient-similarity analysis.

    The defaults here apply when the class is constructed directly; the
    command-line defaults are declared separately in `parse_args`.
    """

    # Config file name or path, resolved via `get_conf_file_path`.
    conf: str
    # Checkpoint paths; each is processed into its own output sub-directory.
    ckpts: List[str]
    # Consecutive optimization steps taken per dataset.
    steps: int = 512
    # Overrides `cfg.data.batch_size` when the config is loaded.
    batch_size: int = 8
    # Seed for global RNGs, the batch samplers and the hashing projector.
    seed: int = 100
    # Dimension of the hashing projection applied before storing gradients
    # (projection is skipped when this is not > 0).
    proj_dim: int = 512
    # Top-k values for the subspace-affinity plots.
    subspace_k: List[int] = field(default_factory=lambda: [3])
    # Passed as `max_points` to the 2-D embedding computation.
    max_cloud_points: int = 512
    # Data split: values starting with "train" select the training split.
    split: str = "train"
    # Learning rate of the Adam optimizer used during the analysis steps.
    lr: float = 5e-5
    # Print a progress line every this many steps.
    log_interval: int = 25
    # Output location and run name, forwarded to `make_run_base_dir`.
    out_dir: Optional[str] = None
    run_name: Optional[str] = None
    # EMA beta used to smooth the evolution curves.
    ema_beta: float = 0.9
    # Treat all transformer blocks as one group instead of depth buckets.
    blocks_as_one: bool = False


# -----------------------------
# Config, dataloaders, models
# -----------------------------

def load_config(args: RunArgs) -> CBraModConfig:
    """Load and validate the CBraMod config referenced by `args.conf`.

    The file is resolved with `get_conf_file_path`, read through OmegaConf
    (with interpolations resolved) and validated into a `CBraModConfig`.
    The batch size from `args` replaces the one in the file, and the resolved
    path is recorded on the returned config as `conf_file`.
    """
    conf_path = get_conf_file_path(args.conf)
    plain = OmegaConf.to_container(OmegaConf.load(conf_path), resolve=True)
    cfg = CBraModConfig.model_validate(plain)

    # Command-line batch size takes precedence over the config file.
    cfg.data.batch_size = int(args.batch_size)
    cfg.conf_file = conf_path
    return cfg


def build_ds_dict(cfg: CBraModConfig) -> Dict[str, int]:
    """Map every configured dataset name to its number of classes.

    A missing (falsy) `cfg.data.datasets` yields an empty mapping.
    """
    n_classes: Dict[str, int] = {}
    for ds_name, ds_conf in (cfg.data.datasets or {}).items():
        n_classes[ds_name] = get_dataset_n_class(ds_name, ds_conf)
    return n_classes


def build_loaders(cfg: CBraModConfig, split: str, seed: int) -> Tuple[List[str], List[DataLoader]]:
    """Build one shuffled DataLoader per configured dataset.

    Args:
        cfg: Run configuration; reads `cfg.data.datasets`, `cfg.data.batch_size`
            and `cfg.data.num_workers`.
        split: Any value starting with "train" (case-insensitive) selects the
            training split; anything else selects the validation split.
        seed: Seed handed to every batch sampler.

    Returns:
        `(ds_names, loaders)` in matching order, with names ordered by
        `_order_dataset_names`. Both lists are empty when no dataset is
        configured.
    """
    split_enum = datasets.Split.TRAIN if split.lower().startswith('train') else datasets.Split.VALIDATION

    # Treat a missing dataset mapping like an empty one, as `build_ds_dict` does,
    # instead of failing with an AttributeError on `None.keys()`.
    ds_map = cfg.data.datasets or {}
    ds_names = _order_dataset_names(list(ds_map.keys()))
    if not ds_names:
        return ds_names, []

    batch_size = int(cfg.data.batch_size)
    pin_memory = torch.cuda.is_available()

    # `persistent_workers` / `prefetch_factor` are only valid with worker
    # processes, so they are passed only when num_workers > 0.
    if cfg.data.num_workers > 0:
        worker_kwargs = {
            'num_workers': cfg.data.num_workers,
            'persistent_workers': True,
            'prefetch_factor': 2,
        }
    else:
        worker_kwargs = {'num_workers': 0}

    loader_list: List[DataLoader] = []
    for name in ds_names:
        conf = ds_map[name]
        dataset, _ = load_concat_eeg_datasets(
            dataset_names=[name],
            builder_configs=[conf],
            split=split_enum,
            add_ds_name=True,
            cast_label=True,
            fs=200,
        )
        adapter = CBraModDatasetAdapter(dataset=dataset, dataset_names=[name], dataset_configs=[conf])

        # The sampler is built on the raw dataset, while the loader reads
        # samples through the adapter. Single replica: no real distribution.
        sampler = DistributedGroupBatchSampler(
            dataset=cast(TorchDataset, dataset),
            batch_size=batch_size,
            num_replicas=1,
            rank=0,
            shuffle=True,
            seed=int(seed),
            drop_last=False,
        )
        loader_list.append(
            DataLoader(
                adapter,
                batch_sampler=sampler,
                pin_memory=pin_memory,
                **worker_kwargs,
            )
        )

    return ds_names, loader_list


def build_model(cfg: CBraModConfig, ds_dict: Dict[str, int], device: torch.device) -> CBraModUnifiedModel:
    """Assemble the CBraMod encoder plus multi-head classifier on `device`.

    Args:
        cfg: Run configuration; only `cfg.model` is read.
        ds_dict: Dataset name -> number of classes; one classifier head each.
        device: Target device for the assembled model.

    Returns:
        The unified model, moved to `device` and switched to train mode.
    """
    model_cfg = cfg.model

    encoder = CBraMod(
        in_dim=model_cfg.in_dim,
        out_dim=model_cfg.out_dim,
        d_model=model_cfg.d_model,
        dim_ffn=model_cfg.dim_ffn,
        n_layer=model_cfg.n_layer,
        n_head=model_cfg.n_head,
    )
    classifier = MultiHeadClassifier(
        embed_dim=model_cfg.out_dim,
        head_configs=dict(ds_dict),
        head_cfg=model_cfg.classifier_head,
        t_sne=model_cfg.t_sne,
    )

    # `grad_cam` is optional on the model config; absent means disabled.
    use_grad_cam = getattr(model_cfg, 'grad_cam', False)
    unified = CBraModUnifiedModel(encoder, classifier, grad_cam=use_grad_cam)
    unified = unified.to(device)
    unified.train()
    return unified


def load_model_weights(model: CBraModUnifiedModel, ckpt_path: str, device: torch.device) -> None:
    """Load encoder weights from `ckpt_path` into `model.encoder` (non-strict).

    The checkpoint object is passed as-is to `model.encoder.load_state_dict`
    with `strict=False`; missing and unexpected keys are printed rather than
    raised. A prominent warning is printed when not a single encoder entry
    was found in the checkpoint, because in that case the encoder keeps the
    weights it had before the call. The model is left in train mode.

    Args:
        model: Model whose `encoder` sub-module receives the weights.
        ckpt_path: Path of the checkpoint file.
        device: `map_location` used while loading.

    Raises:
        FileNotFoundError: If `ckpt_path` does not exist.
    """
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    missing, unexpected = model.encoder.load_state_dict(ckpt, strict=False)

    if missing:
        print(f"[ckpt] missing keys: {missing}")
    if unexpected:
        print(f"[ckpt] unexpected keys: {unexpected}")

    # With strict=False a completely mismatched checkpoint (e.g. a wrapped or
    # differently-prefixed state dict) loads nothing; make that visible.
    expected_keys = set(model.encoder.state_dict().keys())
    if expected_keys and expected_keys.issubset(missing):
        print(
            f"[ckpt] WARNING: no encoder weights were loaded from {ckpt_path}; "
            "the encoder keeps its current weights"
        )

    model.train()



# -----------------------------
# Parameter grouping & gradient helpers
# -----------------------------

def cbramod_encoder_group_of(
    name: str,
    num_layers: int,
    bucket_size: int = 6,
    merge_blocks: bool = False,
) -> str:
    """Map a fully-qualified parameter name to its analysis group label.

    Args:
        name: Parameter name as produced by `named_parameters()`.
        num_layers: Number of transformer layers; caps the upper end of the
            last depth bucket.
        bucket_size: Number of consecutive layers per depth bucket.
        merge_blocks: When true, every transformer-layer parameter goes into
            the single group 'backbone'.

    Returns:
        'non_encoder' for names outside 'encoder.'; one of the patch-embedding
        groups; 'backbone', 'blocks_SS_EE', 'blocks_II' (when `bucket_size` or
        `num_layers` is not positive) or 'blocks_misc' (no numeric layer index)
        for transformer layers; 'other' for everything else.
    """
    root = 'encoder.'
    if not name.startswith(root):
        return 'non_encoder'
    sub = name[len(root):]

    patch_groups = (
        ('patch_embedding.mask_encoding', 'patch_mask_encoding'),
        ('patch_embedding.positional_encoding', 'patch_embed_positional'),
        ('patch_embedding.proj_in', 'patch_embed_temporal'),
        ('patch_embedding.spectral_proj', 'patch_embed_spectral'),
    )
    for prefix, group in patch_groups:
        if sub.startswith(prefix):
            return group

    if not sub.startswith('encoder.layers.'):
        # Everything else, including `proj_out.*` (which currently has no
        # dedicated group), lands here.
        return 'other'

    if merge_blocks:
        return 'backbone'

    # sub looks like 'encoder.layers.<idx>.<rest>'; the index is the 3rd token.
    pieces = sub.split('.')
    layer_token = pieces[2] if len(pieces) >= 3 else ''
    if not layer_token.isdigit():
        return 'blocks_misc'

    layer_idx = int(layer_token)
    if bucket_size <= 0 or num_layers <= 0:
        return f'blocks_{layer_idx:02d}'

    first = (layer_idx // bucket_size) * bucket_size
    last = min(num_layers - 1, first + bucket_size - 1)
    return f'blocks_{first:02d}_{last:02d}'


def build_encoder_param_order_cbramod(
    model: torch.nn.Module,
    bucket_size: int = 6,
    merge_blocks: bool = False,
):
    """Index the trainable encoder parameters of `model` and group them.

    Only parameters whose name starts with 'encoder.' and that require
    gradients are considered. The layer count used for depth bucketing is read
    from `model.encoder.encoder.num_layers` when that attribute chain exists,
    otherwise it is 0.

    Returns:
        A 4-tuple `(names, name_to_idx, groups, sizes)`:
        sorted parameter names, name -> position in `names`,
        group label -> sorted member names, and name -> element count.
    """
    # Missing links in the attribute chain simply fall through to 0 layers.
    inner = getattr(getattr(model, 'encoder', None), 'encoder', None)
    num_layers = int(inner.num_layers) if hasattr(inner, 'num_layers') else 0

    collected: List[str] = []
    members: Dict[str, List[str]] = {}
    sizes: Dict[str, int] = {}

    for param_name, param in model.named_parameters():
        if not (param_name.startswith('encoder.') and param.requires_grad):
            continue
        label = cbramod_encoder_group_of(
            param_name,
            num_layers=num_layers,
            bucket_size=bucket_size,
            merge_blocks=merge_blocks,
        )
        if label == 'non_encoder':
            continue
        collected.append(param_name)
        members.setdefault(label, []).append(param_name)
        sizes[param_name] = int(param.numel())

    names = sorted(collected)
    groups = {label: sorted(group_members) for label, group_members in members.items()}
    name_to_idx = {param_name: pos for pos, param_name in enumerate(names)}
    return names, name_to_idx, groups, sizes


# -----------------------------
# Gradient loop helpers
# -----------------------------

def set_seeds_local(seed: int):
    """Seed via `set_seeds`, always requesting deterministic mode."""
    set_seeds(seed, deterministic=True)


def cycle_loader(loader: DataLoader):
    """Yield batches from `loader` forever, restarting it when exhausted.

    Each pass re-iterates `loader` from the start, so any per-iteration
    behaviour of the loader is applied again on every pass.

    Raises:
        ValueError: If a complete pass over `loader` yields no batch. Without
            this check an empty loader would spin forever without ever
            yielding, hanging the caller's `next()`.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("cycle_loader: loader yielded no batches; cannot cycle an empty loader")


def move_batch_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Return a copy of `batch` with tensors and NumPy arrays placed on `device`.

    Tensors are moved as they are. NumPy arrays are converted to tensors
    first: int16/int32/int64/uint8 arrays become int64, every other dtype
    becomes float32. All remaining values are passed through untouched.
    """
    integer_dtypes = (torch.int16, torch.int32, torch.int64, torch.uint8)

    def to_device(value: Any) -> Any:
        if torch.is_tensor(value):
            return value.to(device, non_blocking=True)
        if isinstance(value, np.ndarray):
            as_tensor = torch.from_numpy(value)
            if as_tensor.dtype in integer_dtypes:
                as_tensor = as_tensor.long()
            else:
                as_tensor = as_tensor.float()
            return as_tensor.to(device, non_blocking=True)
        return value

    return {key: to_device(value) for key, value in batch.items()}


def _history_cosine_curves(
    store: GradStore,
    axes: List[str],
    groups: List[str],
) -> Dict[str, List[float]]:
    """Per-axis curve of how well each stored vector aligns with its history.

    For every axis `d` and step `t >= 1`, the value is the mean over groups of
    the cosine between the `t`-th stored vector and the normalised mean of the
    vectors stored before it, clamped to [-1, 1]. Step 0 has no history and is
    NaN, as is any step for which no group has a vector.

    The history mean is kept as a running sum, so the cost is linear in the
    number of stored vectors rather than re-stacking the whole prefix at every
    step (results agree with the stacked mean up to floating-point rounding).
    """
    curves: Dict[str, List[float]] = {}
    for d in axes:
        per_group = [store.data[g][d] for g in groups]
        length = max((len(xs) for xs in per_group), default=0)
        # vals_at[t] collects one cosine per group, in group order.
        vals_at: List[List[float]] = [[] for _ in range(length)]
        for xs in per_group:
            running = None  # sum of xs[0] .. xs[t-1]
            for t, v_t in enumerate(xs):
                if running is None:
                    running = v_t.clone()
                    continue
                m_prev = running / t
                m_prev = m_prev / (torch.norm(m_prev) + 1e-12)
                vtn = v_t / (torch.norm(v_t) + 1e-12)
                c = float(torch.dot(m_prev.flatten(), vtn.flatten()).item())
                vals_at[t].append(max(-1.0, min(1.0, c)))
                running = running + v_t
        curves[d] = [float(np.mean(vals)) if vals else float('nan') for vals in vals_at]
    return curves


def run_one_ckpt(
    args: RunArgs,
    cfg: CBraModConfig,
    device: torch.device,
    ds_list: List[str],
    loader_list: List[DataLoader],
    ds_dict: Dict[str, int],
    base_run_dir: Path,
    ckpt_path: str,
):
    """Collect per-dataset encoder gradients for one checkpoint and export the analysis.

    For every loader a fresh model is built and loaded from `ckpt_path`, then
    optimised with Adam for `args.steps` steps on that dataset alone. At each
    step the normalised (and optionally hash-projected) gradient of every
    visible parameter group is recorded in a `GradStore`. Afterwards heatmaps,
    subspace affinities, 2-D clouds, shared-energy plots and evolution curves
    are written to `base_run_dir / <checkpoint name>`, together with
    `group_params.json` and `run_config.json`.

    Args:
        args: Run options (steps, lr, projection size, plotting options, ...).
        cfg: Model/data configuration.
        device: Device used for the model and the batches.
        ds_list: Dataset names aligned with `loader_list`; only used as a
            fallback when the names derived from `cfg` are shorter.
        loader_list: One DataLoader per dataset.
        ds_dict: Dataset name -> number of classes.
        base_run_dir: Parent directory of this checkpoint's output directory.
        ckpt_path: Checkpoint to analyse.

    Raises:
        RuntimeError: If a label falls outside the classifier's class range,
            or if `loader_list` is empty so that nothing was collected.
    """
    ckpt_short = strip_all_suffixes(Path(ckpt_path))
    out_dir = base_run_dir / ckpt_short
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n=== Running for ckpt: {ckpt_path} -> {out_dir} ===")

    # bfloat16 autocast is only switched on for CUDA devices.
    amp_enabled = device.type == 'cuda'
    amp_dtype = torch.bfloat16
    projector = HashingProjector(args.proj_dim, int(args.seed))
    use_projection = args.proj_dim > 0

    ds_names: List[str] = _order_dataset_names(list(cfg.data.datasets.keys()))
    store: Optional[GradStore] = None

    steps = int(args.steps)
    log_every = max(1, int(args.log_interval))
    for ld_idx, loader in enumerate(loader_list):
        # Every dataset starts again from the checkpoint weights.
        model = build_model(cfg, ds_dict, device)
        load_model_weights(model, ckpt_path, device)
        optim = torch.optim.Adam(model.parameters(), lr=float(args.lr))

        names, name_to_idx, group_names, size_map = build_encoder_param_order_cbramod(
            model,
            merge_blocks=bool(args.blocks_as_one),
        )
        # Hide the catch-all group unless it is the only one available.
        visible_group_names = {g: members for g, members in group_names.items() if g != 'other'}
        if not visible_group_names:
            visible_group_names = dict(group_names)

        params_by_name = dict(model.named_parameters())
        param_list = [params_by_name[n] for n in names]
        sizes_in_order = [size_map[n] for n in names]
        # Positions (into `param_list`) of each group's parameters; invariant
        # across steps, so resolved once per model.
        group_indices = {
            g: [name_to_idx[n] for n in members]
            for g, members in visible_group_names.items()
        }

        if store is None:
            store = GradStore(sorted(visible_group_names.keys()), ds_names)
            try:
                with open(out_dir / 'group_params.json', 'w') as f:
                    json.dump({g: sorted(gns) for g, gns in visible_group_names.items()}, f, indent=2)
            except Exception as e:
                # Best-effort export: report and continue.
                print(f"[export group_params] warn: {e}")

        cyc = cycle_loader(loader)
        ds_name_cur = ds_names[ld_idx] if ld_idx < len(ds_names) else ds_list[ld_idx]
        for step in range(steps):
            batch = next(cyc)
            batch_gpu = move_batch_to_device(batch, device)

            optim.zero_grad(set_to_none=True)
            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
                logits = model(batch_gpu)
                labels = batch_gpu['label'].to(device, dtype=torch.long)
                num_classes = logits.shape[-1]
                if labels.numel() > 0:
                    lmin = int(labels.min().item())
                    lmax = int(labels.max().item())
                    if lmin < 0 or lmax >= num_classes:
                        raise RuntimeError(
                            f"Label out of range for dataset '{ds_name_cur}': min={lmin}, max={lmax}, num_classes={num_classes}"
                        )
                loss_ce = F.cross_entropy(logits, labels, reduction='mean')

            # retain_graph=True keeps the graph alive for the backward() below.
            grads = torch.autograd.grad(loss_ce, param_list, retain_graph=True, allow_unused=True)
            for g, idxs in group_indices.items():
                g_grads = [grads[i] for i in idxs]
                sizes_g = [sizes_in_order[i] for i in idxs]
                v = flatten_from_grads(g_grads, sizes_g, device)
                v = v / (torch.norm(v) + 1e-12)
                if use_projection:
                    v = projector.project_and_norm(v, key=g)
                store.add(ds_name_cur, g, v.detach().cpu())

            loss_ce.backward()
            optim.step()

            if (step + 1) % log_every == 0 or (step + 1) == steps:
                print(f"[ds {ld_idx+1}/{len(loader_list)}:{ds_name_cur}] step {step+1}/{steps} CE={float(loss_ce.item()):.4f}")

    if store is None:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise RuntimeError("run_one_ckpt: no dataloaders were provided, so no gradients were collected")

    fig_heatmap = str(out_dir / 'heatmap.png')
    fig_cloud = str(out_dir / 'cloud2d.png')
    fig_cloud_kde = str(out_dir / 'cloud2d_kde.png')
    fig_evo = str(out_dir / 'dataset_evolution.png')
    fig_r_sweep = str(out_dir / 'r_sweep_datasets.png')
    fig_shared_multi = str(out_dir / 'shared_bars_multi_rank_datasets.png')
    json_cfg = str(out_dir / 'run_config.json')

    heatmaps_final = compute_axis_heatmap(store)
    ds_names_plot = store.get_axes()

    print('[plot] heatmap')
    plot_heatmaps(heatmaps_final, ds_names_plot, fig_heatmap)
    try:
        export_group_matrix_csv(out_dir / 'heatmaps.csv', heatmaps_final, ds_names_plot)
    except Exception as e:
        print(f"[export heatmaps] warn: {e}")

    subspace_figs: List[str] = []
    k_values = args.subspace_k or [3]
    multi_k = len(k_values) > 1
    print('[plot] subspace affinity')
    for k in k_values:
        subspace_final = compute_subspace_affinity(store, k=k)
        if not subspace_final:
            continue
        # File names carry the k value only when several are requested.
        suffix = f'_topk_{k}' if multi_k else ''
        fig_subspace = str(out_dir / f'subspace_affinity{suffix}.png')
        plot_subspace_affinity(subspace_final, ds_names_plot, fig_subspace)
        subspace_figs.append(fig_subspace)
        try:
            export_group_matrix_csv(out_dir / f'subspace_affinity{suffix}.csv', subspace_final, ds_names_plot)
        except Exception as e:
            print(f"[export subspace_affinity{suffix}] warn: {e}")

    try:
        emb = compute_2d_embeddings(store, max_points=args.max_cloud_points)
        print('[plot] cloud2d')
        plot_cloud2d(emb, fig_cloud)
        print('[plot] cloud2d (kde)')
        plot_cloud2d_kde(emb, fig_cloud_kde)
    except Exception as e:
        print(f"[plot cloud2d] warn: {e}")

    try:
        shared_stats_ranks: Dict[int, Dict[str, Dict[str, Any]]] = {}
        ranks = [1, 2, 3, 4, 5]
        for r in ranks:
            stats_map, names_map = compute_axis_shared_specific(store, shared_rank=r)
            merged: Dict[str, Dict[str, Any]] = {}
            for g, st in stats_map.items():
                merged[g] = {**st, 'names': names_map.get(g)}
            shared_stats_ranks[r] = merged
        print('[plot] r_sweep (datasets as axes)')
        r_sweep = r_sweep_shared_energy_axes(store, ranks=list(range(1, 11)))
        plot_r_sweep_axes(r_sweep, fig_r_sweep, title='Gradient Shared Energy VS Rank')
        print('[plot] shared_bars_multi_rank (datasets as axes)')
        plot_shared_bars_multi_rank_axes(
            shared_stats_ranks,
            ds_names_plot,
            fig_shared_multi,
            title='Shared Energy Percentage by Dataset',
        )
    except Exception as e:
        print(f"[plot shared-specific datasets] warn: {e}")

    try:
        ds_names_sorted = store.get_axes()
        groups = store.get_groups()
        curves = _history_cosine_curves(store, ds_names_sorted, groups)

        max_len = max((len(v) for v in curves.values()), default=0)
        steps_axis = list(range(1, max_len + 1))
        # Subsample long curves for plotting; the CSV below keeps every step.
        stride = max(1, max_len // 400)
        steps_ds = steps_axis[::stride] if stride > 1 else steps_axis
        series_ds = [curves[d][::stride] if stride > 1 else curves[d] for d in ds_names_sorted]
        plot_series_with_ema(
            steps_ds,
            series_ds,
            [d.upper() for d in ds_names_sorted],
            fig_evo,
            y_label='Mean Cosine Similarity',
            ema_beta=float(args.ema_beta),
            alpha_raw=0.2,
        )

        with open(out_dir / 'dataset_evolution.csv', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['step'] + ds_names_sorted)
            for i in range(max_len):
                row = [str(i + 1)]
                for d in ds_names_sorted:
                    v = curves[d][i] if i < len(curves[d]) else None
                    # Missing or non-finite values are written as empty cells.
                    row.append(f"{v:.6f}" if (v is not None and np.isfinite(v)) else '')
                writer.writerow(row)

    except Exception as e:
        print(f"[plot evolution] warn: {e}")

    payload = {
        'conf': args.conf,
        'ckpt': ckpt_path,
        'seed': int(args.seed),
        'steps': int(args.steps),
        'batch_size': int(args.batch_size),
        'split': args.split,
        'subspace_k': args.subspace_k,
        'proj_dim': int(args.proj_dim),
        'lr': float(args.lr),
        'datasets': ds_names_plot,
        'blocks_as_one': bool(args.blocks_as_one),
        'figures': {
            'heatmap': fig_heatmap,
            'subspace_affinity': subspace_figs if len(subspace_figs) > 1 else (subspace_figs[0] if subspace_figs else None),
            'cloud2d': fig_cloud,
            'cloud2d_kde': fig_cloud_kde,
            'r_sweep': fig_r_sweep,
            'shared_bars_multi_rank': fig_shared_multi,
            'dataset_evolution': fig_evo,
        }
    }
    with open(json_cfg, 'w') as f:
        json.dump(payload, f, indent=2)

    print(json.dumps({'saved_dir': str(out_dir), 'figs': payload['figures']}, indent=2))


# -----------------------------
# Entrypoint
# -----------------------------

def main():
    """Script entry point: parse options, prepare data, analyse each checkpoint.

    The config, class counts and dataloaders are built once and shared by all
    checkpoints; every checkpoint gets its own sub-directory under one common
    run directory.
    """
    args = parse_args()
    set_seeds(args.seed, deterministic=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    cfg = load_config(args)
    ds_dict = build_ds_dict(cfg)
    ds_list, loader_list = build_loaders(cfg, split=args.split, seed=int(args.seed))

    base_run_dir = make_run_base_dir(
        __file__,
        args.out_dir,
        args.run_name,
        prefix='cbramod_grad_ft_cls_ds',
    )
    for ckpt_path in args.ckpts:
        run_one_ckpt(args, cfg, device, ds_list, loader_list, ds_dict, base_run_dir, ckpt_path)


def parse_args(argv: Optional[List[str]] = None) -> RunArgs:
    """Parse command-line options into a `RunArgs`.

    Args:
        argv: Argument list to parse. `None` (the default) makes argparse read
            `sys.argv[1:]`, which is the behaviour of calling `parse_args()`
            without arguments; passing a list allows programmatic use.

    Returns:
        The parsed options. Duplicate checkpoint paths are dropped (first
        occurrence wins, order preserved) and `--subspace-k` values are
        clamped to at least 1, de-duplicated and sorted.
    """
    ap = argparse.ArgumentParser(
        description="CBraMod dataset-wise encoder gradient similarity under finetune classification"
    )
    ap.add_argument('--conf', type=str, required=True, help='CBraMod config file (relative to assets/conf or absolute)')
    ap.add_argument('--ckpts', type=str, nargs='+', required=True, help='One or more checkpoint paths (.pt/.pth)')
    ap.add_argument('--steps', type=int, default=512, help='Consecutive optimization steps per dataset')
    ap.add_argument('--batch-size', type=int, default=8)
    ap.add_argument('--seed', type=int, default=100)
    ap.add_argument('--proj-dim', type=int, default=512, help='Dim of hashing projection before storing gradients; set >0 to enable')
    ap.add_argument('--subspace-k', type=int, nargs='+', default=[3], help='One or more top-k values for subspace affinity plots')
    # NOTE: the CLI defaults for --max-cloud-points (2000) and --log-interval (1)
    # differ from the RunArgs dataclass defaults (512 and 25); for command-line
    # runs the values declared here are the ones in effect.
    ap.add_argument('--max-cloud-points', type=int, default=2000)
    ap.add_argument('--split', type=str, default='train', choices=['train', 'validation'])
    ap.add_argument('--lr', type=float, default=5e-5)
    ap.add_argument('--log-interval', type=int, default=1)
    ap.add_argument('--out-dir', type=str, default=None)
    ap.add_argument('--run-name', type=str, default=None)
    ap.add_argument('--ema-beta', type=float, default=0.9, help='EMA beta for smoothing evolution curves')
    ap.add_argument(
        '--blocks-as-one',
        action='store_true',
        help='Treat all transformer blocks as a single group instead of splitting them by depth buckets',
    )
    args = ap.parse_args(argv)

    # dict keys keep insertion order, so this de-duplicates while preserving
    # the order in which checkpoints were given.
    ckpts = list(dict.fromkeys(args.ckpts))

    subspace_k = sorted({max(1, int(k)) for k in args.subspace_k}) or [3]

    return RunArgs(
        conf=args.conf,
        ckpts=ckpts,
        steps=args.steps,
        batch_size=args.batch_size,
        seed=args.seed,
        proj_dim=args.proj_dim,
        subspace_k=subspace_k,
        max_cloud_points=args.max_cloud_points,
        split=args.split,
        lr=args.lr,
        log_interval=args.log_interval,
        out_dir=args.out_dir,
        run_name=args.run_name,
        ema_beta=args.ema_beta,
        blocks_as_one=bool(args.blocks_as_one),
    )


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
 main()