# -*- coding: utf-8 -*-
import math, random
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F

def set_seed(seed: int = 23):
    """Seed Python's ``random`` module and torch (CPU plus every CUDA device).

    Args:
        seed: value handed to each seeding function, in the order
            ``random.seed`` -> ``torch.manual_seed`` -> ``torch.cuda.manual_seed_all``.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

@torch.no_grad()
def channel_stats(x: torch.Tensor):
    """Compute per-row peak-channel statistics over the last dimension of ``x``.

    Works for any rank: every leading dimension is treated as a batch axis and
    every returned tensor has shape ``x.shape[:-1]``.

    Returns:
        (c_star, a1, a2, rho, gap, share) where
          c_star: index of the channel with the largest ``|x|``;
          a1:     that largest ``|x|``;
          a2:     the second-largest ``|x|`` (0 when there is a single channel);
          rho:    ``a1 / ||x||_2``;
          gap:    ``a1 - a2``;
          share:  ``a1 / ||x||_1``.
    """
    v = x.abs()
    a1, c_star = v.max(dim=-1)
    if v.size(-1) >= 2:
        # Index the last axis with "..." so inputs of any rank are supported.
        a2 = torch.topk(v, k=2, dim=-1).values[..., 1]
    else:
        # A single channel has no runner-up; topk(k=2) would raise here.
        a2 = torch.zeros_like(a1)
    # The 1e-12 terms guard against all-zero rows.
    rho = a1 / (x.norm(dim=-1) + 1e-12)
    share = a1 / (v.sum(dim=-1) + 1e-12)
    gap = a1 - a2
    return c_star, a1, a2, rho, gap, share

class RunningMoments:
    """Streaming per-feature mean / variance accumulator (float64).

    Batches are merged with Chan et al.'s parallel update, so the result
    equals the statistics of the concatenation of all batches seen.
    """

    def __init__(self, d: int):
        self.n = 0
        self.mean = torch.zeros(d, dtype=torch.float64)
        # Sum of squared deviations from the current mean, per feature.
        self.M2 = torch.zeros(d, dtype=torch.float64)

    @torch.no_grad()
    def update(self, x: torch.Tensor):
        """Merge a ``(rows, d)`` batch into the running statistics; empty batches are ignored."""
        x = x.detach().to(self.mean.device, dtype=torch.float64)
        n_b = x.size(0)
        if n_b == 0:
            return
        n_a = self.n
        n = n_a + n_b
        mean_b = x.mean(0)
        m2_b = ((x - mean_b) ** 2).sum(0)
        delta = mean_b - self.mean
        self.mean += delta * (n_b / n)
        # Chan's merge: M2 = M2_a + M2_b + delta^2 * n_a * n_b / n.
        # The cross term accounts for the two batches having different means.
        self.M2 += m2_b + delta * delta * (n_a * n_b / n)
        self.n = n

    def finalize(self, eps=1e-5):
        """Return float32 ``(mean, sqrt(unbiased_var + eps))``."""
        var = self.M2 / max(self.n - 1, 1)
        std = torch.sqrt(var + eps)
        return self.mean.float(), std.float()

@dataclass
class DataSpec:
    """Configuration for pulling rows out of the stacked hidden-state tensors.

    Field order is part of the positional constructor and must stay fixed.
    """

    # Width of each row; used to size moment accumulators and the SAE input.
    d_model: int = 4096
    # Number of layers visited when no explicit layer list is given.
    num_layers: int = 33
    # When True, layer index 0 is dropped from the visited layers.
    ignore_layer0: bool = True
    # Maximum number of rows per streamed chunk.
    batch_size_stream: int = 8192
    # Tokens trimmed from the start of every sequence.
    token_ignore_first_n: int = 0
    # Tokens trimmed from the end of every sequence.
    token_ignore_last_n: int = 0

def iter_all_rows(
    all_hidden_states: List[torch.Tensor],
    steps: Optional[List[int]],
    layers: Optional[List[int]],
    spec: DataSpec
):
    """Stream CPU float32 row chunks from stacked hidden states, tagged with global offsets.

    Each ``all_hidden_states[t]`` is sliced as ``[layer, batch, token, channel]``
    (inferred from the indexing below -- confirm against the producer). For every
    selected step and layer, ONLY batch element 0 contributes rows. Its tokens
    are trimmed by ``spec.token_ignore_first_n`` / ``spec.token_ignore_last_n``, and
    the remaining rows are yielded in chunks of at most ``spec.batch_size_stream``.

    Args:
        all_hidden_states: per-step 4-D tensors.
        steps: step indices to visit; ``None`` means every step.
        layers: layer indices to visit; ``None`` means ``range(spec.num_layers)``.
            Layer 0 is removed when ``spec.ignore_layer0`` is set.
        spec: streaming configuration.

    Yields:
        ``(chunk, start)``: ``chunk`` is a detached CPU float32 ``(rows, D)`` tensor.
        ``start`` is the global index of its first row, counted across every
        chunk yielded so far. Row masks elsewhere in the file are indexed in
        this space.
    """
    n_steps = len(all_hidden_states)
    n_layers = spec.num_layers
    step_ids = list(range(n_steps)) if steps is None else steps
    layer_ids = list(range(n_layers)) if layers is None else layers
    if spec.ignore_layer0 and 0 in layer_ids:
        layer_ids = [idx for idx in layer_ids if idx != 0]

    offset = 0
    for step in step_ids:
        hidden = all_hidden_states[step]
        batch, seq_len, width = hidden.size(1), hidden.size(2), hidden.size(3)
        first = spec.token_ignore_first_n
        stop = seq_len - spec.token_ignore_last_n
        for layer in layer_ids:
            # Only the first sequence of the batch contributes rows.
            rows = hidden[layer].reshape(batch, seq_len, width)[0, first:stop, :]
            for begin in range(0, rows.size(0), spec.batch_size_stream):
                piece = rows[begin:begin + spec.batch_size_stream]
                chunk = piece.detach().to('cpu', dtype=torch.float32)
                if chunk.size(0) == 0:
                    continue
                yield chunk, offset
                offset += chunk.size(0)

def pass1_compute_stats_and_shares(
    all_hidden_states: List[torch.Tensor],
    steps: Optional[List[int]],
    layers: Optional[List[int]],
    spec: DataSpec
):
    """First streaming pass: global moments plus per-row peak-channel statistics.

    Returns:
        ``(N, shares, rhos, gaps, mean, std)``:
          N: total number of rows streamed;
          shares / rhos / gaps: 1-D tensors with one entry per row, in stream order,
            as produced by ``channel_stats``;
          mean / std: per-channel float32 moments over every row
            (``RunningMoments.finalize``).
    """
    moments = RunningMoments(spec.d_model)
    per_row = {"share": [], "rho": [], "gap": []}
    total_rows = 0
    for chunk, _ in iter_all_rows(all_hidden_states, steps, layers, spec):
        moments.update(chunk)
        *_, rho, gap, share = channel_stats(chunk)
        per_row["share"].append(share.cpu())
        per_row["rho"].append(rho.cpu())
        per_row["gap"].append(gap.cpu())
        total_rows += chunk.size(0)
    shares, rhos, gaps = (torch.cat(per_row[name], dim=0) for name in ("share", "rho", "gap"))
    mean, std = moments.finalize()
    return total_rows, shares, rhos, gaps, mean, std

@dataclass
class SAEConfig:
    """Hyperparameters for a ``SparseAutoencoder`` and its training loop.

    Field order is part of the positional constructor and must stay fixed.
    """

    d: int  # input / reconstruction width
    m: int  # number of latent features
    l1: float = 1e-3  # weight of the L1 penalty on latent activations
    lr: float = 3e-4  # AdamW learning rate
    epochs: int = 2  # passes over the training rows
    device: str = "cuda"  # requested device; trainers fall back to CPU without CUDA
    k_sparse: Optional[int] = None  # if set, keep only the top-k latents per row
    grad_clip: Optional[float] = 1.0  # max gradient norm; None disables clipping
    norm_decoder: bool = True  # renormalize decoder columns to unit norm after steps

class SparseAutoencoder(nn.Module):
    """Single-hidden-layer ReLU autoencoder with optional top-k latent sparsity.

    ``encoder`` maps ``cfg.d -> cfg.m`` and ``decoder`` maps ``cfg.m -> cfg.d``.
    Column ``j`` of ``decoder.weight`` (shape ``(d, m)``) is the dictionary vector
    of latent ``j``.
    """

    def __init__(self, cfg: SAEConfig):
        super().__init__()
        self.cfg = cfg
        # The encoder is created before the decoder, so parameter-creation order
        # (and therefore RNG consumption) is fixed.
        self.encoder = nn.Linear(cfg.d, cfg.m, bias=True)
        self.decoder = nn.Linear(cfg.m, cfg.d, bias=True)
        nn.init.kaiming_uniform_(self.encoder.weight, a=math.sqrt(5))
        nn.init.zeros_(self.decoder.bias)

    def forward(self, x):
        """Return ``(reconstruction, latents)`` for input ``x``.

        With ``cfg.k_sparse`` set, latents below the k-th largest value in each
        row are zeroed (ties with the k-th value are kept).
        """
        latents = F.relu(self.encoder(x))
        k_sparse = self.cfg.k_sparse
        if k_sparse is not None:
            keep = min(k_sparse, latents.size(-1))
            kth_largest = latents.topk(keep, dim=-1).values[..., -1:]
            latents = torch.where(latents >= kth_largest, latents, torch.zeros_like(latents))
        return self.decoder(latents), latents

    @torch.no_grad()
    def normalize_decoder(self, eps=1e-12):
        """Rescale every decoder column to unit L2 norm when ``cfg.norm_decoder`` is set."""
        if self.cfg.norm_decoder:
            weight = self.decoder.weight.data
            self.decoder.weight.data = weight / (weight.norm(dim=0, keepdim=True) + eps)

@dataclass
class SAETrained:
    """A trained SAE bundled with its input normalization and feature prototypes.

    Inputs are standardized as ``(x - mean) / std`` before being fed to ``model``.
    Field order is part of the positional constructor and must stay fixed.
    """

    cfg: SAEConfig  # hyperparameters the model was built with
    model: SparseAutoencoder  # the trained autoencoder
    mean: torch.Tensor  # per-channel mean used for standardization
    std:  torch.Tensor  # per-channel std used for standardization
    feat_channel_proto: torch.Tensor  # per latent j: channel index with the largest |decoder weight|

def iterate_with_mask(all_hidden_states, steps, layers, spec, mask: torch.Tensor):
    """Yield only the rows of ``iter_all_rows`` selected by a global boolean mask.

    ``mask[i]`` refers to the i-th row in ``iter_all_rows``' global row order.
    Chunks with no selected row are skipped entirely.

    Yields:
        ``(rows, idx)``: ``rows`` is the selected subset of the chunk and ``idx``
        holds the 1-D int64 global indices of those rows.

    Raises:
        TypeError: if ``mask`` is not a ``torch.bool`` tensor. An explicit raise
            replaces ``assert`` because asserts vanish under ``python -O``, and
            an integer mask would then silently do integer indexing.
    """
    if mask.dtype != torch.bool:
        raise TypeError(f"mask must be a torch.bool tensor, got {mask.dtype}")
    for X_chunk, start in iter_all_rows(all_hidden_states, steps, layers, spec):
        n = X_chunk.size(0)
        if n == 0:
            continue
        m = mask[start:start + n]
        if m.any():
            yield X_chunk[m], start + m.nonzero(as_tuple=True)[0]

def compute_train_moments_stream(all_hidden_states, steps, layers, spec, mask):
    """Per-channel ``(mean, std)`` over only the rows selected by ``mask``.

    Returns the float32 pair produced by ``RunningMoments.finalize``.
    """
    accumulator = RunningMoments(spec.d_model)
    selected = iterate_with_mask(all_hidden_states, steps, layers, spec, mask)
    for rows, _row_ids in selected:
        accumulator.update(rows)
    return accumulator.finalize()

def train_sae_stream(
    all_hidden_states, steps, layers, spec: DataSpec,
    mask: torch.Tensor,
    cfg: SAEConfig,
    mean: torch.Tensor, std: torch.Tensor
) -> SAETrained:
    """Train a ``SparseAutoencoder`` on the masked rows, streamed chunk by chunk.

    Each chunk is standardized with ``(x - mean) / std`` and used as one AdamW
    step. The loss is reconstruction MSE plus ``cfg.l1`` times the mean per-row
    L1 norm of the latents. Gradients are optionally clipped, and the decoder is
    renormalized after every step. Runs on ``cfg.device`` when CUDA is available,
    otherwise on CPU.

    Returns:
        ``SAETrained`` holding the model in eval mode, CPU copies of ``mean`` and
        ``std``, and, for each latent, the channel with the largest |decoder weight|.
    """
    run_device = cfg.device if torch.cuda.is_available() else "cpu"
    model = SparseAutoencoder(cfg).to(run_device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    mu = mean.to(run_device)
    sigma = std.to(run_device)
    for _epoch in range(cfg.epochs):
        batches = iterate_with_mask(all_hidden_states, steps, layers, spec, mask)
        for rows, _ in batches:
            normed = ((rows.to(run_device) - mu) / sigma).float()
            recon, codes = model(normed)
            sparsity = codes.abs().sum(dim=-1).mean()
            loss = F.mse_loss(recon, normed) + cfg.l1 * sparsity
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            if cfg.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optimizer.step()
            model.normalize_decoder()
    with torch.no_grad():
        # decoder.weight is (d, m); transpose so argmax runs over channels per latent.
        proto = torch.argmax(model.decoder.weight.data.t().abs().cpu(), dim=-1)
    return SAETrained(cfg, model.eval(), mu.cpu(), sigma.cpu(), proto)

class AlignmentEvaluator:
    """Streaming metrics on how well SAE latents align with each row's peak channel.

    For every row, ``c*`` is the channel with the largest ``|x|`` (``channel_stats``)
    and ``j*`` is the latent with the largest activation. Accumulated quantities:
      * top-k hit rate: ``c*`` is among the k channels with the largest
        |decoder weight| for latent ``j*``;
      * L1 purity: ``|W[c*, j*]| / sum_c |W[c, j*]|``;
      * mean number of strictly positive latents per row;
      * mean per-row reconstruction MSE (in standardized space);
      * the joint (c*, j*) histogram, from which ``finalize`` derives mutual
        information in nats.
    Note: a row whose latents are all zero gets ``j* = 0`` (argmax of a zero row).
    """

    def __init__(self, trained: "SAETrained", ks=(1, 5, 10)):
        self.trained = trained
        self.ks = tuple(sorted(set(ks)))
        # |decoder weight|, indexed [channel, latent].
        W = trained.model.decoder.weight.data.detach().cpu().abs()
        self.W = W
        self.d = W.size(0)
        self.m = W.size(1)
        self.col_l1 = W.sum(dim=0) + 1e-12
        # Clamp k to the channel count so large ks do not make topk raise.
        self.topk_idx = {k: torch.topk(W, k=min(k, self.d), dim=0).indices for k in self.ks}
        self.n = 0
        self.topk_hits = {k: 0 for k in self.ks}
        self.purity_sum = 0.0
        self.active_sum = 0.0
        self.mse_sum = 0.0
        self.M_counts = torch.zeros(self.d, self.m, dtype=torch.float64)

    @torch.no_grad()  # evaluation only: avoid building an autograd graph per chunk
    def update(self, X_chunk: torch.Tensor, device='cuda'):
        """Accumulate metrics for one chunk of raw (unstandardized) rows."""
        dev = device if torch.cuda.is_available() else 'cpu'
        m = self.trained
        Xn = ((X_chunk.to(dev) - m.mean.to(dev)) / m.std.to(dev)).float()
        x_hat, z = m.model(Xn)
        X_cpu = X_chunk.to('cpu').float()
        c_star = channel_stats(X_cpu)[0]
        j_star = torch.argmax(z.detach().to('cpu'), dim=-1)
        for k in self.ks:
            # topk_idx[k][:, j_star] is (k, rows); compare each column with c_star.
            hits = (self.topk_idx[k][:, j_star] == c_star)
            self.topk_hits[k] += hits.any(dim=0).sum().item()
        self.purity_sum += (self.W[c_star, j_star] / self.col_l1[j_star]).sum().item()
        # Count in integers: a float32 sum is inexact above 2**24 active entries.
        self.active_sum += (z > 0).sum().item()
        self.mse_sum += F.mse_loss(x_hat, Xn, reduction='none').mean(dim=-1).sum().item()
        # Vectorized histogram update (duplicate (c, j) pairs accumulate).
        self.M_counts.index_put_(
            (c_star, j_star),
            torch.ones(c_star.numel(), dtype=torch.float64),
            accumulate=True,
        )
        self.n += X_cpu.size(0)

    def finalize(self) -> Dict:
        """Return averaged metrics and the (c*, j*) mutual information (nats)."""
        eps = 1e-12
        P = self.M_counts / max(self.n, 1)
        Pc = P.sum(dim=1)  # marginal over peak channels, (d,)
        Pj = P.sum(dim=0)  # marginal over winning latents, (m,)
        # Gather the marginals only at the nonzero cells rather than materializing
        # two extra (d, m) tensors with repeat().
        rows, cols = P.nonzero(as_tuple=True)
        p = P[rows, cols]
        MI = (p * (torch.log(p + eps) - torch.log(Pc[rows] + eps)
                   - torch.log(Pj[cols] + eps))).sum().item()
        out = {
            "N_eval": self.n,
            "topk": {k: (self.topk_hits[k] / max(self.n, 1)) for k in self.ks},
            "purity_l1": self.purity_sum / max(self.n, 1),
            "mean_active_features": self.active_sum / max(self.n, 1),
            "mse_recon": self.mse_sum / max(self.n, 1),
            "mi": MI
        }
        return out

def evaluate_streaming(all_hidden_states, steps, layers, spec, mask, trained: SAETrained, ks=(1,5,10)):
    """Run ``AlignmentEvaluator`` over the masked rows and return its finalized metrics dict."""
    evaluator = AlignmentEvaluator(trained, ks=ks)
    selected_rows = (rows for rows, _ in iterate_with_mask(all_hidden_states, steps, layers, spec, mask))
    for rows in selected_rows:
        evaluator.update(rows)
    return evaluator.finalize()

def run_sae_alignment(
    all_hidden_states: List[torch.Tensor],
    steps: Optional[List[int]] = None,
    layers: Optional[List[int]] = None,
    spec: Optional[DataSpec] = None,
    extreme_mode: str = "share",
    extreme_p: float = 0.95,
    train_ratio: float = 0.8,
    m_mult: int = 4, l1_global: float = 1e-3,
    epochs: int = 2, k_sparse: Optional[int] = None, device: str = "cuda"
) -> Dict:
    """End-to-end pipeline: stats pass, train/test split, SAE training, evaluation.

    Steps:
      1. Seed RNGs and stream every row once to collect per-row share / rho / gap.
      2. Mark "extreme" rows: those whose statistic selected by ``extreme_mode``
         is at or above its ``extreme_p`` quantile. Modes: "share", "gap", "rho",
         and "quantile" (an alias for "share").
      3. Randomly split rows into train (``train_ratio``) and test.
      4. Train one global SAE with ``m = m_mult * d_model`` latents on train rows.
      5. Evaluate on all test rows and on the extreme test rows.

    Args:
        spec: streaming configuration; ``None`` means ``DataSpec()`` (a fresh
            instance per call rather than one shared default object).

    Returns:
        dict with "sizes" (row counts and thresholds), "global" (the ``SAETrained``),
        and "metrics" ("global_on_test", "global_on_ext_test").

    Raises:
        ValueError: if ``extreme_mode`` is not one of the supported modes
            (checked before any data is streamed).
    """
    if extreme_mode not in ("share", "gap", "rho", "quantile"):
        raise ValueError("extreme_mode invalid")
    if spec is None:
        spec = DataSpec()
    set_seed(23)
    N, shares, rhos, gaps, _mean_all, _std_all = pass1_compute_stats_and_shares(
        all_hidden_states, steps, layers, spec
    )
    q = torch.tensor(extreme_p)
    thr_share = torch.quantile(shares, q)
    thr_gap   = torch.quantile(gaps,   q)
    thr_rho   = torch.quantile(rhos,   q)
    # Each mode thresholds its OWN statistic. Previously every mode compared
    # `shares`, which mixed units for "gap" and "rho".
    stat, thr = {
        "share": (shares, thr_share),
        "quantile": (shares, thr_share),
        "gap": (gaps, thr_gap),
        "rho": (rhos, thr_rho),
    }[extreme_mode]
    ext_mask = stat >= thr

    perm = torch.randperm(N)
    n_train = int(train_ratio * N)
    train_mask = torch.zeros(N, dtype=torch.bool)
    train_mask[perm[:n_train]] = True
    test_mask = ~train_mask

    mean_g, std_g = compute_train_moments_stream(all_hidden_states, steps, layers, spec, train_mask)
    cfg_g = SAEConfig(d=spec.d_model, m=m_mult*spec.d_model, l1=l1_global, lr=3e-4,
                      epochs=epochs, device=device, k_sparse=k_sparse)
    sae_global = train_sae_stream(all_hidden_states, steps, layers, spec, train_mask, cfg_g, mean_g, std_g)
    met_global_test = evaluate_streaming(all_hidden_states, steps, layers, spec, test_mask, sae_global)
    met_global_ext  = evaluate_streaming(all_hidden_states, steps, layers, spec, ext_mask & test_mask, sae_global)
    sizes = {
        "N_all": N,
        "N_train": n_train,
        "N_test": N - n_train,
        "N_ext": int(ext_mask.sum().item()),
        "extreme_threshold": float(thr.item()),
        "thr_share": float(thr_share.item()),
        "thr_gap":   float(thr_gap.item()),
        "thr_rho":   float(thr_rho.item()),
    }
    return {
        "sizes": sizes,
        "global": sae_global,
        "metrics": {
            "global_on_test": met_global_test,
            "global_on_ext_test": met_global_ext
        }
    }

def quick_single_layer_step(
    all_hidden_states: List[torch.Tensor],
    step_idx: int, layer_idx: int,
    spec: Optional[DataSpec] = None,
    extreme_p=0.95, m_mult=4, device="cuda"
):
    """Quick single (step, layer) experiment: fit a small SAE and evaluate it.

    Rows are taken as ``H[layer_idx].reshape(-1, d_model)[lo:T - last]`` with
    ``T = H.size(2)``, i.e. the trimmed tokens of the first sequence. This matches
    the rows ``iter_all_rows`` yields for this step/layer. A random 80% of the
    rows trains the SAE with 3 full-batch AdamW steps. Evaluation then runs on
    every row, training rows included.

    Args:
        spec: streaming configuration; ``None`` means a fresh ``DataSpec()``.
        extreme_p: accepted for API compatibility; currently unused.
        m_mult: latent count multiplier (``m = m_mult * d_model``).
        device: requested device; falls back to CPU when CUDA is unavailable,
            as the other training paths do.

    Returns:
        ``{"sizes": {"N": N}, "global_on_all": metrics}``.

    Raises:
        ValueError: if fewer than 2 training rows are available, since the
            per-channel std would be undefined.
    """
    if spec is None:
        spec = DataSpec()
    run_device = device if torch.cuda.is_available() else "cpu"
    H = all_hidden_states[step_idx]
    X = H[layer_idx].reshape(-1, spec.d_model)
    tok_lo = spec.token_ignore_first_n
    tok_hi = H.size(2) - spec.token_ignore_last_n
    X = X.view(1, -1, spec.d_model)[0, tok_lo:tok_hi, :]
    N = X.size(0)
    perm = torch.randperm(N)
    n_train = int(0.8 * N)
    if n_train < 2:
        raise ValueError(f"need at least 2 training rows, got {n_train} (N={N})")
    Xtr = X[perm[:n_train]]

    def fit_sae(X_fit, l1):
        # Standardize in float32 so half-precision inputs don't degrade the moments.
        X_fit = X_fit.float()
        mean = X_fit.mean(0)
        std = X_fit.std(0) + 1e-5
        cfg = SAEConfig(d=spec.d_model, m=m_mult*spec.d_model, l1=l1, device=device)
        model = SparseAutoencoder(cfg).to(run_device)
        opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
        # The input never changes between steps, so normalize it once.
        x = (X_fit.to(run_device) - mean.to(run_device)) / std.to(run_device)
        for _ in range(3):
            x_hat, z = model(x)
            loss = F.mse_loss(x_hat, x) + cfg.l1 * z.abs().sum(-1).mean()
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            model.normalize_decoder()
        with torch.no_grad():
            proto = torch.argmax(model.decoder.weight.data.t().abs().cpu(), dim=-1)
        return SAETrained(cfg, model.eval(), mean.cpu(), std.cpu(), proto)

    sae_g = fit_sae(Xtr, 1e-3)
    m_g = evaluate_streaming([H], steps=[0], layers=[layer_idx], spec=spec,
                             mask=torch.ones(N, dtype=torch.bool), trained=sae_g)
    return {"sizes": {"N": N}, "global_on_all": m_g}
