import numbers
import math
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.ops import (
    _get_meshgrid_coords,
    coords_diff,
)
from timm.models.layers import DropPath, to_2tuple


class LayerNorm(nn.Module):
    """Layer normalization along one axis of a channels-first tensor.

    Mean and (biased) variance are taken along ``dim``; the learnable affine
    parameters are reshaped to ``(C, 1, 1)`` so they broadcast over the two
    trailing spatial axes, i.e. the input is expected to be (B, C, H, W) with
    the default ``dim=1``.

    Args:
        normalized_shape (int or 1-element sequence): number of channels.
        dim (int): axis along which statistics are computed. Default: 1
    """

    def __init__(self, normalized_shape, dim=1):
        super().__init__()
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)

        # Only a single normalized axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape
        self.dim = dim

    def forward(self, x):
        """Normalize ``x`` (B, C, H, W) along ``self.dim`` and apply the affine map."""
        mean = x.mean(self.dim, keepdim=True)
        var = x.var(self.dim, keepdim=True, unbiased=False)
        normed = (x - mean) / torch.sqrt(var + 1e-5)
        # (C,) -> (C, 1, 1) so the affine terms broadcast over H and W.
        gamma = self.weight[..., None, None]
        beta = self.bias[..., None, None]
        return normed * gamma + beta


class Mlp(nn.Module):
    """Two-layer MLP built from 1x1 convolutions (channels-first inputs).

    Pipeline: fc1 -> activation -> dropout -> fc2 -> dropout, as used in
    Vision Transformer, MLP-Mixer and related networks.

    Args:
        in_features (int): input channels.
        hidden_features (int, optional): hidden channels; falls back to ``in_features``.
        out_features (int, optional): output channels; falls back to ``in_features``.
        act_layer (type): activation module class. Default: nn.GELU
        drop (float or pair): dropout probability, expanded with ``to_2tuple``;
            element 0 is used after the activation, element 1 after fc2.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        n_out = out_features or in_features
        n_hidden = hidden_features or in_features
        probs = to_2tuple(drop)

        self.fc1 = nn.Conv2d(in_features, n_hidden, 1, 1, 0)
        self.act = act_layer()
        self.drop1 = nn.Dropout(probs[0])
        self.fc2 = nn.Conv2d(n_hidden, n_out, 1, 1, 0)
        self.drop2 = nn.Dropout(probs[1])

    def forward(self, x):
        """Apply the MLP to a (B, C, H, W) tensor."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


class SELayer(nn.Module):
    """Squeeze-and-excitation gate with a hard-sigmoid activation.

    Global average pooling squeezes each channel to a scalar, a two-layer MLP
    (``channel -> channel // reduction -> channel``) produces per-channel
    weights in [0, 1], and the input is rescaled channel-wise.

    Args:
        channel (int): number of input channels.
        reduction (int): bottleneck reduction ratio. Default: 4
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel),
            nn.Hardsigmoid(),
        )

    def forward(self, x):
        """Rescale the channels of a (B, C, H, W) tensor."""
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate


class ConvBlock(nn.Module):
    """Convolutional block: 3x3 conv down to ``num_feat // compress_ratio``
    channels, GELU, 3x3 conv back to ``num_feat`` channels, then SE gating.

    Args:
        num_feat (int): input and output channels.
        compress_ratio (int): channel compression of the hidden layer. Default: 4
        reduction (int): reduction ratio of the SE layer. Default: 16
    """

    def __init__(self, num_feat, compress_ratio=4, reduction=16):
        super().__init__()
        squeezed = num_feat // compress_ratio
        self.conv = nn.Sequential(
            nn.Conv2d(num_feat, squeezed, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(squeezed, num_feat, 3, 1, 1),
            SELayer(num_feat, reduction),
        )

    def forward(self, x):
        """Apply the block to a (B, C, H, W) tensor; the shape is preserved."""
        return self.conv(x)


class LocalityFFN(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        expand_ratio=4.0,
        reduction=4,
        mlp_kernel_size=3,
    ):
        """
        Feed-forward network with a depthwise convolution for locality.

        Layers: 1x1 conv (expand) -> Hardswish -> depthwise KxK conv ->
        Hardswish -> SE layer -> 1x1 conv (project).

        :param in_dim: the input channel count.
        :param out_dim: the output channel count.
        :param expand_ratio: hidden channels are ``int(in_dim * expand_ratio)``.
        :param reduction: reduction rate in the SE module.
        :param mlp_kernel_size: kernel size of the depthwise convolution;
                                padding is ``mlp_kernel_size // 2``.
        """
        super().__init__()
        hidden = int(in_dim * expand_ratio)
        pad = mlp_kernel_size // 2

        # The first linear layer is replaced by a 1x1 convolution; the
        # depthwise conv (groups=hidden) mixes each channel spatially.
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, hidden, 1, 1, 0, bias=True),
            nn.Hardswish(),
            nn.Conv2d(
                hidden, hidden, mlp_kernel_size, 1, pad, groups=hidden, bias=True
            ),
            nn.Hardswish(),
            SELayer(hidden, reduction=reduction),
            nn.Conv2d(hidden, out_dim, 1, 1, 0, bias=True),
        )

    def forward(self, x):
        """
        Args:
            x: input tensor with shape of B, C, H, W
        Returns:
            output: tensor of shape B, out_dim, H, W
        """
        return self.conv(x)


def get_relative_coords_table(window_size):
    """Build the log-spaced relative-coordinate table fed to the CPB MLP.

    Relative offsets ``-(W-1) .. W-1`` along each axis are normalized to
    [-1, 1], scaled by 8 and mapped through ``sign(t) * log2(|t| + 1) / log2(8)``.

    Args:
        window_size: (Wh, Ww) pair of window side lengths.

    Returns:
        float32 tensor of shape (1, 2*Wh-1, 2*Ww-1, 2); ``[..., 0]`` is the
        height offset and ``[..., 1]`` the width offset.
    """
    ts = window_size

    coord_h = torch.arange(-(ts[0] - 1), ts[0], dtype=torch.float32)
    coord_w = torch.arange(-(ts[1] - 1), ts[1], dtype=torch.float32)
    table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
        1, 2, 0
    )

    table = table.contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2

    # Normalize to [-1, 1]. A size-1 axis only has offset 0; dividing by
    # (1 - 1) would produce 0/0 = NaN, so clamp the divisor at 1.
    table[:, :, :, 0] /= max(ts[0] - 1, 1)
    table[:, :, :, 1] /= max(ts[1] - 1, 1)

    table *= 8
    table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)

    return table


def get_relative_position_index(window_size):
    """Return the pairwise relative-position index for one window.

    Builds the window's coordinate grid with ``_get_meshgrid_coords`` and
    passes it twice to ``coords_diff`` (both helpers live in ``model.ops``).
    ``Attention.attention`` uses the flattened result to gather rows of the
    relative-position bias table.

    Args:
        window_size: (Wh, Ww) window size.

    Returns:
        index tensor, shape (Wh*Ww, Wh*Ww) per the original shape notes.
    """
    grid = _get_meshgrid_coords((0, 0), window_size)  # 2, Wh*Ww
    return coords_diff(grid, grid, max_diff=window_size)  # Wh*Ww, Wh*Ww


class Attention(nn.Module):
    r"""Window/grid based multi-head self attention with continuous relative position bias.

    It supports both first-level (window) and second-level (grid) self-attention in TreeIR.
    When ``qk_reduce`` is True the per-head Q and K dimensions are ``head_dim // 2``
    to save computation, while V keeps ``head_dim``.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Side length of the square windows used by level-1 attention.
        grid_size (int): Grid size used by level-2 attention.
        global_size (int, optional): If set, level-2 attention first splits the input into
            blocks of at most this size (reflect-padding as needed). Default: None
        shift_size (int, optional): If > 0, level-1 attention cyclically shifts the input by
            ``window_size // 2`` and adds ``mask`` to the attention logits. Default: 0
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qkv_conv (bool, optional): If True, follow the 1x1 qkv projection with a 3x3
            depthwise convolution. Default: True
        qk_reduce (bool, optional): If True, produce 2*dim qkv channels (reduced Q/K)
            instead of 3*dim. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        level (int): Attention level. Valid values are 1 (default) and 2.
        t_version (str): "v1" uses scaled cosine attention with a learnable logit scale;
            any other value uses scaled dot-product attention. Default: "v2"
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        grid_size,
        global_size=None,
        shift_size=0,
        qkv_bias=True,
        qkv_conv=True,
        qk_reduce=True,
        attn_drop=0.0,
        level=1,
        t_version="v2",
    ):

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.grid_size = grid_size
        self.global_size = global_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.level = level
        self.qkv_conv = qkv_conv
        self.qk_reduce = qk_reduce
        self.shift_size = shift_size
        self.t_version = t_version

        if t_version == "v1":
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )
        else:
            self.scale = self.head_dim**-0.5

        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_heads, bias=False),
        )

        qkv_dim = dim * (2 if self.qk_reduce else 3)
        if qkv_conv:
            self.qkv = nn.Sequential(
                nn.Conv2d(dim, qkv_dim, 1, 1, 0, bias=qkv_bias),
                nn.Conv2d(qkv_dim, qkv_dim, 3, 1, 1, groups=qkv_dim, bias=qkv_bias),
            )
        else:
            self.qkv = nn.Conv2d(dim, qkv_dim, 1, 1, 0, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)

    def partition(self, x, window_size):
        """Split (B, C, H, W) into (B_, N, C) token groups.

        Level 1 uses ``window_size``; level 2 always uses ``self.grid_size`` and
        ``self.global_size`` (the ``window_size`` argument is ignored there).
        """
        if self.level == 1:
            x = window_partition(x, window_size)
        else:
            x = grid_partition(x, self.grid_size, self.global_size)
        return x

    def reverse(self, x, input_size, window_size):
        """Inverse of ``partition``: (B_, N, C) back to ``input_size`` (B, C, H, W)."""
        if self.level == 1:
            x = window_reverse(x, window_size, input_size)
        else:
            x = grid_reverse(x, self.grid_size, input_size, self.global_size)
        return x

    def attention(self, q, k, v, mask):
        """Multi-head attention over token groups.

        Args:
            q, k: (B_, nH, N, d_qk); v: (B_, nH, N, head_dim).
            mask: (nW, N, N) additive mask, only used when ``shift_size > 0``;
                B_ must be a multiple of nW.
        Returns:
            (B_, N, dim) tensor.
        """
        B_, N = v.shape[0], v.shape[2]
        # The position bias assumes square token groups (N == n * n).
        n = int(math.sqrt(N))
        if self.t_version == "v1":
            # cosine attention map
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
            attn = attn * logit_scale
        else:
            # dot product attention map
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

        # continuous relative position bias
        bias_table = self.cpb_mlp(
            get_relative_coords_table((n, n)).to(dtype=q.dtype, device=q.device)
        )
        bias_table = bias_table.view(-1, self.num_heads)

        bias = bias_table[get_relative_position_index((n, n)).view(-1)]
        bias = bias.view(N, N, -1).permute(2, 0, 1).contiguous()  # nH, N, N
        bias = 16 * torch.sigmoid(bias)
        attn = attn + bias.unsqueeze(0)

        if self.shift_size > 0:
            nW = mask.shape[0]
            # (nW, N, N) -> (1, nW, 1, N, N) to broadcast over batch and heads
            mask = mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask
            attn = attn.view(-1, self.num_heads, N, N)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim)

        return x

    def qkv_split(self, qkv):
        """Split (B_, N, qkv_dim) into per-head q, k, v of shape (B_, nH, N, d).

        With ``qk_reduce`` q and k have ``head_dim // 2`` channels and v has
        ``head_dim``; otherwise all three have ``head_dim``.
        """
        B_, N = qkv.shape[:2]
        if self.qk_reduce:
            qkv = qkv.reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
            split_dim = [self.head_dim // 2, self.head_dim // 2, self.head_dim]
            q, k, v = torch.split(qkv, split_dim, dim=-1)
        else:
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
        return q, k, v

    def forward(self, x, window_size=None, mask=None):
        """
        Args:
            x: input features with shape of (B, C, H, W), C == dim
            window_size: level-1 window size; defaults to ``self.window_size``.
                Level-2 attention partitions with ``self.grid_size`` regardless.
            mask: additive attention mask, required when ``shift_size > 0``.
        Returns:
            tensor of shape (B, C, H, W)
        """
        if window_size is None:
            window_size = self.window_size

        B, C, H, W = x.shape

        qkv = self.qkv(x)

        # Cyclic shift. Parenthesize the negation: `-window_size // 2` is
        # `(-window_size) // 2`, which for odd window sizes is one larger in
        # magnitude than the `window_size // 2` used to roll back, leaving the
        # output misaligned by one pixel.
        if self.shift_size > 0:
            shift = window_size // 2
            qkv = torch.roll(qkv, shifts=(-shift, -shift), dims=(2, 3))

        # partition
        qkv = self.partition(qkv, window_size)

        # qkv split
        q, k, v = self.qkv_split(qkv)

        # attention
        x = self.attention(q, k, v, mask)

        # reverse
        x = self.reverse(x, (B, C, H, W), window_size)

        # reverse cyclic shift (same magnitude as the forward shift)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(shift, shift), dims=(2, 3))

        return x

    def extra_repr(self) -> str:
        s = f"dim={self.dim}, version={self.t_version}, num_heads={self.num_heads}, level={self.level}, "
        if self.level == 1:
            s += f"window_size={self.window_size}, global_size={self.global_size}, shift_size={self.shift_size}, "
        else:
            s += f"grid_size={self.grid_size}, global_size={self.global_size}, "
        s += f"qkv_conv={self.qkv_conv}, qk_reduce={self.qk_reduce}"
        return s

    def flops(self, N):
        """Rough FLOP estimate for one token group of length N.

        Counts a 3*dim qkv projection regardless of ``qk_reduce`` and includes
        an output-projection term, although the projection itself is applied
        by the caller (``TreeAttention.proj``), not by this module.
        """
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # output projection
        flops += N * self.dim * self.dim
        return flops


class TreeAttention(nn.Module):
    """Two-level TreeIR attention: window attention, then grid attention, then a 1x1 projection.

    Both levels share every ``Attention`` hyper-parameter except ``level`` and
    ``shift_size`` (only the first level may be shifted).

    Args:
        dim (int): number of channels.
        input_resolution: accepted for interface compatibility; unused here.
        num_heads (int): number of attention heads.
        window_size (int): level-1 window size.
        grid_size (int): level-2 grid size.
        global_size (int, optional): level-2 global block size. Default: 128
        shift_size (int): level-1 shift flag/size. Default: 0
        qkv_bias, qkv_conv, qk_reduce, attn_drop, t_version: forwarded to ``Attention``.
        proj_drop (float): dropout after the output projection. Default: 0.0
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size,
        grid_size,
        global_size=128,
        shift_size=0,
        qkv_bias=True,
        qkv_conv=False,
        qk_reduce=True,
        attn_drop=0.0,
        proj_drop=0.0,
        t_version="v2",
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size

        shared = dict(
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            grid_size=grid_size,
            global_size=global_size,
            qkv_bias=qkv_bias,
            qkv_conv=qkv_conv,
            qk_reduce=qk_reduce,
            attn_drop=attn_drop,
            t_version=t_version,
        )
        # Creation order (attn1, attn2, proj) is kept so parameter init is unchanged.
        self.attn1 = Attention(level=1, shift_size=shift_size, **shared)
        self.attn2 = Attention(level=2, shift_size=0, **shared)
        self.proj = nn.Conv2d(dim, dim, 1, 1, 0)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, window_size, grid_size, mask=None):
        """
        Args:
            x: input tensor with shape of B, C, H, W
            window_size: level-1 window size.
            grid_size: passed to the level-2 attention.
            mask: optional level-1 attention mask.
        Returns:
            output: tensor shape B, C, H, W
        """
        local_out = self.attn1(x, window_size, mask)
        tree_out = self.attn2(local_out, grid_size)
        return self.proj_drop(self.proj(tree_out))


def window_partition(x, window_size):
    """Cut a (B, C, H, W) tensor into non-overlapping square windows.

    Args:
        x: (B, C, H, W); H and W must be multiples of ``window_size``.
        window_size (int): window side length.

    Returns:
        windows: (B * H/ws * W/ws, ws * ws, C), windows ordered row-major per
        batch item, tokens row-major inside each window.
    """
    batch, channels, height, width = x.shape
    ws = window_size
    tiles = x.view(batch, channels, height // ws, ws, width // ws, ws)
    # -> B, rows-of-windows, cols-of-windows, ws, ws, C
    tiles = tiles.permute(0, 2, 4, 3, 5, 1).contiguous()
    return tiles.view(-1, ws * ws, channels)


def window_reverse(windows, window_size, input_size):
    """Reassemble windows produced by ``window_partition`` into a feature map.

    Args:
        windows: (B * H/ws * W/ws, ws * ws, C) tensor (must be view-compatible).
        window_size (int): window side length.
        input_size: (B, C, H, W) of the original feature map.

    Returns:
        x: (B, C, H, W)
    """
    batch, channels, height, width = input_size
    ws = window_size
    grid = windows.view(batch, height // ws, width // ws, ws, ws, channels)
    # -> B, C, rows-of-windows, ws, cols-of-windows, ws
    grid = grid.permute(0, 5, 1, 3, 2, 4).contiguous()
    return grid.view(batch, channels, height, width)


def grid_partition(x, grid_size, global_size=None):
    """Gather strided (dilated) token groups of ``grid_size ** 2`` tokens.

    Without ``global_size`` each group takes one token from every cell of a
    ``grid_size x grid_size`` partition of the whole map. With ``global_size``
    the map is first reflect-padded so any side longer than ``global_size``
    becomes a multiple of it, then split into blocks of at most
    ``global_size`` per side, and the grid grouping is applied per block.

    Args:
        x: (B, C, H, W)
        grid_size (int): number of grid cells per side.
        global_size (int, optional): maximum block side length.

    Returns:
        windows: (B_, grid_size ** 2, C)
    """
    batch, channels, height, width = x.shape
    g = grid_size

    if global_size is None:
        cells = x.view(batch, channels, g, height // g, g, width // g)
        # -> B, H/g, W/g, g, g, C
        cells = cells.permute(0, 3, 5, 2, 4, 1).contiguous()
        return cells.view(-1, g * g, channels)

    def _extra(size):
        # padding needed to round `size` up to a multiple of global_size
        if size > global_size:
            return math.ceil(size / global_size) * global_size - size
        return 0

    # F.pad order for the last two dims: (left, right, top, bottom)
    x = F.pad(x, [0, _extra(width), 0, _extra(height)], "reflect")
    padded_h, padded_w = x.shape[2:]
    block_h = min(padded_h, global_size)
    block_w = min(padded_w, global_size)
    cells = x.view(
        batch,
        channels,
        padded_h // block_h,
        g,
        block_h // g,
        padded_w // block_w,
        g,
        block_w // g,
    )
    # -> B, blocks_h, blocks_w, block_h/g, block_w/g, g, g, C
    cells = cells.permute(0, 2, 5, 4, 7, 3, 6, 1).contiguous()
    return cells.view(-1, g * g, channels)


def grid_reverse(windows, grid_size, input_size, global_size=None):
    """Inverse of ``grid_partition``: scatter token groups back to (B, C, H, W).

    With ``global_size`` the padded size is recomputed the same way the
    partition padded it, and the result is cropped back to (H, W).

    Args:
        windows: (B_, grid_size ** 2, C)
        grid_size (int): number of grid cells per side.
        input_size: (B, C, H, W) of the un-padded feature map.
        global_size (int, optional): maximum block side length.

    Returns:
        x: (B, C, H, W)
    """
    batch, channels, height, width = input_size
    g = grid_size

    if global_size is None:
        cells = windows.view(batch, height // g, width // g, g, g, channels)
        # -> B, C, g, H/g, g, W/g
        cells = cells.permute(0, 5, 3, 1, 4, 2).contiguous()
        return cells.view(batch, channels, height, width)

    def _rounded(size):
        # size rounded up to a multiple of global_size, only when it exceeds it
        if size > global_size:
            return math.ceil(size / global_size) * global_size
        return size

    padded_h, padded_w = _rounded(height), _rounded(width)
    block_h = min(padded_h, global_size)
    block_w = min(padded_w, global_size)
    cells = windows.view(
        batch,
        padded_h // block_h,
        padded_w // block_w,
        block_h // g,
        block_w // g,
        g,
        g,
        channels,
    )
    # -> B, C, blocks_h, g, block_h/g, blocks_w, g, block_w/g
    cells = cells.permute(0, 7, 1, 5, 3, 2, 6, 4).contiguous()
    full = cells.view(batch, channels, padded_h, padded_w)
    return full[:, :, :height, :width]


class TreeTransformerBlock(nn.Module):
    r"""TreeIR Block: two-level tree attention followed by a feed-forward network.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution (H, W); used by ``flops``.
        num_heads (int): Number of attention heads.
        window_size (int): Level-1 window size. Default: 8
        grid_size (int): Level-2 grid size. Default: 8
        global_size (int, optional): Level-2 global block size. Default: None
        shift_size (int): If > 0, level-1 attention uses shifted windows (needs ``mask``). Default: 0
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qkv_conv (bool, optional): Depthwise 3x3 conv after the qkv projection. Default: False
        qk_reduce (bool, optional): Halve per-head Q/K dimension. Default: True
        drop (float, optional): Dropout rate after the attention projection. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        mlp_type (str): "locality" selects ``LocalityFFN``; anything else selects ``Mlp``.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0
        mlp_kernel_size (int): Depthwise kernel size of ``LocalityFFN``. Default: 3
        version (str): "v1" = post-norm residuals, "v2" = pre-norm residuals; also
            forwarded to the attention as ``t_version``. Default: "v2"
        conv_scale (float): If > 0, add ``ConvBlock(x) * conv_scale`` after attention. Default: 0
        compression_ratio (int): Compression ratio of that ``ConvBlock``. Default: 4
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=8,
        grid_size=8,
        global_size=None,
        shift_size=0,
        qkv_bias=True,
        qkv_conv=False,
        qk_reduce=True,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        mlp_type="locality",
        mlp_ratio=4.0,
        mlp_kernel_size=3,
        version="v2",
        conv_scale=0,
        compression_ratio=4,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.grid_size = grid_size
        self.mlp_ratio = mlp_ratio
        self.version = version
        self.conv_scale = conv_scale

        self.attn = TreeAttention(
            dim=dim,
            input_resolution=input_resolution,
            num_heads=num_heads,
            window_size=window_size,
            grid_size=grid_size,
            global_size=global_size,
            shift_size=shift_size,
            qkv_bias=qkv_bias,
            qkv_conv=qkv_conv,
            qk_reduce=qk_reduce,
            attn_drop=attn_drop,
            proj_drop=drop,
            t_version=version,
        )
        # channels-first layer norm (normalizes over dim 1 of B, C, H, W)
        self.norm1 = LayerNorm(dim, 1)

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        if self.conv_scale > 0:
            self.conv = ConvBlock(dim, compression_ratio)

        if mlp_type == "locality":
            self.mlp = LocalityFFN(dim, dim, mlp_ratio, mlp_kernel_size=mlp_kernel_size)
        else:
            self.mlp = Mlp(dim, hidden_features=int(dim * mlp_ratio))
        self.norm2 = LayerNorm(dim, 1)

    def forward(self, x, window_size=None, grid_size=None, mask=None):
        """
        Args:
            x: input tensor with shape of B, C, H, W
            window_size: level-1 window size; defaults to ``self.window_size``.
            grid_size: level-2 grid size; defaults to ``self.grid_size``.
            mask: level-1 attention mask (needed when shift_size > 0).
        Returns:
            output: tensor shape B, C, H, W
        Raises:
            ValueError: if ``self.version`` is neither "v1" nor "v2".
        """
        if window_size is None:
            window_size = self.window_size
        if grid_size is None:
            grid_size = self.grid_size
        if self.version == "v1":
            # post-norm: normalize the attention / FFN outputs
            y = x + self.drop_path(
                self.norm1(self.attn(x, window_size, grid_size, mask))
            )
            if self.conv_scale > 0:
                y += self.conv(x) * self.conv_scale
            y = y + self.drop_path(self.norm2(self.mlp(y)))
        elif self.version == "v2":
            # pre-norm: normalize the inputs of attention / FFN
            y = x + self.drop_path(
                self.attn(self.norm1(x), window_size, grid_size, mask)
            )
            if self.conv_scale > 0:
                y += self.conv(x) * self.conv_scale
            y = y + self.drop_path(self.mlp(self.norm2(y)))
        else:
            # previously fell through and failed with UnboundLocalError on `y`
            raise ValueError(
                f"Unsupported version {self.version!r}; expected 'v1' or 'v2'"
            )

        return y

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, mlp_ratio={self.mlp_ratio}, version={self.version}"

    def flops(self):
        """Approximate FLOPs for one (H, W) = ``input_resolution`` input.

        ``TreeAttention`` has no ``flops`` method, so the two attention levels
        are counted through their ``Attention.flops`` per-group estimates.
        Padding introduced by ``global_size`` is not counted.
        """
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # level-1 window attention: H*W / ws^2 windows of ws^2 tokens
        n_windows = H * W / self.window_size / self.window_size
        flops += n_windows * self.attn.attn1.flops(self.window_size * self.window_size)
        # level-2 grid attention: H*W / gs^2 groups of gs^2 tokens
        n_groups = H * W / self.grid_size / self.grid_size
        flops += n_groups * self.attn.attn2.flops(self.grid_size * self.grid_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class TreeTransConvParallelBlock(nn.Module):
    r"""Parallel transformer / convolution TreeIR block.

    A linear layer projects the channels-last input, which is then split
    channel-wise in half. One half goes through a ``ConvBlock`` (with a
    residual connection), the other through a ``TreeTransformerBlock``. The
    halves are concatenated, merged by a second linear layer and added back
    to the input.

    Args:
        dim (int): Number of input channels (split evenly between the branches).
        input_resolution (tuple[int]): Input resolution (H, W); used by ``flops``.
        num_heads (int): Number of attention heads (the transformer branch uses half).
        window_size (int): Window size. Default: 7
        shift_size (int): Level-1 shift of the transformer branch. Default: 0
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        fixed_attn_size (bool): Accepted for backward compatibility; not used. Default: True
        mlp_kernel_size (int): Depthwise kernel size of the transformer FFN. Default: 3
        version (str): Transformer-branch version ("v1" or "v2"). Default: "v2"
        conv_scale (float): Accepted for backward compatibility; not used
            (the transformer branch is built with ``conv_scale=0``). Default: 0.01
        compression_ratio (int): The conv branch uses ``compression_ratio // 4``
            as its ``ConvBlock`` compression. Default: 4
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        fixed_attn_size=True,
        mlp_kernel_size=3,
        version="v2",
        conv_scale=0.01,
        compression_ratio=4,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.fixed_attn_size = fixed_attn_size

        # Keyword arguments are required: TreeTransformerBlock's positional
        # order (window_size, grid_size, global_size, shift_size, ...) differs
        # from this block's. Passing these positionally misassigned them and
        # raised TypeError (qk_reduce received twice).
        # NOTE(review): grid_size is left at TreeTransformerBlock's default.
        self.trans_branch = TreeTransformerBlock(
            dim // 2,
            input_resolution,
            round(num_heads / 2),
            window_size=window_size,
            shift_size=shift_size,
            qkv_bias=qkv_bias,
            qk_reduce=False,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            mlp_ratio=mlp_ratio,
            mlp_kernel_size=mlp_kernel_size,
            version=version,
            conv_scale=0,
        )
        self.conv_branch = ConvBlock(dim // 2, compression_ratio // 4)

        self.split = nn.Linear(dim, dim, bias=True)
        self.merge = nn.Linear(dim, dim, bias=True)

    def forward(self, x, window_size=None, mask=None):
        """
        Args:
            x: input tensor with shape of B, H, W, C (channels last, as the
                ``nn.Linear`` split/merge layers act on the last axis)
            window_size: optional window size forwarded to the transformer branch.
            mask: optional attention mask forwarded to the transformer branch
                (required there when shift_size > 0).
        Returns:
            output: tensor shape B, H, W, C
        """
        half = self.dim // 2
        conv_x, trans_x = torch.split(self.split(x), [half, half], dim=-1)

        # Both branches are built from Conv2d layers and expect B, C, H, W.
        conv_x = conv_x.permute(0, 3, 1, 2).contiguous()
        trans_x = trans_x.permute(0, 3, 1, 2).contiguous()

        conv_x = self.conv_branch(conv_x) + conv_x
        trans_x = self.trans_branch(trans_x, window_size, mask=mask)

        # back to channels last for the merge projection
        fused = torch.cat((conv_x, trans_x), dim=1).permute(0, 2, 3, 1)
        res = self.merge(fused)
        return x + res

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        """Approximate FLOPs for one (H, W) = ``input_resolution`` input.

        Counts the split/merge projections, the transformer branch (norms,
        both attention levels via ``Attention.flops``, FFN) at half width, and
        the two 3x3 convolutions of the conv branch. SE layers are ignored.
        """
        H, W = self.input_resolution
        half = self.dim // 2
        tb = self.trans_branch
        flops = 0
        # split and merge linear projections
        flops += 2 * H * W * self.dim * self.dim
        # transformer branch: norm1 + norm2
        flops += 2 * half * H * W
        # transformer branch: level-1 window attention
        n_windows = H * W / tb.window_size / tb.window_size
        flops += n_windows * tb.attn.attn1.flops(tb.window_size * tb.window_size)
        # transformer branch: level-2 grid attention
        n_groups = H * W / tb.grid_size / tb.grid_size
        flops += n_groups * tb.attn.attn2.flops(tb.grid_size * tb.grid_size)
        # transformer branch: mlp
        flops += 2 * H * W * half * half * self.mlp_ratio
        # conv branch: 3x3 conv half -> hidden, then hidden -> half
        hidden = self.conv_branch.conv[0].out_channels
        flops += 2 * H * W * 9 * half * hidden
        return flops
