import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import functional, layer, neuron
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
except ImportError:
    # Four names are imported above, so the fallback must bind exactly four
    # values; a mismatched count would raise ValueError inside this handler.
    selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None
import torch.nn.functional as F
try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
import random
from einops import rearrange, repeat
try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

from .Skeletonbackbone import SGNModel  
from .Eventbackbone import SpikMAMBA 
from .Semanticlearning import SparseSemanticExtractor,MLP
from .ITHU import *

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    # Bind one None per name; unpacking a bare None raises TypeError.
    causal_conv1d_fn, causal_conv1d_update = None, None
from collections import OrderedDict
from torch.distributions import Bernoulli
import numpy as np
import cv2
import matplotlib.pyplot as plt

class GradCAM:
    """Grad-CAM helper that records a layer's output and its gradient.

    A forward hook on ``target_layer`` stores the layer output. The gradient
    with respect to that output is captured by a tensor hook attached to the
    output itself, so it is recorded for any backward pass that flows through
    the layer.

    Attributes:
        model: module whose gradients are cleared before each heatmap.
        target_layer: module whose output is visualised.
        activation: detached copy of the most recent layer output.
        gradients: detached gradient w.r.t. the most recent layer output.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activation = None
        # Graph-attached copy of the output; needed so compute_heatmap can
        # back-propagate from it. It keeps the autograd graph of the last
        # forward pass alive until the next forward pass replaces it.
        self._activation_graph = None

        self.target_layer.register_forward_hook(self.save_activation)

    def save_activation(self, module, input, output):
        """Forward hook: remember the layer output and watch its gradient."""
        self._activation_graph = output
        self.activation = output.detach()
        # A tensor hook replaces the deprecated module-level
        # ``register_backward_hook``. It can only be attached to tensors that
        # take part in autograd.
        if output.requires_grad:
            output.register_hook(self._save_output_gradient)

    def _save_output_gradient(self, grad):
        """Tensor hook: store the gradient flowing into the layer output."""
        self.gradients = grad.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Module backward-hook compatible setter for ``gradients``.

        Kept so external code that registers this method as a module backward
        hook keeps working; the class itself no longer registers it.
        """
        self.gradients = grad_output[0].detach()

    def compute_heatmap(self, class_idx):
        """Return a normalised heatmap for index ``class_idx`` as a NumPy array.

        The score is the sum of the recorded layer output at index
        ``class_idx`` along axis 1. Its gradient is averaged over axes 2 and 3
        to weight the channels of the recorded activation.

        Raises:
            RuntimeError: if no forward pass has been recorded, or the
                recorded output does not require gradients.
        """
        if self._activation_graph is None:
            raise RuntimeError(
                "GradCAM.compute_heatmap called before a forward pass through target_layer"
            )
        if not self._activation_graph.requires_grad:
            raise RuntimeError(
                "target_layer output does not require grad; run the forward pass with autograd enabled"
            )

        self.model.zero_grad()
        # Score must come from the graph-attached output: the detached copy in
        # ``self.activation`` has no grad_fn and cannot be back-propagated.
        score = self._activation_graph[:, class_idx].sum()
        score.backward(retain_graph=True)

        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(weights * self.activation, dim=1).squeeze(0)
        cam = F.relu(cam)
        cam = cam - cam.min()
        peak = cam.max()
        # An all-zero map would otherwise become NaN through 0 / 0.
        if peak > 0:
            cam = cam / peak
        return cam.cpu().numpy()

def overlay_heatmap(img, cam):
    """Blend a JET-coloured version of ``cam`` over ``img`` at equal weight.

    ``cam`` is resized to the width and height of ``img``, scaled by 255 and
    cast to uint8, colour-mapped, and mixed 50/50 with ``img``.
    Assumes ``cam`` values lie in [0, 1] and ``img`` is an image array
    compatible with the colour-mapped output — TODO confirm against callers.
    """
    rows, cols = img.shape[0], img.shape[1]
    resized = cv2.resize(cam, (cols, rows))
    as_bytes = np.uint8(255 * resized)
    colored = cv2.applyColorMap(as_bytes, cv2.COLORMAP_JET)
    blended = cv2.addWeighted(img, 0.5, colored, 0.5, 0)
    return blended
class MacMamba(Mamba):
    """Two-stream selective-scan block built on ``Mamba`` with a spiking output.

    ``forward`` projects the first stream into the scanned signal ``x`` and
    the second stream into the gate ``z`` handed to ``selective_scan_fn``,
    then applies batch norm and a parametric LIF neuron to the scan output and
    adds both input streams back as residuals.
    """

    def __init__(self, d_model,d_inner,expand=1,**kwargs):
        """Build the base ``Mamba`` plus the extra projections and spiking layers.

        Args:
            d_model: channel width used by every layer added here.
            d_inner: forwarded as the second positional argument of the base
                constructor.
            expand: accepted for signature compatibility; the base class is
                always constructed with ``expand=1``.
            **kwargs: forwarded unchanged to ``Mamba.__init__``.
        """
        # NOTE(review): in upstream mamba_ssm the second positional parameter
        # of Mamba.__init__ is ``d_state``, so ``d_inner`` may be setting the
        # state size here — confirm against the Mamba variant in use.
        super().__init__(d_model,d_inner,expand=1, **kwargs)

        # fc1_* / bn1* / bn2* are only referenced from commented-out code in
        # forward; they are kept so the module's parameters stay the same.
        self.fc1_bn = nn.BatchNorm1d(self. d_model)
        self.fc1_lif = neuron.ParametricLIFNode(step_mode='m',backend='cupy',v_threshold=0.5)
        self.bn1 = nn.BatchNorm1d(self.d_model)
        self.bn1_lif = neuron.ParametricLIFNode(step_mode='m',backend='cupy',v_threshold=0.5)
        self.bn2 = nn.BatchNorm1d(self.d_model)
        self.bn2_lif = neuron.ParametricLIFNode(step_mode='m',backend='cupy',v_threshold=0.5)

        self.fc2_bn = nn.BatchNorm1d(self.d_model)
        self.fc2_lif = neuron.ParametricLIFNode(step_mode='m',backend='cupy',v_threshold=0.5)
        # ``self.bias`` and ``self.factory_kwargs`` are expected to be set by
        # the base class — presumably a Mamba variant that stores them.
        self.in_proj1 = nn.Linear(self.d_model, self.d_model, bias=self.bias, **self.factory_kwargs)
        self.in_proj2 = nn.Linear(self.d_model, self.d_model, bias=self.bias, **self.factory_kwargs)

    @staticmethod
    def _project_stream(proj, states, seqlen):
        """Apply ``proj`` to a 4-D stream and return it as (b, d, l).

        The first two axes of ``states`` are merged into one batch axis. The
        matmul and the transpose to channel-first layout are done in one step.
        The bias, when present, is the one belonging to ``proj`` itself.
        """
        flat = states.flatten(0,1)
        projected = rearrange(
            proj.weight @ rearrange(flat, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if proj.bias is not None:
            projected = projected + rearrange(proj.bias.to(dtype=projected.dtype), "d -> d 1")
        return projected

    def forward(self, hidden_states1,hidden_states2,inference_params=None):
        """Fuse two streams with a gated selective scan.

        Args:
            hidden_states1: 4-D tensor; its last two axes are swapped on entry
                and the result is unpacked as (Times, batch, seqlen, dim).
                This stream is convolved and scanned.
            hidden_states2: 4-D tensor with the same layout; projected into
                the gate ``z`` of the selective scan.
            inference_params: optional cache descriptor; when given, the conv
                and SSM states are read from and written back to the cache.

        Returns:
            Tensor with the same axis order as the inputs.
        """
        hidden_states1 = hidden_states1.permute(0,1,3,2)
        hidden_states2 = hidden_states2.permute(0,1,3,2)
        Times, batch, seqlen, _ = hidden_states1.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                # NOTE(review): step() receives the 4-D first stream only and
                # the second stream is ignored on this path — confirm intent.
                out, _, _ = self.step(hidden_states1, conv_state, ssm_state)
                return out

        # Each stream uses the bias of the projection that produced it. The
        # base class's ``in_proj`` bias belongs to a different layer and is
        # not applied here.
        z = self._project_stream(self.in_proj2, hidden_states2, seqlen)
        x = self._project_stream(self.in_proj1, hidden_states1, seqlen)

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # Compute short convolution
        if conv_state is not None:
            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
        if causal_conv1d_fn is None:
            x = self.conv1d(x)[..., :seqlen]
        else:
            x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=None,
                )

        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = self.dt_proj.weight @ dt.t()
        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        assert self.activation in ["silu", "swish"]
        y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
        if ssm_state is not None:
            y, last_state = y
            ssm_state.copy_(last_state)

        # ``out_proj`` / ``gamma`` are intentionally not applied: their result
        # was previously computed and then immediately overwritten, so
        # skipping them leaves the returned tensor unchanged. ``y`` is already
        # in the (b, d, l) layout BatchNorm1d consumes.
        # NOTE(review): the BN output is (Times*batch, d, seqlen) and is
        # reshaped — not transposed — to (Times, batch, seqlen, -1), which
        # reinterprets memory instead of moving the channel axis last. Left
        # unchanged because altering it would change the model's numerical
        # output; confirm whether a transpose was intended.
        out = self.fc2_bn(y).reshape(Times,batch,seqlen,-1)
        out = self.fc2_lif(out)
        out = out + hidden_states1+hidden_states2
        out = out.permute(0,1,3,2)
        return out

class BernoulliXORMask(nn.Module):
    """Add a random Bernoulli mask to the input modulo 2.

    For inputs holding only 0 and 1 this is an element-wise XOR with the
    sampled mask. The flip probability is either a fixed number or the sigmoid
    of a learnable logit.
    """

    def __init__(self, learnable_p=True, initial_p=0.5):
        """Store the flip probability, as a logit parameter when learnable."""
        super().__init__()
        self.learnable_p = learnable_p
        if not learnable_p:
            self.p = initial_p
            return
        self.logits = nn.Parameter(torch.logit(torch.tensor(initial_p)))

    def forward(self, x, training=True):
        """Return ``fmod(x + mask, 2)``.

        When ``training`` is true the mask is sampled element-wise with the
        current flip probability; otherwise the mask is all zeros.
        """
        flip_prob = torch.sigmoid(self.logits) if self.learnable_p else self.p

        if not training:
            return torch.fmod(x + torch.zeros_like(x), 2)

        flips = torch.bernoulli(torch.ones_like(x) * flip_prob).float()
        return torch.fmod(x + flips, 2)

class EN(nn.Module):
    """Linear -> BatchNorm1d -> parametric LIF encoder over a 4-D input.

    ``drop`` is accepted but not used by this block.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        # Both widths default to the input width when not given.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.c_hidden = hidden_features
        self.c_output = out_features
        self.fc1_conv = nn.Linear(in_features, hidden_features)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')

    def forward(self, x):
        """Encode ``x`` of shape (T, B, C, N) into (T, B, hidden, N)."""
        steps, batch, _, tokens = x.shape
        # Put channels last and merge the first two axes for the linear layer.
        merged = x.transpose(-1, -2).flatten(0, 1)
        projected = self.fc1_conv(merged)
        # Back to channel-first so BatchNorm1d normalises over the features.
        normed = self.fc1_bn(projected.transpose(-1, -2))
        restored = normed.reshape(steps, batch, -1, tokens).contiguous()
        return self.fc1_lif(restored)



# class ITHP(nn.Module):
#     def __init__(self, in_channel=256, p_beta=0.2, p_gamma=0.5, p_lambda=0.3):
#         super(ITHP, self).__init__()
#         self.in_channel = in_channel
#         self.en1 = EN(self.in_channel)  
#         self.en4 = EN(self.in_channel)  
#         self.MLP1 = MLP(self.in_channel)  
#         self.MLP2 = MLP(self.in_channel)  
#         self.p_beta = p_beta
#         self.p_gamma = p_gamma
#         self.p_lambda = p_lambda

#         self.criterion = nn.MSELoss()

#     def discrete_kl_loss(self, p, q):

#         eps = 1e-8  
#         p = torch.clamp(p, min=eps, max=1 - eps)
#         q = torch.clamp(q, min=eps, max=1 - eps)
#         kl_div = p * (torch.log(p) - torch.log(q)) + (1 - p) * (torch.log(1 - p) - torch.log(1 - q))
#         return torch.mean(kl_div)

#     def reparameterise(self, logits, method="bernoulli"):

#         if method == "bernoulli":
#             probs = torch.sigmoid(logits)  
#             b = Bernoulli(probs).sample()  
#         elif method == "threshold":
#             probs = torch.sigmoid(logits)
#             b = (probs > 0.5).float()  
#         else:
#             raise ValueError("Unknown reparameterisation method: choose 'bernoulli' or 'threshold'")
#         return b

#     def forward(self, output, skeleton_sementic, event_semantic):
        
#         h1 = self.en1(output) 
#         mu1 = torch.sigmoid(h1)  
#         kl_loss_0 = self.discrete_kl_loss(mu1, torch.full_like(mu1, 0.5))  
#         b0 = self.reparameterise(h1, method="bernoulli") 
        
     
#         output1 = self.MLP1(b0)
#         mse_0 = self.criterion(output1, skeleton_sementic) 
#         IB0 = kl_loss_0 + self.p_beta * mse_0 

       
#         h2 = self.en4(b0) 
#         mu2 = torch.sigmoid(h2)  
#         kl_loss_1 = self.discrete_kl_loss(mu2, torch.full_like(mu2, 0.5))  
#         b1 = self.reparameterise(h2, method="bernoulli") 

      
#         output2 = self.MLP2(b1)
#         mse_1 = self.criterion(output2, event_semantic)  
#         IB1 = kl_loss_1 + self.p_gamma * mse_1 

   
#         IB_total = IB0 + IB1 

  
#         final_output = output + output1 + output2
#         return final_output, IB_total, kl_loss_0, mse_0, kl_loss_1, mse_1


class MultiModalFusionNetwork(nn.Module):
    """Fuse skeleton and event features and classify the result.

    Each modality runs through its own backbone and semantic extractor, a
    shared MLP adds a residual refinement, ``MacMamba`` fuses the two streams,
    and ``ITHPd`` produces the representation fed to a linear classifier
    together with an auxiliary loss term.
    """

    def __init__(self, num_classes, skeleton_config,event_config,skeleton_pretrained_path=None,event_pretrained_path=None,device=None):
        """Build the backbones, semantic extractors, fusion block and classifier.

        Args:
            num_classes: number of outputs of the final linear layer.
            skeleton_config: keyword arguments for ``SGNModel``.
            event_config: keyword arguments for ``SpikMAMBA``.
            skeleton_pretrained_path: accepted but not used here.
            event_pretrained_path: accepted but not used here.
            device: accepted but not used here.
        """
        super(MultiModalFusionNetwork, self).__init__()

        self.skeleton_backbone = SGNModel(**skeleton_config)
        self.event_backbone = SpikMAMBA(**event_config)
        # NOTE(review): this layer is used by forward() but its definition was
        # truncated in the source; only the tail "rs=8, init_sparsity=0.1)"
        # survived. It is rebuilt as the event-side twin of Skeleton_Sementic
        # with k_neighbors=8; in_channels=256 mirrors that sibling — confirm.
        self.Event_Sementic = SparseSemanticExtractor(in_channels=256, k_neighbors=8, init_sparsity=0.1)
        self.share_mlp = MLP(in_features=256)
        self.Skeleton_Sementic =SparseSemanticExtractor(in_channels=256, k_neighbors=5, init_sparsity=0.1)
        self.Infonet= ITHPd(in_channel=256)
        self.dropout  = nn.Dropout(0.1)
        self.classifier = nn.Linear(256, num_classes)
        nn.init.normal_(self.classifier.weight, 0, math.sqrt(2. / num_classes))
        self.mmamba = MacMamba(d_model=256,d_inner=256)

    def forward(self, skeleton_input, event_input,train=True):
        """Return ``(logits, IB_total)`` for a skeleton/event input pair.

        Args:
            skeleton_input: input passed unchanged to the skeleton backbone.
            event_input: 5-D tensor whose first two axes are swapped before it
                is passed to the event backbone.
            train: flag forwarded to ``self.Infonet``.
        """
        event_input = event_input.permute(1,0,2,3,4)
        skeleton_features_0 = self.skeleton_backbone(skeleton_input)
        # Only the third output of the event backbone is used.
        _, _, event_features_0 = self.event_backbone(event_input)
        event_semantic= self.Event_Sementic(event_features_0)
        skeleton_semantic=self.Skeleton_Sementic(skeleton_features_0)

        # The same MLP refines both modalities; each keeps a residual path.
        event_features = self.share_mlp(event_semantic)+event_features_0
        skeleton_features = self.share_mlp(skeleton_semantic )+skeleton_features_0

        output = self.mmamba(skeleton_features,event_features)

        b1, IB_total= self.Infonet(output,event_features,skeleton_features,train)
        # NOTE(review): dropout is applied both before and after pooling;
        # kept as is to preserve training behaviour — confirm it is intended.
        b1 =self.dropout(b1)
        # Average over axis 3, then over axis 0, before classification.
        b1 = b1.mean(3).mean(0)
        b1 =self.dropout(b1)
        out = self.classifier(b1)
        return out, IB_total
from thop import profile

if __name__ == "__main__":
    # Profiling entry point: build the fusion network on GPU 0, feed it one
    # random skeleton/event pair, and report FLOPs and parameter count.
    skeleton_config = dict(
        num_class=60,
        num_point=25,
        num_person=2,
        in_channels=3,
        num_frames=10,
        num_set=3,
        adaptive=True,
        drop_out=0.5,
    )

    event_config = dict(
        img_size_h=640,
        img_size_w=480,
        patch_size=16,
        in_channels=2,
        num_classes=11,
        embed_dims=256,
        num_heads=4,
        mlp_ratios=4,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        depths=4,
        sr_ratios=4,
    )

    device_id = 0
    device = torch.device(f"cuda:{device_id}")
    print(f"Using device: {device}")

    model = MultiModalFusionNetwork(
        num_classes=10,
        skeleton_config=skeleton_config,
        event_config=event_config,
    )
    model = model.to(device)

    # Random inputs are sampled on the CPU and then moved to the device.
    skeleton_input = torch.randn(1, 3, 16, 25, 2).to(device)
    event_input = torch.randn(1, 16, 2, 480, 640).to(device)

    model.eval()
    profile_inputs = (skeleton_input, event_input)
    flops, params = profile(model, inputs=profile_inputs)

    print(f"Total FLOPs: {flops / 1e9:.2f} GFLOPs")
    print(f"Total Params: {params / 1e6:.2f} M")