import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training each sample in the batch (dimension 0) is kept with
    probability ``1 - drop_prob``; dropped samples become all zeros. When
    ``scale_by_keep`` is true the surviving samples are divided by the keep
    probability. In eval mode, or when ``drop_prob`` is 0, the input is
    returned untouched.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Return ``x`` with a per-sample Bernoulli mask applied (training only)."""
        active = self.training and self.drop_prob != 0.
        if not active:
            return x

        survive = 1 - self.drop_prob
        # One mask value per sample, broadcast over every remaining dimension,
        # so the same code serves tensors of any rank.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(survive)
        # Skip the rescale when everything is dropped to avoid dividing by zero.
        if self.scale_by_keep and survive > 0.0:
            mask.div_(survive)
        return x * mask

    def extra_repr(self):
        """Show the drop probability in the module's printed representation."""
        return f'drop_prob={round(self.drop_prob,3):0.3f}'


class MlpHead(nn.Module):
    """Classification head: Linear -> BatchNorm1d -> activation -> Dropout -> Linear.

    The hidden width is ``int(mlp_ratio * dim)`` capped at 2048. Both linear
    layers are created without a bias term.

    Args:
        dim: Number of input features.
        num_classes: Number of output logits.
        mlp_ratio: Multiplier applied to ``dim`` to get the hidden width.
        act_type: Activation name forwarded to ``Act``.
        drop_rate: Dropout probability applied before the final projection.
    """

    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_type="gelu", drop_rate=0.2):
        super().__init__()
        width = min(int(mlp_ratio * dim), 2048)
        self.fc1 = nn.Linear(dim, width, bias=False)
        self.norm = nn.BatchNorm1d(width)
        self.act = Act(width, act_type)
        self.drop = nn.Dropout(drop_rate)
        self.fc2 = nn.Linear(width, num_classes, bias=False)

    def forward(self, x):
        """Map features to class logits.

        NOTE(review): presumably ``x`` is ``(batch, dim)`` as produced by the
        pooled backbone output -- confirm against other callers.
        """
        hidden = self.act(self.norm(self.fc1(x)))
        return self.fc2(self.drop(hidden))


class Act(nn.Module):
    """Activation selected by name.

    Supported names are ``"relu"``, ``"prelu"``, ``"hardswish"``, ``"silu"``
    and ``"gelu"``. Any other value (including ``None``) leaves ``self.act``
    as ``None`` and the module behaves as the identity.

    Args:
        out_planes: Number of learnable slopes for ``"prelu"``. When ``None``
            a single shared slope is used (the ``nn.PReLU`` default). Ignored
            by every other activation.
        act_type: Name of the activation to build.
        inplace: Whether the activation may overwrite its input. Honoured by
            ``"relu"``, ``"hardswish"`` and ``"silu"``; ``"prelu"`` and
            ``"gelu"`` have no in-place variant.
    """

    def __init__(self, out_planes=None, act_type="gelu", inplace=True):
        super(Act, self).__init__()

        self.act = None
        if act_type == "relu":
            self.act = nn.ReLU(inplace=inplace)
        elif act_type == "prelu":
            # nn.PReLU(None) is invalid; fall back to one shared parameter.
            self.act = nn.PReLU() if out_planes is None else nn.PReLU(out_planes)
        elif act_type == "hardswish":
            # Previously hard-coded to inplace=True, ignoring the argument.
            self.act = nn.Hardswish(inplace=inplace)
        elif act_type == "silu":
            # Previously hard-coded to inplace=True, ignoring the argument.
            self.act = nn.SiLU(inplace=inplace)
        elif act_type == "gelu":
            self.act = nn.GELU()

    def forward(self, x):
        """Apply the configured activation, or return ``x`` unchanged if none."""
        if self.act is not None:
            x = self.act(x)
        return x


class ConvX(nn.Module):
    """Bias-free 2-D convolution followed by BatchNorm2d and an optional activation.

    Padding is ``kernel_size // 2``.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count.
        groups: Convolution groups.
        kernel_size: Square kernel size.
        stride: Convolution stride.
        act_type: Activation name passed to ``Act``; ``None`` disables the
            activation entirely.
    """

    def __init__(self, in_planes, out_planes, groups=1, kernel_size=3, stride=1, act_type="gelu"):
        super(ConvX, self).__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=groups,
            bias=False,
        )
        # NOTE: nn.SyncBatchNorm(out_planes) was kept as a commented-out
        # alternative here in earlier revisions.
        self.norm = nn.BatchNorm2d(out_planes)
        self.act = Act(out_planes, act_type) if act_type is not None else None

    def forward(self, x):
        """Run conv -> norm -> (activation)."""
        normed = self.norm(self.conv(x))
        return normed if self.act is None else self.act(normed)


class LayerNorm(nn.Module):
    """Parameter-free normalisation along one axis.

    Subtracts the mean and divides by ``sqrt(var + 1e-6)`` computed over
    ``dim`` using ``torch.var_mean`` with its default settings. There are no
    learnable scale or shift parameters.

    Args:
        dim: Index of the axis to normalise over (not a feature size).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` standardised along ``self.dim``."""
        # NOTE: the statistics are computed in the input's own dtype. An
        # autocast-disabled float32 variant existed here as commented-out code.
        variance, center = torch.var_mean(x, dim=self.dim, keepdim=True)
        return (x - center) / torch.sqrt(variance + 1e-6)


class RE(nn.Module):
    """Region-wise channel gate.

    The input map is reduced to a ``split_size x split_size`` descriptor per
    channel, passed through a 1x1-conv bottleneck ending in a sigmoid, resized
    back to the input resolution and multiplied with the input.

    Args:
        dim: Number of input/output channels.
        ratio: Channel reduction factor of the bottleneck; the hidden width
            is ``max(8, dim // ratio)``.
        split_size: Side length of the pooled descriptor.
    """

    def __init__(self, dim, ratio=8, split_size=8):
        super().__init__()
        self.split_size = split_size

        hidden_dim = max(8, dim // ratio)

        self.region = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            # nn.SyncBatchNorm(hidden_dim),
            nn.ReLU(True),
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Gate ``x`` (a 4-D ``(B, C, H, W)`` tensor) and return the same shape."""
        identity = x
        Ho, Wo = x.shape[2:]
        s = self.split_size

        # Amount needed to reach the next multiple of split_size on each axis;
        # 0 when the axis is already divisible. Previously a divisible axis
        # was still enlarged by a full split_size whenever the other axis
        # needed padding, so duplicated rows/columns ended up in the wrong
        # pooling bins.
        H_pad = (-Ho) % s
        W_pad = (-Wo) % s
        if H_pad > 0 or W_pad > 0:
            # Default (nearest) resize up to the next multiple of split_size.
            x = F.interpolate(x, size=(Ho + H_pad, Wo + W_pad))

        B, C, H, W = x.shape
        # View H as (H // s, s) and W as (W // s, s), then average over the
        # coarse axes: entry (i, j) of the result is the mean of all pixels
        # whose row is i and column is j modulo split_size.
        out = x.reshape(B, C, H // s, s, W // s, s).mean(dim=(2, 4))
        out = self.region(out)
        out = F.interpolate(out, size=(Ho, Wo))
        return identity * out


class GateMLP(nn.Module):
    """Gated convolutional MLP.

    A 1x1 ``ConvX`` expands the input to two halves of ``mid_planes`` channels
    each. The first half goes through GELU and a depthwise 3x3 ``ConvX``, and
    is multiplied element-wise by the second half. The product is then
    re-weighted by ``RE`` and projected to ``out_planes`` with a 1x1 ``ConvX``.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count.
        mlp_ratio: ``mid_planes = int(out_planes * mlp_ratio)``.
        split_size: Forwarded to ``RE``.
    """

    def __init__(self, in_planes, out_planes, mlp_ratio=1.0, split_size=7):
        super().__init__()
        hidden = int(out_planes * mlp_ratio)

        self.conv_in = ConvX(in_planes, hidden * 2, groups=1, kernel_size=1, stride=1, act_type=None)
        self.dw = ConvX(hidden, hidden, groups=hidden, kernel_size=3, stride=1, act_type=None)
        self.re = RE(hidden, split_size=split_size)
        self.proj = ConvX(hidden, out_planes, groups=1, kernel_size=1, stride=1, act_type=None)
        self.act = Act(act_type="gelu", inplace=False)

    def forward(self, x):
        """Apply expand -> gated depthwise conv -> region gate -> project."""
        # Split the expanded channels into the filtered half and the gate half.
        value, gate = self.conv_in(x).chunk(2, dim=1)
        gated = self.dw(self.act(value)) * gate
        return self.proj(self.re(gated))


class Attention(nn.Module):
    """Softmax attention with a learnable, clamped per-head logit scale.

    ``logit_scale`` holds the log of the scale, has shape ``(num_head, 1, 1)``
    and is initialised to ``log(10)``.

    Args:
        num_head: Number of heads, i.e. number of independent scale values.
    """

    def __init__(self, num_head):
        super().__init__()
        initial = torch.log(10 * torch.ones((num_head, 1, 1)))
        self.logit_scale = nn.Parameter(initial, requires_grad=True)

    def forward(self, q, k, v):
        """Return ``softmax((q @ k^T) * scale) @ v`` over the last two dims."""
        # The log-scale is capped at 4.6052 (about ln 100) before
        # exponentiation, so the effective multiplier never exceeds ~100.
        scale = self.logit_scale.clamp(max=4.6052).exp()
        similarity = (q @ k.transpose(-2, -1)) * scale
        weights = torch.softmax(similarity, dim=-1)
        return weights @ v


class DilatedCNN(nn.Module):
    """Token mixer combining grouped attention with a depthwise-conv branch.

    A 1x1 ``ConvX`` doubles the channels; the result is split into an
    attention branch and a convolution branch of ``planes`` channels each.
    The attention branch is regrouped so that pixels sharing the same
    (row, column) offset modulo ``split_size`` form one group, attention is
    applied inside each group, and the result is laid back onto the spatial
    grid. The convolution branch passes through a depthwise 3x3 ``ConvX``.
    Both branches are summed, activated with GELU and projected by a 1x1
    ``ConvX``.

    Args:
        planes: Input/output channel count; must be divisible by ``num_head``
            for the head split in the rearrange pattern to succeed.
        split_size: Modulus used to form the pixel groups.
        num_head: Number of attention heads.
    """

    def __init__(self, planes, split_size=7, num_head=1):
        super().__init__()
        self.planes = planes
        self.num_head = num_head
        self.split_size = split_size

        self.conv_in = ConvX(planes, planes*2, groups=1, kernel_size=1, stride=1, act_type=None)
        self.spe = ConvX(planes, planes, groups=planes, kernel_size=3, stride=1, act_type=None)
        self.att = Attention(num_head)

        self.act = Act(act_type="gelu")
        self.proj = ConvX(planes, planes, groups=1, kernel_size=1, stride=1, act_type=None)

    def forward(self, x):
        """Mix a ``(B, planes, H, W)`` tensor and return the same shape."""
        qkv = self.conv_in(x)
        qkv, c = torch.split(qkv, split_size_or_sections=(self.planes, self.planes), dim=1)

        Ho, Wo = x.shape[2:]
        s = self.split_size
        # Distance to the next multiple of split_size; 0 for an axis that is
        # already divisible. Previously such an axis was still enlarged by a
        # full split_size whenever the other axis needed padding.
        H_pad = (-Ho) % s
        W_pad = (-Wo) % s
        resized = H_pad > 0 or W_pad > 0
        if resized:
            # Default (nearest) resize so both axes divide evenly.
            qkv = F.interpolate(qkv, size=(Ho + H_pad, Wo + W_pad))

        H, W = qkv.shape[2:]
        hsize = H // s
        wsize = W // s
        hnum = s
        wnum = s

        spe = self.spe(c)
        # Result axes: batch, group (offset modulo split_size), head,
        # per-head channel, token (the hsize*wsize pixels of that group).
        qkv = rearrange(qkv, 'b (h d) (hsize hnum) (wsize wnum) -> b (hnum wnum) h d (hsize wsize)', h=self.num_head, hsize=hsize, wsize=wsize, hnum=hnum, wnum=wnum)

        # Queries and keys are the same tensor, L2-normalised along the token
        # axis; values are the un-normalised tensor.
        l2_qkv = F.normalize(qkv, dim=-1)
        out = self.att(l2_qkv, l2_qkv, qkv)

        out = rearrange(out, 'b (hnum wnum) h d (hsize wsize) -> b (h d) (hsize hnum) (wsize wnum)', h=self.num_head, hsize=hsize, wsize=wsize, hnum=hnum, wnum=wnum)
        if resized:
            # Undo the earlier resize so the map matches the conv branch.
            out = F.interpolate(out, size=(Ho, Wo))
        out = out + spe

        out = self.act(out)
        out = self.proj(out)

        return out


class DownBlock(nn.Module):
    """Stride-2 downsampling block with a convolutional shortcut.

    The main branch is 1x1 ``ConvX`` -> depthwise 3x3 stride-2 ``ConvX`` ->
    1x1 ``ConvX`` (no activation on the last). The shortcut is a depthwise
    3x3 stride-2 ``ConvX`` followed by a 1x1 ``ConvX``, both without
    activation. Stochastic depth is applied to the main branch only.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count.
        mlp_ratio: ``mid_planes = int(out_planes * mlp_ratio)`` for the main branch.
        drop_path: Drop probability; values ``<= 0`` use ``nn.Identity``.
    """

    def __init__(self, in_planes, out_planes, mlp_ratio=1.0, drop_path=0.0):
        super().__init__()
        hidden = int(out_planes * mlp_ratio)

        main_branch = [
            ConvX(in_planes, hidden, groups=1, kernel_size=1, stride=1),
            ConvX(hidden, hidden, groups=hidden, kernel_size=3, stride=2),
            ConvX(hidden, out_planes, groups=1, kernel_size=1, stride=1, act_type=None),
        ]
        self.mlp = nn.Sequential(*main_branch)

        shortcut = [
            ConvX(in_planes, in_planes, groups=in_planes, kernel_size=3, stride=2, act_type=None),
            ConvX(in_planes, out_planes, groups=1, kernel_size=1, stride=1, act_type=None),
        ]
        self.skip = nn.Sequential(*shortcut)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Return the sum of the (possibly dropped) main branch and the shortcut."""
        main = self.drop_path(self.mlp(x))
        return main + self.skip(x)


class Block(nn.Module):
    """Residual block: normalised ``GateMLP`` followed by normalised ``DilatedCNN``.

    Each sub-layer sees the input normalised over the channel axis (``dim=1``)
    and its output, after stochastic depth, is added back to the
    un-normalised input. The same ``LayerNorm`` and ``DropPath`` modules are
    reused for both sub-layers.

    Args:
        in_planes: Input channel count. The residual additions require the
            sub-layer output to match the input shape.
        out_planes: Channel count produced by ``GateMLP`` and used by ``DilatedCNN``.
        mlp_ratio: Forwarded to ``GateMLP``.
        split_size: Forwarded to ``GateMLP`` and ``DilatedCNN``.
        num_head: Forwarded to ``DilatedCNN``.
        drop_path: Drop probability; values ``<= 0`` use ``nn.Identity``.
    """

    def __init__(self, in_planes, out_planes, mlp_ratio=1.0, split_size=7, num_head=1, drop_path=0.0):
        super().__init__()
        self.split_size = split_size

        self.mlp = GateMLP(in_planes, out_planes, mlp_ratio, split_size)
        self.dcnn = DilatedCNN(out_planes, split_size=split_size, num_head=num_head)

        self.ln = LayerNorm(dim=1)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply both residual sub-layers in sequence."""
        mlp_out = self.mlp(self.ln(x))
        x = self.drop_path(mlp_out) + x

        mixer_out = self.dcnn(self.ln(x))
        x = self.drop_path(mixer_out) + x
        return x


class RaCNN(nn.Module):
    """Four-stage convolutional classifier built from ``DownBlock`` and ``Block``.

    Pipeline: stride-2 3x3 stem -> four stages (each a ``DownBlock`` followed
    by ``layers[i] - 1`` ``Block`` modules) -> 1x1 ``ConvX`` head -> global
    average pool -> ``MlpHead`` classifier.

    Args:
        dims: Either an int base width ``d`` (stage widths become
            ``d, 2d, 4d, 8d``) or a sequence of four stage widths. The stem
            width is half of the first stage width.
        layers: Sequence with the number of blocks in each of the four stages.
            Each entry must be at least 1, because every stage starts with a
            ``DownBlock``.
        mlp_ratio: Expansion ratio forwarded to every block.
        split_sizes: Per-stage ``split_size`` values.
        num_heads: Per-stage attention head counts.
        drop_path_rate: Maximum stochastic-depth rate; per-block rates grow
            linearly from 0 to this value across all blocks.
        num_classes: Number of output logits.
    """

    # pylint: disable=unused-variable
    def __init__(self, dims, layers, mlp_ratio=3, split_sizes=(8, 4, 2, 1), num_heads=(2, 4, 8, 16), drop_path_rate=0., num_classes=1000):
        super(RaCNN, self).__init__()
        self.mlp_ratio = mlp_ratio
        self.drop_path_rate = drop_path_rate

        if isinstance(dims, int):
            dims = [dims // 2, dims, dims * 2, dims * 4, dims * 8]
        else:
            # list(...) lets callers pass a tuple (or any sequence) of widths;
            # concatenating a list with a tuple would raise TypeError.
            dims = [dims[0] // 2] + list(dims)

        self.first_conv = ConvX(3, dims[0], 1, 3, 2)

        # One drop-path rate per block, increasing linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]

        # bounds[i]:bounds[i + 1] is the slice of dpr that belongs to stage i.
        bounds = [0]
        for depth in layers[:4]:
            bounds.append(bounds[-1] + depth)

        self.layer1 = self._make_layers(dims[0], dims[1], layers[0], split_size=split_sizes[0], num_head=num_heads[0], drop_path=dpr[bounds[0]:bounds[1]])
        self.layer2 = self._make_layers(dims[1], dims[2], layers[1], split_size=split_sizes[1], num_head=num_heads[1], drop_path=dpr[bounds[1]:bounds[2]])
        self.layer3 = self._make_layers(dims[2], dims[3], layers[2], split_size=split_sizes[2], num_head=num_heads[2], drop_path=dpr[bounds[2]:bounds[3]])
        self.layer4 = self._make_layers(dims[3], dims[4], layers[3], split_size=split_sizes[3], num_head=num_heads[3], drop_path=dpr[bounds[3]:bounds[4]])

        head_dim = max(1024, dims[4])
        self.head = ConvX(dims[4], head_dim, 1, 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = MlpHead(head_dim, num_classes)

        self.init_params(self)

    def _make_layers(self, inputs, outputs, num_block, split_size, num_head, drop_path):
        """Build one stage: a ``DownBlock`` followed by ``num_block - 1`` ``Block`` modules.

        ``drop_path`` is a sequence with one rate per block of the stage.
        """
        layers = [DownBlock(inputs, outputs, self.mlp_ratio, drop_path[0])]
        layers.extend(
            Block(outputs, outputs, self.mlp_ratio, split_size, num_head, drop_path[i])
            for i in range(1, num_block)
        )
        return nn.Sequential(*layers)

    def init_params(self, model):
        """Initialise conv, batch-norm and linear layers of ``model`` in place.

        * Conv2d: normal weights with std 0.01 when the module name contains
          ``'first'``, otherwise std ``1 / weight.shape[1]``; zero bias.
        * BatchNorm1d/2d: weight 1, bias 1e-4, running mean 0.
        * Linear: normal weights with std 0.01; zero bias.
        """
        for name, m in model.named_modules():
            if isinstance(m, nn.Conv2d):
                std = 0.01 if 'first' in name else 1.0 / m.weight.shape[1]
                nn.init.normal_(m.weight, 0, std)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the backbone and classifier, returning the classifier output.

        NOTE(review): presumably ``x`` is an RGB batch ``(B, 3, H, W)`` since
        the stem is built with 3 input channels -- confirm against callers.
        """
        x = self.first_conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        out = self.head(x)
        # Pool to 1x1 and drop the spatial axes before the MLP head.
        out = self.gap(out).flatten(1)
        out = self.classifier(out)
        return out


