# -----------------------------------------------------------------------------------
# Hi-IR: Hierarchical Information Flow for Generalized Efficient Image Restoration
# -----------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import checkpoint_wrapper
from omegaconf import OmegaConf
from model.upsample import Upsample, UpsampleOneStep
from model.ops import calculate_mask
from model.common import build_last_conv
from model.tree_block import LayerNorm, TreeTransformerBlock
from timm.models.layers import to_2tuple, trunc_normal_


# window_pool = {
#     128: (16, 32),
#     160: (16, 20, 32),
#     192: (16, 24, 32),
#     224: (16, 28, 32),
#     256: (16, 32),
# }

# window_pool_more = {
#     128: (8, 16, 32),
#     160: (8, 10, 16, 20, 32),
#     192: (8, 12, 16, 24, 32),
#     224: (8, 14, 16, 28, 32),
#     256: (8, 16, 32),
# }


# Pool of window sizes keyed by patch size: {patch_size: (window_size, ...)}.
# Every listed window size divides its patch size evenly.
# NOTE(review): presumably used to sample (patch, window) pairs during
# training; the only consumer visible in this file is the flattening helper
# below -- confirm against the training code.
patch_window_pool = {
    224: (7, 8, 14, 16, 28, 32),
    240: (48, 40, 30, 24, 20, 16, 15, 12, 10, 8, 6, 5),
    256: (8, 16, 32),
}


def _paired_patch_window(pool_dict):
    """Select patch window."""
    pool_list = []
    for patch_size, window_sizes in pool_dict.items():
        for window_size in window_sizes:
            pool_list.append((patch_size, window_size))
    return pool_list


# Flat list of every (patch_size, window_size) combination in
# patch_window_pool, built once when the module is imported.
paired_patch_window = _paired_patch_window(patch_window_pool)


# def patch_window_selection(training, random_training, index):
#     """Window selection."""
#     if training:
#         if random_training is None:
#             return patch_size, window_size
#         elif random_training == "less":
#             return paired_patch_window[index]
#         # elif random_training == "more":
#         #     current_window_size = window_pool_more[index]
#     else:
#         return patch_size, window_size


def _parse_list(model_param):
    if isinstance(model_param, str):
        if model_param.find("+") >= 0:
            model_param = list(map(int, model_param.split("+")))
        else:
            model_param = list(map(int, model_param.split("x")))
            model_param = [model_param[0]] * model_param[1]
    return model_param


class HiIRStage(nn.Module):
    """Hi-IR stage: a stack of tree transformer blocks followed by a convolution.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution ``(H, W)``. Forwarded
            to every block and used by :meth:`flops`.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size. Odd-indexed blocks use a shift of
            ``window_size // 2`` when ``args.shift`` is true.
        grid_size: Forwarded unchanged to every ``TreeTransformerBlock``.
        global_size: Forwarded unchanged to every ``TreeTransformerBlock``.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate. A
            ``list`` supplies one rate per block; any other value is passed to
            every block as-is. Default: 0.0
        skip (bool, optional): If True, the stage input is added to the stage
            output (residual connection). Default: True
        args: Attribute-style config. The constructor reads ``shift``,
            ``qkv_bias``, ``qkv_conv``, ``qk_reduce``, ``mlp_type``,
            ``mlp_ratio``, ``mlp_kernel_size``, ``version``, ``conv_scale``,
            ``fairscale_checkpoint``, ``offload_to_cpu`` and ``conv_type``.
            Despite the ``None`` default it must be provided.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        grid_size,
        global_size,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        skip=True,
        args=None,
    ):
        super().__init__()

        self.dim = dim
        self.input_resolution = input_resolution
        # self.init_method = args.init_method
        self.skip = skip

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = TreeTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                grid_size=grid_size,
                global_size=global_size,
                # Only odd-indexed blocks are shifted, and only if enabled.
                shift_size=window_size // 2 if args.shift and i % 2 > 0 else 0,
                qkv_bias=args.qkv_bias,
                qkv_conv=args.qkv_conv,
                qk_reduce=args.qk_reduce,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                mlp_type=args.mlp_type,
                mlp_ratio=args.mlp_ratio,
                mlp_kernel_size=args.mlp_kernel_size,
                version=args.version,
                conv_scale=args.conv_scale,
            )

            if args.fairscale_checkpoint:
                block = checkpoint_wrapper(block, offload_to_cpu=args.offload_to_cpu)
            self.blocks.append(block)

        # Convolution applied after the last block, before the optional skip.
        self.conv = build_last_conv(args.conv_type, dim)

    # def _init_weights(self):
    #     for n, m in self.named_modules():
    #         if self.init_method == "w":
    #             if isinstance(m, (nn.Linear, nn.Conv2d)) and n.find("cpb_mlp") < 0:
    #                 print("nn.Linear and nn.Conv2d weight initilization")
    #                 m.weight.data *= self.res_scale
    #         elif self.init_method == "l":
    #             if isinstance(m, nn.LayerNorm):
    #                 print("nn.LayerNorm initialization")
    #                 nn.init.constant_(m.bias, 0)
    #                 nn.init.constant_(m.weight, 0)
    #         elif self.init_method.find("t") >= 0:
    #             scale = 0.1 ** (len(self.init_method) - 1) * int(self.init_method[-1])
    #             if isinstance(m, nn.Linear) and n.find("cpb_mlp") < 0:
    #                 trunc_normal_(m.weight, std=scale)
    #             elif isinstance(m, nn.Conv2d):
    #                 m.weight.data *= self.res_scale
    #             print(
    #                 "Initialization nn.Linear - trunc_normal; nn.Conv2d - weight rescale."
    #             )
    #         else:
    #             raise NotImplementedError(
    #                 f"Parameter initialization method {self.init_method} not implemented in TransformerStage."
    #             )

    def forward(self, x, window_size, grid_size, mask):
        """Run all blocks, the trailing convolution and the optional skip.

        Args:
            x: input tensor with shape of B, C, H, W
            window_size: Forwarded to every block.
            grid_size: Forwarded to every block.
            mask: Forwarded to every block.
        Returns:
            output: tensor shape B, C, H, W
        """
        res = x
        for blk in self.blocks:
            res = blk(res, window_size, grid_size, mask)
        res = self.conv(res)
        if self.skip:
            return res + x
        return res

    def flops(self):
        """Estimate the FLOPs of this stage.

        Returns:
            The sum of ``flops()`` over all blocks plus the cost of the
            trailing convolution at ``self.input_resolution``.
        """
        # nn.ModuleList itself has no flops(); accumulate over its members.
        flops = sum(blk.flops() for blk in self.blocks)
        H, W = self.input_resolution
        # Cost of a single 3x3 dim -> dim convolution per pixel.
        # NOTE(review): presumably matches the "1conv" variant of
        # build_last_conv only -- confirm for other conv types.
        flops += H * W * self.dim * self.dim * 9

        return flops


class HiIR(nn.Module):
    r"""Hi-IR image restoration network.

    The network consists of a first convolution, a body of ``HiIRStage``
    modules wrapped by two layer norms, a convolution after the body with a
    residual connection, and a reconstruction tail selected by ``upsampler``.

    Args:
        img_size (int | tuple(int)): Input image size. Default: 64
        in_channels (int): Number of input image channels. Default: 3
        out_channels (int | None): Number of output channels. Falls back to
            ``in_channels`` when falsy. Default: None
        embed_dim (int): Feature dimension of the body. Default: 96
        depths (str | list(int)): Number of blocks in each stage. Strings are
            expanded by ``_parse_list`` (``"6+6+6"`` or ``"6x3"``).
        num_heads (str | list(int)): Number of attention heads in each stage,
            same formats as ``depths``.
        window_size (int): Window size; also the default padding multiple of
            the input. Default: 8
        grid_size: Forwarded to every stage. Default: 8
        global_size: Forwarded to every stage. Default: None
        shift (bool): Enable shifted windows in odd-indexed blocks. Default: False
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qkv_conv (bool): Forwarded to the transformer blocks. Default: False
        qk_reduce (bool): Forwarded to the transformer blocks. Default: True
        drop_rate (float): Dropout rate. Default: 0. Used by attention projection layer and MLP.
        attn_drop_rate (float): Attention dropout rate. Default: 0. Used by attention MHA.
        drop_path_rate (float): Maximum stochastic depth rate; the per-block
            rates grow linearly from 0 to this value. Default: 0.1
        mlp_type (str): Forwarded to the transformer blocks. Default: "locality"
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        mlp_kernel_size (int): Forwarded to the transformer blocks. Default: 3
        version: Transformer block version. Choices: v1, v2.
                v1: cosine similarity attention. post-norm
                v2: dot product attention. pre-norm
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv';
            any other value selects a single output convolution (denoising / artifact reduction).
        conv_type: The convolutional block before residual connection. '1conv'/'3conv'
        fairscale_checkpoint (bool): Wrap every block with fairscale activation
            checkpointing. Default: False
        offload_to_cpu (bool): Passed to the checkpoint wrapper. Default: False
        conv_scale (float): Forwarded to the transformer blocks. Default: 0.0
        **kwargs: Accepted and ignored.

    Note:
        ``init_method`` is currently disabled (its parameter is commented out).
        The choices it used to document were:
            n, normal -- Swin V1 init method.
            l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
            r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
            w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
            t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale
    """

    def __init__(
        self,
        img_size=64,
        in_channels=3,
        out_channels=None,
        embed_dim=96,
        depths="6+6+6+6+6+6",
        num_heads="6+6+6+6+6+6",
        window_size=8,
        grid_size=8,
        global_size=None,
        shift=False,
        qkv_bias=True,
        qkv_conv=False,
        qk_reduce=True,
        drop_rate=0.0,  # Used by attention projection layer and MLP
        attn_drop_rate=0.0,  # Used by attention MHA
        drop_path_rate=0.1,  # stochastic depth decay rule used for attention block and MLP.
        mlp_type="locality",
        mlp_ratio=4.0,
        mlp_kernel_size=3,
        version="v2",  #  tree transformer version
        upscale=2,
        img_range=1.0,
        upsampler="",
        conv_type="1conv",
        fairscale_checkpoint=False,  # fairscale activation checkpointing
        offload_to_cpu=False,
        conv_scale=0.00,
        # init_method="n",  # initialization method of the weight parameters used to train large scale models.
        **kwargs,
    ):
        super().__init__()
        # Process the input arguments
        out_channels = out_channels or in_channels
        num_out_feats = 64
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim
        depths = _parse_list(depths)
        num_heads = _parse_list(num_heads)
        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size
        # Kept so that flops() has a resolution to work with.
        self.input_resolution = to_2tuple(img_size)

        self.img_range = img_range
        if in_channels == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)

        # Inputs are padded to a multiple of this unless forward() is given
        # an explicit window size.
        self.pad_size = window_size
        args = OmegaConf.create(
            {
                "fairscale_checkpoint": fairscale_checkpoint,
                "offload_to_cpu": offload_to_cpu,
                "version": version,
                "shift": shift,
                "qkv_bias": qkv_bias,
                "qkv_conv": qkv_conv,
                "qk_reduce": qk_reduce,
                "mlp_ratio": mlp_ratio,
                "mlp_type": mlp_type,
                "mlp_kernel_size": mlp_kernel_size,
                "conv_type": conv_type,
                "conv_scale": conv_scale,
            }
        )

        # Head of the network. First convolution.
        self.conv_first = nn.Conv2d(in_channels, embed_dim, 3, 1, 1)

        # Body of the network
        # self.norm_start = nn.LayerNorm(embed_dim)
        self.norm_start = LayerNorm(embed_dim, 1)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: one rate per block, linearly increasing.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.layers = nn.ModuleList()
        block_offset = 0  # index of the current stage's first block in dpr
        for i_layer, depth in enumerate(depths):
            layer = HiIRStage(
                dim=embed_dim,
                input_resolution=self.input_resolution,
                depth=depth,
                num_heads=num_heads[i_layer],
                window_size=window_size,
                grid_size=grid_size,
                global_size=global_size,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[
                    block_offset : block_offset + depth
                ],  # no impact on SR results
                args=args,
            )
            self.layers.append(layer)
            block_offset += depth
        # self.norm_end = nn.LayerNorm(embed_dim)
        self.norm_end = LayerNorm(embed_dim, 1)

        # Tail of the network
        self.conv_after_body = build_last_conv(conv_type, embed_dim)

        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == "pixelshuffle":
            # for classical SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_out_feats, 3, 1, 1), nn.LeakyReLU(inplace=True)
            )
            self.upsample = Upsample(upscale, num_out_feats)
            self.conv_last = nn.Conv2d(num_out_feats, out_channels, 3, 1, 1)
        elif self.upsampler == "pixelshuffledirect":
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(
                upscale,
                embed_dim,
                out_channels,
            )
        elif self.upsampler == "nearest+conv":
            # for real-world SR (less artifacts)
            assert self.upscale == 4, "only support x4 now."
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_out_feats, 3, 1, 1), nn.LeakyReLU(inplace=True)
            )
            self.conv_up1 = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_out_feats, out_channels, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, out_channels, 3, 1, 1)

        self.apply(self._init_weights)
        # if init_method in ["l", "w"] or init_method.find("t") >= 0:
        #     for layer in self.layers:
        #         layer._init_weights()

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` modules; used with ``self.apply``."""
        if isinstance(m, nn.Linear):
            # Only used to initialize linear layers
            # weight_shape = m.weight.shape
            # if weight_shape[0] > 256 and weight_shape[1] > 256:
            #     std = 0.004
            # else:
            #     std = 0.02
            # print(f"Standard deviation during initialization {std}.")
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names excluded from weight decay.

        NOTE(review): this class defines no ``absolute_pos_embed`` attribute;
        the name looks carried over from another model -- confirm.
        """
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return parameter-name keywords excluded from weight decay."""
        return {"relative_position_bias_table"}

    def check_image_size(self, x, pad_size):
        """Reflect-pad ``x`` on the bottom/right to a multiple of ``pad_size``.

        Args:
            x: Tensor of shape B, C, H, W.
            pad_size: Padding multiple; ``self.pad_size`` is used when falsy.
        Returns:
            The padded tensor. Reflect padding requires the padding amount to
            be smaller than the corresponding input dimension.
        """
        pad_size = pad_size or self.pad_size
        _, _, h, w = x.size()
        mod_pad_h = (pad_size - h % pad_size) % pad_size
        mod_pad_w = (pad_size - w % pad_size) % pad_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
        # print("check size", x.shape, pad_size)
        return x

    def forward_features(self, x, window_size=None, grid_size=None):
        """Run the body (norm, dropout, all stages, norm) on a feature map.

        Args:
            x: Feature tensor of shape B, C, H, W.
            window_size: Window size for this call; ``self.window_size`` when falsy.
            grid_size: Forwarded to every stage.
        Returns:
            Feature tensor produced by the body.
        """
        x = self.norm_start(x)
        x = self.pos_drop(x)
        window_size = window_size or self.window_size
        # The mask is built for the current spatial size with a shift of half
        # a window and shared by all stages.
        mask = calculate_mask(
            x.shape[2:], to_2tuple(window_size), to_2tuple(window_size // 2)
        ).to(dtype=x.dtype, device=x.device)

        for layer in self.layers:
            x = layer(x, window_size, grid_size, mask)

        x = self.norm_end(x)

        return x

    def forward(self, x, window_size=None, grid_size=None):
        """Restore / super-resolve an image batch.

        Args:
            x: Input tensor of shape B, C, H, W.
            window_size: Optional window size overriding ``self.window_size``;
                also used as the padding multiple.
            grid_size: Forwarded to the body.
        Returns:
            Output tensor cropped to ``H * upscale`` by ``W * upscale``.
        """
        H, W = x.shape[2:]
        x = self.check_image_size(x, window_size)

        # Keep the mean on the same dtype/device as the input.
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        # Shared by every reconstruction mode: shallow features, then deep
        # features with a residual connection around the body.
        shallow = self.conv_first(x)
        deep = (
            self.conv_after_body(self.forward_features(shallow, window_size, grid_size))
            + shallow
        )

        if self.upsampler == "pixelshuffle":
            # for classical SR
            x = self.conv_before_upsample(deep)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == "pixelshuffledirect":
            # for lightweight SR
            x = self.upsample(deep)
        elif self.upsampler == "nearest+conv":
            # for real-world SR: two nearest-neighbour x2 steps
            x = self.conv_before_upsample(deep)
            x = self.lrelu(
                self.conv_up1(F.interpolate(x, scale_factor=2, mode="nearest"))
            )
            x = self.lrelu(
                self.conv_up2(F.interpolate(x, scale_factor=2, mode="nearest"))
            )
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            if self.in_channels == self.out_channels:
                # Predict a residual on top of the (normalised) input.
                x = x + self.conv_last(deep)
            else:
                x = self.conv_last(deep)

        x = x / self.img_range + self.mean

        # Remove the padding added by check_image_size.
        return x[:, :, : H * self.upscale, : W * self.upscale]

    def flops(self):
        """Rough FLOPs estimate at the resolution given by ``img_size``.

        Returns:
            The sum of the first convolution, every stage's ``flops()``, the
            convolution after the body, and the upsampler when it exists and
            provides a ``flops()`` method.
        """
        H, W = self.input_resolution
        # conv_first: 3x3 convolution from in_channels to embed_dim.
        flops = H * W * self.in_channels * self.embed_dim * 9
        for layer in self.layers:
            flops += layer.flops()
        # Estimate for conv_after_body, coefficients kept from the original.
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        # self.upsample only exists for the pixel-shuffle based upsamplers.
        upsample = getattr(self, "upsample", None)
        if upsample is not None and hasattr(upsample, "flops"):
            flops += upsample.flops()
        return flops

    def convert_checkpoint(self, state_dict):
        """Adapt a checkpoint whose qkv/proj weights are stored as 2-D matrices.

        Args:
            state_dict: Mapping from parameter name to tensor.
        Returns:
            dict: A new mapping in which every 2-D tensor whose key contains
            ``"qkv.weight"`` or ``"proj.weight"`` gets two trailing singleton
            dimensions; all other entries are passed through unchanged.
        """
        new_state = {}
        for k, v in state_dict.items():
            is_projection = "qkv.weight" in k or "proj.weight" in k
            if is_projection and v.dim() == 2:
                new_state[k] = v[..., None, None]
            else:
                new_state[k] = v
        return new_state


if __name__ == "__main__":
    # Smoke test: build a classical-SR Hi-IR model, print it and hand it to
    # the project's analysis helper.
    upscale = 4
    window_size = 8
    # Next multiple of window_size above 512 // upscale (136 here).
    # height / width are not used further down.
    windows_per_side = 512 // upscale // window_size + 1
    height = windows_per_side * window_size
    width = windows_per_side * window_size

    num_stages = 6

    # HiIR
    model_config = dict(
        upscale=upscale,
        window_size=16,
        grid_size=16,
        global_size=128,
        img_size=64,
        img_range=1.0,
        depths=[6] * num_stages,
        embed_dim=180,
        num_heads=[6] * num_stages,
        mlp_ratio=2,
        conv_type="3conv",
        shift=True,
        upsampler="pixelshuffle",
    )
    model = HiIR(**model_config)

    # Hi-IR large (disabled): identical configuration except for
    # window_size=8 and grid_size=8.
    # model = HiIR(**{**model_config, "window_size": 8, "grid_size": 8})

    print(model)

    from model.common import model_analysis
    model_analysis(model)

