import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import transforms
from spikingjelly.activation_based import surrogate
from functional import t_last_vmap_forward, cal_fun_t
from triton_impl import t_frist_triton_impl, t_last_triton_impl


from torch import Tensor
from einops import rearrange
from typing import Tuple, Union, Callable
import numpy as np

# Global seed for reproducible runs; every RNG seeded below uses this one value.
_seed_ = 2020
import random
random.seed(_seed_)

# Seed torch (CPU and all CUDA devices) and force deterministic, non-autotuned cuDNN kernels.
torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

@torch.jit.script
def round_to_pow2_forward(x: torch.Tensor):
    """Snap each element of ``x`` to a signed power of two.

    The exponent is ``round(log2(|x|))``, so the rounding decision is made in
    log2 space. Zeros map to zero because log2(0) = -inf and 2 ** -inf = 0.
    """
    exponent = torch.round(torch.log2(torch.abs(x)))
    return torch.sign(x) * torch.pow(2., exponent)

class RoundToPow2(torch.autograd.Function):
    """Power-of-two rounding with a straight-through gradient.

    The forward pass quantizes with ``round_to_pow2_forward``; the backward pass
    treats the rounding as the identity and returns the incoming gradient as is.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        quantized = round_to_pow2_forward(x)
        return quantized

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Straight-through estimator: d(round)/dx is taken to be 1.
        passthrough = grad_output
        return passthrough

def round_to_pow_2(x: torch.Tensor):
    """Quantize ``x`` to signed powers of two, passing gradients straight through."""
    quantize = RoundToPow2.apply
    return quantize(x)

class Layout(object):
    """Mixin that tracks a tensor layout and caches the fastest implementation for it.

    Subclasses put candidate callables into ``self.methods['t_first']`` /
    ``self.methods['t_last']``; ``auto_select_methods`` times the candidates of the
    current layout and stores the winner in ``selected_method``. Assigning a
    different ``layout`` clears that cached choice.

    ``select`` reads ``self.training``, so this mixin expects to be combined with a
    class that provides it (e.g. ``nn.Module``).
    """

    def __init__(self, layout='t_first'):
        self._layout = None
        self._selected_method = None
        self.methods = {"t_first": [], "t_last": []}
        # Routed through the property setter, which validates the value.
        self.layout = layout

    @property
    def layout(self):
        return self._layout

    @layout.setter
    def layout(self, value):
        assert value in ['t_first', 't_last'], "layout must be 't_first' or 't_last'"
        if self._layout == value:
            return
        self._layout = value
        # Any cached implementation belonged to the previous layout.
        self.selected_method = None

    @property
    def selected_method(self):
        return self._selected_method

    @selected_method.setter
    def selected_method(self, value):
        assert value in self.methods[self.layout] or value is None, "selected_method must match the layout"
        self._selected_method = value

    def select(self, x_seq, methods, repeat=8):
        """Time every callable in ``methods`` on ``x_seq``; return ``(fastest, its_time)``.

        Timing is delegated to ``cal_fun_t``; on ties the earliest candidate wins.
        """
        timings = [
            cal_fun_t(self.training, repeat, x_seq.device, candidate, x_seq)
            for candidate in methods
        ]
        best = timings.index(min(timings))
        return methods[best], timings[best]

    def auto_select_methods(self, x_seq):
        """Pick and cache the fastest implementation for the current layout.

        Meant to run once after the layout is set or changed (that is when the
        cache is reset); returns ``(selected_method, its_time)``.
        """
        winner, winner_time = self.select(x_seq, self.methods[self.layout])
        self.selected_method = winner
        return winner, winner_time



class BN(nn.Module):
    """Per-channel batch-norm parameters and statistics, without a normalizing forward.

    Holds the affine ``weight`` / ``bias`` parameters and the ``running_mean`` /
    ``running_var`` buffers. ``get_mean_var`` returns the statistics to use: batch
    statistics in training mode, the running buffers in eval mode. Callers apply
    (or fold) the normalization themselves.
    """
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        device=None,
        dtype=None,
        layout='t_first'
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # NOTE(review): BN does not inherit from Layout; this call only attaches the
        # `layout`, `methods` and cached-selection fields as plain attributes.
        Layout.__init__(self, layout)
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(num_features, **factory_kwargs))
        self.bias = nn.Parameter(torch.zeros(num_features, **factory_kwargs))
        self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
        self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))

    def get_mean_var(self, input: Tensor=None, dim:int=None) -> Tuple[Tensor, Tensor]:
        """Return the ``(mean, var)`` pair to normalize with.

        Args:
            input: activations; required in training mode, ignored in eval mode.
            dim: index of the channel axis of ``input`` (negative values count from
                the end); required in training mode.

        Returns:
            Training mode: the biased mean / variance over every axis except ``dim``
            (both keep their autograd history). As a side effect the running
            buffers are updated in place with an exponential moving average
            (weight ``momentum``), using the unbiased variance.
            Eval mode: the ``running_mean`` / ``running_var`` buffers themselves.

        Raises:
            ValueError: in training mode when ``input`` or ``dim`` is missing.
        """
        if not self.training:
            return self.running_mean, self.running_var

        if input is None or dim is None:
            raise ValueError("BN.get_mean_var requires `input` and `dim` in training mode")

        # Normalize a negative channel index; otherwise `i != dim` never matches and
        # every axis (including the channel axis) would be reduced.
        channel_dim = dim + input.dim() if dim < 0 else dim
        dims = [i for i in range(input.dim()) if i != channel_dim]

        mean = input.mean(dims)
        var = input.var(dims, unbiased=False)

        with torch.no_grad():
            # Update the registered buffers in place (keeping their identity, dtype
            # and device) instead of rebinding them to new tensors every step.
            m = self.momentum
            self.running_mean.mul_(1. - m).add_(mean, alpha=m)
            self.running_var.mul_(1. - m).add_(input.var(dims, unbiased=True), alpha=m)

        return mean, var


class Mul_Free_Depthwise_PSN(nn.Module, Layout):
    """Depthwise causal temporal convolution with folded, power-of-two-quantized BN.

    Every channel is filtered along time by its own causal, dilated kernel of
    length ``K`` (``self.conv``: ``nn.Conv1d`` with ``groups=C``, applied to inputs
    left-padded by ``dilation * (K - 1)``). The statistics of ``self.bn`` are folded
    into that kernel plus a per-channel bias, the folded kernel is rounded to signed
    powers of two by ``round_to_pow_2``, and ``surrogate_function`` is applied to
    the result.

    In training mode the BN statistics come from the *unquantized* convolution
    output; in eval mode the BN running statistics are used. ``forward`` times the
    implementations registered in ``self.methods`` for the current layout on first
    use and caches the fastest.
    """

    def __init__(self, C: int, K: int, surrogate_function=None, dilation: int=1, layout: str="t_first"):
        """
            if layout == 't_first', we thought the input x_seq is shaped as [T, N, ...]
            if layout == 't_last', we thought the input x_seq is shaped as [N, ..., T]

            surrogate_function defaults to a new ``surrogate.ATan()`` per instance, so
            layers never share one module object through a default argument.
        """
        assert layout in ['t_first', 't_last']
        if surrogate_function is None:
            surrogate_function = surrogate.ATan()
        Layout.__init__(self, layout)
        nn.Module.__init__(self)

        self.C = C
        self.K = K
        self.surrogate_function = surrogate_function
        self.dilation = dilation
        self.conv = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=K, groups=C, dilation=dilation, bias=False)
        self.bn = BN(C)
        nn.init.constant_(self.bn.bias, -1.)
        self.methods = {
            # t_first_conv1d / t_first_conv2d_xxxT / t_first_conv2d_Txxx and
            # t_last_vmap_conv1d / t_last_conv2d are implemented below but are not
            # offered to the auto-selection.
            "t_first": [self.t_first_fc],
            "t_last": [self.t_last_fc]
        }

        if dilation == 1:
            # The Triton kernels are only registered for undilated kernels.
            self.methods['t_first'].append(self.t_first_triton)
            self.methods['t_last'].append(self.t_last_triton)

    def _fold_bn(self, mean: Tensor, var: Tensor) -> Tuple[Tensor, Tensor]:
        """Fold BN with statistics ``(mean, var)`` into the conv kernel.

        Returns the folded kernel rounded to signed powers of two (same shape as
        ``self.conv.weight``) and the per-channel bias that goes with it.
        """
        std = (var + self.bn.eps).sqrt()
        # transpose(0, 2) moves the channel axis last so the per-channel vectors
        # broadcast against it, then moves it back.
        weight = (self.conv.weight.transpose(0, 2) * self.bn.weight / std).transpose(0, 2)
        qweight = round_to_pow_2(weight)
        bias = self.bn.bias - mean * self.bn.weight / std
        return qweight, bias

    def t_first_conv1d(self, x_seq):
        """[T, N, C, ...] input; folds spatial axes into the batch and uses ``F.conv1d``."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'T N C H W -> (N H W) C T')
        elif len(shape) == 4:
            x_seq = rearrange(x_seq, 'T N C L -> (N L) C T')
        else:
            x_seq = rearrange(x_seq, 'T N C -> N C T')

        # Left-only padding makes the convolution causal along T.
        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0))

        if self.training:
            t = self.conv(x_pad)
            mean, var = self.bn.get_mean_var(t, dim=1)
        else:
            mean, var = self.bn.get_mean_var(dim=1)

        qweight, bias = self._fold_bn(mean, var)
        h_seq = F.conv1d(x_pad, qweight, bias=bias, stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)

        s_seq = self.surrogate_function(h_seq)

        if len(shape) == 5:
            s_seq = rearrange(s_seq, '(N H W) C T -> T N C H W', N=shape[1], H=shape[3], W=shape[4])
        elif len(shape) == 4:
            s_seq = rearrange(s_seq, '(N L) C T -> T N C L', N=shape[1], L=shape[3])
        else:
            s_seq = rearrange(s_seq, 'N C T -> T N C')

        return s_seq

    def t_first_conv2d_xxxT(self, x_seq):
        """[T, N, C, ...] input; runs a [1, K] ``F.conv2d`` over an [N, C, spatial, T] view."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'T N C H W -> N C (H W) T')
        elif len(shape) == 4:
            x_seq = rearrange(x_seq, 'T N C L -> N C L T')
        else:
            x_seq = rearrange(x_seq, 'T N C -> N C 1 T')

        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0))

        if self.training:
            t = F.conv2d(x_pad, weight=self.conv.weight.unsqueeze(-2), bias=self.conv.bias, stride=(1,) + self.conv.stride, padding=(0,) + self.conv.padding, dilation=(1, ) + self.conv.dilation, groups=self.conv.groups)
            mean, var = self.bn.get_mean_var(t, dim=1)
        else:
            mean, var = self.bn.get_mean_var(dim=1)

        qweight, bias = self._fold_bn(mean, var)
        h_seq = F.conv2d(x_pad, weight=qweight.unsqueeze(-2), bias=bias, stride=(1,) + self.conv.stride, padding=(0,) + self.conv.padding, dilation=(1, ) + self.conv.dilation, groups=self.conv.groups)

        s_seq = self.surrogate_function(h_seq)

        if len(shape) == 5:
            s_seq = rearrange(s_seq, 'N C (H W) T -> T N C H W', H=shape[3], W=shape[4])
        elif len(shape) == 4:
            s_seq = rearrange(s_seq, 'N C L T -> T N C L')
        else:
            s_seq = rearrange(s_seq, 'N C 1 T -> T N C')

        return s_seq

    def t_first_conv2d_Txxx(self, x_seq):
        """[T, N, C, ...] input; runs a [K, 1] ``F.conv2d`` over an [N, C, T, spatial] view."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'T N C H W -> N C T (H W)')
        elif len(shape) == 4:
            x_seq = rearrange(x_seq, 'T N C L -> N C T L')
        else:
            x_seq = rearrange(x_seq, 'T N C -> N C T 1')

        # Pad the front of the T axis (dim -2); the last axis is left untouched.
        x_pad = F.pad(x_seq, (0, 0, self.dilation * (self.K - 1), 0))

        if self.training:
            t = F.conv2d(x_pad, weight=self.conv.weight.unsqueeze(-1), bias=self.conv.bias, stride=self.conv.stride + (1,), padding=self.conv.padding + (0,), dilation=self.conv.dilation + (1,), groups=self.conv.groups)
            mean, var = self.bn.get_mean_var(t, dim=1)
        else:
            mean, var = self.bn.get_mean_var(dim=1)

        qweight, bias = self._fold_bn(mean, var)
        h_seq = F.conv2d(x_pad, weight=qweight.unsqueeze(-1), bias=bias, stride=self.conv.stride + (1,), padding=self.conv.padding + (0,), dilation=self.conv.dilation + (1,), groups=self.conv.groups)

        s_seq = self.surrogate_function(h_seq)

        if len(shape) == 5:
            s_seq = rearrange(s_seq, 'N C T (H W) -> T N C H W', H=shape[3], W=shape[4])
        elif len(shape) == 4:
            s_seq = rearrange(s_seq, 'N C T L -> T N C L')
        else:
            s_seq = rearrange(s_seq, 'N C T 1 -> T N C')

        return s_seq

    def t_first_triton(self, x_seq):
        """[T, N, C, ...] input (channel axis 2) using the Triton kernel."""
        pad = self.dilation * (self.K - 1)
        if self.training:
            t = t_frist_triton_impl(x_seq, self.conv.weight, pad, self.conv.bias)
            mean, var = self.bn.get_mean_var(t, dim=2)
        else:
            mean, var = self.bn.get_mean_var(dim=2)

        qweight, bias = self._fold_bn(mean, var)
        h_seq = t_frist_triton_impl(x_seq, qweight, pad, bias)
        return self.surrogate_function(h_seq)

    def _causal_taps(self, t: int):
        """Return ``(steps, length)`` for output time step ``t``.

        ``steps`` slices the input time steps ``t - dilation * (K - 1), ...,
        t - dilation, t`` (clipped at 0) and ``length`` is how many of them exist;
        they pair with the last ``length`` kernel taps, the newest tap at ``t``.
        """
        d = self.dilation
        # The oldest tap lies dilation * (K - 1) steps back. When that is before 0,
        # start at the first step congruent to t modulo the dilation instead.
        start = max(t % d, t - (self.K - 1) * d)
        length = min((t - start) // d + 1, self.K)
        return slice(start, t + 1, d), length

    def gen_gemm_weight_first(self, T: int, weight: torch.Tensor):
        """Expand a depthwise kernel ``weight`` ([C, 1, K]) into [C, T, T] causal matrices.

        Row ``t`` of each channel's matrix holds the taps that produce output step
        ``t``, i.e. ``y = res[c] @ x`` for a time-major slice ``x``.
        """
        res = torch.zeros([self.C, T, T], device=weight.device, dtype=weight.dtype)
        for t in range(T):
            steps, length = self._causal_taps(t)
            res[:, t, steps] = weight[:, 0, self.K - length:self.K]
        return res

    def gen_gemm_weight_last(self, T: int, weight: torch.Tensor):
        """Like ``gen_gemm_weight_first`` but transposed: column ``t`` produces output ``t``
        (``y = x @ res[c]`` for a time-last slice ``x``)."""
        res = torch.zeros([self.C, T, T], device=weight.device, dtype=weight.dtype)
        for t in range(T):
            steps, length = self._causal_taps(t)
            res[:, steps, t] = weight[:, 0, self.K - length:self.K]
        return res

    def gen_gemm_weight(self, T:int, weight):
        """Dispatch to the matrix orientation that matches ``self.layout``."""
        if self.layout == 't_first':
            return self.gen_gemm_weight_first(T, weight)
        return self.gen_gemm_weight_last(T, weight)

    def t_first_fc(self, x_seq):
        """[T, N, C, ...] input; each channel's kernel is applied as a [T, T] matmul via vmap."""
        shape = x_seq.shape  # [T, N, C, *]

        def apply_linear(x, w, b=None):
            # x: one channel's [T, N, *] slice; w: that channel's [T, T] matrix.
            x_shape = x.shape
            y = w @ x.flatten(1)
            if b is not None:
                y += b
            return y.view(x_shape)

        if self.training:
            weight = self.gen_gemm_weight_first(shape[0], self.conv.weight)
            t = torch.vmap(apply_linear, in_dims=(2, 0, None), out_dims=2)(x_seq, weight, None).view(shape)
            mean, var = self.bn.get_mean_var(t, dim=2)
        else:
            mean, var = self.bn.get_mean_var()

        qweight, bias = self._fold_bn(mean, var)
        qweight = self.gen_gemm_weight_first(shape[0], qweight)
        h_seq = torch.vmap(apply_linear, in_dims=(2, 0, 0), out_dims=2)(x_seq, qweight, bias).view(shape)
        return self.surrogate_function(h_seq)

    def t_last_fc(self, x_seq):
        """[N, C, ..., T] input; each channel's kernel is applied as a [T, T] matmul via vmap."""
        shape = x_seq.shape  # [N, C, *, T]

        def apply_linear(x, w, b=None):
            # x: one channel's [N, *, T] slice; w: that channel's [T, T] matrix.
            x_shape = x.shape
            y = x @ w
            if b is not None:
                y += b
            return y.view(x_shape)

        if self.training:
            weight = self.gen_gemm_weight_last(shape[-1], self.conv.weight)
            t = torch.vmap(apply_linear, in_dims=(1, 0, None), out_dims=1)(x_seq, weight, None).view(shape)
            mean, var = self.bn.get_mean_var(t, dim=1)
        else:
            mean, var = self.bn.get_mean_var()

        qweight, bias = self._fold_bn(mean, var)
        qweight = self.gen_gemm_weight_last(shape[-1], qweight)
        h_seq = torch.vmap(apply_linear, in_dims=(1, 0, 0), out_dims=1)(x_seq, qweight, bias)
        return self.surrogate_function(h_seq)

    def t_last_vmap_conv1d(self, x_seq):
        """[N, C, ..., T] input; vmaps ``conv1d`` over the flattened spatial axis."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'N C H W T -> N C (H W) T')
        elif len(shape) == 3:
            x_seq = rearrange(x_seq, 'N C T -> N C 1 T')

        # NOTE(review): `.float()` upcasts non-float32 inputs; the output is float32 too.
        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0)).float()

        if self.training:
            t = torch.vmap(self.conv, -2, -2)(x_pad)
            mean, var = self.bn.get_mean_var(t, dim=1)
        else:
            mean, var = self.bn.get_mean_var(dim=1)

        qweight, bias = self._fold_bn(mean, var)

        def conv1d(x):
            return F.conv1d(x, qweight, bias=bias, stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)

        h_seq = torch.vmap(conv1d, in_dims=-2, out_dims=-2)(x_pad)

        s_seq = self.surrogate_function(h_seq)

        if len(shape) == 5:
            s_seq = rearrange(s_seq, 'N C (H W) T -> N C H W T', H=shape[2], W=shape[3])
        elif len(shape) == 3:
            s_seq = rearrange(s_seq, 'N C 1 T -> N C T')

        return s_seq

    def t_last_conv2d(self, x_seq):
        """[N, C, ..., T] input; runs a [1, K] ``F.conv2d`` over an [N, C, spatial, T] view."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'N C H W T -> N C (H W) T')
        elif len(shape) == 3:
            x_seq = rearrange(x_seq, 'N C T -> N C 1 T')

        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0))

        if self.training:
            t = F.conv2d(x_pad, weight=self.conv.weight.unsqueeze(-2), bias=self.conv.bias, stride=(1,) + self.conv.stride, padding=(0,) + self.conv.padding, dilation=(1, ) + self.conv.dilation, groups=self.conv.groups)
            mean, var = self.bn.get_mean_var(t, dim=1)
        else:
            mean, var = self.bn.get_mean_var(dim=1)

        qweight, bias = self._fold_bn(mean, var)
        h_seq = F.conv2d(x_pad, weight=qweight.unsqueeze(-2), bias=bias, stride=(1,) + self.conv.stride, padding=(0,) + self.conv.padding, dilation=(1, ) + self.conv.dilation, groups=self.conv.groups)

        s_seq = self.surrogate_function(h_seq)

        if len(shape) == 5:
            s_seq = rearrange(s_seq, 'N C (H W) T -> N C H W T', H=shape[2], W=shape[3])
        elif len(shape) == 3:
            s_seq = rearrange(s_seq, 'N C 1 T -> N C T')

        return s_seq

    def t_last_triton(self, x_seq):
        """[N, C, ..., T] input (channel axis 1) using the Triton kernel."""
        pad = self.dilation * (self.K - 1)
        if self.training:
            t = t_last_triton_impl(x_seq, self.conv.weight, pad, self.conv.bias)
            mean, var = self.bn.get_mean_var(t, dim=1)
        else:
            mean, var = self.bn.get_mean_var(dim=1)

        qweight, bias = self._fold_bn(mean, var)
        h_seq = t_last_triton_impl(x_seq, qweight, pad, bias)
        return self.surrogate_function(h_seq)

    def forward(self, x_seq: torch.Tensor):
        """Run the cached implementation, benchmarking the candidates on first use."""
        if self.selected_method is None:
            self.auto_select_methods(x_seq)
        return self.selected_method(x_seq)

class BN_WO_Quantize(nn.BatchNorm2d, Layout):
    """``nn.BatchNorm2d`` over the channel axis of a time-major or time-last sequence.

    The rank check is disabled, so inputs of any rank are normalized over axis 1
    (after folding T into the batch for the 't_first' layout).
    """

    def __init__(self, num_features, eps = 0.00001, momentum = 0.1, affine = True, track_running_stats = True, device=None, dtype=None, layout: str="t_first"):
        """
            if layout == 't_first', we thought the input x_seq is shaped as [T, N, C, H, W]
            if layout == 't_last', we thought the input x_seq is shaped as [N, C, H, W, T]
        """
        Layout.__init__(self, layout)
        nn.BatchNorm2d.__init__(self, num_features, eps, momentum, affine, track_running_stats, device, dtype)

        candidates_first = [self.t_first_fused_TN_impl]
        candidates_last = [self.t_last_high_dim_impl]
        self.methods = {"t_first": candidates_first, "t_last": candidates_last}

    def _check_input_dim(self, input):
        """Accept any input rank (overrides the 4-D check of ``nn.BatchNorm2d``)."""

    def t_first_fused_TN_impl(self, x_seq):
        """Merge T and N into one batch axis, normalize, then split them again."""
        steps, batch = x_seq.shape[0], x_seq.shape[1]
        normed = super().forward(x_seq.flatten(0, 1))
        return normed.view([steps, batch, *normed.shape[1:]])

    def t_last_high_dim_impl(self, x_seq):
        """Normalize the [N, C, ..., T] tensor directly; T is just another reduced axis."""
        return super().forward(x_seq)

    def forward(self, x_seq):
        if self.selected_method is None:
            self.auto_select_methods(x_seq)
        return self.selected_method(x_seq)

class Mul_Free_Depthwise_PSN_WO_Quantize(nn.Module, Layout):
    """Depthwise causal temporal convolution + BatchNorm + surrogate, without quantization.

    Every channel is filtered along time by its own causal, dilated kernel of
    length ``K`` (``self.conv``: ``nn.Conv1d`` with ``groups=C``, inputs left-padded
    by ``dilation * (K - 1)``). The result is normalized by a separate
    ``BN_WO_Quantize`` layer (whose layout is kept in sync with this module's) and
    passed through ``surrogate_function``. ``forward`` times the implementations
    registered in ``self.methods`` on first use and caches the fastest.
    """

    def __init__(self, C: int, K: int, surrogate_function=None, dilation: int=1, layout: str="t_first"):
        """
            if layout == 't_first', we thought the input x_seq is shaped as [T, N, ...]
            if layout == 't_last', we thought the input x_seq is shaped as [N, ..., T]

            surrogate_function defaults to a new ``surrogate.ATan()`` per instance, so
            layers never share one module object through a default argument.
        """
        assert layout in ['t_first', 't_last']
        if surrogate_function is None:
            surrogate_function = surrogate.ATan()
        nn.Module.__init__(self)
        # `bn` must exist before Layout.__init__: the overridden `layout` setter
        # forwards the layout to it.
        self.bn = BN_WO_Quantize(num_features=C)
        Layout.__init__(self, layout)

        self.C = C
        self.K = K
        self.surrogate_function = surrogate_function
        self.dilation = dilation
        self.conv = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=K, groups=C, dilation=dilation, bias=False)
        nn.init.constant_(self.bn.bias, -1.)

        self.methods = {
            "t_first": [self.t_first_fc, self.t_first_conv1d, self.t_first_conv2d_xxxT, self.t_first_conv2d_Txxx],
            "t_last": [self.t_last_fc, self.t_last_vmap_conv1d, self.t_last_conv2d]
        }

        if dilation == 1:
            # The Triton kernels are only registered for undilated kernels.
            self.methods['t_first'].append(self.t_first_triton)
            self.methods['t_last'].append(self.t_last_triton)

    @Layout.layout.setter
    def layout(self, value):
        """Same as ``Layout.layout`` but also propagates the new layout to ``self.bn``."""
        assert value in ['t_first', 't_last'], "layout must be 't_first' or 't_last'"
        if self._layout is None or self._layout != value:
            self._layout = value
            self.selected_method = None
            self.bn.layout = value

    def t_first_conv1d(self, x_seq):
        """[T, N, C, ...] input; folds spatial axes into the batch and uses ``self.conv``."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'T N C H W -> (N H W) C T')
        elif len(shape) == 4:
            x_seq = rearrange(x_seq, 'T N C L -> (N L) C T')
        else:
            x_seq = rearrange(x_seq, 'T N C -> N C T')

        # Left-only padding makes the convolution causal along T.
        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0))

        h_seq = self.conv(x_pad)
        if len(shape) == 5:
            h_seq = rearrange(h_seq, '(N H W) C T -> T N C H W', N=shape[1], H=shape[3], W=shape[4])
        elif len(shape) == 4:
            h_seq = rearrange(h_seq, '(N L) C T -> T N C L', N=shape[1], L=shape[3])
        else:
            h_seq = rearrange(h_seq, 'N C T -> T N C')

        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def t_first_conv2d_xxxT(self, x_seq):
        """[T, N, C, ...] input; runs a [1, K] ``F.conv2d`` over an [N, C, spatial, T] view."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'T N C H W -> N C (H W) T')
        elif len(shape) == 4:
            x_seq = rearrange(x_seq, 'T N C L -> N C L T')
        else:
            x_seq = rearrange(x_seq, 'T N C -> N C 1 T')

        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0))

        h_seq = F.conv2d(x_pad, weight=self.conv.weight.unsqueeze(-2), bias=self.conv.bias, stride=(1,) + self.conv.stride, padding=(0,) + self.conv.padding, dilation=(1, ) + self.conv.dilation, groups=self.conv.groups)

        if len(shape) == 5:
            h_seq = rearrange(h_seq, 'N C (H W) T -> T N C H W', H=shape[3], W=shape[4])
        elif len(shape) == 4:
            h_seq = rearrange(h_seq, 'N C L T -> T N C L')
        else:
            h_seq = rearrange(h_seq, 'N C 1 T -> T N C')

        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def t_first_conv2d_Txxx(self, x_seq):
        """[T, N, C, ...] input; runs a [K, 1] ``F.conv2d`` over an [N, C, T, spatial] view."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'T N C H W -> N C T (H W)')
        elif len(shape) == 4:
            x_seq = rearrange(x_seq, 'T N C L -> N C T L')
        else:
            x_seq = rearrange(x_seq, 'T N C -> N C T 1')

        # Pad the front of the T axis (dim -2); the last axis is left untouched.
        x_pad = F.pad(x_seq, (0, 0, self.dilation * (self.K - 1), 0))

        h_seq = F.conv2d(x_pad, weight=self.conv.weight.unsqueeze(-1), bias=self.conv.bias, stride=self.conv.stride + (1,), padding=self.conv.padding + (0,), dilation=self.conv.dilation + (1,), groups=self.conv.groups)

        if len(shape) == 5:
            h_seq = rearrange(h_seq, 'N C T (H W) -> T N C H W', H=shape[3], W=shape[4])
        elif len(shape) == 4:
            h_seq = rearrange(h_seq, 'N C T L -> T N C L')
        else:
            h_seq = rearrange(h_seq, 'N C T 1 -> T N C')

        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def t_first_triton(self, x_seq):
        """[T, N, C, ...] input using the Triton kernel."""
        h_seq = t_frist_triton_impl(x_seq, self.conv.weight, self.dilation * (self.K - 1), self.conv.bias)
        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def _causal_taps(self, t: int):
        """Return ``(steps, length)`` for output time step ``t``.

        ``steps`` slices the input time steps ``t - dilation * (K - 1), ...,
        t - dilation, t`` (clipped at 0) and ``length`` is how many of them exist;
        they pair with the last ``length`` kernel taps, the newest tap at ``t``.
        """
        d = self.dilation
        # The oldest tap lies dilation * (K - 1) steps back. When that is before 0,
        # start at the first step congruent to t modulo the dilation instead.
        start = max(t % d, t - (self.K - 1) * d)
        length = min((t - start) // d + 1, self.K)
        return slice(start, t + 1, d), length

    def gen_gemm_weight_first(self, T: int):
        """Expand ``self.conv.weight`` ([C, 1, K]) into [C, T, T] causal matrices.

        Row ``t`` of each channel's matrix holds the taps that produce output step
        ``t``, i.e. ``y = weight[c] @ x`` for a time-major slice ``x``.
        """
        kernel = self.conv.weight
        weight = torch.zeros([self.C, T, T], device=kernel.device, dtype=kernel.dtype)
        for t in range(T):
            steps, length = self._causal_taps(t)
            weight[:, t, steps] = kernel[:, 0, self.K - length:self.K]
        return weight

    def gen_gemm_weight_last(self, T: int):
        """Like ``gen_gemm_weight_first`` but transposed: column ``t`` produces output ``t``
        (``y = x @ weight[c]`` for a time-last slice ``x``)."""
        kernel = self.conv.weight
        weight = torch.zeros([self.C, T, T], device=kernel.device, dtype=kernel.dtype)
        for t in range(T):
            steps, length = self._causal_taps(t)
            weight[:, steps, t] = kernel[:, 0, self.K - length:self.K]
        return weight

    def gen_gemm_weight(self, T:int):
        """Dispatch to the matrix orientation that matches ``self.layout``."""
        if self.layout == 't_first':
            return self.gen_gemm_weight_first(T)
        return self.gen_gemm_weight_last(T)

    def t_first_fc(self, x_seq):
        """[T, N, C, ...] input; each channel's kernel is applied as a [T, T] matmul via vmap."""
        shape = x_seq.shape  # [T, N, C, *]

        def apply_linear(x, w):
            # x: one channel's [T, N, *] slice; w: that channel's [T, T] matrix.
            x_shape = x.shape
            y = w @ x.flatten(1)
            return y.view(x_shape)

        weight = self.gen_gemm_weight_first(shape[0])
        h_seq = torch.vmap(apply_linear, in_dims=(2, 0), out_dims=2)(x_seq, weight)
        h_seq = h_seq.view(shape)
        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def t_last_fc(self, x_seq):
        """[N, C, ..., T] input; each channel's kernel is applied as a [T, T] matmul via vmap."""
        shape = x_seq.shape  # [N, C, *, T]

        def apply_linear(x, w):
            # x: one channel's [N, *, T] slice; w: that channel's [T, T] matrix.
            x_shape = x.shape
            y = x @ w
            return y.view(x_shape)

        weight = self.gen_gemm_weight_last(shape[-1])
        h_seq = torch.vmap(apply_linear, in_dims=(1, 0), out_dims=1)(x_seq, weight).view(shape)
        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def t_last_vmap_conv1d(self, x_seq):
        """[N, C, ..., T] input; vmaps ``self.conv`` over the flattened spatial axis."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'N C H W T -> N C (H W) T')
        elif len(shape) == 3:
            x_seq = rearrange(x_seq, 'N C T -> N C 1 T')

        # NOTE(review): `.float()` upcasts non-float32 inputs; the output is float32 too.
        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0)).float()

        h_seq = torch.vmap(self.conv, -2, -2)(x_pad)

        if len(shape) == 5:
            h_seq = rearrange(h_seq, 'N C (H W) T -> N C H W T', H=shape[2], W=shape[3])
        elif len(shape) == 3:
            h_seq = rearrange(h_seq, 'N C 1 T -> N C T')

        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def t_last_conv2d(self, x_seq):
        """[N, C, ..., T] input; runs a [1, K] ``F.conv2d`` over an [N, C, spatial, T] view."""
        shape = x_seq.shape
        if len(shape) == 5:
            x_seq = rearrange(x_seq, 'N C H W T -> N C (H W) T')
        elif len(shape) == 3:
            x_seq = rearrange(x_seq, 'N C T -> N C 1 T')

        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0))

        h_seq = F.conv2d(x_pad, weight=self.conv.weight.unsqueeze(-2), bias=self.conv.bias, stride=(1,) + self.conv.stride, padding=(0,) + self.conv.padding, dilation=(1, ) + self.conv.dilation, groups=self.conv.groups)

        if len(shape) == 5:
            h_seq = rearrange(h_seq, 'N C (H W) T -> N C H W T', H=shape[2], W=shape[3])
        elif len(shape) == 3:
            h_seq = rearrange(h_seq, 'N C 1 T -> N C T')

        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def t_last_triton(self, x_seq):
        """[N, C, ..., T] input using the Triton kernel."""
        h_seq = t_last_triton_impl(x_seq, self.conv.weight, self.dilation * (self.K - 1), self.conv.bias)
        h_seq = self.bn(h_seq)
        return self.surrogate_function(h_seq)

    def forward(self, x_seq: torch.Tensor):
        """Run the cached implementation, benchmarking the candidates on first use."""
        if self.selected_method is None:
            self.auto_select_methods(x_seq)
        return self.selected_method(x_seq)

class TConv2d(nn.Conv2d, Layout):
    """``nn.Conv2d`` applied independently at every time step of a spike sequence."""

    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, groups = 1, bias = True, padding_mode = 'zeros', device=None, dtype=None, layout: str="t_first",):
        """
            if layout == 't_first', we thought the input x_seq is shaped as [T, N, C, H, W]
            if layout == 't_last', we thought the input x_seq is shaped as [N, C, H, W, T]
        """
        Layout.__init__(self, layout)
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)

        self.methods = {
            "t_first": [self.t_first_fused_TN_impl],
            "t_last": [self.t_last_vmap_impl, self.t_last_high_dim_impl]
        }

    def t_first_fused_TN_impl(self, x_seq):
        """Merge T and N into one batch axis, convolve, then split them again."""
        y_shape = [x_seq.shape[0], x_seq.shape[1]]
        y = x_seq.flatten(0, 1)
        y = super().forward(y)
        y_shape.extend(y.shape[1:])
        return y.view(y_shape)

    def t_last_vmap_impl(self, x_seq):
        return t_last_vmap_forward(x_seq, super().forward)

    def t_last_high_dim_impl(self, x_seq):
        """Run the 2-D conv as a 3-D conv whose kernel has size 1 along the trailing T axis.

        Non-``'zeros'`` padding modes are applied with an explicit ``F.pad`` that
        leaves T unpadded (as ``nn.Conv2d`` does for its spatial axes), and string
        paddings (``'same'`` / ``'valid'``) are forwarded to ``F.conv3d`` unchanged.
        """
        if self.padding_mode != 'zeros':
            # Pad tuple is ordered last axis first: (T, W, H); T gets no padding.
            x_seq = F.pad(x_seq, (0, 0) + tuple(self._reversed_padding_repeated_twice), mode=self.padding_mode)
            padding = (0, 0, 0)
        elif isinstance(self.padding, str):
            padding = self.padding
        else:
            padding = self.padding + (0, )
        return F.conv3d(x_seq, self.weight.unsqueeze(-1), self.bias, self.stride + (1, ), padding, self.dilation + (1, ), self.groups)

    def forward(self, x_seq):
        """Run the cached implementation, benchmarking the candidates on first use."""
        if self.selected_method is None:
            self.auto_select_methods(x_seq)
        return self.selected_method(x_seq)

class TConv1d(nn.Conv1d, Layout):
    """``nn.Conv1d`` applied independently at every time step of a spike sequence."""

    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, groups = 1, bias = True, padding_mode = 'zeros', device=None, dtype=None,  layout: str="t_first"):
        """
            if layout == 't_first', we thought the input x_seq is shaped as [T, N, C, L]
            if layout == 't_last', we thought the input x_seq is shaped as [N, C, L, T]
        """
        Layout.__init__(self, layout)
        # The base initializer is called unbound, so `self` must be passed explicitly.
        nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)

        self.methods = {
            "t_first": [self.t_first_fused_TN_impl],
            "t_last": [self.t_last_vmap_impl, self.t_last_high_dim_impl]
        }

    def t_first_fused_TN_impl(self, x_seq):
        """Merge T and N into one batch axis, convolve, then split them again."""
        y_shape = [x_seq.shape[0], x_seq.shape[1]]
        y = x_seq.flatten(0, 1)
        y = super().forward(y)
        y_shape.extend(y.shape[1:])
        return y.view(y_shape)

    def t_last_vmap_impl(self, x_seq):
        return t_last_vmap_forward(x_seq, super().forward)

    def t_last_high_dim_impl(self, x_seq):
        """Run the 1-D conv as a 2-D conv whose kernel has size 1 along the trailing T axis.

        Non-``'zeros'`` padding modes are applied with an explicit ``F.pad`` that
        leaves T unpadded, and string paddings (``'same'`` / ``'valid'``) are
        forwarded to ``F.conv2d`` unchanged.
        """
        if self.padding_mode != 'zeros':
            # Pad tuple is ordered last axis first: (T, L); T gets no padding.
            x_seq = F.pad(x_seq, (0, 0) + tuple(self._reversed_padding_repeated_twice), mode=self.padding_mode)
            padding = (0, 0)
        elif isinstance(self.padding, str):
            padding = self.padding
        else:
            padding = self.padding + (0, )
        return F.conv2d(x_seq, self.weight.unsqueeze(-1), self.bias, self.stride + (1, ), padding, self.dilation + (1, ), self.groups)

    def forward(self, x_seq):
        """Run the cached implementation, benchmarking the candidates on first use."""
        if self.selected_method is None:
            self.auto_select_methods(x_seq)
        return self.selected_method(x_seq)
    
class TAvgPool2d(nn.AvgPool2d, Layout):
    """``nn.AvgPool2d`` applied independently at every time step of a spike sequence."""

    def __init__(self, kernel_size, stride = None, padding = 0, ceil_mode = False, count_include_pad = True, divisor_override = None, layout: str="t_first"):
        """
            if layout == 't_first', we thought the input x_seq is shaped as [T, N, C, H, W]
            if layout == 't_last', we thought the input x_seq is shaped as [N, C, H, W, T]
        """
        Layout.__init__(self, layout)
        nn.AvgPool2d.__init__(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)

        # Store the pooling geometry as 2-tuples so the t_last path can append the
        # size-1 T entry with `+ (1, )` (lists are converted as well as ints).
        def _as_pair(v):
            return (v, v) if isinstance(v, int) else tuple(v)

        self.kernel_size = _as_pair(self.kernel_size)
        self.stride = _as_pair(self.stride)
        self.padding = _as_pair(self.padding)

        self.methods = {
            "t_first": [self.t_first_fused_TN_impl],
            "t_last": [self.t_last_high_dim_impl]
        }

    def t_first_fused_TN_impl(self, x_seq):
        """Merge T and N into one batch axis, pool, then split them again."""
        y_shape = [x_seq.shape[0], x_seq.shape[1]]
        y = x_seq.flatten(0, 1)
        y = super().forward(y)
        y_shape.extend(y.shape[1:])
        return y.view(y_shape)

    def t_last_high_dim_impl(self, x_seq):
        """Pool [N, C, H, W, T] with a 3-D pool that has window 1 / stride 1 along T.

        All pooling options are forwarded so this path matches the 't_first' one.
        """
        return F.avg_pool3d(
            x_seq,
            self.kernel_size + (1, ),
            self.stride + (1, ),
            self.padding + (0, ),
            ceil_mode=self.ceil_mode,
            count_include_pad=self.count_include_pad,
            divisor_override=self.divisor_override,
        )

    def forward(self, x_seq):
        """Run the cached implementation, benchmarking the candidates on first use."""
        if self.selected_method is None:
            self.auto_select_methods(x_seq)
        return self.selected_method(x_seq)
    
class TAvgPool1d(nn.AvgPool1d, Layout):
    """1-D average pooling applied independently at every time step.

    Two input layouts are supported:

    * ``'t_first'``: ``x_seq`` is ``[T, N, C, L]``. T and N are fused into one batch
      axis and ``nn.AvgPool1d.forward`` is applied once.
    * ``'t_last'``: ``x_seq`` is ``[N, C, L, T]``. The pooling is expressed as a 2-D
      average pool whose kernel/stride/padding along T are ``1``/``1``/``0``, so
      time steps are never mixed.
    """

    def __init__(self, kernel_size, stride = None, padding = 0, ceil_mode = False, count_include_pad = True, layout: str="t_first"):
        """
        Args:
            kernel_size, stride, padding, ceil_mode, count_include_pad:
                same meaning as in ``nn.AvgPool1d``.
            layout: ``'t_first'`` expects ``x_seq`` shaped ``[T, N, C, L]``;
                ``'t_last'`` expects ``[N, C, L, T]``.
        """
        Layout.__init__(self, layout)
        nn.AvgPool1d.__init__(self, kernel_size, stride, padding, ceil_mode, count_include_pad)

        # The t_last path appends an entry for T with ``+``, so these must be tuples.
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, )
        if isinstance(self.stride, int):
            self.stride = (self.stride, )
        if isinstance(self.padding, int):
            self.padding = (self.padding, )

        self.methods = {
            "t_first": [self.t_first_fused_TN_impl],
            "t_last": [self.t_last_high_dim_impl]
        }

    def t_first_fused_TN_impl(self, x_seq):
        """Pool ``[T, N, C, L]`` by fusing the first two axes into one batch axis."""
        # Pooling shrinks L, so the output shape must be rebuilt from the pooled
        # tensor. Reusing the input shape (as a duplicate, shadowing definition of
        # this method used to do) fails whenever the kernel or stride is not 1.
        y_shape = [x_seq.shape[0], x_seq.shape[1]]
        y = super().forward(x_seq.flatten(0, 1))
        y_shape.extend(y.shape[1:])
        return y.view(y_shape)

    def t_last_high_dim_impl(self, x_seq):
        """Pool ``[N, C, L, T]`` as a 2-D pool that leaves the T axis untouched."""
        # The extra entries must be ints: F.avg_pool2d rejects float elements in
        # kernel_size/stride/padding. ceil_mode and count_include_pad are forwarded
        # so this path matches nn.AvgPool1d.forward used by the t_first path.
        return F.avg_pool2d(
            x_seq,
            self.kernel_size + (1, ),
            self.stride + (1, ),
            self.padding + (0, ),
            ceil_mode=self.ceil_mode,
            count_include_pad=self.count_include_pad,
        )

    def forward(self, x_seq):
        """Dispatch to the selected implementation, selecting one lazily on first use."""
        if self.selected_method is None:
            self.auto_select_methods(x_seq)
        return self.selected_method(x_seq)
    
class TBatchNorm2d(nn.BatchNorm2d, Layout):
    """``nn.BatchNorm2d`` for sequences in ``'t_first'`` or ``'t_last'`` layout.

    * ``'t_first'``: ``x_seq`` is ``[T, N, C, H, W]``. The first two axes are fused
      into one batch axis before ``nn.BatchNorm2d.forward``.
    * ``'t_last'``: ``x_seq`` is ``[N, C, H, W, T]`` and is passed to
      ``nn.BatchNorm2d.forward`` unchanged. Channels stay on dim 1.
    """

    def __init__(self, num_features, eps = 0.00001, momentum = 0.1, affine = True, track_running_stats = True, device=None, dtype=None, layout: str="t_first"):
        """
        Args:
            num_features, eps, momentum, affine, track_running_stats, device, dtype:
                forwarded positionally to ``nn.BatchNorm2d``.
            layout: ``'t_first'`` expects ``[T, N, C, H, W]``;
                ``'t_last'`` expects ``[N, C, H, W, T]``.
        """
        Layout.__init__(self, layout)
        nn.BatchNorm2d.__init__(self, num_features, eps, momentum, affine, track_running_stats, device, dtype)

        self.methods = dict(
            t_first=[self.t_first_fused_TN_impl],
            t_last=[self.t_last_high_dim_impl],
        )

    def _check_input_dim(self, input):
        # Invoked by nn.BatchNorm2d.forward. The t_first path hands it the fused
        # 4-D tensor; the t_last path hands it the raw 5-D tensor.
        is_t_first = self.layout == "t_first"
        if is_t_first and input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")
        if not is_t_first and input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")

    def t_first_fused_TN_impl(self, x_seq):
        """Normalize ``[T, N, ...]`` by treating ``T * N`` as the batch axis."""
        t_len, n_batch = x_seq.shape[0], x_seq.shape[1]
        out = super().forward(x_seq.flatten(0, 1))
        return out.view([t_len, n_batch, *out.shape[1:]])

    def t_last_high_dim_impl(self, x_seq):
        """Normalize the 5-D tensor directly; batch statistics span every non-channel axis."""
        normalized = super().forward(x_seq)
        return normalized

    def forward(self, x_seq):
        """Dispatch to the selected implementation, selecting one lazily on first use."""
        method = self.selected_method
        if method is None:
            self.auto_select_methods(x_seq)
            method = self.selected_method
        return method(x_seq)

class TBatchNorm1d(nn.BatchNorm1d, Layout):
    """``nn.BatchNorm1d`` for sequences in ``'t_first'`` or ``'t_last'`` layout.

    * ``'t_first'``: ``x_seq`` is ``[T, N, C]`` or ``[T, N, C, L]``. The first two
      axes are fused into one batch axis before ``nn.BatchNorm1d.forward``.
    * ``'t_last'``: ``x_seq`` is ``[N, C, L, T]`` (or a 3-D tensor without the L
      axis) and is passed to ``nn.BatchNorm1d.forward`` unchanged. Channels stay
      on dim 1.
    """

    def __init__(self, num_features, eps = 0.00001, momentum = 0.1, affine = True, track_running_stats = True, device=None, dtype=None, layout: str="t_first"):
        """
        Args:
            num_features, eps, momentum, affine, track_running_stats, device, dtype:
                forwarded positionally to ``nn.BatchNorm1d``.
            layout: ``'t_first'`` expects ``[T, N, C]`` or ``[T, N, C, L]``;
                ``'t_last'`` expects a 3-D or 4-D tensor such as ``[N, C, L, T]``.
        """
        Layout.__init__(self, layout)
        nn.BatchNorm1d.__init__(self, num_features, eps, momentum, affine, track_running_stats, device, dtype)

        self.methods = {
            "t_first": [self.t_first_fused_TN_impl],
            "t_last": [self.t_last_high_dim_impl]
        }

    def t_first_fused_TN_impl(self, x_seq):
        """Normalize ``[T, N, ...]`` by treating ``T * N`` as the batch axis."""
        y_shape = [x_seq.shape[0], x_seq.shape[1]]
        y = x_seq.flatten(0, 1)
        y = super().forward(y)
        y_shape.extend(y.shape[1:])
        return y.view(y_shape)

    def t_last_high_dim_impl(self, x_seq):
        """Normalize the tensor directly; channels are already on dim 1."""
        return super().forward(x_seq)

    def _check_input_dim(self, input):
        # Invoked by nn.BatchNorm1d.forward. The t_first path passes the fused
        # tensor (one axis fewer than x_seq); the t_last path passes x_seq itself.
        if self.layout == "t_first":
            if input.dim() != 2 and input.dim() != 3:
                raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
        else:
            if input.dim() != 3 and input.dim() != 4:
                raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")

    def forward(self, x_seq):
        """Dispatch to the selected implementation, selecting one lazily on first use."""
        # Select only when nothing is selected yet, like the other T* layers.
        # Previously auto_select_methods ran on every call, repeating its selection
        # work each step. If that selection runs candidate forwards, it would also
        # update the running statistics again. Changing ``layout`` resets
        # ``selected_method`` to None, so re-selection still happens after a
        # layout switch.
        if self.selected_method is None:
            self.auto_select_methods(x_seq)
        return self.selected_method(x_seq)

class TLinear(nn.Linear, Layout):
    """``nn.Linear`` for sequences in ``'t_first'`` or ``'t_last'`` layout.

    * ``'t_first'``: the two leading axes (presumably T and N) are fused before
      ``nn.Linear.forward``, so features are on the last axis.
    * ``'t_last'``: two candidate implementations:
      ``nn.Linear.forward`` mapped with ``t_last_vmap_forward``, or a kernel-size-1
      ``F.conv1d`` that treats dim 1 as the feature axis.
      NOTE(review): the conv path needs a 3-D input, presumably ``[N, C, T]`` — confirm.
    """

    def __init__(self, in_features, out_features, bias = True, device=None, dtype=None, layout: str='t_first'):
        Layout.__init__(self, layout)
        nn.Linear.__init__(self, in_features, out_features, bias, device, dtype)

        self.methods = dict(
            t_first=[self.t_first_fused_TN_impl],
            t_last=[self.t_last_vmap_impl, self.t_last_high_dim_impl],
        )

    def t_first_fused_TN_impl(self, x_seq):
        """Apply the linear map with ``T * N`` treated as one batch axis."""
        t_len, n_batch = x_seq.shape[0], x_seq.shape[1]
        out = super().forward(x_seq.flatten(0, 1))
        return out.view([t_len, n_batch, *out.shape[1:]])

    def t_last_vmap_impl(self, x_seq):
        """Delegate to ``t_last_vmap_forward`` with ``nn.Linear.forward`` as the per-step function."""
        per_step_forward = super().forward
        return t_last_vmap_forward(x_seq, per_step_forward)

    def t_last_high_dim_impl(self, x_seq):
        """Express the linear map as a kernel-size-1 1-D convolution over dim 1."""
        kernel = self.weight.unsqueeze(-1)  # [out_features, in_features] -> [..., 1]
        return F.conv1d(x_seq, kernel, self.bias, stride=1, padding=0)

    def forward(self, x_seq):
        """Dispatch to the selected implementation, selecting one lazily on first use."""
        method = self.selected_method
        if method is None:
            self.auto_select_methods(x_seq)
            method = self.selected_method
        return method(x_seq)

class TFlatten(nn.Flatten, Layout):
    """``nn.Flatten`` for sequences in ``'t_first'`` or ``'t_last'`` layout.

    * ``'t_first'``: the two leading axes are fused into one before
      ``nn.Flatten.forward``. ``start_dim``/``end_dim`` therefore index the fused
      tensor; with the default ``start_dim=1`` everything after the fused axis is
      flattened.
    * ``'t_last'``: ``nn.Flatten.forward`` is handed to ``t_last_vmap_forward``.
    """

    def __init__(self, start_dim = 1, end_dim = -1, layout: str='t_first'):
        nn.Flatten.__init__(self, start_dim, end_dim)
        Layout.__init__(self, layout)
        self.methods = dict(
            t_first=[self.t_first_fused_TN_impl],
            t_last=[self.t_last_vmap_impl],
        )

    def t_first_fused_TN_impl(self, x_seq):
        """Flatten with ``T * N`` treated as one leading axis, then split it back."""
        t_len, n_batch = x_seq.shape[0], x_seq.shape[1]
        out = super().forward(x_seq.flatten(0, 1))
        return out.view([t_len, n_batch, *out.shape[1:]])

    def t_last_vmap_impl(self, x_seq):
        """Delegate to ``t_last_vmap_forward`` with ``nn.Flatten.forward`` as the per-step function."""
        per_step_forward = super().forward
        return t_last_vmap_forward(x_seq, per_step_forward)

    def forward(self, x_seq):
        """Dispatch to the selected implementation, selecting one lazily on first use."""
        method = self.selected_method
        if method is None:
            self.auto_select_methods(x_seq)
            method = self.selected_method
        return method(x_seq)
    
if __name__ == '__main__':
    # Smoke test: run a Mul_Free_Depthwise_PSN layer on a random tensor and print
    # the output shape. Everything is moved to 'cuda:0', so a CUDA device is required.
    device = 'cuda:0'
    sn = Mul_Free_Depthwise_PSN(128, 2).to(device)
    x = torch.rand([16, 128, 128]).to(device)
    y = sn(x)
    print(f'y.shape = {y.shape}')