import torch
import torch.nn as nn
from models.spiking_layer import LIFSpike, ExpandTime
from models.surrogate_module import SurrogateModule
from spikingjelly.clock_driven.neuron import MultiStepIFNode  # , surrogate
from Qtrick_architecture.clock_driven import neuron
from Qtrick_architecture.clock_driven import surrogate
import torch.nn.functional as F


class Quant(torch.autograd.Function):
    """Clamp-and-round quantizer with a straight-through gradient.

    Forward: ``round(clamp(i, min_value, max_value))``.
    Backward: the incoming gradient passes through unchanged where
    ``min_value <= i <= max_value`` and is zeroed where ``i`` fell outside
    that range. ``min_value`` / ``max_value`` receive no gradient.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, i, min_value, max_value):
        # Bounds may be Python scalars or tensors (DynamicMultiSpike passes
        # tensors); they are kept on ctx as-is to build the backward mask.
        ctx.min = min_value
        ctx.max = max_value
        ctx.save_for_backward(i)
        return torch.round(torch.clamp(i, min=min_value, max=max_value))

    @staticmethod
    # custom_bwd (not custom_fwd): makes backward run under the same autocast
    # state that forward recorded.
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        i, = ctx.saved_tensors
        # Straight-through estimator, masked where the clamp was active.
        grad_input[i < ctx.min] = 0
        grad_input[i > ctx.max] = 0
        return grad_input, None, None


class MultiSpike(nn.Module):
    """Multi-level spike activation.

    Quantizes the input to integers in ``[min_value, max_value]`` via
    ``Quant`` (straight-through gradient) and divides by ``Norm``.

    Args:
        min_value: lower clamp bound.
        max_value: upper clamp bound.
        Norm: divisor applied to the quantized output; ``None`` means
            ``max_value``.
    """

    def __init__(
            self,
            min_value=0,
            max_value=4,
            Norm=None,
    ):
        super().__init__()
        self.Norm = max_value if Norm is None else Norm
        self.min_value = min_value
        self.max_value = max_value

    @staticmethod
    def spike_function(x, min_value, max_value):
        return Quant.apply(x, min_value, max_value)

    def __repr__(self):
        return f"MultiSpike(Max_Value={self.max_value}, Min_Value={self.min_value}, Norm={self.Norm})"

    def forward(self, x):  # B C H W
        # The original read `self.self.max_value`, which raised AttributeError.
        return self.spike_function(x, min_value=self.min_value, max_value=self.max_value) / self.Norm


class DynamicMultiSpike(nn.Module):
    """Multi-level spike activation with a learnable per-element bound scale.

    The clamp bounds are ``min_value * s`` and ``max_value * s`` with
    ``s = clamp(self.scale, 0, 1)``; the quantized result is divided by
    ``Norm`` (``None`` means ``max_value``).

    NOTE(review): ``Quant.backward`` returns ``None`` for the bounds, so
    ``self.scale`` receives no gradient through this op — confirm intended.
    NOTE(review): with the default ``num_dim=0`` the scale is empty and will
    not broadcast against a non-empty input — callers presumably always pass
    ``num_dim``.
    """

    def __init__(
            self,
            min_value=0,
            max_value=4,
            num_dim=0,
            Norm=None,
    ):
        super().__init__()
        self.Norm = max_value if Norm is None else Norm
        self.min_value = min_value
        self.max_value = max_value
        self.scale = nn.Parameter(torch.randn(num_dim, ))

    @staticmethod
    def spike_function(x, min_value, max_value):
        return Quant.apply(x, min_value, max_value)

    def __repr__(self):
        return f"DynamicMultiSpike(Max_Value={self.max_value}, Min_Value={self.min_value}, Norm={self.Norm})"

    def forward(self, x):  # B C H W
        scale = torch.clamp(self.scale, min=0, max=1)
        # The original read `self.self.max_value`, which raised AttributeError.
        return self.spike_function(x, min_value=self.min_value * scale, max_value=self.max_value * scale) / self.Norm


def spatial_dynamic_token(x, score, mask_ratio=0.2):
    """Split tokens (last axis) into kept and pruned sets by summed score.

    For each batch element, ``score`` is summed over the time and channel
    axes to get one value per token. Tokens are sorted by that value in
    descending order, and the first ``int(N * (1 - mask_ratio))`` are kept.

    Args:
        x: tensor of shape (T, B, C, N).
        score: tensor with the same shape as ``x`` (presumably the spike
            output for ``x`` — confirm at call sites).
        mask_ratio: fraction of tokens to prune.

    Returns:
        ``(x_masked, x_unmasked, score_masked, score_unmasked)``: the pruned
        and kept slices of ``x`` and ``score`` along the token axis. Tokens
        are ordered from highest to lowest summed score.
    """
    T, B, C, N = x.shape
    score_sum = torch.sum(score, dim=[0, 2])  # B, N
    len_keep = int(N * (1 - mask_ratio))
    ids_shuffle = torch.argsort(score_sum, dim=-1, descending=True)  # B, N
    # One shared index for both gathers; expand is a view (no copy) and
    # gather only reads it.
    index = ids_shuffle[None, :, None, :].expand(T, B, C, N)
    x_sort = torch.gather(x, dim=3, index=index)  # T, B, C, N
    score_sort = torch.gather(score, dim=3, index=index)  # T, B, C, N
    x_unmasked = x_sort[:, :, :, :len_keep]
    score_unmasked = score_sort[:, :, :, :len_keep]
    x_masked = x_sort[:, :, :, len_keep:]
    score_masked = score_sort[:, :, :, len_keep:]

    return x_masked, x_unmasked, score_masked, score_unmasked


def soft_matching(x_src, x_remove, score_src, score_remove):
    """Merge every removed token into its best-matching source token.

    Similarity is ``score_remove^T @ score_src`` (contracting the channel
    axis). Each removed token's features are added via ``scatter_add`` into
    the source token with the highest similarity.

    Shapes implied by the indexing: ``x_remove`` is (T, B, C, N2), and the
    token axis is last for all four inputs.

    Returns a new tensor; ``x_src`` itself is not modified.
    """
    similarity = torch.matmul(score_remove.transpose(-1, -2), score_src)  # (..., N2, N1)
    best_match = similarity.max(dim=-1).indices  # (..., N2)
    _, _, channels, _ = x_remove.shape
    # Same destination token for every channel of a removed token.
    scatter_index = best_match[:, :, None, :].expand(-1, -1, channels, -1)
    return x_src.scatter_add(-1, scatter_index, x_remove)


def fix_temporal_dynamic_token(x, score, fix_prun_T=1):
    """Drop a fixed number of the lowest-scoring time steps, per batch element.

    ``score`` is summed over its last two axes to rank time steps. For each
    batch element, time steps are sorted in descending order of that sum,
    and the last ``fix_prun_T`` of them are split off.

    Args:
        x: tensor of shape (T, B, C, N).
        score: tensor whose first two axes are (T, B); its last two axes
            are summed.
        fix_prun_T: number of time steps to prune; must be in ``[0, T]``.

    Returns:
        ``(x_unprune, x_prune)``: the kept ``T - fix_prun_T`` and the pruned
        ``fix_prun_T`` time steps of ``x``, ordered by descending score.

    Raises:
        ValueError: if ``fix_prun_T`` is outside ``[0, T]``. Without this
            check a too-large value made ``len_keep`` negative, and the slice
            silently wrapped around instead of pruning as requested.
    """
    T, B, C, N = x.shape
    if not 0 <= fix_prun_T <= T:
        raise ValueError(f"fix_prun_T must be in [0, {T}], got {fix_prun_T}")
    score_sum = torch.sum(score.flatten(-2), dim=-1)  # T, B
    len_keep = int(T - fix_prun_T)
    ids_shuffle = torch.argsort(score_sum, dim=0, descending=True)  # T, B

    # Broadcast the per-(t, b) order over channels and tokens (view, no copy).
    index = ids_shuffle[:, :, None, None].expand(T, B, C, N)
    x_sort = torch.gather(x, dim=0, index=index)  # T, B, C, N
    x_unprune = x_sort[:len_keep]
    x_prune = x_sort[len_keep:]

    return x_unprune, x_prune


def temporal_dynamic_token(x, mask_ratio=0.25):
    """Rank time steps by their own total activity and split off the weakest.

    Each (t, b) slice of ``x`` (shape (T, B, C, N)) is summed over C and N.
    Time steps are sorted independently per batch element in descending
    order of that sum. The first ``int(T * (1 - mask_ratio))`` are returned
    as kept, the remainder as pruned.
    """
    steps, batch, channels, tokens = x.shape
    activity = x.flatten(-2).sum(dim=-1)  # (T, B)
    n_keep = int(steps * (1 - mask_ratio))
    order = activity.argsort(dim=0, descending=True)  # (T, B)

    gather_idx = order[..., None, None].expand(steps, batch, channels, tokens)
    ranked = x.gather(0, gather_idx)  # (T, B, C, N)
    return ranked[:n_keep], ranked[n_keep:]


class SPSV2(nn.Module):
    """Spiking patch-splitting stem.

    Main branch: five 3x3 Conv2d + BatchNorm2d stages. The first four are
    each followed by a ``MultiStepIFNode``, and stages 4 and 5 use stride 2.
    Shortcut branch: one Conv2d whose kernel and stride are both
    ``2 ** downsample_times``. The two outputs are concatenated on the channel
    axis (``7 * (embd_dims // 8)`` + ``embd_dims // 8`` channels), and the
    spatial axes are flattened into one token axis.

    Only ``downsample_times == 2`` is implemented. Any other value, including
    the default of 4, raises ``NotImplementedError``.

    ``forward`` flattens the first two input axes and later reshapes with
    ``self.T`` as the leading axis, so the input is presumably
    (T, B, C, H, W) with T == ``self.T`` — confirm at call sites. The output
    is (T, -1, channels, H' * W').
    """

    def __init__(self, img_size=128, downsample_times=4, in_channels=3, embd_dims=256, T=1, max_value=0):
        super(SPSV2, self).__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.downsample_times = downsample_times
        self.max_value = int(max_value)
        self.T = T
        eighth = embd_dims // 8
        self.main_embd_dims = eighth * 7
        self.short_embd_dims = eighth * 1

        if downsample_times != 2:
            raise NotImplementedError

        def conv_bn(c_in, c_out, stride):
            # 3x3 convolution (padding 1, no bias) followed by BatchNorm2d.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(c_out))

        def if_node():
            return neuron.MultiStepIFNode(surrogate_function=surrogate.Quant(max_value=max_value))

        # Keep the creation order stable: it determines how PyTorch's default
        # parameter init consumes the RNG, and the module iteration order.
        self.proj_conv1 = conv_bn(in_channels, eighth, stride=1)
        self.neuron1 = if_node()
        self.proj_conv2 = conv_bn(eighth, embd_dims // 4, stride=1)
        self.neuron2 = if_node()
        self.proj_conv3 = conv_bn(embd_dims // 4, embd_dims // 2, stride=1)
        self.neuron3 = if_node()
        self.proj_conv4 = conv_bn(embd_dims // 2, embd_dims, stride=2)
        self.neuron4 = if_node()
        self.proj_conv5 = conv_bn(embd_dims, self.main_embd_dims, stride=2)

        # Shortcut: one non-overlapping conv with the same total downsampling
        # as the main branch (2 * 2 == 2 ** 2).
        short_stride = 2 ** downsample_times
        self.short_proj1 = nn.Sequential(
            nn.Conv2d(in_channels, self.short_embd_dims, kernel_size=short_stride,
                      stride=short_stride, padding=0, bias=False),
            nn.BatchNorm2d(self.short_embd_dims),
        )

    def _fire(self, node, x):
        # (T*B, C, H, W) -> (T, B, C, H, W) for the multi-step node, then back.
        return node(x.reshape(self.T, -1, x.shape[1], x.shape[2], x.shape[3])).flatten(0, 1)

    def forward(self, x):
        frames = x.flatten(0, 1)
        shortcut = self.short_proj1(frames)

        y = self._fire(self.neuron1, self.proj_conv1(frames))
        y = self._fire(self.neuron2, self.proj_conv2(y))
        y = self._fire(self.neuron3, self.proj_conv3(y))
        y = self._fire(self.neuron4, self.proj_conv4(y))
        y = self.proj_conv5(y)

        tokens = torch.cat([shortcut, y], dim=1).flatten(-2)
        return tokens.reshape(self.T, -1, tokens.shape[1], tokens.shape[2])


class STM(nn.Module):
    """Per-head token mixing.

    The channel axis (``feature_dim``) is split into ``num_head`` groups of
    ``feature_dim // num_head`` channels. Group ``h`` is mixed along the last
    (token) axis by its own bias-free ``nn.Linear(num_dim, num_dim)``, so the
    token axis must have length ``num_dim``.

    ``T`` and ``spatial_ratio`` are stored but not used by ``forward``.
    """

    def __init__(self, feature_dim, num_dim, num_head, T, spatial_ratio):
        super().__init__()
        self.num_head = num_head
        self.T = T
        self.num_dim = num_dim
        self.spatial_ratio = spatial_ratio
        self.feature_dim = feature_dim
        self.num_group = feature_dim // num_head
        self.linear_mix = nn.ModuleList(
            nn.Linear(num_dim, num_dim, bias=False) for _ in range(num_head)
        )
        # Kaiming-normal init (fan_out mode, ReLU gain) for every head.
        for layer in self.linear_mix:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        steps, batch = x.size(0), x.size(1)
        grouped = x.reshape(steps, batch, self.num_head, self.num_group, -1)
        mixed = torch.zeros_like(grouped)
        for head, mix in enumerate(self.linear_mix):
            mixed[:, :, head] = mix(grouped[:, :, head])
        return mixed.reshape(steps, batch, self.feature_dim, -1)


class TokenMixer(nn.Module):
    """Spiking token-mixing sub-block.

    Pipeline: IF neuron -> 1x1 Conv1d (C -> 2C) + BN -> IF neuron -> STM
    (per-head mixing along the token axis) -> IF neuron -> 1x1 Conv1d
    (2C -> C) + BN. ``forward`` takes and returns a (T, B, C, N) tensor.
    """

    def __init__(self, feature_dim, num_patches, num_head, T, max_value=0, spatial_ratio=0.0):
        super(TokenMixer, self).__init__()
        self.T = T
        self.max_value = int(max_value)
        ratio = 2  # hidden-width expansion factor
        self.mid_dim1 = int(feature_dim * ratio)

        def if_node():
            return neuron.MultiStepIFNode(surrogate_function=surrogate.Quant(max_value=max_value))

        self.neuron_start = if_node()
        self.fc1 = nn.Sequential(
            nn.Conv1d(feature_dim, self.mid_dim1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(self.mid_dim1),
        )
        self.neuron1 = if_node()
        # STM's Linear layers require exactly `num_patches` tokens.
        # (A standard SSA module was the previously tried alternative here.)
        self.attn = STM(self.mid_dim1, num_patches, num_head, T, spatial_ratio)
        self.neuron_mixer = if_node()
        self.fc2 = nn.Sequential(
            nn.Conv1d(self.mid_dim1, feature_dim, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(feature_dim),
        )

    def forward(self, x):
        _, batch, channels, tokens = x.shape
        spikes = self.neuron_start(x.reshape(-1, batch, channels, tokens)).flatten(0, 1)
        hidden = self.fc1(spikes)
        hidden = hidden.reshape(-1, batch, hidden.shape[1], hidden.shape[2])
        mixed = self.neuron_mixer(self.attn(self.neuron1(hidden))).flatten(0, 1)
        return self.fc2(mixed).reshape(-1, batch, channels, tokens).contiguous()


class FFN(nn.Module):
    """Spiking channel MLP.

    Pipeline: IF neuron -> 1x1 Conv1d (C -> int(C * ratio)) + BN -> IF
    neuron -> 1x1 Conv1d (back to C) + BN. ``forward`` takes and returns a
    (T, B, C, N) tensor.
    """

    def __init__(self, feature_dim, ratio, T, max_value=0):
        super().__init__()
        self.feature_dim = feature_dim
        self.ratio = ratio
        self.T = T
        self.max_value = int(max_value)
        self.mid_dim = int(feature_dim * ratio)

        def if_node():
            return neuron.MultiStepIFNode(surrogate_function=surrogate.Quant(max_value=max_value))

        def pointwise(c_in, c_out):
            # 1x1 Conv1d (no bias) + BatchNorm1d, i.e. a per-token linear map.
            return nn.Sequential(
                nn.Conv1d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm1d(c_out),
            )

        self.neuron1 = if_node()
        self.fc1 = pointwise(feature_dim, self.mid_dim)
        self.neuron2 = if_node()
        self.fc2 = pointwise(self.mid_dim, feature_dim)

    def forward(self, x):
        _, batch, _, tokens = x.shape
        spikes = self.neuron1(x.reshape(-1, batch, self.feature_dim, tokens)).flatten(0, 1)
        hidden = self.fc1(spikes)
        hidden = self.neuron2(hidden.reshape(-1, batch, self.mid_dim, hidden.shape[2])).flatten(0, 1)
        return self.fc2(hidden).reshape(-1, batch, self.feature_dim, tokens).contiguous()


class Encoder(nn.Module):
    """Residual encoder layer: ``x + TokenMixer(x)``, then ``x + FFN(x)``."""

    def __init__(self, feature_dim, num_pathes, ratio, num_head, T, max_value=0, spatial_ratio=0.0):
        super().__init__()
        # `num_pathes` (sic) is the token count handed to TokenMixer; the
        # misspelled name is kept for caller compatibility.
        self.token_mix = TokenMixer(feature_dim, num_pathes, num_head, T,
                                    max_value=max_value, spatial_ratio=spatial_ratio)
        self.channel_mix = FFN(feature_dim, ratio, T, max_value=max_value)

    def forward(self, x):
        mixed = x + self.token_mix(x)
        return mixed + self.channel_mix(mixed)


class STMixerV3(nn.Module):
    """Spiking token-mixer classifier.

    Pipeline: SPSV2 stem -> ``depths`` Encoder blocks -> mean over tokens ->
    linear head -> mean over time.

    With ``sml=True``, ``forward`` uses ``forward_sp``: temporal pruning
    (training only), spatial token pruning, and auxiliary ``SurrogateModule``
    outputs, returned as a list with the logits first. Otherwise it uses
    ``forward_sdt`` and returns the logits only.
    """

    def __init__(self, img_size=128, downsample_times=4, in_channels=3, embd_dims=256,
                 T=1, mlp_ratio=2, depths=6, num_head=8, num_classes=100, sml=False):
        super(STMixerV3, self).__init__()
        self.img_size = img_size
        self.T = T
        self.depths = depths
        self.sml = sml
        self.HW = img_size // (2 ** downsample_times)
        self.num_patches = self.HW ** 2
        self.in_channels = in_channels
        if self.in_channels == 3:
            self.expand = ExpandTime(T=T)
        # Passed as max_value to every surrogate.Quant in the model.
        self.max_value = 8
        self.patch_embd = SPSV2(img_size=img_size, downsample_times=downsample_times,
                                in_channels=in_channels, embd_dims=embd_dims, T=T, max_value=self.max_value)

        # self.spatial_ratio = 0.1
        self.temporal_ratio = 1
        self.spatial_ratio = 0.3
        # forward_sp keeps int(N * (1 - spatial_ratio)) tokens after spatial
        # pruning; the Encoder blocks (their STM Linear layers) are sized for that.
        self.token_dim = int(self.num_patches * (1.0 - self.spatial_ratio))

        self.block = nn.ModuleList(
            [Encoder(embd_dims, self.token_dim, mlp_ratio, num_head, T, self.max_value, self.spatial_ratio)
             for _ in range(depths)]
        )

        self.head = nn.Linear(embd_dims, num_classes)
        # Spike layers whose outputs serve only as pruning scores in forward_sp.
        self.temporal_lif = neuron.MultiStepIFNode(surrogate_function=surrogate.Quant(max_value=self.max_value))
        self.spatial_lif = neuron.MultiStepIFNode(surrogate_function=surrogate.Quant(max_value=self.max_value))

        # init the weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        if self.sml:
            self.sml1 = SurrogateModule(embd_dims, embd_dims, self.num_patches, num_classes, num_head, num_layer=2, T=T)
            self.sml2 = SurrogateModule(embd_dims, embd_dims, self.num_patches, num_classes, num_head, num_layer=1, T=T)

    def forward_sdt(self, x):
        """Plain forward pass (no pruning, no auxiliary heads); returns logits."""
        x = self.patch_embd(x)  # (T, B, C, N)
        # B was referenced below without ever being defined (NameError).
        B = x.shape[1]
        # NOTE(review): no spatial pruning on this path, so N == num_patches,
        # while the STM layers in self.block were built for self.token_dim
        # tokens; with spatial_ratio = 0.3 these differ — confirm this path
        # is expected to run.
        for blk in self.block:
            x = blk(x)
        x = x.mean(dim=-1)
        x = x.reshape(-1, B, x.shape[-1])
        x = self.head(x)
        x = x.mean(dim=0)
        return x

    def forward_sp(self, x):
        """Forward pass with token pruning and auxiliary heads (``sml=True``).

        Returns ``[logits, sml1(stem output), sml2(output of block 1)]``; the
        sml2 entry is present only when there are at least two blocks.
        """
        T, B, C, H, W = x.shape
        outs = []
        x = self.patch_embd(x)  # (T, B, C, N)
        outs.append(self.sml1(x))
        if self.training:
            # Training only: drop the 3 lowest-scoring time steps.
            score = self.temporal_lif(x)
            x, x_prune = fix_temporal_dynamic_token(x, score, 3)  # 0.3 3:81.05
        # Keep the top int(N * (1 - spatial_ratio)) tokens by spike score.
        score = self.spatial_lif(x)
        x_prune, x, score_prune, score_src = spatial_dynamic_token(x, score, self.spatial_ratio)
        # x = soft_matching(x, x_prune, score_src, score_prune)
        for bi, blk in enumerate(self.block):
            x = blk(x)

            if bi == 1:
                outs.append(self.sml2(x))

        x = x.mean(dim=-1)
        x = x.reshape(-1, B, x.shape[-1])
        x = self.head(x)
        x = x.mean(dim=0)
        outs.insert(0, x)
        return outs

    def forward(self, x):
        if self.in_channels == 3:
            x = self.expand(x)
        elif self.in_channels == 2:
            # Swap the first two axes (presumably (B, T, 2, H, W) -> (T, B, 2, H, W)).
            # The original then reshaped to 4-D (T*B, 2, H, W), but SPSV2.forward
            # and forward_sp both expect a 5-D (T, B, C, H, W) tensor, so that
            # reshape broke this path; the 5-D layout is kept instead.
            x = x.permute(1, 0, 2, 3, 4)
        if self.sml:
            return self.forward_sp(x)
        else:
            return self.forward_sdt(x)


if __name__ == '__main__':
    # Smoke test: run a multi-step IF node with a Quant surrogate on a few
    # hand-picked values and print its threshold and output. Falls back to
    # CPU when CUDA is unavailable (the original hard-coded .cuda()).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inputs = torch.tensor([[1.4], [0.7], [1.0], [4.0]], device=device)  # shape (4, 1)
    neuron3 = neuron.MultiStepIFNode(surrogate_function=surrogate.Quant(max_value=4), backend='torch').to(device)
    print(neuron3.v_threshold)
    output2 = neuron3(inputs)
    print(output2)
