import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import to_2tuple

from basicsr.ops.layernorm import LayerNorm2d
from basicsr.utils.registry import ARCH_REGISTRY

from .omnisr_arch import MBConvResidual, SqueezeExcitation



def MBConv(
    dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
):
    """Build an inverted-bottleneck convolution stack.

    The stack is: 1x1 conv to the hidden width, SiLU, 3x3 depthwise conv
    (``groups`` equals the hidden width), SiLU, ``SqueezeExcitation`` and a
    final 1x1 conv to ``dim_out``.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        downsample: When true the depthwise conv uses stride 2, otherwise 1.
        expansion_rate: Hidden width is ``int(expansion_rate * dim_out)``.
        shrinkage_rate: Forwarded to ``SqueezeExcitation``.
        dropout: Forwarded to ``MBConvResidual`` when the wrapper is used.

    Returns:
        The ``nn.Sequential`` stack, wrapped in ``MBConvResidual`` when the
        input and output channel counts match and no downsampling is done.
    """
    width = int(expansion_rate * dim_out)

    layers = [
        nn.Conv2d(dim_in, width, 1),
        nn.SiLU(inplace=True),
        nn.Conv2d(
            width,
            width,
            3,
            stride=2 if downsample else 1,
            padding=1,
            groups=width,
        ),
        nn.SiLU(inplace=True),
        SqueezeExcitation(width, shrinkage_rate=shrinkage_rate),
        nn.Conv2d(width, dim_out, 1),
    ]
    stack = nn.Sequential(*layers)

    # The wrapper is only applied when the stack preserves channel count and
    # spatial stride.
    shape_preserving = dim_in == dim_out and not downsample
    return MBConvResidual(stack, dropout=dropout) if shape_preserving else stack


class ESA(nn.Module):
    """Spatial attention gate that rescales its input by a sigmoid mask.

    The mask is computed on a reduced channel width of
    ``max(n_feats // 4, 16)`` from two paths: a strided-conv + max-pool branch
    that is bilinearly resized back to the input resolution, and a 1x1 branch
    at full resolution.

    Args:
        n_feats: Number of channels of the input (and of the mask).
        conv: Convolution constructor used for every layer.
    """

    def __init__(self, n_feats, conv=nn.Conv2d):
        super().__init__()
        reduced = max(n_feats // 4, 16)
        self.conv1 = conv(n_feats, reduced, kernel_size=1)
        self.conv_f = conv(reduced, reduced, kernel_size=1)
        self.conv2 = conv(reduced, reduced, kernel_size=3, stride=2, padding=0)
        self.conv3 = conv(reduced, reduced, kernel_size=3, stride=1, padding=1)
        self.conv4 = conv(reduced, n_feats, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the attention mask."""
        squeezed = self.conv1(x)

        # Pooled branch: stride-2 conv, then 7x7 max-pool with stride 3.
        pooled = F.max_pool2d(self.conv2(squeezed), kernel_size=7, stride=3)
        context = self.act(self.conv3(pooled))
        # Bring the pooled branch back to the spatial size of the input.
        context = F.interpolate(
            context, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
        )

        # Full-resolution branch is added before the final 1x1 projection.
        mask = self.sigmoid(self.conv4(context + self.conv_f(squeezed)))
        return x * mask


class SharedAttention(nn.Module):
    """Window attention that can either compute or reuse an attention map.

    Args:
        channels: Accepted for interface compatibility; not used by this class.
        window_size: Pair ``(dh, dw)`` giving the window height and width.
        calc_attn: When true the module owns a learnable ``quad_scale`` and is
            meant to compute attention; otherwise it owns ``aff`` (an
            identity) and is meant to reuse a cached attention map.
        shift: When true the feature map is cyclically shifted by half a
            window before windowing and shifted back afterwards.
    """

    def __init__(self, channels, window_size, calc_attn=True, shift=True) -> None:
        super().__init__()
        self.ws = window_size
        self.shift = shift
        if calc_attn:
            self.quad_scale = nn.Parameter(torch.FloatTensor([0.125]))
        else:
            self.aff = nn.Identity()

    def _cyclic_shift(self, x, reverse=False):
        """Roll ``x`` by half a window (or back again); no-op without shift."""
        if not self.shift:
            return x
        # Negate *after* the floor division: ``-ws // 2`` is ``(-ws) // 2``,
        # which differs from ``-(ws // 2)`` for odd window sizes and would
        # leave the output misaligned after the reverse roll.
        step_h, step_w = self.ws[0] // 2, self.ws[1] // 2
        if not reverse:
            step_h, step_w = -step_h, -step_w
        return torch.roll(x, shifts=(step_h, step_w), dims=(2, 3))

    def _merge_windows(self, windows, h, w):
        """Fold ``(b, windows, tokens, c)`` back into a ``(b, c, h, w)`` map."""
        return rearrange(
            windows,
            "b (h w) (dh dw) c -> b c (h dh) (w dw)",
            h=h // self.ws[0],
            w=w // self.ws[1],
            dh=self.ws[0],
            dw=self.ws[1],
        )

    def forward(self, qkv, pre_attn_v=None):
        """Apply window attention.

        Args:
            qkv: 4-D feature map whose height and width are multiples of the
                window size. Without ``pre_attn_v`` its channel axis is split
                into three equal parts (query, key, value).
            pre_attn_v: Optional dict with keys ``"attn"`` and ``"v"`` as
                returned by a previous call. When given, the cached attention
                is applied to ``pre_attn_v["v"] + self.aff(qkv)``.

        Returns:
            ``(out, cache)`` where ``cache`` is the dict described above when
            attention was computed here, and ``None`` when it was reused.
        """
        _, _, h, w = qkv.shape

        if pre_attn_v is None:
            qkv = self._cyclic_shift(qkv)

            quad_q, quad_k, v = rearrange(
                qkv,
                "b (n c) (h dh) (w dw) -> n b (h w) (dh dw) c",
                n=3,
                dh=self.ws[0],
                dw=self.ws[1],
            )

            attn = (
                (quad_q * self.quad_scale) @ quad_k.transpose(-1, -2).contiguous()
            ).softmax(dim=-1)

            out = self._cyclic_shift(self._merge_windows(attn @ v, h, w), reverse=True)

            # NOTE(review): the cached value map is taken from the already
            # shifted tensor, and the reuse branch below shifts its input
            # again. Kept exactly as in the original; confirm this is intended.
            return out, {"attn": attn, "v": qkv.chunk(3, dim=1)[-1]}

        attn = pre_attn_v["attn"]
        v = self._cyclic_shift(pre_attn_v["v"] + self.aff(qkv))
        v = rearrange(
            v,
            "b c (h dh) (w dw) -> b (h w) (dh dw) c",
            dh=self.ws[0],
            dw=self.ws[1],
        )

        out = self._cyclic_shift(self._merge_windows(attn @ v, h, w), reverse=True)
        return out, None


# class MLP(nn.Module):
#     def __init__(self, channels) -> None:
#         super().__init__()
#         # self.norm = nn.BatchNorm2d(channels)
#         self.norm = LayerNorm2d(channels)
#         # self.fc1 = ShiftConv2d1x1(channels, channels * 2, 1, shift_mode="+")
#         self.fc1 = nn.Conv2d(channels, channels * 2, 1)
#         self.act = nn.SiLU(inplace=True)
#         # self.dw1 = nn.Conv2d(channels * 2, channels * 2, 3, 1, 1, groups=channels * 2)
#         self.fc2 = nn.Conv2d(channels * 2, channels, 1)
#         # self.fc2 = ShiftConv2d1x1(channels * 2, channels, 1, shift_mode="+")

#     def forward(self, x):
#         shortcut, x = x, self.fc1(self.norm(x))
#         x = self.act(x)
#         # x = self.dw1(x)
#         x = self.fc2(x)
#         return x + shortcut


class GatedMLP(nn.Module):
    """Residual gated feed-forward block built from 1x1 and depthwise convs.

    The input is normalised, expanded to twice the channel count, passed
    through a 3x3 depthwise conv and split in two halves along the channel
    axis; the GELU of the first half gates the second half before a final
    1x1 conv and the residual addition.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels) -> None:
        super().__init__()
        expanded = channels * 2
        self.norm = LayerNorm2d(channels)
        self.fc1 = nn.Conv2d(channels, expanded, 1)
        self.act = nn.GELU()
        self.dw1 = nn.Conv2d(expanded, expanded, 3, 1, 1, groups=expanded)
        self.fc2 = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Return ``x`` plus the gated feed-forward update."""
        residual = x
        hidden = self.dw1(self.fc1(self.norm(x)))
        gate, value = hidden.chunk(2, dim=1)
        return self.fc2(self.act(gate) * value) + residual


class GroupedSharedAttention(nn.Module):
    """Run one ``SharedAttention`` per window size on equal channel groups.

    Args:
        channels: Total channel count; must be divisible by the number of
            window sizes.
        window_size: Sequence of window sizes, one per channel group. Entries
            that are not lists are expanded with ``to_2tuple``.
        calc_attn: When true the block computes attention and ends with a
            ``GatedMLP``; otherwise it reuses cached attention, projects to
            value and gate halves, and ends with an identity.
        shift: Forwarded to every ``SharedAttention``.
    """

    def __init__(self, channels, window_size, calc_attn=True, shift=True) -> None:
        super().__init__()
        assert channels % len(window_size) == 0
        self.channels = channels
        self.window_size = window_size
        self.calc_attn = calc_attn

        # The attention-computing variants are always constructed first and
        # then replaced for the reuse case; the construction order is kept so
        # that parameter initialisation draws are unchanged.
        self.proj_in = GatedMLP(channels)
        self.qkv_conv = nn.Sequential(
            LayerNorm2d(channels),
            nn.Conv2d(channels, channels * 3, 1),
        )
        if not calc_attn:
            self.proj_in = nn.Identity()
            self.qkv_conv = nn.Sequential(
                LayerNorm2d(channels),
                nn.Conv2d(channels, channels * 2, 1),
            )

        self.attns = nn.ModuleList(
            [
                SharedAttention(
                    channels // len(window_size),
                    ws if isinstance(ws, list) else to_2tuple(ws),
                    calc_attn=calc_attn,
                    shift=shift,
                )
                for ws in window_size
            ]
        )
        self.proj_out = nn.Conv2d(channels, channels, 1)

    def forward(self, x, buffer=None):
        """Apply grouped attention.

        Args:
            x: Input feature map.
            buffer: ``None`` to compute attention here, or the per-group list
                of caches returned by a previous attention-computing block.

        Returns:
            ``(out, buffer)``; ``buffer`` is the list of per-group caches when
            attention was computed in this call and ``None`` otherwise.
        """
        residual = x
        projected = self.qkv_conv(x)
        num_groups = len(self.window_size)

        if self.calc_attn and buffer is None:
            results = [
                block(part)
                for part, block in zip(projected.chunk(num_groups, dim=1), self.attns)
            ]
            mixed = torch.cat([out for out, _ in results], dim=1)
            buffer = [cache for _, cache in results]
        else:
            # First half of the projection feeds the attention groups, the
            # second half gates their concatenated output.
            values, gate = projected.chunk(2, dim=1)
            outs = [
                block(part, pre_attn_v=buffer[idx])[0]
                for idx, (part, block) in enumerate(
                    zip(values.chunk(num_groups, dim=1), self.attns)
                )
            ]
            mixed = gate * torch.cat(outs, dim=1)
            buffer = None

        return self.proj_in(self.proj_out(mixed) + residual), buffer


class RCDG(nn.Module):
    """Residual group of ``GroupedSharedAttention`` blocks.

    Blocks at even positions compute attention and blocks at odd positions
    reuse it; positions with ``index % 4 > 1`` use shifted windows. The group
    is preceded by an ``MBConv`` and followed by a 1x1 conv, a residual
    addition and an ``ESA`` gate.

    Args:
        channels: Channel count used throughout the group.
        window_size: Window sizes forwarded to every attention block.
        block_nums: Number of attention blocks.
        last_norm: Accepted but not used by the current implementation.
    """

    def __init__(self, channels, window_size, block_nums, last_norm=False) -> None:
        super().__init__()
        self.channels = channels
        self.window_size = window_size
        self.proj_in = MBConv(
            channels, channels, expansion_rate=1, shrinkage_rate=0.25, downsample=False
        )

        stages = []
        for idx in range(block_nums):
            stages.append(
                GroupedSharedAttention(
                    channels,
                    window_size,
                    calc_attn=(idx % 2 == 0),
                    shift=(idx % 4 > 1),
                )
            )
        self.blocks = nn.ModuleList(stages)

        self.tail = nn.Conv2d(channels, channels, 1)
        self.proj_out = ESA(channels)

    def forward(self, x):
        """Run the group and return the gated residual output."""
        residual = x
        feat = self.proj_in(x)
        shared = None  # attention cache handed from one block to the next
        for block in self.blocks:
            feat, shared = block(feat, shared)
        return self.proj_out(self.tail(feat) + residual)


class MeanShift(nn.Conv2d):
    r"""Frozen 1x1 convolution that shifts RGB images by a dataset mean.

    The weight is the 3x3 identity (divided by a unit standard deviation) and
    the bias is ``sign * rgb_range * mean``. All parameters have
    ``requires_grad`` set to ``False``.

    Args:
        rgb_range (int): Scale applied to the per-channel mean.
        sign (int): Factor applied to the mean; ``-1`` subtracts it and ``1``
            adds it back.
        data_type (str): Which stored mean to use, ``"DIV2K"`` or ``"DF2K"``.

    Raises:
        NotImplementedError: If ``data_type`` is not one of the known names.
    """

    def __init__(self, rgb_range: int, sign: int = -1, data_type: str = "DF2K") -> None:
        super().__init__(3, 3, kernel_size=(1, 1))

        known_means = {
            # RGB mean for DIV2K 1-800
            "DIV2K": (0.4488, 0.4371, 0.4040),
            # RGB mean for DF2K 1-3450
            "DF2K": (0.4690, 0.4490, 0.4036),
        }
        if data_type not in known_means:
            raise NotImplementedError(f"Unknown data type for MeanShift: {data_type}.")

        channel_std = torch.Tensor((1.0, 1.0, 1.0))
        channel_mean = torch.Tensor(known_means[data_type])

        identity_kernel = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity_kernel / channel_std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * channel_mean / channel_std

        # The shift is a fixed transform, never a trainable layer.
        self.requires_grad_(False)


class Upsample(nn.Sequential):
    """Upsample module built from conv + pixel-shuffle stages.

    A power-of-two scale is realised as ``log2(scale)`` stages of a 3x3 conv
    to ``4 * num_feat`` channels followed by ``PixelShuffle(2)``; scale 3 is a
    single 3x3 conv to ``9 * num_feat`` channels followed by
    ``PixelShuffle(3)``. Scale 1 yields an empty (identity) sequence.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # ``scale > 0`` is required: 0 also satisfies the bit trick below but
        # is not a power of two.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # Exact integer log2; avoids float rounding of math.log.
            num_stages = int(scale).bit_length() - 1
            for _ in range(num_stages):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(
                f"scale {scale} is not supported. Supported scales: 2^n and 3."
            )
        super().__init__(*m)


class UpsampleOneStep(nn.Sequential):
    """Single-stage upsampler: one 3x3 conv followed by one pixel shuffle.

    Unlike ``Upsample`` it always consists of exactly these two layers, which
    keeps the parameter count low for lightweight SR.

    Args:
        scale (int): Upscaling factor passed to ``PixelShuffle``.
        num_feat (int): Channel number of the input features.
        num_out_ch (int): Channel number of the upsampled output.
        input_resolution (tuple | None): ``(height, width)`` used only by
            :meth:`flops`.
    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        super().__init__(
            nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1),
            nn.PixelShuffle(scale),
        )
        self.num_feat = num_feat
        self.input_resolution = input_resolution

    def flops(self):
        """Return the FLOP estimate ``h * w * num_feat * 3 * 9``.

        Requires ``input_resolution`` to have been given at construction.
        """
        # NOTE(review): the factor 3 is fixed here and does not depend on
        # ``scale`` or ``num_out_ch`` — confirm this estimate is intended.
        height, width = self.input_resolution
        return height * width * self.num_feat * 3 * 9


@ARCH_REGISTRY.register()
class ShareFormer(nn.Module):
    """Super-resolution network built from ``RCDG`` groups.

    Pipeline: mean subtraction, 3x3 intro conv, ``rc_blocks`` ``RCDG`` groups,
    3x3 neck conv with a global residual, an upsampling head and mean
    re-addition.

    Args:
        img_channel: Channels of the input image fed to the intro conv.
        channels: Feature width of the body.
        up_scale: Upscaling factor of the head.
        window_size: Window sizes, one per attention channel group. Entries
            may be ints or ``[h, w]`` pairs.
        gau_blocks: Number of attention blocks inside each ``RCDG``.
        rc_blocks: Number of ``RCDG`` groups.
        out_channel: Channels of the output image.
        rgb_range: Value range passed to both ``MeanShift`` layers.
        upsampler: ``"pixelshuffle"`` (conv + ``Upsample`` + conv) or
            ``"pixelshuffledirect"`` (``UpsampleOneStep``).

    Raises:
        ValueError: If ``upsampler`` is not one of the two supported names.
    """

    _UPSAMPLERS = ("pixelshuffle", "pixelshuffledirect")

    def __init__(
        self,
        img_channel=3,
        channels=64,
        up_scale=2,
        window_size=[4, 8, 16],
        gau_blocks=16,
        rc_blocks=1,
        out_channel=3,
        rgb_range=1,
        upsampler="pixelshuffledirect",
    ) -> None:
        super().__init__()
        # Fail at construction: without a head the forward pass would hand
        # un-upsampled features straight to the output mean shift.
        if upsampler not in self._UPSAMPLERS:
            raise ValueError(
                f"upsampler {upsampler!r} is not supported. "
                f"Supported upsamplers: {', '.join(self._UPSAMPLERS)}."
            )
        # Copy so the shared default list is never aliased by an instance.
        window_size = list(window_size)

        self.sub_mean = MeanShift(rgb_range)
        self.add_mean = MeanShift(rgb_range, sign=1)

        self.up_scale = up_scale
        self.window_size = window_size
        self.rgb_range = rgb_range
        # Height/width multiples every window size divides; used for padding.
        self._pad_multiple = self._window_multiple(window_size)

        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=channels,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True,
        )
        self.blocks = nn.Sequential(
            *[
                RCDG(channels, window_size, gau_blocks, last_norm=i == rc_blocks - 1)
                for i in range(rc_blocks)
            ]
        )
        self.neck = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True,
        )

        self.upsampler = upsampler
        if self.upsampler == "pixelshuffle":
            # for classical SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(channels, 64, 3, 1, 1), nn.SiLU(inplace=True)
            )
            self.up = Upsample(up_scale, 64)
            self.conv_last = nn.Conv2d(64, out_channel, 3, 1, 1)
        else:
            # for lightweight SR (to save parameters)
            self.up = UpsampleOneStep(
                up_scale,
                channels,
                out_channel,
            )

    @staticmethod
    def _window_multiple(window_size):
        """Return the least common multiples of all window heights and widths."""
        mult_h = mult_w = 1
        for ws in window_size:
            ws_h, ws_w = ws if isinstance(ws, (list, tuple)) else (ws, ws)
            mult_h = mult_h * ws_h // math.gcd(mult_h, ws_h)
            mult_w = mult_w * ws_w // math.gcd(mult_w, ws_w)
        return mult_h, mult_w

    def forward(self, x):
        """Upscale a batch of images.

        Inputs whose height or width is not a multiple of every window size
        are padded on the bottom/right before the body and the output is
        cropped back to ``up_scale`` times the original size. Inputs that are
        already aligned take exactly the unpadded path.
        """
        height, width = x.shape[-2:]
        mult_h, mult_w = self._pad_multiple
        pad_h = -height % mult_h
        pad_w = -width % mult_w
        padded = pad_h > 0 or pad_w > 0
        if padded:
            # Reflect padding needs the pad to be smaller than the input
            # side; fall back to edge replication for very small inputs.
            mode = "reflect" if pad_h < height and pad_w < width else "replicate"
            x = F.pad(x, (0, pad_w, 0, pad_h), mode=mode)

        x = self.sub_mean(x)

        x = self.intro(x)

        x = self.neck(self.blocks(x)) + x

        if self.upsampler == "pixelshuffle":
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.up(x))
        elif self.upsampler == "pixelshuffledirect":
            x = self.up(x)

        x = self.add_mean(x)

        if padded:
            x = x[..., : height * self.up_scale, : width * self.up_scale]
        return x


