# This piece of code is developed based on
# https://github.com/pytorch/examples/tree/master/imagenet
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from math import sqrt
from math import log
import math
from .eca_module import eca_layer
from timm.models.vision_transformer import default_cfgs, _cfg
from .se_module import SELayer
# from functools import partial
# from layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
# from selective_scan_interface import selective_scan
from einops import rearrange, repeat
# from collections import namedtuple
# from utils.generation import GenerationMixin
# from utils.hf import load_config_hf, load_state_dict_hf
# from dataclasses import dataclass, field


# Names exported by ``from <module> import *``: the model class and its
# factory functions.
__all__ = ['MAMBA_ResNet', 'mamba_resnet50', 'mamba_resnet50_eca',
           'mamba_resnet101', 'mamba_resnet101_eca',
           'mamba_resnet152', 'mamba_resnet152_eca',
           'mamba_resnext50_32x4d', 'mamba_resnext50_32x4d_se', 'mamba_resnext50_32x4d_eca',
           'mamba_resnext101_32x4d', 'mamba_resnext101_32x4d_se', 'mamba_resnext101_32x4d_eca'
           ]

# Default compute device, chosen once at import time: CUDA when available,
# otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# s6la channel k: s6la_channel = 32 (default)

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``; ``stride``, ``groups`` and
    ``dilation`` are forwarded to the convolution unchanged.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) ``nn.Conv2d`` with the given stride."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


#=========================== define bottleneck ============================
class s6la_Bottleneck(nn.Module):
    """ResNet bottleneck block that also consumes a hidden feature map ``h``.

    The input ``x`` is concatenated with ``h`` along the channel axis before
    the first 1x1 convolution, so ``conv1`` takes ``inplanes + s6la_channel``
    input channels. The rest is the usual 1x1 -> 3x3 -> 1x1 bottleneck with an
    optional SE layer and/or ECA layer applied after the last batch norm.

    Args:
        inplanes: channels of ``x``.
        planes: base width of the block; the output has
            ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the identity branch.
        s6la_channel: channels of the hidden map ``h``.
        SE: when true, an ``SELayer`` is applied to the residual branch.
        ECA_size: when not ``None``, kernel size of an ``eca_layer`` applied
            to the residual branch.
        groups: groups of the 3x3 convolution.
        base_width: width per group (64 gives the plain ResNet width).
        dilation: dilation of the 3x3 convolution.
        norm_layer: normalisation layer factory; defaults to
            ``nn.BatchNorm2d``.
        reduction: reduction ratio passed to ``SELayer``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 s6la_channel=32, SE=False, ECA_size=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, reduction=16):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # `planes * base_width / 64 * cardinality`
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes + s6la_channel, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        # The hidden map is average-pooled by 2 only when this block has a
        # downsample branch and a stride other than 1.
        self.averagePooling = None
        if downsample is not None and stride != 1:
            self.averagePooling = nn.AvgPool2d((2, 2), stride=(2, 2))

        self.se = SELayer(out_planes, reduction) if SE else None

        self.eca = None
        if ECA_size is not None:
            self.eca = eca_layer(out_planes, int(ECA_size))

    def forward(self, x, h):
        """Run the block.

        Args:
            x: main feature map.
            h: hidden feature map; it is concatenated with ``x`` on dim 1, so
                it must match ``x`` in every other dimension.

        Returns:
            Tuple ``(out, y, h)``: the block output after the residual add and
            ReLU, the residual-branch tensor ``y``, and the hidden map
            (average-pooled when ``self.averagePooling`` is set, otherwise the
            object that was passed in).
        """
        identity = x

        x = torch.cat((x, h), dim=1)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.se is not None:
            out = self.se(out)

        if self.eca is not None:
            out = self.eca(out)

        # NOTE(review): `y` is the same tensor object as `out`. The in-place
        # add and the in-place ReLU below therefore also modify `y`, so the
        # returned `y` is identical to the returned `out`. This is kept as-is
        # because changing it would alter the numerics of trained
        # checkpoints; confirm whether the pre-residual value was intended.
        y = out

        if self.downsample is not None:
            identity = self.downsample(identity)
        if self.averagePooling is not None:
            h = self.averagePooling(h)

        out += identity
        out = self.relu(out)

        return out, y, h

class MAMBA_ResNet(nn.Module):
    """ResNet backbone with a selective state-space (S6-style) hidden branch.

    Besides the main feature map ``x``, a hidden map ``h`` with
    ``s6la_channel`` channels is threaded through every block. After each
    block, the block's ``y`` output is projected to ``s6la_channel`` channels,
    pooled to a ``p x p`` grid (``p = avg_pool_par``) and used to update the
    pooled hidden state with an input-dependent discretised recurrence
    ``h <- exp(dt * A) * h + dt * B * y``. The result is upsampled back to the
    feature-map resolution, batch-normalised and passed through ``tanh``.
    The classifier sees the globally pooled concatenation of ``x`` and ``h``.

    Args:
        block: block class; it is called as ``block(inplanes, planes, ...)``,
            must expose ``expansion``, and its forward must accept ``(x, h)``
            and return ``(x, y, h)``.
        layers: number of blocks in each of the four stages.
        avg_pool_par: side length ``p`` of the pooled grid; the state size is
            ``p * p``.
        num_classes: output size of the final linear layer.
        s6la_channel: number of channels of the hidden map.
        SE: whether the blocks use an SE layer.
        ECA: ``None`` to disable ECA, or a 4-element sequence of kernel sizes
            (one per stage).
        zero_init_last_bn: zero-initialise ``bn3.weight`` of every
            ``s6la_Bottleneck`` so each residual branch starts at zero.
        groups: groups of the 3x3 convolutions in the blocks.
        drop_rate: stored as ``self.drop_rate``; ``forward`` does not apply
            dropout.
        width_per_group: base width passed to the blocks.
        replace_stride_with_dilation: ``None`` or a 3-element sequence; each
            element says whether the stride of stage 2/3/4 is replaced by
            dilation.
        norm_layer: normalisation layer factory; defaults to
            ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have three
            elements or ``ECA`` does not have four.

    Notes:
        ``D``, ``conv_hid`` and ``recurrent_convs`` are registered on the
        module but are not used by ``forward``.
    """

    def __init__(self, block, layers, avg_pool_par=4, num_classes=1000,
                 s6la_channel=32, SE=False, ECA=None,
                 zero_init_last_bn=True,
                 groups=1, drop_rate=0.0, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.s6la_channel = s6la_channel
        self.dt_rank = math.ceil(self.s6la_channel / 16)
        self.inplanes = 64
        self.drop_rate = drop_rate
        self.dilation = 1
        self.avg_pool_par = avg_pool_par
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        # State-space dimensions: n_state values per inner channel, and one
        # inner channel per (hidden channel, pooled position) pair.
        n_state = self.avg_pool_par ** 2
        d_inner = self.s6la_channel * n_state

        # Mamba settings.
        # S4D real initialization. The tensor used to be allocated with
        # torch.empty and never filled, so A was computed from uninitialised
        # memory. Parameters are created on the default device so that
        # `.to()` / `.cuda()` moves them together with the rest of the model.
        self.A_log = nn.Parameter(self._s4d_real_log_init(d_inner, n_state))

        self.x_proj = nn.Linear(d_inner, self.dt_rank + n_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, d_inner, bias=True)
        self.D = nn.Parameter(torch.ones(d_inner))  # Keep in fp32
        self.norm = nn.LayerNorm(n_state)

        if ECA is None:
            ECA = [None] * 4
        elif len(ECA) != 4:
            raise ValueError("argument ECA should be a 4-element tuple, got {}".format(ECA))

        self.flops = False
        # flops: whether compute the flops and params or not
        # when use paras_flops, set as True
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        conv_outs = [None] * 4
        recurrent_convs = [None] * 4  # placeholders only; never populated
        stages = [None] * 4
        stage_bns = [None] * 4

        stages[0], stage_bns[0], conv_outs[0] = self._make_layer(block, 64, layers[0], s6la_channel=s6la_channel, SE=SE, ECA_size=ECA[0])
        stages[1], stage_bns[1], conv_outs[1] = self._make_layer(block, 128, layers[1], s6la_channel=s6la_channel, SE=SE, ECA_size=ECA[1], stride=2, dilate=replace_stride_with_dilation[0])
        stages[2], stage_bns[2], conv_outs[2] = self._make_layer(block, 256, layers[2], s6la_channel=s6la_channel, SE=SE, ECA_size=ECA[2], stride=2, dilate=replace_stride_with_dilation[1])
        stages[3], stage_bns[3], conv_outs[3] = self._make_layer(block, 512, layers[3], s6la_channel=s6la_channel, SE=SE, ECA_size=ECA[3], stride=2, dilate=replace_stride_with_dilation[2])

        self.conv_outs = nn.ModuleList(conv_outs)
        self.recurrent_convs = nn.ModuleList(recurrent_convs)
        self.stages = nn.ModuleList(stages)
        self.stage_bns = nn.ModuleList(stage_bns)
        self.linear_c = nn.Linear(s6la_channel, d_inner)
        self.linear_c_back = nn.Linear(d_inner, s6la_channel)

        self.tanh = nn.Tanh()

        self.bn2 = norm_layer(s6la_channel)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.avg_pool = nn.AdaptiveAvgPool2d((self.avg_pool_par, self.avg_pool_par))
        self.fc = nn.Linear(512 * block.expansion + s6la_channel, num_classes)
        self.conv_hid = nn.Conv2d(self.s6la_channel, self.s6la_channel, kernel_size=1)

        # Learnable initial hidden state, broadcast over batch and space in forward.
        self.h_0 = nn.Parameter(torch.zeros(self.s6la_channel, 1, 1))
        nn.init.kaiming_normal_(self.h_0)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_last_bn:
            for m in self.modules():
                if isinstance(m, s6la_Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    @staticmethod
    def _s4d_real_log_init(d_inner, d_state):
        """Return ``log(A)`` for the S4D-real init: every row is ``log([1, ..., d_state])``."""
        a = torch.arange(1, d_state + 1, dtype=torch.float32)
        return torch.log(a).repeat(d_inner, 1)

    def _make_layer(self, block, planes, blocks,
                    s6la_channel, SE, ECA_size, stride=1, dilate=False):
        """Build one stage.

        Returns:
            Tuple ``(stage, bns, conv_out)``: an ``nn.Sequential`` of
            ``blocks`` blocks, an ``nn.ModuleList`` with one norm layer per
            block for the hidden map, and the 1x1 convolution that projects
            the stage's block output to ``s6la_channel`` channels.
        """
        conv_out = conv1x1(planes * block.expansion, s6la_channel)

        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample,
                            s6la_channel=s6la_channel, SE=SE, ECA_size=ECA_size, groups=self.groups,
                            base_width=self.base_width, dilation=previous_dilation, norm_layer=norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                s6la_channel=s6la_channel, SE=SE, ECA_size=ECA_size, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        bns = [norm_layer(s6la_channel) for _ in range(blocks)]

        return nn.Sequential(*layers), nn.ModuleList(bns), conv_out

    def _update_hidden(self, y, h, conv_out, bn):
        """Apply one selective state-space update to the hidden map.

        Shape comments use ``c = s6la_channel``, ``p = avg_pool_par`` and
        ``n = p * p``.

        Args:
            y: block output that drives the update.
            h: hidden map returned by the block.
            conv_out: the stage's 1x1 projection to ``c`` channels.
            bn: the norm layer paired with this block.

        Returns:
            The new hidden map, with ``c`` channels and the spatial size of
            ``conv_out(y)``.
        """
        n_state = self.avg_pool_par ** 2

        y_out = conv_out(y)  # [b, c, H, W]
        b, c, ori_h, ori_w = y_out.size()
        # feature descriptor on the global spatial information
        y_out = self.avg_pool(y_out)  # [b, c, p, p]
        y_out = y_out.flatten(1).unsqueeze(1).transpose(1, 2)  # [b, c*n, 1]

        h_out = self.avg_pool(h).flatten(2)  # [b, c, n]
        h_out = self.linear_c(h_out.transpose(1, 2)).transpose(1, 2)  # [b, c*n, n]
        h_out = self.norm(h_out.to(dtype=self.norm.weight.dtype))  # [b, c*n, n]

        # Input-dependent step size (dt) and input matrix (B). x_proj also
        # emits a readout projection (third chunk), which is not used: the
        # full state, not a readout, is carried forward.
        x_dbl = self.x_proj(rearrange(y_out, "b c l -> (b l) c"))  # [b, dt_rank + 2n]
        dt, B, _ = torch.split(x_dbl, [self.dt_rank, n_state, n_state], dim=-1)
        dt = self.dt_proj.weight @ dt.t()  # [c*n, b]
        dt = rearrange(dt, "c (b l) -> b c l", l=1)  # [b, c*n, 1]
        dt = F.softplus(dt + self.dt_proj.bias[..., None].float())
        B = rearrange(B, "(b l) n -> b n l", l=1).contiguous()  # [b, n, 1]

        A = -torch.exp(self.A_log.float())  # [c*n, n], strictly negative
        deltaA = torch.exp(torch.einsum('bcl,cn->bcln', dt, A))  # [b, c*n, 1, n]
        deltaB_y = torch.einsum('bcl,bnl,bcl->bcln', dt, B, y_out)  # [b, c*n, 1, n]

        h = deltaA[:, :, 0] * h_out + deltaB_y[:, :, 0]  # [b, c*n, n]

        h = self.linear_c_back(h.transpose(1, 2)).transpose(1, 2)  # [b, c, n]
        h = h.reshape(b, c, self.avg_pool_par, self.avg_pool_par)  # [b, c, p, p]
        h = F.interpolate(h, size=(ori_h, ori_w), mode='nearest')  # [b, c, H, W]
        h = bn(h)
        return self.tanh(h)

    def forward(self, x):
        """Return the class logits for an image batch ``x`` (3 input channels)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        batch, _, height, width = x.size()

        # Broadcast the learned initial state over batch and space. It follows
        # the input's device rather than a module-level global so the model
        # works on whichever device it has been moved to.
        h = self.h_0.unsqueeze(0).expand(batch, -1, height, width).to(x.device)

        for layers, bns, conv_out in zip(self.stages, self.stage_bns, self.conv_outs):
            for layer, bn in zip(layers, bns):
                x, y, h = layer(x, h)
                h = self._update_hidden(y, h, conv_out, bn)

        h = self.bn2(h)
        h = self.relu(h)
        x = torch.cat((x, h), dim=1)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    # def forward(self, x):
    #     return self._forward_impl(x)
    
def mamba_resnet50(s6la_channel=32):
    """Construct an s6la ResNet-50 (stage depths ``[3, 4, 6, 3]``).

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).

    All other ``MAMBA_ResNet`` options keep their defaults
    (``num_classes=1000``, ``SE=False``, ``ECA=None``).
    """
    print("Constructing s6la_resnet50......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 6, 3], s6la_channel=s6la_channel)
    return model

def mamba_resnet50_eca(s6la_channel=16, k_size=[5, 5, 5, 7]):
    """Construct an s6la ResNet-50 with ECA (stage depths ``[3, 4, 6, 3]``).

    Args:
        s6la_channel: number of channels of the hidden map.
        k_size: ECA kernel size for each of the four stages.
    """
    print("Constructing s6la_resnet50_eca......")
    stage_depths = [3, 4, 6, 3]
    return MAMBA_ResNet(
        s6la_Bottleneck,
        stage_depths,
        s6la_channel=s6la_channel,
        ECA=k_size,
    )


def mamba_resnet101(s6la_channel=32):
    """Construct an s6la ResNet-101 (stage depths ``[3, 4, 23, 3]``).

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).

    All other ``MAMBA_ResNet`` options keep their defaults
    (``num_classes=1000``, ``SE=False``, ``ECA=None``).
    """
    print("Constructing s6la_resnet101......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 23, 3], s6la_channel=s6la_channel)
    return model

def mamba_resnet101_eca(s6la_channel=32, k_size=[5, 5, 5, 7]):
    """Construct an s6la ResNet-101 with ECA (stage depths ``[3, 4, 23, 3]``).

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).
        k_size: ECA kernel size for each of the four stages.
    """
    print("Constructing s6la_resnet101_eca......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 23, 3],
                         s6la_channel=s6la_channel, ECA=k_size)
    return model


def mamba_resnet152(s6la_channel=32):
    """Construct an s6la ResNet-152 (stage depths ``[3, 8, 36, 3]``).

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).

    All other ``MAMBA_ResNet`` options keep their defaults
    (``num_classes=1000``, ``SE=False``, ``ECA=None``).
    """
    print("Constructing s6la_resnet152......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 8, 36, 3], s6la_channel=s6la_channel)
    return model

def mamba_resnet152_eca(s6la_channel=32, k_size=[5, 5, 5, 7]):
    """Construct an s6la ResNet-152 with ECA (stage depths ``[3, 8, 36, 3]``).

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).
        k_size: ECA kernel size for each of the four stages.
    """
    # The message used to say "resnet101" although this builds the 152-layer
    # layout.
    print("Constructing s6la_resnet152_eca......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 8, 36, 3],
                         s6la_channel=s6la_channel, ECA=k_size)
    return model


def mamba_resnext50_32x4d(s6la_channel=32):
    """Construct an s6la ResNeXt-50 32x4d (stage depths ``[3, 4, 6, 3]``).

    The blocks use ``groups=32`` and ``width_per_group=4``.

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).

    All other ``MAMBA_ResNet`` options keep their defaults
    (``num_classes=1000``, ``SE=False``, ``ECA=None``).
    """
    print("Constructing s6la_resnext50_32x4d......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 6, 3], s6la_channel=s6la_channel,
                         groups=32, width_per_group=4)
    return model

def mamba_resnext50_32x4d_se(s6la_channel=32):
    """Construct an s6la ResNeXt-50 32x4d with SE (stage depths ``[3, 4, 6, 3]``).

    The blocks use ``groups=32``, ``width_per_group=4`` and ``SE=True``.

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).

    All other ``MAMBA_ResNet`` options keep their defaults
    (``num_classes=1000``, ``ECA=None``).
    """
    print("Constructing s6la_resnext50_32x4d_se......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 6, 3], s6la_channel=s6la_channel,
                         SE=True, groups=32, width_per_group=4)
    return model

def mamba_resnext50_32x4d_eca(s6la_channel=32, k_size=[5, 5, 5, 7]):
    """Construct an s6la ResNeXt-50 32x4d with ECA (stage depths ``[3, 4, 6, 3]``).

    The blocks use ``groups=32`` and ``width_per_group=4``.

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).
        k_size: ECA kernel size for each of the four stages.
    """
    print("Constructing s6la_resnext50_32x4d_eca......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 6, 3], s6la_channel=s6la_channel,
                         ECA=k_size, groups=32, width_per_group=4)
    return model


def mamba_resnext101_32x4d(s6la_channel=32):
    """Construct an s6la ResNeXt-101 32x4d (stage depths ``[3, 4, 23, 3]``).

    The blocks use ``groups=32`` and ``width_per_group=4``.

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).

    All other ``MAMBA_ResNet`` options keep their defaults
    (``num_classes=1000``, ``SE=False``, ``ECA=None``).
    """
    print("Constructing s6la_resnext101_32x4d......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 23, 3], s6la_channel=s6la_channel,
                         groups=32, width_per_group=4)
    return model

def mamba_resnext101_32x4d_se(s6la_channel=32):
    """Construct an s6la ResNeXt-101 32x4d with SE (stage depths ``[3, 4, 23, 3]``).

    The blocks use ``groups=32``, ``width_per_group=4`` and ``SE=True``.

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).

    All other ``MAMBA_ResNet`` options keep their defaults
    (``num_classes=1000``, ``ECA=None``).
    """
    print("Constructing s6la_resnext101_32x4d_se......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 23, 3], s6la_channel=s6la_channel,
                         SE=True, groups=32, width_per_group=4)
    return model

def mamba_resnext101_32x4d_eca(s6la_channel=32, k_size=[5, 5, 5, 7]):
    """Construct an s6la ResNeXt-101 32x4d with ECA (stage depths ``[3, 4, 23, 3]``).

    The blocks use ``groups=32`` and ``width_per_group=4``.

    Args:
        s6la_channel: number of channels of the hidden map; forwarded to
            ``MAMBA_ResNet`` (it was previously accepted but ignored).
        k_size: ECA kernel size for each of the four stages.
    """
    print("Constructing s6la_resnext101_32x4d_eca......")
    model = MAMBA_ResNet(s6la_Bottleneck, [3, 4, 23, 3], s6la_channel=s6la_channel,
                         ECA=k_size, groups=32, width_per_group=4)
    return model
