# -----------------------------------------------------------------------------------
# Hi-IR: Hierarchical Information Flow for Generalized Efficient Image Restoration
# Hi-IR with Unet architecture
# -----------------------------------------------------------------------------------


from functools import partial

import numbers
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import checkpoint_wrapper
from omegaconf import OmegaConf
from model.ops import (
    bchw_to_blc,
    blc_to_bchw,
    calculate_mask,
)
from model.tree_block import TreeTransformerBlock
from model.hiir import _parse_list
from timm.models.layers import to_2tuple


##########################################################################
## Resizing modules


class Downsample(nn.Module):
    """Halve the spatial size with a 3x3 convolution followed by pixel-unshuffle.

    The convolution (stride 1, padding 1, no bias) maps ``in_channels`` to
    ``out_channels``; ``nn.PixelUnshuffle(2)`` then folds every 2x2 spatial
    patch into the channel axis, so the result has ``4 * out_channels``
    channels at half the height and width.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the convolution, i.e. before
            the pixel-unshuffle multiplies them by four.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        fold = nn.PixelUnshuffle(downscale_factor=2)
        # Keep the attribute name and layer order: checkpoints address the
        # convolution weight as ``body.0.weight``.
        self.body = nn.Sequential(conv, fold)

    def forward(self, x):
        """Apply the convolution and the pixel-unshuffle to ``x``."""
        out = self.body(x)
        return out


class Upsample(nn.Module):
    """Double the spatial size with a 3x3 convolution followed by pixel-shuffle.

    The convolution (stride 1, padding 1, no bias) maps ``in_channels`` to
    ``out_channels``; ``nn.PixelShuffle(2)`` then spreads groups of four
    channels over 2x2 spatial patches, so the result has ``out_channels // 4``
    channels at twice the height and width.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the convolution, i.e. before
            the pixel-shuffle divides them by four.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        unfold = nn.PixelShuffle(upscale_factor=2)
        # Keep the attribute name and layer order: checkpoints address the
        # convolution weight as ``body.0.weight``.
        self.body = nn.Sequential(conv, unfold)

    def forward(self, x):
        """Apply the convolution and the pixel-shuffle to ``x``."""
        out = self.body(x)
        return out


def downsample(
    in_channels, out_channels, kernel_size=2, stride=2, padding=0, bias=True
):
    """Build the strided convolution used by the "simple" down-sampling mode.

    With the defaults (2x2 kernel, stride 2, no padding) the returned layer
    halves the height and width of its input.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels of the produced feature map.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride.
        padding (int): Zero padding added on every side.
        bias (bool): Whether the convolution has a bias term.

    Returns:
        nn.Conv2d: The configured convolution.
    """
    conv_options = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "bias": bias,
    }
    return nn.Conv2d(in_channels, out_channels, **conv_options)


def upsample(in_channels, out_channels, kernel_size=2, stride=2, padding=0, bias=True):
    """Build the transposed convolution used by the "simple" up-sampling mode.

    With the defaults (2x2 kernel, stride 2, no padding) the returned layer
    doubles the height and width of its input.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels of the produced feature map.
        kernel_size (int): Transposed-convolution kernel size.
        stride (int): Transposed-convolution stride.
        padding (int): Padding argument of the transposed convolution.
        bias (bool): Whether the layer has a bias term.

    Returns:
        nn.ConvTranspose2d: The configured transposed convolution.
    """
    deconv_options = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "bias": bias,
    }
    return nn.ConvTranspose2d(in_channels, out_channels, **deconv_options)


def _parse_block(
    block_version,
    dim,
    input_resolution,
    num_heads,
    window_size,
    grid_size,
    global_size,
    shift_size,
    drop,
    attn_drop,
    drop_path,
    compression_ratio,
    args,
):
    """Instantiate one transformer block of the requested kind.

    Only ``block_version == "hiir"`` is supported; it yields a
    ``TreeTransformerBlock``. The explicit arguments are forwarded to the
    block unchanged, the remaining hyper-parameters are read from ``args``.

    Args:
        block_version (str): Block family to build.
        dim, input_resolution, num_heads, window_size, grid_size, global_size,
        shift_size, drop, attn_drop, drop_path, compression_ratio: Forwarded
            as keyword arguments of the same name.
        args: Config object providing ``qkv_bias``, ``qkv_conv``,
            ``qk_reduce``, ``mlp_type``, ``mlp_ratio``, ``mlp_kernel_size``,
            ``version`` and ``conv_scale``.

    Returns:
        The constructed block.

    Raises:
        NotImplementedError: If ``block_version`` is not ``"hiir"``.
    """
    # Reject unsupported block families before touching ``args``.
    if block_version != "hiir":
        raise NotImplementedError(f"Block version {block_version} not implemented!")

    block_kwargs = {
        "dim": dim,
        "input_resolution": input_resolution,
        "num_heads": num_heads,
        "window_size": window_size,
        "grid_size": grid_size,
        "global_size": global_size,
        "shift_size": shift_size,
        "qkv_bias": args.qkv_bias,
        "qkv_conv": args.qkv_conv,
        "qk_reduce": args.qk_reduce,
        "drop": drop,
        "attn_drop": attn_drop,
        "drop_path": drop_path,
        "mlp_type": args.mlp_type,
        "mlp_ratio": args.mlp_ratio,
        "mlp_kernel_size": args.mlp_kernel_size,
        "version": args.version,
        "conv_scale": args.conv_scale,
        "compression_ratio": compression_ratio,
    }
    return TreeTransformerBlock(**block_kwargs)


class HiIRStage(nn.Module):
    """Hi-IR stage: a stack of ``depth`` blocks applied one after another.

    Args:
        dim (int): Number of channels handed to every block.
        input_resolution (tuple[int]): ``(H, W)`` handed to every block and
            used by :meth:`flops`.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads handed to every block.
        window_size (int): Local window size. When ``args.shift`` is true,
            every odd-indexed block is built with a shift of
            ``window_size // 2``; all other blocks use a shift of 0.
        grid_size: Forwarded unchanged to every block.
        global_size: Forwarded unchanged to every block.
        drop (float, optional): Dropout rate forwarded to the blocks.
            Default: 0.0
        attn_drop (float, optional): Attention dropout rate forwarded to the
            blocks. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic
            depth rate. A sequence is indexed per block, a scalar is shared
            by all blocks. Default: 0.0
        compression_ratio (int, optional): Forwarded unchanged to every
            block. Default: 4
        args: Required config object. This class reads ``block_version``,
            ``shift``, ``fairscale_checkpoint`` and ``offload_to_cpu``; the
            block factory reads the remaining block hyper-parameters.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        grid_size,
        global_size,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        compression_ratio=4,
        args=None,
    ):

        super().__init__()

        self.dim = dim
        self.input_resolution = input_resolution
        self.block_version = args.block_version

        # Accept tuples as well as lists for per-block stochastic depth.
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        self.blocks = nn.ModuleList()
        for i in range(depth):
            # Shifted windows are only used on odd blocks, and only on request.
            shift_size = window_size // 2 if args.shift and i % 2 > 0 else 0
            block = _parse_block(
                block_version=self.block_version,
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                grid_size=grid_size,
                global_size=global_size,
                shift_size=shift_size,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if per_block_drop_path else drop_path,
                compression_ratio=compression_ratio,
                args=args,
            )

            if args.fairscale_checkpoint:
                block = checkpoint_wrapper(block, offload_to_cpu=args.offload_to_cpu)
            self.blocks.append(block)

    def forward(self, x, window_size, grid_size, mask):
        """
        Args:
            x: input tensor with shape of B, C, H, W
            window_size, grid_size, mask: passed to every block when
                ``block_version`` is "hiir"; ignored by the other branches.
        Returns:
            output: tensor shape B, C, H, W
        Raises:
            NotImplementedError: if ``block_version`` is not handled here.
        """
        if self.block_version == "hiir":
            for blk in self.blocks:
                x = blk(x, window_size, grid_size, mask)
            return x
        if self.block_version in ("restormer", "scunet"):
            for blk in self.blocks:
                x = blk(x)
            return x
        if self.block_version in ("swin_v1", "swin_v2"):
            # These blocks consume (B, L, C) tokens plus the spatial size.
            H, W = x.shape[2:]
            res = bchw_to_blc(x)
            for blk in self.blocks:
                res = blk(res, (H, W))
            return blc_to_bchw(res, (H, W))
        # Fail loudly instead of silently returning None.
        raise NotImplementedError(
            f"Block version {self.block_version} not implemented!"
        )

    def flops(self):
        """Return the summed FLOPs of the blocks plus a 3x3-convolution term.

        Every block is expected to expose a ``flops()`` method.
        """
        # nn.ModuleList has no flops(); accumulate over its members instead.
        flops = sum(blk.flops() for blk in self.blocks)
        H, W = self.input_resolution
        # NOTE(review): this term matches a dim -> dim 3x3 convolution, but the
        # stage itself owns no such layer -- confirm whether it should stay.
        flops += H * W * self.dim * self.dim * 9

        return flops


##########################################################################
##---------- Hi-IR with Unet architecture -----------------------
class HiIRUnet(nn.Module):
    """Hi-IR arranged as a U-Net: encoder stages, a latent stage, decoder stages.

    The input is embedded by a 3x3 convolution, goes through
    ``len(depths) - 1`` encoder stages (each followed by a 2x down-sampler),
    one latent stage, and the mirrored decoder stages (each preceded by a 2x
    up-sampler and a 1x1 convolution that fuses the encoder skip feature).
    A final 3x3 convolution maps the result to ``out_channels``.

    Args:
        img_size (int): Nominal input size; stage ``i`` is built with
            ``input_resolution = img_size // 2**i``.
        in_channels (int): Channels of the input image.
        out_channels (int | None): Channels of the output; defaults to
            ``in_channels``.
        embed_dim (int): Channels after the first convolution. Stage ``i``
            works on ``embed_dim * expansion_ratio[i]`` channels.
        depths, num_heads, expansion_ratio: Per-stage settings, parsed by
            ``_parse_list`` (defaults look like ``"8+8+8+8"``). The last entry
            describes the latent stage.
        window_size, grid_size (int | sequence): Per-stage sizes; an integer
            is repeated for every stage.
        global_size (int | None | sequence): Same convention as above.
        shift, qkv_bias, qkv_conv, qk_reduce, mlp_type, mlp_ratio,
        mlp_kernel_size, version, conv_scale, trans_conv_parallel,
        block_version, fairscale_checkpoint, offload_to_cpu: Stored in the
            config object handed to every stage.
        drop_rate (float): Dropout rate handed to every stage.
        attn_drop_rate (float): Attention dropout rate handed to every stage.
        drop_path_rate (float): Accepted for compatibility; not used.
        reduce_stage1_encoder_blocks (bool): Halve the depth of the first
            encoder stage.
        subsample_type (str): ``"simple"`` (strided / transposed convolution)
            or ``"pixelshuffle"`` (``Downsample`` / ``Upsample``).
        multiple_degradation (bool): When the channel counts differ, add
            ``x[:, :-1]`` as the residual.
        bias (bool): Bias flag of the plain convolutions of this class.
        dual_conv (bool): With ``dual_pixel_task``, use a learned 1x1 skip
            convolution on the embedded feature instead of an image residual.
        dual_pixel_task (bool): True for dual-pixel defocus deblurring only.
            Also set ``in_channels=6``.
        **kwargs: Ignored.

    Raises:
        ValueError: If ``subsample_type`` is not one of the supported values.
    """

    def __init__(
        self,
        img_size=64,
        in_channels=3,
        out_channels=None,
        embed_dim=96,
        depths="8+8+8+8",
        num_heads="3+3+6+12",
        expansion_ratio="1+1+2+4",
        window_size=8,
        grid_size=8,
        global_size=None,
        shift=False,
        qkv_bias=True,
        qkv_conv=False,
        qk_reduce=True,
        drop_rate=0.0,  # Used by attention projection layer and MLP
        attn_drop_rate=0.0,  # Used by attention MHA
        drop_path_rate=0.1,  # stochastic depth decay rule used for attention block and MLP. Not used.
        mlp_type="locality",
        mlp_ratio=2.0,
        mlp_kernel_size=3,
        version="v2",
        reduce_stage1_encoder_blocks=False,
        block_version="hiir",
        subsample_type="pixelshuffle",
        multiple_degradation=False,
        bias=True,
        fairscale_checkpoint=False,  # fairscale activation checkpointing
        offload_to_cpu=False,
        conv_scale=0.00,
        trans_conv_parallel=False,
        dual_conv=True,
        dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set in_channels=6
        **kwargs,
    ):

        super().__init__()
        # Without this check an unknown type silently builds a network with
        # no resampling layers, and the encoder/decoder loops never run.
        if subsample_type not in ("simple", "pixelshuffle"):
            raise ValueError(
                f"Unknown subsample_type {subsample_type!r}; "
                "expected 'simple' or 'pixelshuffle'."
            )

        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The first convolution serves as the patch embedding.
        self.conv_first = nn.Conv2d(in_channels, embed_dim, 3, 1, 1, bias=bias)
        depths = _parse_list(depths)
        num_heads = _parse_list(num_heads)
        expansion_ratio = _parse_list(expansion_ratio)
        self.multiple_degradation = multiple_degradation
        self.subsample_type = subsample_type

        num_levels = len(depths)
        # The last entry of every per-stage list belongs to the latent stage.
        latent_idx = num_levels - 1

        if isinstance(window_size, numbers.Integral):
            self.window_size = [window_size] * num_levels
        else:
            self.window_size = window_size
        if isinstance(grid_size, numbers.Integral):
            self.grid_size = [grid_size] * num_levels
        else:
            self.grid_size = grid_size
        if isinstance(global_size, numbers.Integral) or global_size is None:
            self.global_size = [global_size] * num_levels
        else:
            self.global_size = global_size
        self.pad_size = 8
        args = OmegaConf.create(
            {
                "fairscale_checkpoint": fairscale_checkpoint,
                "offload_to_cpu": offload_to_cpu,
                "version": version,
                "shift": shift,
                "qkv_bias": qkv_bias,
                "qkv_conv": qkv_conv,
                "qk_reduce": qk_reduce,
                "mlp_ratio": mlp_ratio,
                "mlp_type": mlp_type,
                "mlp_kernel_size": mlp_kernel_size,
                "conv_scale": conv_scale,
                "trans_conv_parallel": trans_conv_parallel,
                "block_version": block_version,
            }
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        stage = partial(
            HiIRStage,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            args=args,
        )
        compression_ratio_factor = 4

        ## Encoders of the Unet stages.
        self.encoders = nn.ModuleDict()
        self.downsamplers = nn.ModuleDict()
        for i in range(latent_idx):
            self.encoders[f"stage{i+1}"] = stage(
                dim=embed_dim * expansion_ratio[i],
                window_size=self.window_size[i],
                grid_size=self.grid_size[i],
                global_size=self.global_size[i],
                input_resolution=to_2tuple(img_size // 2**i),
                depth=depths[i] // 2
                if i == 0 and reduce_stage1_encoder_blocks
                else depths[i],
                num_heads=num_heads[i],
                compression_ratio=expansion_ratio[i] * compression_ratio_factor,
            )
            if self.subsample_type == "simple":
                self.downsamplers[f"stage{i+1}"] = downsample(
                    embed_dim * expansion_ratio[i],
                    embed_dim * expansion_ratio[i + 1],
                )
            else:  # "pixelshuffle"
                # PixelUnshuffle multiplies the channels by 4, hence the // 4.
                self.downsamplers[f"stage{i+1}"] = Downsample(
                    embed_dim * expansion_ratio[i],
                    embed_dim * expansion_ratio[i + 1] // 4,
                )

        ## Module for the latent space.
        self.latent = stage(
            dim=embed_dim * expansion_ratio[latent_idx],
            window_size=self.window_size[latent_idx],
            grid_size=self.grid_size[latent_idx],
            global_size=self.global_size[latent_idx],
            input_resolution=to_2tuple(img_size // 2**latent_idx),
            depth=depths[latent_idx],
            num_heads=num_heads[latent_idx],
            compression_ratio=expansion_ratio[latent_idx] * compression_ratio_factor,
        )

        ## Decoders of the Unet stages, registered from the deepest level up
        ## so that iterating the dicts follows the decoding order.
        self.upsamplers = nn.ModuleDict()
        self.reducers = nn.ModuleDict()
        self.decoders = nn.ModuleDict()
        for i in range(latent_idx - 1, -1, -1):
            if self.subsample_type == "simple":
                self.upsamplers[f"stage{i+1}"] = upsample(
                    embed_dim * expansion_ratio[i + 1],
                    embed_dim * expansion_ratio[i],
                )
            else:  # "pixelshuffle"
                # PixelShuffle divides the channels by 4, hence the * 4.
                self.upsamplers[f"stage{i+1}"] = Upsample(
                    embed_dim * expansion_ratio[i + 1],
                    embed_dim * expansion_ratio[i] * 4,
                )
            # Fuses the up-sampled feature with the encoder skip feature.
            self.reducers[f"stage{i+1}"] = nn.Conv2d(
                embed_dim * expansion_ratio[i] * 2,
                embed_dim * expansion_ratio[i],
                kernel_size=1,
                bias=bias,
            )
            self.decoders[f"stage{i+1}"] = stage(
                dim=embed_dim * expansion_ratio[i],
                window_size=self.window_size[i],
                grid_size=self.grid_size[i],
                global_size=self.global_size[i],
                input_resolution=to_2tuple(img_size // 2**i),
                depth=depths[i],
                num_heads=num_heads[i],
                compression_ratio=expansion_ratio[i] * compression_ratio_factor,
            )

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        self.dual_conv = dual_conv
        if self.dual_pixel_task and self.dual_conv:
            # Operates on the embedded feature (embed_dim channels).
            self.skip_conv = nn.Conv2d(embed_dim, embed_dim, kernel_size=1, bias=bias)
        ###########################

        self.output = nn.Conv2d(embed_dim, out_channels, 3, 1, 1, bias=bias)

    def encoding(self, x, window_size, grid_size, masks):
        """Run the encoder stages.

        Args:
            x: Embedded feature map.
            window_size: One window size per encoder stage.
            grid_size: Passed unchanged to every stage.
            masks (dict): One attention mask per encoder stage, in stage order.

        Returns:
            tuple: The down-sampled output of the last encoder stage and the
            list of per-stage features taken before down-sampling.
        """
        features = []
        for encoder, downsampler, w, m in zip(
            self.encoders.values(),
            self.downsamplers.values(),
            window_size,
            masks.values(),
        ):
            x = encoder(x, w, grid_size, m)
            features.append(x)
            x = downsampler(x)
        return x, features

    def decoding(self, x, features, window_size, grid_size, masks):
        """Run the decoder stages from the deepest level up.

        Args:
            x: Output of the latent stage.
            features (list): Encoder features in encoder order.
            window_size: One window size per decoder stage, deepest first.
            grid_size: Passed unchanged to every stage.
            masks (dict): Encoder-ordered masks; consumed in reverse.

        Returns:
            The decoded feature map.
        """
        for feature, upsampler, reducer, decoder, w, m in zip(
            reversed(features),
            self.upsamplers.values(),
            self.reducers.values(),
            self.decoders.values(),
            window_size,
            reversed(list(masks.values())),
        ):
            x = upsampler(x)
            x = reducer(torch.cat([x, feature], 1))
            x = decoder(x, w, grid_size, m)
        return x

    def forward_features(self, x, window_size=None, grid_size=None):
        """Apply encoder, latent and decoder stages to an embedded feature.

        Args:
            x: Embedded feature map whose spatial size is already padded.
            window_size (sequence | None): One window size per stage
                (latent last). ``None`` selects ``self.window_size``.
            grid_size: Passed unchanged to every stage.
        """
        if window_size is None:
            window_size = self.window_size
        num_levels = len(window_size)

        # attention mask, one per resolution level
        masks = {}
        H, W = x.shape[2:]
        for i in range(num_levels):
            mask = calculate_mask(
                (int(H // 2**i), int(W // 2**i)),
                to_2tuple(window_size[i]),
                to_2tuple(window_size[i] // 2),
            ).to(dtype=x.dtype, device=x.device)
            masks[f"stage{i+1}"] = mask
        # The deepest mask belongs to the latent stage; the rest are shared
        # between the encoder and the decoder of the same level.
        latent_mask = masks.pop(f"stage{num_levels}")

        # encoding
        x, features = self.encoding(x, window_size[:-1], grid_size, masks)

        # latent (with residual connection)
        x = self.latent(x, window_size[-1], grid_size, latent_mask) + x

        # decoding, deepest level first
        x = self.decoding(x, features, window_size[-2::-1], grid_size, masks)

        return x

    def check_image_size(self, x, window_size):
        """Pad ``x`` on the bottom/right and resolve the per-stage window sizes.

        Args:
            x: Input of shape (B, C, H, W).
            window_size (number | None): ``None`` keeps ``self.window_size``.
                Otherwise every configured window is rescaled so that the
                first stage uses ``window_size``.

        Returns:
            tuple: The reflect-padded input and the list of window sizes.
        """
        if window_size is None:
            window_size = self.window_size
        else:
            window_size = [
                int(window_size / self.window_size[0] * w) for w in self.window_size
            ]
        # Height and width are padded to a multiple of the deepest window size
        # times the total down-sampling factor.
        pad_size = window_size[-1] * 2 ** (len(window_size) - 1)

        _, _, h, w = x.size()
        mod_pad_h = (pad_size - h % pad_size) % pad_size
        mod_pad_w = (pad_size - w % pad_size) % pad_size
        # NOTE: reflect padding needs the pad to be smaller than the padded
        # dimension, so very small inputs are rejected by F.pad.
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
        return x, window_size

    def forward(self, x, window_size=None, grid_size=None):
        """Restore ``x`` and return a tensor with the input's height and width.

        Args:
            x: Input of shape (B, in_channels, H, W).
            window_size (number | None): See :meth:`check_image_size`.
            grid_size: Passed unchanged to every stage.

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        H, W = x.shape[2:]
        # first conv serve as patch embedding
        x, window_size = self.check_image_size(x, window_size)
        feature = self.conv_first(x)

        # The Unet architecture
        out = self.forward_features(feature, window_size, grid_size)

        same_channels = self.in_channels == self.out_channels
        fused = out + feature
        if not same_channels and self.dual_pixel_task and self.dual_conv:
            # skip_conv maps embed_dim -> embed_dim, so it has to act on the
            # embedded feature before the output convolution; applying it to
            # the raw input fails on the channel mismatch.
            fused = fused + self.skip_conv(feature)
        out = self.output(fused)

        if same_channels:
            out += x
        else:
            if self.dual_pixel_task and not self.dual_conv:
                out += x[:, :3, ...]
            if self.multiple_degradation:
                out += x[:, :-1, ...]

        # Drop the rows/columns added by check_image_size.
        return out[:, :, :H, :W]

    def convert_checkpoint(self, state_dict):
        """Return a copy of ``state_dict`` with 2-D qkv/proj weights made 4-D.

        Entries whose key contains ``"qkv.weight"`` or ``"proj.weight"`` and
        whose tensor is 2-D get two trailing singleton dimensions; every other
        entry is copied as is.
        """
        new_state = {}
        for k, v in state_dict.items():
            is_attn_weight = "qkv.weight" in k or "proj.weight" in k
            if is_attn_weight and v.dim() == 2:
                new_state[k] = v[..., None, None]
            else:
                new_state[k] = v
        return new_state


class Model(nn.Module):
    """Thin wrapper that registers ``model`` as a sub-module and delegates to it.

    Args:
        model (nn.Module): The network to wrap; stored as ``self.model``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return ``self.model(x)``."""
        wrapped = self.model
        return wrapped(x)


if __name__ == "__main__":
    # Smoke test: build a small network, print it, run one forward pass on a
    # random 64x64 image and report the number of trainable parameters.
    config = dict(
        img_size=128,
        in_channels=3,
        out_channels=3,
        embed_dim=128,
        depths="4+4+4+4",
        num_heads="4+4+8+16",
        expansion_ratio="1+1+2+4",
        window_size=8,
        mlp_ratio=2.0,
        # True for dual-pixel defocus deblurring only. Also set in_channels=6.
        dual_pixel_task=False,
        version="v2",
        mlp_type="locality",
        mlp_kernel_size=3,
        multiple_degradation=False,
        subsample_type="simple",
        dual_conv=False,
        trans_conv_parallel=False,
    )
    model = HiIRUnet(**config)

    # The wrapper instance is created here but not used further below.
    model_warp = Model(model)
    print(model)

    x = torch.randn((1, 3, 64, 64))
    x = model(x)
    print(x.shape)

    # Count only the parameters that receive gradients.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")
