import math

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from spikingjelly.activation_based import functional, layer, neuron
from collections import namedtuple
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None
from einops import rearrange, repeat
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
except ImportError:
    selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None
import torch.nn.functional as F
try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
import random

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
import torch.fft
from .multfft import *
def import_class(name):
    """Resolve a dotted path such as ``'pkg.module.ClassName'`` to an object.

    The first component is imported as a module and every following component
    is looked up as an attribute of the previous object. When that lookup fails
    on a package, the component is imported as a submodule, so paths through
    submodules the package does not import itself also resolve.

    Args:
        name: Dotted path to a module, class, function or other attribute.

    Returns:
        The object the dotted path refers to.

    Raises:
        ImportError: If the top-level module cannot be imported.
        AttributeError: If a component is neither an attribute nor an
            importable submodule of the preceding object.
    """
    import importlib

    components = name.split('.')
    obj = importlib.import_module(components[0])
    path = components[0]
    for comp in components[1:]:
        path = f"{path}.{comp}"
        try:
            obj = getattr(obj, comp)
        except AttributeError as attr_err:
            # Only packages can have submodules that are not imported yet.
            if not hasattr(obj, '__path__'):
                raise
            try:
                obj = importlib.import_module(path)
            except ModuleNotFoundError as import_err:
                if import_err.name != path:
                    raise  # a real missing dependency inside the submodule
                raise attr_err from None
    return obj


def conv_branch_init(conv, branches):
    """Normal-initialise a convolution that is one of ``branches`` parallel branches.

    The standard deviation is ``sqrt(2 / (d0 * d1 * d2 * branches))`` where
    ``d0..d2`` are the first three weight dimensions; a bias, if present, is
    zeroed.
    """
    w = conv.weight
    fan = w.size(0) * w.size(1) * w.size(2) * branches
    nn.init.normal_(w, 0, math.sqrt(2. / fan))
    if conv.bias is None:
        return
    nn.init.constant_(conv.bias, 0)


def conv_init(conv):
    """Kaiming-normal (fan_out) weight init and zero bias, skipping ``None`` tensors."""
    weight = conv.weight
    if weight is not None:
        nn.init.kaiming_normal_(weight, mode='fan_out')
    bias = conv.bias
    if bias is not None:
        nn.init.zeros_(bias)


def bn_init(bn, scale):
    """Fill a BatchNorm's affine weight with ``scale`` and its bias with 0."""
    for param, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(param, value)

def conv_bn_init(conv, bn, std=0.1, bias=0.1, bn_bias=0.1):
    """Initialise a convolution and its following BatchNorm for a positive start.

    Args:
        conv: Convolution; its weight is drawn from N(0, std) and its bias
            (if present) is filled with ``bias``.
        bn: BatchNorm layer; its affine weight is set to 1 (no scaling) and
            its affine bias to ``bn_bias``. Non-affine BatchNorm layers (whose
            ``weight``/``bias`` are ``None``) are left untouched.
        std: Standard deviation of the conv weight initialisation.
        bias: Constant for the conv bias (positive output shift).
        bn_bias: Constant for the BatchNorm bias; defaults to the previously
            hard-coded 0.1.
    """
    nn.init.normal_(conv.weight, mean=0.0, std=std)
    if conv.bias is not None:
        nn.init.constant_(conv.bias, bias)
    # affine=False BatchNorm has weight=None; guard it just like the bias.
    if bn.weight is not None:
        nn.init.constant_(bn.weight, 1.0)
    if bn.bias is not None:
        nn.init.constant_(bn.bias, bn_bias)


def linear_init(linear, std=0.1, bias=0.1):
    """Draw ``linear.weight`` from N(0, std) and fill ``linear.bias`` (if any) with ``bias``."""
    nn.init.normal_(linear.weight, mean=0.0, std=std)
    if linear.bias is None:
        return
    nn.init.constant_(linear.bias, bias)


  



def import_class(name):
    """Resolve a dotted path such as ``'pkg.module.ClassName'`` to an object.

    The first component is imported as a module and every following component
    is looked up as an attribute of the previous object. When that lookup fails
    on a package, the component is imported as a submodule, so paths through
    submodules the package does not import itself also resolve.

    Args:
        name: Dotted path to a module, class, function or other attribute.

    Returns:
        The object the dotted path refers to.

    Raises:
        ImportError: If the top-level module cannot be imported.
        AttributeError: If a component is neither an attribute nor an
            importable submodule of the preceding object.
    """
    import importlib

    components = name.split('.')
    obj = importlib.import_module(components[0])
    path = components[0]
    for comp in components[1:]:
        path = f"{path}.{comp}"
        try:
            obj = getattr(obj, comp)
        except AttributeError as attr_err:
            # Only packages can have submodules that are not imported yet.
            if not hasattr(obj, '__path__'):
                raise
            try:
                obj = importlib.import_module(path)
            except ModuleNotFoundError as import_err:
                if import_err.name != path:
                    raise  # a real missing dependency inside the submodule
                raise attr_err from None
    return obj


def conv_branch_init(conv, branches):
    """Branch-aware normal init of a convolution.

    std = sqrt(2 / (n * k1 * k2 * branches)), using the first three weight
    dimensions; the bias, when present, is reset to zero.
    """
    weight = conv.weight
    dims = [weight.size(axis) for axis in range(3)]
    std = math.sqrt(2. / (dims[0] * dims[1] * dims[2] * branches))
    nn.init.normal_(weight, 0, std)
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)

def bn_branch_init(bn, branches):
    """Reset a branch's BatchNorm layer to the identity affine transform.

    When ``bn`` is a ``BatchNorm1d``/``BatchNorm2d``/``BatchNorm3d`` its weight
    is set to 1 and its bias to 0; any other module is left unchanged.
    Non-affine BatchNorm layers (``weight``/``bias`` are ``None``) are skipped
    rather than raising.

    Args:
        bn: Module to initialise.
        branches: Number of branches. Unused; it only keeps the signature
            parallel to ``conv_branch_init``.
    """
    if not isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        return
    if bn.weight is not None:
        nn.init.constant_(bn.weight, 1)
    if bn.bias is not None:
        nn.init.constant_(bn.bias, 0)

def conv_init(conv):
    """Initialise a convolution: He/Kaiming-normal weights (fan_out mode), zero bias."""
    if conv.weight is not None:
        nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    if conv.bias is None:
        return
    nn.init.zeros_(conv.bias)


def bn_init(bn, scale):
    """Set a BatchNorm's scale (weight) to ``scale`` and its shift (bias) to zero."""
    nn.init.constant_(bn.weight, scale)
    nn.init.zeros_(bn.bias)



def import_class(name):
    """Resolve a dotted path such as ``'pkg.module.ClassName'`` to an object.

    The first component is imported as a module and every following component
    is looked up as an attribute of the previous object. When that lookup fails
    on a package, the component is imported as a submodule, so paths through
    submodules the package does not import itself also resolve.

    Args:
        name: Dotted path to a module, class, function or other attribute.

    Returns:
        The object the dotted path refers to.

    Raises:
        ImportError: If the top-level module cannot be imported.
        AttributeError: If a component is neither an attribute nor an
            importable submodule of the preceding object.
    """
    import importlib

    components = name.split('.')
    obj = importlib.import_module(components[0])
    path = components[0]
    for comp in components[1:]:
        path = f"{path}.{comp}"
        try:
            obj = getattr(obj, comp)
        except AttributeError as attr_err:
            # Only packages can have submodules that are not imported yet.
            if not hasattr(obj, '__path__'):
                raise
            try:
                obj = importlib.import_module(path)
            except ModuleNotFoundError as import_err:
                if import_err.name != path:
                    raise  # a real missing dependency inside the submodule
                raise attr_err from None
    return obj


def conv_branch_init(conv, branches):
    """Initialise one of ``branches`` parallel convolutions.

    Weights ~ N(0, sqrt(2 / (out * in * k * branches))) computed from the first
    three weight dimensions; any bias is set to 0.
    """
    kernel = conv.weight
    out_ch, in_ch, width = kernel.size(0), kernel.size(1), kernel.size(2)
    variance = 2. / (out_ch * in_ch * width * branches)
    nn.init.normal_(kernel, 0, math.sqrt(variance))
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)


def conv_init(conv):
    """Apply fan_out Kaiming-normal init to the weight and zero the bias (each if present)."""
    kernel = conv.weight
    if kernel is not None:
        nn.init.kaiming_normal_(kernel, mode='fan_out')
    offset = conv.bias
    if offset is not None:
        nn.init.constant_(offset, 0)


def bn_init(bn, scale):
    """Initialise a normalisation layer: weight <- ``scale``, bias <- 0."""
    gamma, beta = bn.weight, bn.bias
    nn.init.constant_(gamma, scale)
    nn.init.constant_(beta, 0)


class MSA_Conv(nn.Module):
    """Spiking self-attention over features laid out as (T, N, C, V).

    Q, K and V each come from a 1x1 ``Conv1d`` + ``BatchNorm1d`` + LIF neuron
    applied to the LIF-gated input. Channels are split into ``num_heads``
    groups of ``D = C // num_heads`` and arranged as (T, N, D, H, V). The
    product ``K^T @ V`` contracts over the head axis, giving a (V, V) map per
    in-head channel, which is then applied to Q and scaled. The heads are
    merged back to (T, N, C, V), passed through a LIF neuron, projected by a
    1x1 conv + BatchNorm, and added to the original (pre-LIF) input.

    Args:
        dim: Channel count ``C`` of input and output; must be divisible by
            ``num_heads``.
        num_heads: Number of channel groups (heads).
        qkv_bias: Accepted for API compatibility; unused.
        qk_scale: Multiplier for the attention output; 0.125 when ``None``.
        attn_drop: Accepted for API compatibility; unused.
        proj_drop: Accepted for API compatibility; unused.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        # qk_scale used to be ignored; 0.125 remains the default.
        self.scale = 0.125 if qk_scale is None else qk_scale
        # Creation order is kept as before so parameter order and seeded
        # initialisation match existing runs.
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        self.attn_lif = neuron.ParametricLIFNode(step_mode='m', v_threshold=0.5, backend='cupy')
        # Not used in forward; kept so the module layout (and any state_dict
        # entries it owns) matches existing checkpoints.
        self.talking_heads_lif = neuron.ParametricLIFNode(step_mode='m', v_threshold=0.5, backend='cupy')
        self.shortcut_lif = neuron.ParametricLIFNode(step_mode='m', v_threshold=0.5, backend='cupy')
        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)

    def _to_heads(self, x_flat, conv, bn, lif, T, N, C, V):
        """Conv1d -> BatchNorm1d -> LIF on (T*N, C, V), split into (T, N, D, H, V)."""
        out = bn(conv(x_flat)).reshape(T, N, C, V).contiguous()
        out = lif(out)
        return (
            out.reshape(T, N, self.num_heads, C // self.num_heads, V)
            .permute(0, 1, 3, 2, 4)
            .contiguous()
        )

    def forward(self, x):
        """Apply attention to ``x`` of shape (T, N, C, V); returns the same shape."""
        T, N, C, V = x.size()
        identity = x
        x = self.shortcut_lif(x)
        x_flat = x.flatten(0, 1)
        q = self._to_heads(x_flat, self.q_conv, self.q_bn, self.q_lif, T, N, C, V)
        k = self._to_heads(x_flat, self.k_conv, self.k_bn, self.k_lif, T, N, C, V)
        v = self._to_heads(x_flat, self.v_conv, self.v_bn, self.v_lif, T, N, C, V)

        attn = k.transpose(-2, -1) @ v          # (T, N, D, V, V)
        x = (q @ attn) * self.scale             # (T, N, D, H, V)

        # Merge heads with the inverse of the split permutation:
        # (T, N, D, H, V) -> (T, N, H, D, V) -> (T, N, C, V). The previous
        # transpose(3, 4) + reshape interleaved joint and channel indices.
        x = x.permute(0, 1, 3, 2, 4).reshape(T, N, C, V).contiguous()
        x = self.attn_lif(x)
        x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, N, C, V).contiguous()
        return x + identity


class unit_gcn(nn.Module):
    """Spiking graph convolution followed by spiking self-attention.

    For ``x`` of shape (T, N, C, V), the unit computes
    ``sum_i conv_d[i](x @ A_i)`` over the ``num_subset`` adjacency matrices,
    then applies BatchNorm1d and a LIF neuron. It adds the LIF-gated ``down``
    projection of ``x`` and feeds the sum through ``MSA_Conv``. The output has
    shape (T, N, out_channels, V).

    Args:
        in_channels: Input channel count ``C``.
        out_channels: Output channel count.
        A: numpy array of adjacency matrices, assumed (num_subset, V, V).
        adaptive: If True, use the learnable copy ``PA``, L2-normalised along
            dim 1. Otherwise use the fixed ``A``; this branch previously
            raised UnboundLocalError.
        Times: Stored but not used by this unit.
    """

    def __init__(self, in_channels, out_channels, A, adaptive=True, Times=10):
        super(unit_gcn, self).__init__()
        self.out_c = out_channels
        self.in_c = in_channels
        self.Times = Times
        self.num_subset = A.shape[0]
        self.adaptive = adaptive
        self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)), requires_grad=True)
        # Fixed adjacency for adaptive=False. A non-persistent buffer follows
        # .to()/.cuda() with the module but adds no state_dict key, so existing
        # checkpoints still load.
        self.register_buffer('A', torch.from_numpy(A.astype(np.float32)), persistent=False)
        self.SSA = MSA_Conv(out_channels, 8)
        self.conv_d = nn.ModuleList(
            [nn.Conv1d(in_channels, out_channels, 1) for _ in range(self.num_subset)]
        )
        # bn_d is not used in forward; kept for checkpoint compatibility.
        self.bn_d = nn.ModuleList(
            [nn.BatchNorm1d(out_channels) for _ in range(self.num_subset)]
        )
        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1),
                nn.BatchNorm1d(out_channels)
            )
        else:
            self.down = lambda x: x

        self.bn = nn.BatchNorm1d(out_channels)
        self.lif1 = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        self.lif2 = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        # Not used in forward; kept for layout compatibility.
        self.relu = neuron.ParametricLIFNode(step_mode='m', backend='cupy')

        # Note: this also re-initialises the convs/BNs inside self.SSA.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm1d):
                bn_init(m, 1)
        # Near-zero scale so the graph branch starts almost silent.
        bn_init(self.bn, 1e-6)
        for bn in self.bn_d:
            bn_branch_init(bn, self.num_subset)

    def L2_norm(self, A):
        """Divide ``A`` by its L2 norm along dim 1 (plus 1e-4 for stability)."""
        A_norm = torch.norm(A, 2, dim=1, keepdim=True) + 1e-4
        return A / A_norm

    def forward(self, x):
        """Graph-convolve ``x`` of shape (T, N, C, V) and apply attention."""
        T, N, C, V = x.size()
        if self.adaptive:
            A = self.L2_norm(self.PA)
        else:
            A = self.A.to(x.device)
        # Loop-invariant: flatten time and batch once for every subset.
        x_flat = x.reshape(T * N, C, V)
        y = None
        for i in range(self.num_subset):
            z = self.conv_d[i](torch.matmul(x_flat, A[i]))
            y = z + y if y is not None else z
        y = self.bn(y).reshape(T, N, -1, V).contiguous()
        y = self.lif1(y)
        res = self.down(x.flatten(0, 1)).reshape(T, N, -1, V).contiguous()
        res = self.lif2(res)
        return self.SSA(y + res)

class DilatedConvBranch(nn.Module):
    """Conv1d -> BatchNorm1d -> ParametricLIF branch for (T, N, C, V) input.

    The 1-D convolution slides along the last (V) axis of each (C, V) slice of
    the flattened (T*N, C, V) tensor. ``dilation`` is now passed to the
    convolution; before, it was accepted but ignored. 'same' padding keeps V
    unchanged, so the reshape back to (T, N, -1, V) stays valid for any
    ``kernel_size``/``dilation``. With the defaults (1, 1) the layer is
    numerically identical to a plain 1x1 Conv1d.

    Args:
        in_channels: Input channel count ``C``.
        out_channels: Output channel count.
        kernel_size: Kernel size along V.
        dilation: Dilation along V.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, dilation=1):
        super(DilatedConvBranch, self).__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding='same',
            dilation=dilation,
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.lif = neuron.ParametricLIFNode(step_mode='m', v_threshold=0.5, backend='cupy')

    def forward(self, x):
        """Map (T, N, in_channels, V) to LIF outputs of shape (T, N, out_channels, V)."""
        T, N, C, V = x.size()
        x = self.bn(self.conv(x.flatten(0, 1)))
        x = x.reshape(T, N, -1, V).contiguous()
        return self.lif(x)

class TCN_GCN_unit(nn.Module):
    """Spiking stage: ``unit_gcn`` then the ``fcn`` extractor, plus a shortcut.

    Tensors use the (T, N, C, V) layout. The shortcut is the identity when the
    channel count and stride are both preserved; otherwise it is a
    ``DilatedConvBranch(in_channels, out_channels)``. The ``residual`` flag is
    only forwarded to ``fcn``, a ``MultiScaleDilatedFourierFeatureExtractorplus``.
    That class is defined outside this file, presumably in ``.multfft``.
    """

    def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, Times=10):
        super(TCN_GCN_unit, self).__init__()
        self.Times = Times
        # Not referenced in forward; retained so the module layout is unchanged.
        self.shortcut_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive, Times=Times)
        self.fcn = MultiScaleDilatedFourierFeatureExtractorplus(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            dilations=[1, 2, 3, 4],
            residual=residual,
            residual_kernel_size=True,
        )
        keeps_shape = in_channels == out_channels and stride == 1
        if keeps_shape:
            self.residual = lambda x: x
        else:
            self.residual = DilatedConvBranch(in_channels, out_channels)

    def forward(self, x):
        """Return ``fcn(gcn1(x)) + residual(x)`` for ``x`` of shape (T, N, C, V)."""
        shortcut = self.residual(x)
        return self.fcn(self.gcn1(x)) + shortcut
    




class SGNModel(nn.Module):
    """Spiking skeleton encoder built from four ``TCN_GCN_unit`` stages.

    ``forward`` takes ``x`` of shape (N, C, T, V, M): batch, input channels,
    frames, joints, persons. It returns LIF outputs of shape (T, N, 256, V),
    after averaging over the M persons.

    Notes:
        * ``graph`` and ``graph_args`` are accepted for config compatibility
          but unused; the adjacency is ``num_set`` stacked identity matrices.
        * ``fc`` and ``drop_out`` are built (and ``fc`` initialised) but are
          not applied in ``forward``; ``num_class`` only sizes ``fc``.
        * ``Times`` is passed to the units; ``num_frames`` is unused.
        * ``data_bn`` is sized ``num_person * 128 * num_point``, so inputs
          must have M == num_person and V == num_point.
    """

    def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=None, in_channels=3,
                 drop_out=0.3, adaptive=True, num_set=3, Times=4, num_frames=16):
        super(SGNModel, self).__init__()
        A = np.stack([np.eye(num_point)] * num_set, axis=0)
        self.Times = Times
        self.num_class = num_class
        self.num_point = num_point
        # Construction order below is unchanged so seeded initialisation and
        # parameter order match existing runs.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_point, 128))
        self.to_joint_embedding = nn.Linear(in_channels, 128)
        self.data_bn = nn.BatchNorm1d(num_person * 128 * num_point)
        self.l1 = TCN_GCN_unit(128, 128, A, residual=False, adaptive=adaptive, Times=self.Times)
        self.l4 = TCN_GCN_unit(128, 256, A, adaptive=adaptive, Times=self.Times)
        self.l8 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, Times=self.Times)
        self.l9 = TCN_GCN_unit(256, 256, A, adaptive=adaptive, Times=self.Times)
        # Registered once. It used to be re-assigned further down, which only
        # replaced this instance with an identically configured one.
        self.proj_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')
        self.fc = nn.Linear(256, num_class)
        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)
        if drop_out:
            self.drop_out = nn.Dropout(drop_out)
        else:
            self.drop_out = lambda x: x
        self.proj_conv = nn.Conv1d(
            128, 128,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.proj_bn = nn.BatchNorm1d(128)
        self.out_embedding = nn.Linear(256, 256)
        self.out_bn = nn.BatchNorm1d(256)
        self.out_lif = neuron.ParametricLIFNode(step_mode='m', backend='cupy')

    def forward(self, x):
        """Encode ``x`` of shape (N, C, T, V, M) into (T, N, 256, V)."""
        N, C, T, V, M = x.size()
        # Per-joint embedding: (N, C, T, V, M) -> (N*M*T, V, 128), plus position.
        x = rearrange(x, 'n c t v m -> (n m t) v c').contiguous()
        x = self.to_joint_embedding(x)
        x += self.pos_embedding[:, :self.num_point]
        # BatchNorm each (person, joint, channel) feature over batch and time.
        x = rearrange(x, '(n m t) v c -> n (m v c) t', m=M, t=T).contiguous()
        x = self.data_bn(x)
        # -> (N*M, 128, T, V) -> (T, N*M, 128, V)
        x = x.view(N, M, V, 128, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, 128, T, V)
        x = x.permute(2, 0, 1, 3)
        # Conv along the joint axis, BatchNorm, then LIF.
        x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, N * M, -1, V).contiguous()
        x = self.proj_lif(x).contiguous()
        x = self.l1(x)
        x = self.l4(x)
        x = self.l8(x)
        x = self.l9(x)
        # (T, N*M, 256, V) -> (T, N, M, V, 256) -> mean over persons -> (T, N, V, 256)
        x = x.reshape(T, N, M, -1, V).transpose(-1, -2).mean(2)
        x = self.out_embedding(x.flatten(0, 1))
        x = self.out_bn(x.transpose(-1, -2)).reshape(T, N, -1, V).contiguous()
        return self.out_lif(x)