from turtle import forward
import torch.nn as nn
import torch
import torch.nn.functional as F
from typing import Optional
import numpy as np
from timm.models.layers import to_2tuple
from einops import rearrange
from torch.cuda.amp import custom_fwd, custom_bwd
from timm.models.layers import trunc_normal_


import math
from typing import Optional, List

class LoRALayer():
    """Mixin that stores the hyper-parameters shared by the LoRA layers below.

    Attributes set on the instance:
        r: rank of the low-rank update (0 disables LoRA in the subclasses).
        lora_alpha: numerator of the ``lora_alpha / r`` scaling factor.
        lora_dropout: callable applied to the input of the low-rank branch;
            an ``nn.Dropout`` when the probability is positive, otherwise an
            identity function.
        merge_weights: whether subclasses may fold the update into the
            frozen weight.
        merged: current merge state, starts as ``False``.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Dropout is optional: fall back to a no-op when the rate is not positive.
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        )
        self.merge_weights = merge_weights
        # The low-rank update has not been folded into the weight yet.
        self.merged = False



class Linear(nn.Linear, LoRALayer):
    """Dense layer with an optional low-rank (LoRA) update of its weight.

    When ``r > 0`` the layer owns two extra parameters, ``lora_A`` of shape
    ``(r, in_features)`` and ``lora_B`` of shape ``(out_features, r)``, and the
    original ``weight``/``bias`` are frozen. The effective weight is
    ``weight + (lora_B @ lora_A) * lora_alpha / r``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        r: rank of the update; 0 makes this a plain linear layer.
        lora_alpha: scaling numerator.
        lora_dropout: dropout probability applied to the low-rank branch input.
        fan_in_fan_out: set to True if the layer to replace stores its weight
            as ``(fan_in, fan_out)``.
        merge_weights: allow folding the update into ``weight`` in eval mode.
        **kwargs: forwarded to ``nn.Linear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            if self.bias is not None:
                self.bias.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _transpose_if_needed(self, w):
        """Return ``w`` transposed when the weight is stored as (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def _lora_delta(self):
        """Return the scaled low-rank update in the layout of ``self.weight``."""
        return self._transpose_if_needed(self.lora_B @ self.lora_A) * self.scaling

    def _merge(self):
        """Fold the low-rank update into ``weight`` and mark the layer merged."""
        if self.r > 0:
            self.weight.data += self._lora_delta()
        self.merged = True

    def _unmerge(self):
        """Remove the low-rank update from ``weight`` and mark the layer unmerged."""
        if self.r > 0:
            self.weight.data -= self._lora_delta()
        self.merged = False

    def reset_parameters(self):
        """(Re-)initialise the weight, the bias and the LoRA factors.

        Called by ``nn.Linear.__init__`` (before the LoRA factors exist) and
        again at the end of ``__init__``. Also switches the LoRA branch on.
        """
        self.use_lora = True
        # nn.Linear.reset_parameters(self)
        trunc_normal_(self.weight, std=.02)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)
        if hasattr(self, 'lora_A'):
            # A gets the default nn.Linear initialisation and B starts at zero,
            # so the low-rank update is exactly zero right after construction.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def lora(self, mode: bool = True):
        """Enable (``mode=True``) or disable (``mode=False``) the LoRA branch.

        Enabling follows the same merge policy as :meth:`train`: un-merged
        while training, merged in eval mode when ``merge_weights`` is set.
        Disabling always removes a previously merged update from ``weight``.
        Does nothing when ``r == 0``.
        """
        if self.r == 0:
            return
        if mode:
            self.use_lora = True
            if self.training:
                if self.merge_weights and self.merged:
                    # Make sure that the weights are not merged
                    self._unmerge()
            elif self.merge_weights and not self.merged:
                # Merge the weights and mark it
                self._merge()
        else:
            self.use_lora = False
            if self.merged and self.r > 0:
                self._unmerge()

    def train(self, mode: bool = True):
        """Switch train/eval mode and merge or un-merge the update accordingly.

        Returns:
            ``self``, like ``nn.Module.train``, so ``layer.train()`` and
            ``layer.eval()`` can be chained or assigned.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                self._unmerge()
        elif self.merge_weights and not self.merged and self.use_lora:
            # Merge the weights and mark it
            self._merge()
        return self

    def forward(self, x: torch.Tensor):
        """Apply the frozen linear map plus, when active and un-merged, the LoRA branch."""
        result = F.linear(x, self._transpose_if_needed(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged and self.use_lora:
            # x -> dropout -> A^T -> B^T, scaled by lora_alpha / r
            low_rank = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
            result += low_rank * self.scaling
        return result



# class ConvLoRA(nn.Module, LoRALayer):
#     def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
#         super(ConvLoRA, self).__init__()
#         self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
#         LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
#         assert isinstance(kernel_size, int) or len(kernel_size) <= 2
#         if len(kernel_size) == 1:
#             kernel_size = kernel_size[0]
#         # Actual trainable parameters
#         if r > 0 and isinstance(kernel_size, int):
#             self.lora_A = nn.Parameter(
#                 self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
#             )
#             self.lora_B = nn.Parameter(
#               self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
#             )
#             self.scaling = self.lora_alpha / self.r
#             # Freezing the pre-trained weight matrix
#             self.conv.weight.requires_grad = False
#             if self.conv.bias is not None:
#                 self.conv.bias.requires_grad = False
#         elif r > 0:
#             self.lora_A = nn.Parameter(
#                 self.conv.weight.new_zeros((r * kernel_size[0], in_channels * kernel_size[0]))
#             )
#             self.lora_B = nn.Parameter(
#               self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size[1], r*kernel_size[0]))
#             )
#             self.scaling = self.lora_alpha / self.r
#             # Freezing the pre-trained weight matrix
#             self.conv.weight.requires_grad = False
#             if self.conv.bias is not None:
#                 self.conv.bias.requires_grad = False

#         self.reset_parameters()
#         self.merged = False

#     def reset_parameters(self):
#         self.use_lora = True
#         self.conv.reset_parameters()
#         if hasattr(self, 'lora_A'):
#             # initialize A the same way as the default for nn.Linear and B to zero
#             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
#             nn.init.zeros_(self.lora_B)

#     def lora(self, mode: bool = True):
#         def T(w):
#             return w.transpose(0, 1) if self.fan_in_fan_out else w
#         if self.r == 0:
#             return
#         if mode:
#             self.use_lora = True
#             if self.training:
#                 if self.merge_weights and self.merged:
#                     # Make sure that the weights are not merged
#                     if self.r > 0:
#                         self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
#                     self.merged = False
#             else:
#                 if self.merge_weights and not self.merged:
#                     # Merge the weights and mark it
#                     if self.r > 0:
#                         self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
#                     self.merged = True  
#         else:
#             self.use_lora = False
#             if self.merged:
#                 self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
#                 self.merged = False

#     def train(self, mode=True):
#         super(ConvLoRA, self).train(mode)
#         if mode:
#             if self.merge_weights and self.merged:
#                 if self.r > 0:
#                     # Make sure that the weights are not merged
#                     self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
#                 self.merged = False
#         else:
#             if self.merge_weights and not self.merged and self.use_lora:
#                 if self.r > 0:
#                     # Merge the weights and mark it
#                     self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
#                 self.merged = True

#     def forward(self, x):
#         if self.r > 0 and not self.merged:
#             return self.conv._conv_forward(
#                 x, 
#                 self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
#                 self.conv.bias
#             )
#         return self.conv(x)



class Conv2d(nn.Conv2d, LoRALayer):
    """2-D convolution with an optional low-rank (LoRA) update of its kernel.

    When ``r > 0`` the layer owns ``lora_A`` and ``lora_B`` whose product has
    as many elements as ``weight``; the product is reshaped to the kernel
    shape, scaled by ``lora_alpha / r`` and added to the frozen kernel.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: an int, or a sequence of one or two ints.
        r: rank of the update; 0 makes this a plain convolution.
        lora_alpha: scaling numerator.
        lora_dropout: stored via ``LoRALayer``; not applied in ``forward``.
        merge_weights: allow folding the update into ``weight`` in eval mode.
        **kwargs: forwarded to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        assert isinstance(kernel_size, int) or len(kernel_size) <= 2
        # Normalise the kernel size to two ints. ``len()`` must only be called
        # on sequences: calling it on an int raises TypeError.
        if isinstance(kernel_size, int):
            k_first = k_second = kernel_size
        elif len(kernel_size) == 1:
            k_first = k_second = kernel_size[0]
        else:
            k_first, k_second = kernel_size[0], kernel_size[1]

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * k_first, in_channels * k_first))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * k_second, r * k_first))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            if self.bias is not None:
                self.bias.requires_grad = False

        self.reset_parameters()
        self.merged = False

    def _lora_delta(self):
        """Return the scaled low-rank update reshaped to the kernel shape."""
        return (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling

    def _merge(self):
        """Fold the low-rank update into ``weight`` and mark the layer merged."""
        if self.r > 0:
            self.weight.data += self._lora_delta()
        self.merged = True

    def _unmerge(self):
        """Remove the low-rank update from ``weight`` and mark the layer unmerged."""
        if self.r > 0:
            self.weight.data -= self._lora_delta()
        self.merged = False

    def reset_parameters(self):
        """Re-initialise the convolution and the LoRA factors; switch LoRA on."""
        self.use_lora = True
        nn.Conv2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # A gets the default nn.Linear initialisation and B starts at zero,
            # so the update is exactly zero right after construction.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def lora(self, mode: bool = True):
        """Enable (``mode=True``) or disable (``mode=False``) the LoRA branch.

        Enabling follows the same merge policy as :meth:`train`; disabling
        removes a previously merged update. Does nothing when ``r == 0``.
        """
        if self.r == 0:
            return
        if mode:
            self.use_lora = True
            if self.training:
                if self.merge_weights and self.merged:
                    # Make sure that the weights are not merged
                    self._unmerge()
            elif self.merge_weights and not self.merged:
                # Merge the weights and mark it
                self._merge()
        else:
            self.use_lora = False
            if self.merged and self.r > 0:
                self._unmerge()

    def train(self, mode=True):
        """Switch train/eval mode and merge or un-merge the update accordingly.

        Returns:
            ``self``, like ``nn.Module.train``, so ``eval()``/``train()`` chain.
        """
        nn.Conv2d.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                self._unmerge()
        elif self.merge_weights and not self.merged and self.use_lora:
            # Merge the weights and mark it
            self._merge()
        return self

    def forward(self, x):
        """Convolve ``x`` with the frozen kernel plus the un-merged LoRA update, if active."""
        if self.r > 0 and not self.merged and self.use_lora:
            return self._conv_forward(x, self.weight + self._lora_delta(), self.bias)
        return self._conv_forward(x, self.weight, self.bias)

class ConvTranspose2d(nn.ConvTranspose2d, LoRALayer):
    """2-D transposed convolution with an optional low-rank (LoRA) kernel update.

    When ``r > 0`` the layer owns ``lora_A`` and ``lora_B`` whose product has
    as many elements as ``weight``; the product is reshaped to the kernel
    shape, scaled by ``lora_alpha / r`` and added to the frozen kernel.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: an int, or a sequence of one or two ints.
        r: rank of the update; 0 makes this a plain transposed convolution.
        lora_alpha: scaling numerator.
        lora_dropout: stored via ``LoRALayer``; not applied in ``forward``.
        merge_weights: allow folding the update into ``weight`` in eval mode.
        **kwargs: forwarded to ``nn.ConvTranspose2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        assert isinstance(kernel_size, int) or len(kernel_size) <= 2
        # Normalise the kernel size to two ints. ``len()`` must only be called
        # on sequences: calling it on an int raises TypeError.
        if isinstance(kernel_size, int):
            k_first = k_second = kernel_size
        elif len(kernel_size) == 1:
            k_first = k_second = kernel_size[0]
        else:
            k_first, k_second = kernel_size[0], kernel_size[1]

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * k_first, in_channels * k_first))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * k_second, r * k_first))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            if self.bias is not None:
                self.bias.requires_grad = False

        self.reset_parameters()
        self.merged = False

    def _lora_delta(self):
        """Return the scaled low-rank update reshaped to the kernel shape."""
        return (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling

    def _merge(self):
        """Fold the low-rank update into ``weight`` and mark the layer merged."""
        if self.r > 0:
            self.weight.data += self._lora_delta()
        self.merged = True

    def _unmerge(self):
        """Remove the low-rank update from ``weight`` and mark the layer unmerged."""
        if self.r > 0:
            self.weight.data -= self._lora_delta()
        self.merged = False

    def reset_parameters(self):
        """Re-initialise the convolution and the LoRA factors; switch LoRA on."""
        self.use_lora = True
        nn.ConvTranspose2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # A gets the default nn.Linear initialisation and B starts at zero,
            # so the update is exactly zero right after construction.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def lora(self, mode: bool = True):
        """Enable (``mode=True``) or disable (``mode=False``) the LoRA branch.

        Enabling follows the same merge policy as :meth:`train`; disabling
        removes a previously merged update. Does nothing when ``r == 0``.
        """
        if self.r == 0:
            return
        if mode:
            self.use_lora = True
            if self.training:
                if self.merge_weights and self.merged:
                    # Make sure that the weights are not merged
                    self._unmerge()
            elif self.merge_weights and not self.merged:
                # Merge the weights and mark it
                self._merge()
        else:
            self.use_lora = False
            if self.merged and self.r > 0:
                self._unmerge()

    def train(self, mode=True):
        """Switch train/eval mode and merge or un-merge the update accordingly.

        Returns:
            ``self``, like ``nn.Module.train``, so ``eval()``/``train()`` chain.
        """
        nn.ConvTranspose2d.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                self._unmerge()
        elif self.merge_weights and not self.merged and self.use_lora:
            # Merge the weights and mark it
            self._merge()
        return self

    def forward(self, x):
        """Apply the transposed convolution, adding the un-merged LoRA update if active.

        ``_conv_forward`` is only implemented by the regular convolution
        classes, so the transposed convolution is computed explicitly with
        ``F.conv_transpose2d`` using this layer's own hyper-parameters.
        """
        if self.r > 0 and not self.merged and self.use_lora:
            return F.conv_transpose2d(
                x,
                self.weight + self._lora_delta(),
                self.bias,
                self.stride,
                self.padding,
                self.output_padding,
                self.groups,
                self.dilation,
            )
        # No active update: defer to the stock implementation.
        return nn.ConvTranspose2d.forward(self, x)


# class ConvTranspose2d(ConvLoRA):
#     def __init__(self, *args, **kwargs):
#         super(Conv2d, self).__init__(nn.ConvTranspose2d, *args, **kwargs)

# class Conv2d(ConvLoRA):
#     def __init__(self, *args, **kwargs):
#         super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)

# class Conv1d(ConvLoRA):
#     def __init__(self, *args, **kwargs):
#         super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)

# # Can Extend to other ones like this

# class Conv3d(ConvLoRA):
#     def __init__(self, *args, **kwargs):
#         super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)


class Swish(nn.Module):
    """Swish activation: the input multiplied element-wise by its sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class GroupNorm(nn.Module):
    """Affine group normalisation with 32 groups and ``eps=1e-6``.

    Args:
        in_channels: number of channels of the input; passed to
            ``nn.GroupNorm`` as ``num_channels``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gn = nn.GroupNorm(
            num_groups=32,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
        )

    def forward(self, x):
        normalised = self.gn(x)
        return normalised



class attn_norm(nn.Module):
    """Normalisation applied to attention scores.

    Args:
        dim: dimension along which the softmax variants normalise.
        method: one of
            ``'softmax'`` - plain softmax;
            ``'squared_relu'`` - ``relu(x) ** 2``;
            ``'softmax_plus'`` - softmax of the scores rescaled by
            ``log(l) / log(512)``, where ``l`` is the size of the last
            dimension. Entries equal to ``-inf`` keep a scale of 1.

    Raises:
        ValueError: if ``method`` is not one of the values above.
    """

    def __init__(self, dim=-1, method='softmax') -> None:
        super().__init__()
        if method in ('softmax', 'softmax_plus'):
            self.attn_norm = nn.Softmax(dim=dim)
        elif method == 'squared_relu':
            self.attn_norm = nn.ReLU()
        else:
            # Fail at construction time instead of returning None from forward.
            raise ValueError(
                f"unknown attention normalisation method: {method!r}"
            )

        self.method = method

    def forward(self, x):
        if self.method == 'softmax':
            return self.attn_norm(x)
        if self.method == 'squared_relu':
            return self.attn_norm(x) ** 2

        # 'softmax_plus'
        # ``-inf / 10`` is still ``-inf``, so the mask is True everywhere
        # except on entries that are exactly ``-inf`` (or NaN).
        mask = x > -torch.inf / 10
        length = x.shape[-1]
        # Plain Python floats keep numpy scalars out of the tensor arithmetic.
        factor = math.log(length) / math.log(512)
        # factor where the mask is True, 1 elsewhere
        scale = factor * mask + 1 - mask * 1
        return self.attn_norm(x * scale)



def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (stochastic depth) per sample, for the main path of residual blocks.

    Each sample along dimension 0 is either zeroed or kept and rescaled by
    ``1 / (1 - drop_prob)``.

    Args:
        x: input tensor; dimension 0 is treated as the sample dimension.
        drop_prob: probability of dropping a sample. ``0`` and ``None`` both
            disable dropping.
        training: dropping is only applied when this is True.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor of the
        same shape.
    """
    # ``not drop_prob`` also covers None, which DropPath uses as its default.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    if keep_prob > 0:
        return x.div(keep_prob) * random_tensor
    # drop_prob == 1: every sample is dropped; skip the division by zero.
    return x * random_tensor


class DropPath(nn.Module):
    """Module wrapper around ``drop_path_f`` (per-sample stochastic depth).

    Args:
        drop_prob: drop probability handed to ``drop_path_f`` on every call.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Dropping is only active while the module is in training mode.
        is_training = self.training
        return drop_path_f(x, self.drop_prob, is_training)


def window_partition(x, window_size: tuple):
    """Split a feature map into non-overlapping windows.

    Args:
        x: ``(B, H, W, C)`` for a 2-element ``window_size`` or
            ``(B, T, H, W, C)`` for a 3-element one. Each spatial/temporal
            size must be divisible by the matching window size.
        window_size: ``(Wh, Ww)`` or ``(Wt, Wh, Ww)``.

    Returns:
        ``(num_windows * B, Wh, Ww, C)`` or ``(num_windows * B, Wt, Wh, Ww, C)``.

    Raises:
        ValueError: if ``window_size`` has neither 2 nor 3 elements.
    """
    if len(window_size) == 3:
        B, T, H, W, C = x.shape
        wt, wh, ww = window_size[0], window_size[1], window_size[2]
        # [B, T, H, W, C] -> [B, T//Wt, Wt, H//Wh, Wh, W//Ww, Ww, C]
        x = x.view(B, T // wt, wt, H // wh, wh, W // ww, ww, C)
        # bring the window-grid axes to the front, then fold them into the batch
        windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, wt, wh, ww, C)
    elif len(window_size) == 2:
        B, H, W, C = x.shape
        wh, ww = window_size[0], window_size[1]
        # [B, H, W, C] -> [B, H//Wh, Wh, W//Ww, Ww, C]
        x = x.view(B, H // wh, wh, W // ww, ww, C)
        # [B, H//Wh, Wh, W//Ww, Ww, C] -> [B, H//Wh, W//Ww, Wh, Ww, C] -> [B*num_windows, Wh, Ww, C]
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, wh, ww, C)
    else:
        raise ValueError(
            f"window_size must have 2 or 3 elements, got {len(window_size)}"
        )
    return windows



def window_reverse(windows, window_size, T=1, H=1, W=1):
    """Reassemble non-overlapping windows into a feature map.

    Args:
        windows: ``(num_windows * B, Wh, Ww, C)`` for a 2-element
            ``window_size`` or ``(num_windows * B, Wt, Wh, Ww, C)`` for a
            3-element one.
        window_size: ``(Wh, Ww)`` or ``(Wt, Wh, Ww)``.
        T: temporal size of the feature map (only used for 3-element windows).
        H: height of the feature map.
        W: width of the feature map.

    Returns:
        ``(B, H, W, C)`` or ``(B, T, H, W, C)``.

    Raises:
        ValueError: if ``window_size`` has neither 2 nor 3 elements, or if the
            feature map is smaller than one window.
    """
    if len(window_size) == 3:
        wt, wh, ww = window_size[0], window_size[1], window_size[2]
        grid = (T // wt, H // wh, W // ww)
    elif len(window_size) == 2:
        wh, ww = window_size[0], window_size[1]
        grid = (H // wh, W // ww)
    else:
        raise ValueError(
            f"window_size must have 2 or 3 elements, got {len(window_size)}"
        )

    windows_per_sample = 1
    for count in grid:
        windows_per_sample *= count
    if windows_per_sample == 0:
        raise ValueError("feature map is smaller than a single window")
    # Integer arithmetic: no float rounding when recovering the batch size.
    B = windows.shape[0] // windows_per_sample

    if len(window_size) == 3:
        # [B*num_windows, Wt, Wh, Ww, C] -> [B, T//Wt, H//Wh, W//Ww, Wt, Wh, Ww, C]
        x = windows.view(B, grid[0], grid[1], grid[2], wt, wh, ww, -1)
        # interleave grid and in-window axes, then merge them: -> [B, T, H, W, C]
        x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, T, H, W, -1)
    else:
        # [B*num_windows, Wh, Ww, C] -> [B, H//Wh, W//Ww, Wh, Ww, C]
        x = windows.view(B, grid[0], grid[1], wh, ww, -1)
        # [B, H//Wh, W//Ww, Wh, Ww, C] -> [B, H//Wh, Wh, W//Ww, Ww, C] -> [B, H, W, C]
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class ScaleOffset(nn.Module):
    """Learnable per-feature scale (``gamma``) and offset (``beta``).

    Args:
        dim: size of the ``gamma`` and ``beta`` vectors.
        scale: create ``gamma`` (initialised from a normal with std 0.02);
            when False the input is not scaled.
        offset: create ``beta`` (initialised to zero); when False nothing is
            added.
    """

    def __init__(self, dim, scale=True, offset=True) -> None:
        super().__init__()
        self.gamma = None
        self.beta = None
        if scale:
            self.gamma = nn.Parameter(torch.zeros(dim))
            nn.init.normal_(self.gamma, std=.02)
        if offset:
            self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, input):
        output = input if self.gamma is None else input * self.gamma
        if self.beta is not None:
            output = output + self.beta
        return output


class PatchEmbed(nn.Module):
    """Image (2-D) or video/volume (3-D) to patch embedding.

    The input is projected by a strided convolution whose kernel and stride
    both equal ``patch_size``, then flattened to a token sequence.

    Args:
        patch_size: two ints for ``(B, C, H, W)`` inputs or three ints for
            ``(B, C, T, H, W)`` inputs.
        in_c: number of input channels.
        embed_dim: number of output channels of the projection.
        norm_layer: optional callable ``norm_layer(embed_dim)`` building the
            normalisation applied to the tokens; identity when omitted.

    Raises:
        ValueError: if ``patch_size`` has neither 2 nor 3 elements.
    """

    def __init__(self, patch_size=(1, 1, 1), in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        if len(patch_size) == 2:
            self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        elif len(patch_size) == 3:
            self.proj = nn.Conv3d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        else:
            raise ValueError(
                f"patch_size must have 2 or 3 elements, got {len(patch_size)}"
            )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed ``x`` and return ``(tokens, T', H', W')``.

        ``tokens`` has shape ``(B, num_patches, embed_dim)``; the remaining
        values are the input sizes floor-divided by the patch sizes, with
        ``T'`` fixed to 1 for 2-D inputs.
        """
        is_3d = len(self.patch_size) == 3
        if is_3d:
            _, _, T, H, W = x.shape
        else:
            _, _, H, W = x.shape

        # Downsample by a factor of patch_size.
        x = self.proj(x)
        # flatten: [B, C, ...] -> [B, C, N]; transpose: [B, C, N] -> [B, N, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        grid_h = H // self.patch_size[-2]
        grid_w = W // self.patch_size[-1]
        grid_t = T // self.patch_size[-3] if is_3d else 1
        return x, grid_t, grid_h, grid_w


class Mlp(nn.Module):
    """Two-layer perceptron: linear, activation, dropout, linear, dropout.

    Args:
        in_features: size of the input features.
        hidden_features: width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: size of the output features; falls back to
            ``in_features`` when falsy.
        act_layer: zero-argument factory for the activation module.
        drop: dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


class SElayer(nn.Module):
    """Channel gating: rescales each channel by a sigmoid gate.

    The gate is computed from the mean over dimensions 2 and 3 of the input,
    passed through two 1x1 convolutions with a ReLU in between.

    Args:
        dim: number of input (and output) channels.
        reduction: the hidden 1x1 convolution has ``dim // reduction`` channels.
    """

    def __init__(self, dim, reduction=4) -> None:
        super().__init__()

        squeezed = dim // reduction
        self.channel_conv1 = nn.Conv2d(dim, squeezed, 1, 1, 0)
        self.act1 = nn.ReLU()
        self.channel_conv2 = nn.Conv2d(squeezed, dim, 1, 1, 0)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        # One descriptor per channel: average over dims 2 and 3.
        pooled = x.mean(dim=[2, 3], keepdim=True)
        hidden = self.act1(self.channel_conv1(pooled))
        gate = self.act2(self.channel_conv2(hidden))
        return x * gate



class PeriodicPad2d(nn.Module):
    """Pad the last dimension circularly and the second-to-last with zeros.

    Intended for longitude (left-right, periodic) and latitude (top-bottom,
    zero) padding.

    Args:
        pad_width: int or pair; converted with ``to_2tuple``. Element 0 is the
            top/bottom padding and element 1 the left/right padding.
    """

    def __init__(self, pad_width):
        super().__init__()
        self.pad_width = to_2tuple(pad_width)

    def forward(self, x):
        pad_vertical, pad_horizontal = self.pad_width[0], self.pad_width[1]
        # left/right: wrap around
        wrapped = F.pad(x, (pad_horizontal, pad_horizontal, 0, 0), mode="circular")
        # top/bottom: zeros
        return F.pad(wrapped, (0, 0, pad_vertical, pad_vertical), mode="constant", value=0)




class ResidualBlock(nn.Module):
    """Residual block of two (GroupNorm, Swish, periodic pad, 3x3 conv) stages.

    When the channel counts differ, the skip path goes through a 1x1
    convolution (``channel_up``); otherwise the input is added unchanged.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        stages = [
            GroupNorm(in_channels),
            Swish(),
            PeriodicPad2d(1),
            nn.Conv2d(in_channels, out_channels, 3, 1, 0),
            GroupNorm(out_channels),
            Swish(),
            PeriodicPad2d(1),
            nn.Conv2d(out_channels, out_channels, 3, 1, 0),
        ]
        self.block = nn.Sequential(*stages)
        if in_channels != out_channels:
            self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        if self.in_channels == self.out_channels:
            return x + self.block(x)
        # Channel counts differ: project the skip path before adding.
        return self.block(x) + self.channel_up(x)


class UpSampleBlock(nn.Module):
    """Upsample with a stride-2 transposed conv, then periodic pad and 3x3 conv.

    Args:
        channels: number of channels, unchanged by the block.
    """

    def __init__(self, channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(channels, channels, kernel_size=(2, 2), stride=(2, 2))
        self.ppad = PeriodicPad2d(1)

        self.conv = nn.Conv2d(channels, channels, 3, 1, 0)

    def forward(self, x):
        # x = F.interpolate(x, scale_factor=2.)
        upsampled = self.up(x)
        padded = self.ppad(upsampled)
        return self.conv(padded)


class DownSampleBlock(nn.Module):
    """Downsample with a periodic pad followed by a stride-2 3x3 convolution.

    Args:
        channels: number of channels, unchanged by the block.
    """

    def __init__(self, channels):
        super().__init__()
        self.ppad = PeriodicPad2d(1)
        self.conv = nn.Conv2d(channels, channels, 3, 2, 0)

    def forward(self, x):
        padded = self.ppad(x)
        return self.conv(padded)


class NonLocalBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual.

    Queries, keys and values come from 1x1 convolutions of the group-normalised
    input; the attended values go through a final 1x1 convolution and are
    added to the input.

    Args:
        in_channels: number of channels, unchanged by the block.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = GroupNorm(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0)

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        positions = height * width

        # query: [b, hw, c]; key and value: [b, c, hw]
        query = query.reshape(batch, channels, positions).permute(0, 2, 1)
        key = key.reshape(batch, channels, positions)
        value = value.reshape(batch, channels, positions)

        # scores[b, i, j]: query position i against key position j
        scores = torch.bmm(query, key) * (int(channels) ** (-0.5))
        weights = F.softmax(scores, dim=2)

        # transpose so the batched matmul sums over key positions
        weights = weights.permute(0, 2, 1)
        attended = torch.bmm(value, weights).reshape(batch, channels, height, width)

        return x + self.proj_out(attended)


class PatchExpand(nn.Module):
    """Upsample a ``(B, H, W, C)`` map by 2x along H and W.

    Each position's channels are split into a 2x2 block of positions, leaving
    a quarter of the (expanded) channels per position, followed by a norm.

    Args:
        dim: number of input channels.
        dim_scale: when 2, a bias-free linear layer first doubles the
            channels; otherwise the input is passed through unchanged.
        norm_layer: factory called with ``dim // dim_scale``.
    """

    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        if dim_scale == 2:
            self.expand = nn.Linear(dim, 2 * dim, bias=False)
        else:
            self.expand = nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H, W, C
        """
        expanded = self.expand(x)
        _, _, _, channels = expanded.shape

        # move groups of channels into a 2x2 spatial neighbourhood
        shuffled = rearrange(expanded, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=channels//4)
        return self.norm(shuffled)



class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Concatenates every 2x2 (or 2x2x2) neighbourhood of tokens along the
    channel axis, normalises the result and reduces it to ``2 * dim`` channels
    with a bias-free linear layer.

    Args:
        dim (int): Number of input channels.
        window_length (int): 2 for ``(H, W)`` token grids, 3 for ``(T, H, W)``
            token grids.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm

    Raises:
        ValueError: if ``window_length`` is neither 2 nor 3.
    """

    def __init__(self, dim, window_length, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.window_length = window_length
        if window_length == 3:
            merged = 8 * dim
        elif window_length == 2:
            merged = 4 * dim
        else:
            raise ValueError(
                f"window_length must be 2 or 3, got {window_length!r}"
            )
        self.reduction = nn.Linear(merged, 2 * dim, bias=False)
        self.norm = norm_layer(merged)

    def forward(self, x, T, H, W):
        """
        x: B, L, C with L == H*W (window_length 2) or L == T*H*W (window_length 3)

        Returns a tensor of shape (B, L_merged, 2*C).
        """
        B, L, C = x.shape
        if self.window_length == 3:
            assert L == T * H * W, "input feature has wrong size"
            if T > 1:
                x = x.view(B, T, H, W, C)
            else:
                # NOTE(review): with T == 1 only four neighbours are merged
                # (4*C channels) while norm/reduction were built for 8*C, so
                # this path looks like it cannot succeed - confirm intent.
                x = x.view(B, H, W, C)
        else:
            assert L == H * W, "input feature has wrong size"
            x = x.view(B, H, W, C)

        if x.dim() == 5:
            # Eight neighbours of each 2x2x2 cell; the T offset varies
            # fastest, then H, then W.
            parts = [
                x[:, t::2, h::2, w::2, :]
                for w in (0, 1) for h in (0, 1) for t in (0, 1)
            ]  # each [B, T/2, H/2, W/2, C]
            x = torch.cat(parts, -1)  # [B, T/2, H/2, W/2, 8*C]
            x = x.view(B, -1, 8 * C)  # [B, T/2*H/2*W/2, 8*C]
        else:
            # Four neighbours of each 2x2 cell; the H offset varies fastest.
            parts = [
                x[:, h::2, w::2, :]
                for w in (0, 1) for h in (0, 1)
            ]  # each [B, H/2, W/2, C]
            x = torch.cat(parts, -1)  # [B, H/2, W/2, 4*C]
            x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, L_merged, 2*C]

        return x


class scale_before_norm(torch.autograd.Function):
    """Divide the input by its clamped standard deviation over the last dim.

    Forward computes ``input / clamp(std(input, dim=-1), 1, 10)``. Backward
    passes the incoming gradient through unchanged, i.e. the division is
    ignored when differentiating.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input):
        """Return ``input`` divided by its per-row std, clamped to [1, 10].

        The input tensor itself is left untouched; a new tensor is returned.
        """
        # Tensor.clamp is out-of-place: its result must be kept, otherwise
        # the bounds are never applied.
        scale = input.std(dim=-1, keepdim=True).clamp(min=1, max=10)
        ctx.scale = scale

        # Out-of-place division: modifying an input in place inside a custom
        # Function (without ctx.mark_dirty) silently changes the caller's
        # tensor.
        return input / scale

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Pass the gradient through unchanged (the scale is not applied)."""
        return grad_output,



def scale_before_norm_func(
    input: torch.Tensor
) -> torch.Tensor:
    """Apply the ``scale_before_norm`` autograd function to ``input``.

    Args:
        input: Tensor forwarded as the single argument of
            ``scale_before_norm.apply``.

    Returns:
        Whatever ``scale_before_norm.apply`` returns for ``input``.
    """
    return scale_before_norm.apply(input)


class scale_norm(nn.Module):
    """Channel-last wrapper around LayerNorm / InstanceNorm1d.

    Accepts (B, L, C) or (B, H, W, C) inputs and returns a tensor of the
    same shape. When ``use_scale`` is set, the input is first passed through
    ``scale_before_norm_func``.
    """

    def __init__(self, features, norm_type="layernorm", use_scale=False, **kwargs) -> None:
        """Build the underlying norm layer.

        Args:
            features: Size of the channel axis handed to the norm layer.
            norm_type: ``"layernorm"`` or ``"instancenorm"``.
            use_scale: Whether to call ``scale_before_norm_func`` on the
                input before normalising.
            **kwargs: Forwarded to ``nn.LayerNorm`` / ``nn.InstanceNorm1d``.

        Raises:
            NotImplementedError: for any other ``norm_type``.
        """
        super().__init__()
        self.use_scale = use_scale
        self.norm_type = norm_type

        if norm_type == "layernorm":
            self.norm = nn.LayerNorm(features, **kwargs)
        elif norm_type == "instancenorm":
            self.norm = nn.InstanceNorm1d(features, **kwargs)
        else:
            raise NotImplementedError(f"unsupported norm_type: {norm_type!r}")

    def forward(self, input_data):
        """Normalise ``input_data`` and return it in its original shape.

        Args:
            input_data: Tensor of shape (B, L, C) or (B, H, W, C).

        Returns:
            Normalised tensor with the same shape as ``input_data``.

        Raises:
            ValueError: if ``input_data`` is neither 3-D nor 4-D.
        """
        # Remember the full shape instead of using H == 0 as a "was 3-D"
        # flag, which misfires for 4-D inputs with an empty H axis.
        original_shape = input_data.shape
        if len(original_shape) not in (3, 4):
            raise ValueError(
                f"expected a 3-D or 4-D input, got shape {tuple(original_shape)}"
            )
        batch, channels = original_shape[0], original_shape[-1]

        channels_first = self.norm_type == "instancenorm"
        if channels_first:
            # InstanceNorm1d takes (B, C, L): flatten the middle axes and
            # move channels in front of them.
            input_data = input_data.reshape(batch, -1, channels).permute(0, 2, 1)
        if self.use_scale:
            input_data = scale_before_norm_func(input_data)

        output_data = self.norm(input_data)
        if channels_first:
            output_data = output_data.permute(0, 2, 1)

        return output_data.reshape(original_shape)


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal embedding of width ``frequency_embedding_size`` is fed
    through a two-layer MLP (Linear -> SiLU -> Linear) built from the
    LoRA-capable ``Linear`` layer.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256,
                 r=0, lora_alpha=1, lora_dropout=0,
                 fan_in_fan_out=False, merge_weights=True):
        """
        :param hidden_size: output width of both MLP layers.
        :param frequency_embedding_size: width of the sinusoidal embedding
                          fed to the first layer.
        :param r, lora_alpha, lora_dropout, fan_in_fan_out, merge_weights:
                          forwarded unchanged to both ``Linear`` layers.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            Linear(frequency_embedding_size, hidden_size, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=True),
            nn.SiLU(),
            Linear(hidden_size, hidden_size, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=True)
        )
        self.frequency_embedding_size = frequency_embedding_size

    def lora(self, mode=True):
        """
        Call ``lora(mode)`` on every layer of the MLP that provides it.
        """
        # The only direct child is the nn.Sequential container, which has no
        # ``lora`` attribute, so iterate its layers to reach the Linear
        # modules.
        for layer in self.mlp:
            if hasattr(layer, "lora"):
                layer.lora(mode)

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd widths: cos/sin cover 2 * half columns, pad one zero column.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """
        Embed timesteps and run them through the MLP.
        :param t: a 1-D Tensor of timesteps, or a list of such Tensors. For a
                          list, each entry is embedded with width
                          ``frequency_embedding_size // len(t)`` and the
                          results are concatenated on the last dim.
        :return: the MLP output for the (concatenated) embedding.
        """
        if isinstance(t, list):
            # NOTE(review): if frequency_embedding_size is not divisible by
            # len(t), the concatenated width differs from the width the
            # first MLP layer was built with -- confirm callers avoid this.
            t_freq = torch.cat(
                [
                    self.timestep_embedding(part, self.frequency_embedding_size // len(t))
                    for part in t
                ],
                dim=-1,
            )
        else:
            t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)

