# -----------------------------------------------------------------------------------
# SemanIR: Sharing Key Semantics in Transformer Makes Efficient Image Restoration
# -----------------------------------------------------------------------------------


import numbers
import numpy as np
import torch
import torch.nn as nn
from model.ops import (
    bchw_to_blc,
    blc_to_bchw,
    _get_meshgrid_coords,
    coords_diff,
)
from timm.models.layers import DropPath, to_2tuple
import math
import torch.nn.functional as F


class Linear(nn.Linear):
    """``nn.Linear`` that accepts and returns (B, C, H, W) feature maps.

    The input is flattened with ``bchw_to_blc`` (presumably to (B, H*W, C) --
    confirm in model.ops), passed through the standard linear layer, and
    folded back to the original spatial size with ``blc_to_bchw``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)

    def forward(self, x):
        # Unpack all four dims so non-4-D input is rejected here, as before.
        _, _, height, width = x.shape
        tokens = super().forward(bchw_to_blc(x))
        return blc_to_bchw(tokens, (height, width))


def build_last_conv(conv_type, dim):
    """Build the trailing convolution block that maps ``dim`` channels to ``dim``.

    Args:
        conv_type: one of
            ``"1conv"``    -- a single 3x3 conv (stride 1, padding 1);
            ``"3conv"``    -- 3x3 -> 1x1 -> 3x3 convs with a dim // 4 bottleneck;
            ``"1conv1x1"`` -- a single 1x1 conv;
            ``"linear"``   -- the channel-wise ``Linear`` wrapper.
        dim: number of input and output channels.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: if ``conv_type`` is not one of the values above (previously
            this surfaced as an ``UnboundLocalError``).
    """
    if conv_type == "1conv":
        return nn.Conv2d(dim, dim, 3, 1, 1)
    if conv_type == "3conv":
        # Bottleneck to dim // 4 channels to save parameters and memory.
        return nn.Sequential(
            nn.Conv2d(dim, dim // 4, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(dim // 4, dim, 3, 1, 1),
        )
    if conv_type == "1conv1x1":
        return nn.Conv2d(dim, dim, 1, 1, 0)
    if conv_type == "linear":
        return Linear(dim, dim)
    raise ValueError(
        f"Unknown conv_type {conv_type!r}; expected one of "
        "'1conv', '3conv', '1conv1x1', 'linear'."
    )


def model_analysis(model):
    """Print ``model`` followed by its trainable-parameter count in millions.

    Only parameters with ``requires_grad`` set are counted.
    """
    print(model)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Params: {trainable / 10 ** 6: 0.2f} M")


class LayerNorm(nn.Module):
    """Layer normalization over a single tensor axis with a learnable affine.

    Normalizes ``x`` to zero mean and unit (biased) variance along ``dim``,
    then applies a learnable scale and shift of length ``normalized_shape``
    broadcast along that same axis. With the default ``dim=1`` this is
    channel-wise LayerNorm for (B, C, H, W) feature maps.

    Args:
        normalized_shape: size of the normalized axis (an int or a
            one-element shape).
        dim: axis to normalize over; negative values count from the end.
    """

    def __init__(self, normalized_shape, dim=1):
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # Only a single normalized axis is supported.
        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape
        self.dim = dim

    def forward(self, x):
        mu = x.mean(self.dim, keepdim=True)
        sigma = x.var(self.dim, keepdim=True, unbiased=False)
        # Shape the affine parameters as (1, ..., C, ..., 1) with C on
        # ``self.dim``. The former hard-coded ``[..., None, None]`` only
        # broadcast correctly for dim=1 on 4-D input.
        shape = [1] * x.dim()
        shape[self.dim] = -1
        weight = self.weight.view(shape)
        bias = self.bias.view(shape)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * weight + bias


def _parse_list(model_param):
    if isinstance(model_param, str):
        if model_param.find("+") >= 0:
            model_param = list(map(int, model_param.split("+")))
        else:
            model_param = list(map(int, model_param.split("x")))
            model_param = [model_param[0]] * model_param[1]
    return model_param


class Mlp(nn.Module):
    """Point-wise two-layer MLP (ViT / MLP-Mixer style) built from 1x1 convs.

    Layout: fc1 (1x1 conv) -> activation -> dropout -> fc2 (1x1 conv) -> dropout.
    ``drop`` is expanded with ``to_2tuple``, so a scalar applies to both
    dropout layers while a pair sets them individually.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        probs = to_2tuple(drop)

        # Creation order is kept as fc1, act, drop1, fc2, drop2 so parameter
        # initialization order and state_dict keys are unchanged.
        self.fc1 = nn.Conv2d(in_features, hidden, 1, 1, 0)
        self.act = act_layer()
        self.drop1 = nn.Dropout(probs[0])
        self.fc2 = nn.Conv2d(hidden, out, 1, 1, 0)
        self.drop2 = nn.Dropout(probs[1])

    def forward(self, x):
        for layer in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            x = layer(x)
        return x


def grid_partition(x, grid_size, global_size=None):
    """
    Regroup a feature map into strided (dilated) grids of grid_size x grid_size tokens.

    Args:
        x: (B, C, H, W)
        grid_size: number of sampled positions per side in each group.
        global_size: optional tile size. When None, the whole image is treated
            as one tile. Otherwise any of H / W larger than global_size is
            reflect-padded on the bottom / right up to a multiple of
            global_size, and grids are formed independently inside each tile
            of size min(H_new, global_size) x min(W_new, global_size).
    Returns:
        windows: (B_, grid_size ** 2, C)

    Inside a tile of height h, a row index r is decomposed as
    r = i * (h // grid_size) + j with i in [0, grid_size). Each group fixes j
    (and the analogous column offset) and collects all grid_size values of i,
    so its members are h // grid_size rows (w // grid_size columns) apart.
    The tile side (H and W themselves when global_size is None) must be
    divisible by grid_size, otherwise the .view calls below raise.

    Use grid_reverse with the same grid_size / global_size to undo this.
    """
    B, C, H, W = x.shape
    if global_size is None:
        # Split each spatial axis as (grid_size, H // grid_size): outer index i,
        # inner offset j.
        x = x.view(B, C, grid_size, H // grid_size, grid_size, W // grid_size)
        # -> (B, j_h, j_w, i_h, i_w, C): one group per (b, j_h, j_w).
        windows = x.permute(0, 3, 5, 2, 4, 1).contiguous().view(-1, grid_size**2, C)
    else:
        # F.pad order for the last two dims is [left, right, top, bottom];
        # only right (pad[1]) and bottom (pad[3]) are padded.
        pad = [0, 0, 0, 0]
        if W > global_size:
            pad[1] = math.ceil(W / global_size) * global_size - W
        if H > global_size:
            pad[3] = math.ceil(H / global_size) * global_size - H
        x = F.pad(x, pad, "reflect")
        H_new, W_new = x.shape[2:]
        # print(pad, H, W, H_new, W_new, global_size)
        # Tile size: the whole (padded) axis if it fits in global_size.
        h, w = min(H_new, global_size), min(W_new, global_size)
        # (B, C, n_tiles_h, i_h, j_h, n_tiles_w, i_w, j_w)
        x = x.view(
            B,
            C,
            H_new // h,
            grid_size,
            h // grid_size,
            W_new // w,
            grid_size,
            w // grid_size,
        )
        # -> (B, n_tiles_h, n_tiles_w, j_h, j_w, i_h, i_w, C), then flatten
        # everything but the grid_size ** 2 group members and C.
        windows = (
            x.permute(0, 2, 5, 4, 7, 3, 6, 1).contiguous().view(-1, grid_size**2, C)
        )
    return windows


def grid_reverse(windows, grid_size, input_size, global_size=None):
    """
    Inverse of grid_partition: scatter grouped tokens back into a feature map.

    Args:
        windows: (B_, L, C), produced by grid_partition with the same
            grid_size and global_size (L == grid_size ** 2).
        grid_size: must equal the value passed to grid_partition.
        input_size: (B, C, H, W) of the tensor before partitioning, i.e.
            before any reflect padding.
        global_size: must equal the value passed to grid_partition.
    Returns:
        x: (B, C, H, W)
    """
    B, C, H, W = input_size
    if global_size is None:
        # windows order is (B, j_h, j_w, i_h, i_w, C); see grid_partition.
        x = windows.view(
            B,
            H // grid_size,
            W // grid_size,
            grid_size,
            grid_size,
            C,
        )
        # -> (B, C, i_h, j_h, i_w, j_w), which flattens back to (B, C, H, W).
        x = x.permute(0, 5, 3, 1, 4, 2).contiguous().view(B, C, H, W)
    else:
        # Recompute the padded size exactly as grid_partition does.
        H_new, W_new = H, W
        if H > global_size:
            H_new = math.ceil(H / global_size) * global_size
        if W > global_size:
            W_new = math.ceil(W / global_size) * global_size
        h, w = min(H_new, global_size), min(W_new, global_size)
        # windows order is (B, n_tiles_h, n_tiles_w, j_h, j_w, i_h, i_w, C).
        x = windows.view(
            B,
            H_new // h,
            W_new // w,
            h // grid_size,
            w // grid_size,
            grid_size,
            grid_size,
            C,
        )
        # -> (B, C, n_tiles_h, i_h, j_h, n_tiles_w, i_w, j_w) -> (B, C, H_new, W_new)
        x = x.permute(0, 7, 1, 5, 3, 2, 6, 4).contiguous().view(B, C, H_new, W_new)
        # Drop the bottom / right padding added by grid_partition.
        x = x[:, :, :H, :W]
    return x


def get_relative_coords_table(window_size):
    """Build the log-spaced table of relative coordinates for a window.

    Args:
        window_size: (Wh, Ww). A single int is treated as a square window.

    Returns:
        Tensor of shape (1, 2*Wh-1, 2*Ww-1, 2). For every relative offset
        d along an axis of side W it holds
        sign(t) * log2(|t| + 1) / log2(8) with t = 8 * d / (W - 1).
        For a side of length 1 the only offset is 0, and the entry is 0.
    """
    if isinstance(window_size, int):
        window_size = (window_size, window_size)
    ts = window_size

    coord_h = torch.arange(-(ts[0] - 1), ts[0], dtype=torch.float32)
    coord_w = torch.arange(-(ts[1] - 1), ts[1], dtype=torch.float32)
    table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
        1, 2, 0
    )

    table = table.contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2

    # Normalize offsets to [-1, 1]. A side of length 1 would otherwise divide
    # 0 by 0 and fill the table with NaN, so the divisor is clamped to 1.
    table[:, :, :, 0] /= max(ts[0] - 1, 1)
    table[:, :, :, 1] /= max(ts[1] - 1, 1)

    # Scale to [-8, 8], then log-compress back into [-1, 1].
    table *= 8
    table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)

    return table


def get_relative_position_index(window_size):
    """Return the pairwise relative-position index for positions in a window.

    Thin wrapper over model.ops: ``_get_meshgrid_coords`` produces the
    coordinates of the window starting at (0, 0), and ``coords_diff`` compares
    every coordinate against every other one with ``max_diff=window_size``.
    Shapes (per the original annotations): coords 2 x Wh*Ww, result
    Wh*Ww x Wh*Ww. The exact index encoding is defined by ``coords_diff``;
    confirm it there.
    """
    grid = _get_meshgrid_coords((0, 0), window_size)  # 2, Wh*Ww
    return coords_diff(grid, grid, max_diff=window_size)  # Wh*Ww, Wh*Ww

