import math
import random
from collections import OrderedDict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import functional, layer, neuron
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
except ImportError:
    # One None per imported name (four names, four values).
    selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None
import torch.nn.functional as F
try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from einops import rearrange, repeat
try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    # A bare None cannot be unpacked into two names.
    causal_conv1d_fn, causal_conv1d_update = None, None
from torch.distributions import Bernoulli
from geomloss import SamplesLoss

from .Semanticlearning import MLP


class SNNEmbeddingWithXOR(nn.Module):
    """Linear -> BatchNorm1d -> parametric LIF embedding with an optional random XOR.

    The input is a 4-D tensor unpacked as ``(T, B, C, N)``. The linear layer
    acts on the ``C`` axis, batch normalisation is applied per output channel,
    and the result is passed through a multi-step ``ParametricLIFNode``. When
    ``train`` is true the output is additionally combined with a sampled mask
    through ``(x + mask) mod 2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, initial_p=0.5):
        """Build the embedding layers.

        Args:
            in_features: Size of the ``C`` axis of the input.
            hidden_features: Output width of the linear layer; falls back to
                ``in_features`` when falsy.
            out_features: Accepted for interface compatibility; not used.
            initial_p: Accepted for interface compatibility; not used.
        """
        super().__init__()
        hidden_features = hidden_features or in_features

        self.fc1_conv = nn.Linear(in_features, hidden_features)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')

    def mask(self, x):
        """Sample a float32 Bernoulli mask with probability ``log2(1 + x)``.

        ``torch.bernoulli`` requires probabilities in ``[0, 1]``, so ``x`` must
        lie in ``[0, 1]``.

        NOTE(review): for inputs that are exactly 0 or 1 the probability is
        (up to floating-point rounding) 0 or 1, so the mask equals ``x`` and
        the XOR in ``forward`` maps every element to 0 -- confirm this is the
        intended behaviour.
        """
        # log1p(x) / ln(2) == log2(1 + x)
        p = torch.log1p(x) / torch.log(torch.tensor(2.0))
        return torch.bernoulli(p).float()

    def forward(self, x, train=True):
        """Embed ``x`` and, when ``train`` is true, XOR it with a random mask.

        Args:
            x: Tensor unpacked as ``(T, B, C, N)``; ``C`` must match the
                linear layer's input width.
            train: When true, return ``(x + mask) mod 2``; otherwise return
                the LIF output unchanged.

        Returns:
            Tensor of shape ``(T, B, hidden_features, N)``.
        """
        T, B, C, N = x.shape

        # (T, B, C, N) -> (T*B, N, C) so the linear layer sees C last.
        x = x.transpose(-1, -2)
        x = self.fc1_conv(x.flatten(0, 1))
        # BatchNorm1d expects channels on dim 1: (T*B, hidden, N).
        x = self.fc1_bn(x.transpose(-1, -2)).reshape(T, B, -1, N).contiguous()

        x = self.fc1_lif(x)
        if train:
            # Only sample when the mask is used, so evaluation does not
            # consume random numbers.
            x = torch.fmod(x + self.mask(x), 2)
        return x




class CosineSimilarity(nn.Module):
    """Masked per-token cosine similarity between two ``(T, B, C, N)`` tensors.

    Each of the ``T * N`` tokens per batch element is a ``C``-dimensional
    vector. With ``margin=None`` the module returns the cosine similarity
    itself (higher is more similar); with a margin it returns the hinge
    ``max(0, margin - cos)``.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, eps: float = 1e-6, margin: Optional[float] = None, reduction: str = "mean"):
        """Configure the criterion.

        Args:
            eps: Stabiliser passed to ``F.normalize``.
            margin: If given, switch to the hinge form ``relu(margin - cos)``.
            reduction: ``"mean"`` (masked mean), ``"sum"`` or ``"none"``.

        Raises:
            ValueError: If ``reduction`` is not one of the supported values.
        """
        super().__init__()
        # Raise explicitly: an assert is stripped under ``python -O`` and an
        # invalid value would then silently behave like "none".
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.eps = float(eps)
        self.margin = margin
        self.reduction = reduction

    def forward(
        self,
        pred: torch.Tensor,   # (T, B, C, N)
        gt: torch.Tensor,     # (T, B, C, N)
        mask: Optional[torch.Tensor] = None  # (T, B, N)
    ) -> torch.Tensor:
        """Compute the (optionally masked) similarity or hinge value.

        Args:
            pred: Predictions of shape ``(T, B, C, N)``.
            gt: Targets with the same shape as ``pred``.
            mask: Optional per-token weights of shape ``(T, B, N)``; tokens
                with weight 0 do not contribute.

        Returns:
            A scalar for ``"mean"`` / ``"sum"``; a ``(B, T*N)`` tensor for
            ``"none"``. The mean divides by ``mask.sum()`` clamped to at
            least 1, so an all-zero mask yields 0 rather than NaN.
        """
        T, B, C, N = pred.shape

        # (T, B, C, N) -> (B, T*N, C)
        pred = pred.permute(1, 0, 3, 2).reshape(B, T * N, C).contiguous()
        gt = gt.permute(1, 0, 3, 2).reshape(B, T * N, C).contiguous()

        # L2-normalise along the channel axis so the dot product is a cosine.
        pred = F.normalize(pred, p=2, dim=-1, eps=self.eps)
        gt = F.normalize(gt, p=2, dim=-1, eps=self.eps)

        # Cosine per token: (B, T*N)
        cos = (pred * gt).sum(dim=-1)

        if mask is not None:
            mask = mask.permute(1, 0, 2).reshape(B, T * N).to(dtype=cos.dtype, device=cos.device)
        else:
            mask = torch.ones_like(cos)

        if self.margin is None:
            # Plain cosine similarity (not negated).
            val = cos * mask
        else:
            # Hinge: max(0, margin - cos), still masked.
            val = F.relu(self.margin - cos) * mask

        if self.reduction == "mean":
            denom = mask.sum().clamp_min(1.0)
            return val.sum() / denom
        if self.reduction == "sum":
            return val.sum()
        return val  # shape (B, T*N)
class PB(nn.Module):
    """Posterior/prior block: per-element Bernoulli posterior with an EMA prior.

    The posterior ``p`` is ``sigmoid(BatchNorm1d(Linear(x)))``. The prior ``q``
    is a per-channel exponential moving average of the posterior's batch mean,
    stored in the ``ema_q`` buffer and broadcast to the shape of ``p``.
    """

    def __init__(self, in_features, hidden_features=None, ema_momentum=0.95):
        """Build the block.

        Args:
            in_features: Size of the ``C`` axis of the input.
            hidden_features: Output width of the linear layer; falls back to
                ``in_features`` when falsy.
            ema_momentum: Weight kept from the previous EMA value on each
                update.
        """
        super().__init__()
        hidden_features = hidden_features or in_features
        self.fc = nn.Linear(in_features, hidden_features)
        self.bn = nn.BatchNorm1d(hidden_features)
        self.sigmoid = nn.Sigmoid()
        self.register_buffer('ema_q', torch.zeros(1, hidden_features, 1))  # [1, C, 1]
        self.ema_momentum = ema_momentum
        # True once ema_q holds a real estimate instead of its zero init.
        self.initialized = False

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Restore ``initialized`` after loading so a resumed EMA is not reset.

        ``initialized`` is a plain attribute and is not part of the state
        dict. Sigmoid means are positive, so an all-zero ``ema_q`` is taken to
        mean that the buffer was never updated.
        """
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        ema = self.ema_q
        if not ema.is_meta:
            self.initialized = bool(ema.any())

    def forward(self, x, update_ema=True):
        """Return the posterior ``p`` and the broadcast EMA prior ``q``.

        Args:
            x: Tensor unpacked as ``(T, B, C, N)``.
            update_ema: When true, fold this batch's mean posterior into
                ``ema_q`` before building ``q``.

        Returns:
            ``(p, q)``, both of shape ``(T, B, hidden_features, N)``.
        """
        # x: [T, B, C, N]
        T, B, C, N = x.shape
        x = x.transpose(-1, -2)  # [T, B, N, C]
        x = self.fc(x.flatten(0, 1))  # [T*B, N, C]
        x = self.bn(x.transpose(-1, -2)).reshape(T, B, -1, N).contiguous()
        p = self.sigmoid(x)  # posterior distribution

        if update_ema:
            # Reduce to one value per channel and match the buffer's
            # [1, C, 1] layout. keepdim=True would give [1, 1, C, 1], which
            # the in-place updates below cannot broadcast into ema_q.
            q_batch_mean = p.detach().mean(dim=(0, 1, 3)).reshape(self.ema_q.shape)
            if not self.initialized:
                self.ema_q.copy_(q_batch_mean)
                self.initialized = True
            else:
                self.ema_q.mul_(self.ema_momentum).add_((1 - self.ema_momentum) * q_batch_mean)

        q = self.ema_q.expand_as(p)  # broadcast prior shape
        return p, q
    
class EN(nn.Module):
    """Encoder stage: Linear -> BatchNorm1d -> multi-step parametric LIF neuron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        """Create the layers.

        ``hidden_features`` and ``out_features`` fall back to ``in_features``
        when falsy. ``drop`` is accepted but not used by this block.
        """
        super().__init__()
        hidden_width = in_features if not hidden_features else hidden_features
        output_width = in_features if not out_features else out_features

        self.fc1_conv = nn.Linear(in_features, hidden_width)
        self.fc1_bn = nn.BatchNorm1d(hidden_width)
        self.fc1_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')

        # Recorded widths; only c_hidden corresponds to a layer built here.
        self.c_hidden = hidden_width
        self.c_output = output_width

    def forward(self, x):
        """Map a ``(T, B, C, N)`` tensor to ``(T, B, hidden_features, N)``."""
        steps, batch, _, tokens = x.shape

        # Put channels last and merge the two leading axes: (T*B, N, C).
        merged = x.transpose(-1, -2).flatten(0, 1)
        projected = self.fc1_conv(merged)

        # BatchNorm1d normalises dim 1, so move channels back: (T*B, hidden, N).
        normed = self.fc1_bn(projected.transpose(-1, -2))
        normed = normed.reshape(steps, batch, -1, tokens).contiguous()

        return self.fc1_lif(normed)




class ITHPd(nn.Module):
    """Two-stage information-bottleneck module over spiking features.

    Each stage encodes its input (``EN``), estimates a Bernoulli posterior and
    an EMA prior (``PB``), re-embeds the code (``SNNEmbeddingWithXOR``) and
    projects it with an ``MLP`` to be compared against a semantic target by
    cosine similarity. The stage loss is ``KL(p || q) - lambda * cos``.
    """

    def __init__(self, in_channel=256, lambda1=0.5, lambda2=0.6, alpha=0.05):
        """Build both stages.

        Args:
            in_channel: Channel width used by every sub-module.
            lambda1: Weight of the stage-1 similarity term.
            lambda2: Weight of the stage-2 similarity term.
            alpha: Stored as a buffer; not read by ``forward``.
        """
        super().__init__()
        self.in_channel = in_channel
        self.en1 = EN(self.in_channel, self.in_channel)
        self.en2 = EN(self.in_channel, self.in_channel)
        self.pb1 = PB(self.in_channel)
        self.pb2 = PB(self.in_channel)
        self.re1 = SNNEmbeddingWithXOR(self.in_channel)
        self.re2 = SNNEmbeddingWithXOR(self.in_channel)
        self.MLP1 = MLP(self.in_channel)
        self.MLP2 = MLP(self.in_channel)
        # MLP3, spike_neuron and lamda are not used by forward(); they are
        # kept so the module's parameter/state-dict layout stays unchanged.
        self.MLP3 = MLP(self.in_channel)
        # Paper notation: lambda1, lambda2, alpha (set as non-learnable)
        self.register_buffer('lambda1', torch.tensor(lambda1, dtype=torch.float32), persistent=True)
        self.register_buffer('lambda2', torch.tensor(lambda2, dtype=torch.float32), persistent=True)
        self.register_buffer('alpha', torch.tensor(alpha, dtype=torch.float32), persistent=True)
        self.spike_neuron = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        self.lamda = nn.Parameter(torch.tensor(0.02))

        self.criterion = CosineSimilarity()

    def bernoulli_kl_loss(self, p, q):
        """Return the mean element-wise KL divergence between Bernoulli(p) and Bernoulli(q).

        Both inputs are clamped to ``[1e-6, 1 - 1e-6]`` so the logarithms stay
        finite.
        """
        eps = 1e-6
        p = torch.clamp(p, min=eps, max=1 - eps)
        q = torch.clamp(q, min=eps, max=1 - eps)
        kl_div = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
        return kl_div.mean()

    def similarity_loss(self, B1, B2):
        """Return ``self.criterion(B1, B2)`` (cosine similarity by default)."""
        return self.criterion(B1, B2)

    def forward(self, output, skeleton_semantic, event_semantic, train=False):
        """Run both bottleneck stages.

        Args:
            output: Input features; the sub-modules unpack it as
                ``(T, B, C, N)``.
            skeleton_semantic: Target compared with the stage-1 projection.
            event_semantic: Target compared with the stage-2 projection.
            train: Forwarded to both ``PB`` blocks (EMA update) and both
                re-embedding blocks (random XOR).

        Returns:
            ``(B2, IB_total)``: the stage-2 code and the sum of both stage
            losses.
        """
        # Stage 1
        B1 = self.en1(output)
        p1, q1 = self.pb1(B1, update_ema=train)
        kl_loss_0 = self.bernoulli_kl_loss(p1, q1)
        B1 = self.re1(B1, train)
        output1 = self.MLP1(B1)
        cos_0 = self.similarity_loss(output1, skeleton_semantic)
        # LDIB,1 = KL - lambda1 * cos
        IB0 = kl_loss_0 - self.lambda1 * cos_0

        # Stage 2 consumes the stage-1 code.
        B2 = self.en2(B1)
        # Pass the flag, not ``self.train``: the latter is nn.Module's bound
        # method and is always truthy, which forced EMA updates during eval.
        p2, q2 = self.pb2(B2, update_ema=train)
        kl_loss_1 = self.bernoulli_kl_loss(p2, q2)
        B2 = self.re2(B2, train)
        output2 = self.MLP2(B2)
        cos_1 = self.similarity_loss(output2, event_semantic)
        # LDIB,2 = KL - lambda2 * cos
        IB1 = kl_loss_1 - self.lambda2 * cos_1

        IB_total = IB0 + IB1

        return B2, IB_total







def main():
    """Smoke-test ITHPd on random inputs and print the output shape and loss."""
    # Prefer the second GPU as before, but fall back to the first GPU (or the
    # CPU) so the script does not fail on hosts with fewer devices.
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")
    elif torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Only keyword arguments that ITHPd.__init__ actually accepts.
    model = ITHPd(in_channel=256, alpha=0.9).to(device)

    print("Model initialized and moved to device.")

    T, B, C, V = 16, 64, 256, 25
    output = torch.randn(T, B, C, V).to(device)
    skeleton_semantic = torch.randn(T, B, C, V).to(device)
    event_semantic = torch.randn(T, B, C, V).to(device)
    print("Input tensors created and moved to device.")

    try:
        # forward() returns exactly two values: the stage-2 code and the loss.
        final_output, IB_total = model(output, skeleton_semantic, event_semantic, train=True)
        print("Model forward pass completed.")
    except Exception as e:
        # Script boundary: report the failure instead of a raw traceback.
        print("Error during model forward pass:", str(e))
        return

    print("Final Output Shape:", final_output.shape)
    print("Total IB Loss:", IB_total.item())

if __name__ == "__main__":
    main()