import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

import einops
import math

from .config import SAEConfig
from .model import BaseSAE
# from .utils import *

class KronSAE(BaseSAE):
    """Sparse autoencoder whose features are pairs of per-head "m" and "n" keys.

    The encoder emits, for each of ``config.num_heads`` heads, ``num_mkeys``
    m-key and ``num_nkeys`` n-key activations (after ReLU). Every
    (head, m-key, n-key) triple is one dictionary feature, scored as
    ``sqrt(m_act * n_act + 1e-5)``; only the ``config.topk2`` best-scoring
    features per sample are kept (see ``encode_forward``). Consequently
    ``config.dict_size`` must equal ``num_heads * num_mkeys * num_nkeys``.
    """

    def __init__(self, config: SAEConfig):
        super().__init__(config)
        self.config = config

        # NOTE(review): both values are accepted, but this class always builds
        # the decoder with an additive combination and never reads
        # `cartesian_op` after this check.
        assert self.config.cartesian_op in ["sum", "mul"]

        keys_per_head = self.config.dict_size // self.config.num_heads
        depth = self.config.router_depth
        # `round`, not `int`: float roots of perfect powers can land just below
        # the integer (64 ** (1 / 3) == 3.9999999999999996), and truncating
        # rejected valid configurations for router_depth >= 3.
        assert round(keys_per_head ** (1 / depth)) ** depth == keys_per_head
        # NOTE(review): this equals the per-axis key count only when
        # router_depth == 2; it is not read anywhere in this class.
        self.num_keys = int(math.sqrt(keys_per_head))
        # Keep our own copy of topk1: config.topk1 can be changed by a training
        # schedule when router_tree_width does not exist.
        self.topk1 = self.config.topk1

        self.p = depth
        self.h = self.config.num_heads
        self.m, self.n = self.config.num_mkeys, self.config.num_nkeys

        # encode_forward scatters into `dict_size` slots and decode multiplies
        # them by W_dec, which has one row per (head, m-key, n-key) triple.
        assert self.h * self.m * self.n == self.config.dict_size, (
            f"dict_size ({self.config.dict_size}) must equal "
            f"num_heads * num_mkeys * num_nkeys ({self.h} * {self.m} * {self.n})"
        )

        enc_width = self.h * (self.m + self.n)
        # Gaussian init scaled by sqrt(2 / dict_size).
        _t = torch.nn.init.normal_(
            torch.empty(self.config.act_size, enc_width)
        ) / math.sqrt(self.config.dict_size) * math.sqrt(2.0)

        self.W_enc = nn.Parameter(_t)
        self.b_enc = nn.Parameter(torch.zeros(enc_width))

        # Decoder row (h, i, j) starts as the sum of the transposed encoder
        # columns of m-key i and n-key j in head h, then is scaled to unit norm.
        per_head = einops.rearrange(
            _t.t().clone(), '(h mn) d -> h mn d', h=self.h, mn=self.m + self.n
        )
        W_dec_v0 = per_head[:, :self.m]  # (h, m, d)
        W_dec_v1 = per_head[:, self.m:]  # (h, n, d)
        cartesian = W_dec_v0[..., None, :] + W_dec_v1[..., None, :, :]  # (h, m, n, d)
        cartesian = einops.rearrange(cartesian, 'h m n d -> (h m n) d')

        self.W_dec = nn.Parameter(cartesian)
        self.W_dec.data[:] = self.W_dec.data / self.W_dec.data.norm(dim=-1, keepdim=True)

        self.b_dec = nn.Parameter(torch.zeros(self.config.act_size))

        self.register_buffer('num_batches_not_active', torch.zeros((self.config.dict_size,)))

        self.to(self.config.get_torch_dtype(self.config.dtype))
        # `.to(dtype)` also casts this counter. In bfloat16, integers stop being
        # exact at 256 (256 + 1 rounds back to 256), so the counter could never
        # exceed a larger `n_batches_to_dead`. Keep it in float32.
        self.num_batches_not_active = self.num_batches_not_active.float()

        # Input normalization statistics, when the config provides them.
        # NOTE(review): registered after `.to(dtype)` above, so they keep the
        # dtype torch.tensor infers from the config values — confirm that the
        # code reading them (not in this class) expects that.
        if config.input_mean is not None and config.input_std is not None:
            self.register_buffer('input_mean', torch.tensor(config.input_mean))
            self.register_buffer('input_std', torch.tensor(config.input_std))
        else:
            self.input_mean = None
            self.input_std = None

        # CUDA timing events, created only when CUDA is available.
        if torch.cuda.is_available():
            self.start_event = torch.cuda.Event(enable_timing=True)
            self.end_event = torch.cuda.Event(enable_timing=True)
        else:
            self.start_event = None
            self.end_event = None

    def _unit_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Standardise the last dim of ``x`` when ``config.input_unit_norm`` is set.

        Subtracts the per-row mean and divides by the per-row std plus 1e-5;
        returns ``x`` unchanged when the flag is off.
        """
        if not self.config.input_unit_norm:
            return x
        x = x - x.mean(dim=-1, keepdim=True)
        return x / (x.std(dim=-1, keepdim=True) + 1e-5)

    def _pre_acts(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 2-D batch ``x`` (B, act_size) to ReLU key activations (B, h, m + n)."""
        B, _ = x.shape  # the unpacking also rejects non 2-D input
        return F.relu(x @ self.W_enc + self.b_enc).view(B, self.h, self.m + self.n)

    def _pair_scores(self, acts: torch.Tensor) -> torch.Tensor:
        """Score every (m-key, n-key) pair of every head.

        Args:
            acts: key activations of shape (B, h, m + n).

        Returns:
            ``sqrt(m_act * n_act + 1e-5)`` for each pair, flattened to
            (B, h * m * n) in (head, m-key, n-key) order.
        """
        B, _, _ = acts.shape
        m_acts = acts[..., :self.m]
        n_acts = acts[..., self.m:]
        pair_products = m_acts[..., None] * n_acts[..., None, :]  # (B, h, m, n)
        return torch.sqrt(pair_products + 1e-5).reshape(B, -1)

    def _standard_expert_retrieval(self, acts: torch.Tensor):
        """Select the ``config.topk2`` best (m-key, n-key) pairs per sample.

        Args:
            acts: key activations of shape (B, h, m + n).

        Returns:
            ``(all_scores, scores, indices)``: all pair scores (B, h*m*n), then
            the top-k values and their flat indices (B, k), unsorted. ``k`` is
            ``topk2`` clamped to the number of pairs.
        """
        all_scores = self._pair_scores(acts)
        k = min(self.config.topk2, all_scores.shape[-1])
        scores, indices = all_scores.topk(k, dim=-1, sorted=False)
        return all_scores, scores, indices

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sparse code (B, dict_size) of ``x`` (unit-normalised first if configured)."""
        _, acts_topk = self.encode_forward(self._unit_normalize(x))
        return acts_topk

    def encode_pre(self, x: torch.Tensor) -> torch.Tensor:
        """Return the key activations (B, h, m + n) of ``x`` (unit-normalised first if configured)."""
        return self._pre_acts(self._unit_normalize(x))

    def get_masked_pre_acts(self, pre_acts: torch.Tensor) -> torch.Tensor:
        """Returns pre-activations masked to only keep contributors to top-k post-activations.

        Args:
            pre_acts: key activations of shape (B, h, m + n).

        Returns:
            Tensor of the same shape and dtype as ``pre_acts`` in which only the
            m-keys and n-keys belonging to one of the top-``topk2`` pairs of
            each row keep their value; all other entries are zero.
        """
        B, _, _ = pre_acts.shape
        m, n = self.m, self.n
        m_acts = pre_acts[..., :m]
        n_acts = pre_acts[..., m:]

        # Recompute the pair selection made by _standard_expert_retrieval.
        with torch.no_grad():
            all_scores = self._pair_scores(pre_acts)
            k = min(self.config.topk2, all_scores.shape[-1])
            _, indices = all_scores.topk(k, dim=-1, sorted=False)  # (B, k)

        # Flat pair index -> (head, m-key row, n-key column).
        flat_indices = indices.reshape(-1)  # (B * k,)
        head_indices = flat_indices // (m * n)
        pos_in_head = flat_indices % (m * n)
        row_indices = pos_in_head // n
        col_indices = pos_in_head % n
        # Batch index of every flattened entry: [0]*k + [1]*k + ...
        batch_indices = torch.arange(B, device=pre_acts.device)[:, None].expand(-1, k).flatten()

        # Masks share pre_acts' dtype/device so the product is not upcast.
        mask_m = torch.zeros_like(m_acts)
        mask_n = torch.zeros_like(n_acts)
        mask_m[batch_indices, head_indices, row_indices] = 1
        mask_n[batch_indices, head_indices, col_indices] = 1

        return torch.cat([m_acts * mask_m, n_acts * mask_n], dim=-1)

    def encode_forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Encode a batch ``x`` of shape (B, act_size); does not normalise ``x`` itself.

        Returns:
            ``(acts, acts_topk)``: all pair scores (B, h*m*n), and a dense
            (B, dict_size) tensor holding only the top-``topk2`` scores of each
            row, zeros elsewhere.
        """
        acts, scores, indices = self._standard_expert_retrieval(self._pre_acts(x))
        acts_topk = torch.zeros(
            (acts.shape[0], self.config.dict_size), device=scores.device, dtype=scores.dtype
        ).scatter(-1, indices, scores)
        return acts, acts_topk

    def update_inactive_features(self, acts):
        """Track, per feature, how many consecutive batches it has not fired.

        Features whose total activation over the batch is zero get their
        counter incremented; features with a positive total are reset to 0.
        """
        totals = acts.sum(0)
        self.num_batches_not_active += (totals == 0).float()
        self.num_batches_not_active[totals > 0] = 0

    def decode(self, acts_topk: torch.Tensor):
        """Reconstruct inputs from sparse codes: ``acts_topk @ W_dec + b_dec``."""
        return acts_topk @ self.W_dec + self.b_dec

    def forward(self, x):
        """Preprocess, encode and decode ``x``, then return the loss/metrics dict.

        Also updates the dead-feature counter from this batch's sparse codes.
        """
        x, x_mean, x_std = self.preprocess_input(x)
        acts, acts_topk = self.encode_forward(x)
        x_reconstruct = self.decode(acts_topk)
        self.update_inactive_features(acts_topk)
        return self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)

    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
        """Compute the loss and logging statistics for one batch.

        ``acts`` / ``acts_topk`` are the two outputs of ``encode_forward``.
        Only the reconstruction MSE and the auxiliary dead-feature loss are
        summed into ``loss``; ``l1_loss`` is reported but not added.
        """
        sq_err = (x_reconstruct.float() - x.float()).pow(2)
        l2_loss = sq_err.mean()
        l1_norm = acts_topk.float().abs().sum(-1).mean()
        l1_loss = self.config.l1_coeff * l1_norm
        l0_norm = (acts_topk > 0).float().sum(-1).mean()

        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
        loss = l2_loss + aux_loss

        # NOTE(review): this metric counts counters strictly above
        # n_batches_to_dead, while get_auxiliary_loss treats `>=` as dead.
        num_dead_features = (
            self.num_batches_not_active > self.config.n_batches_to_dead
        ).sum()
        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
        per_token_l2_loss_A = sq_err.sum(-1).squeeze()
        total_variance_A = (x.float() - x.float().mean(0)).pow(2).sum(-1).squeeze()
        explained_variance = (1 - per_token_l2_loss_A / total_variance_A).mean()
        return {
            "sae_out": sae_out,
            "feature_acts": acts_topk,
            "num_dead_features": num_dead_features,
            "acts_mean_value": acts_topk[acts_topk > 0.0].mean(),
            "pre_topk_mean": acts.mean(),
            "pre_topk_std": acts.std(),
            "pre_topk_mean_group_std": acts.std(dim=-1).mean(),
            "loss": loss,
            "l1_loss": l1_loss,
            "l2_loss": l2_loss,
            "l0_norm": l0_norm,
            "l1_norm": l1_norm,
            "explained_variance": explained_variance,
            "topk2": self.config.topk2,
            "aux_loss": aux_loss,
        }

    def get_auxiliary_loss(self, x, x_reconstruct, acts):
        """Auxiliary loss that asks dead features to explain the residual.

        Features whose counter is >= ``n_batches_to_dead`` keep their top
        ``topk2_aux`` scores per row of ``acts`` (shape (B, dict_size)); the
        MSE between their decoded output and ``x - x_reconstruct``, scaled by
        ``aux_penalty``, is returned. Returns a zero tensor when no feature is
        dead.
        """
        dead_features = self.num_batches_not_active >= self.config.n_batches_to_dead
        num_dead = int(dead_features.sum())
        if num_dead == 0:
            return torch.tensor(0, dtype=x.dtype, device=x.device)

        residual = x.float() - x_reconstruct.float()
        dead_acts = acts[:, dead_features]
        top = torch.topk(dead_acts, min(self.config.topk2_aux, num_dead), dim=-1)
        acts_aux = torch.zeros_like(dead_acts).scatter(-1, top.indices, top.values)
        x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
        return self.config.aux_penalty * (x_reconstruct_aux.float() - residual).pow(2).mean()

# Register KronSAE with the auto-class mechanism it inherits via BaseSAE.
# NOTE(review): BaseSAE is not shown here; this looks like Hugging Face's
# `PreTrainedModel.register_for_auto_class`, which lets saved checkpoints
# reference this custom class — confirm against BaseSAE.
KronSAE.register_for_auto_class()