# -----------------------------------------------------------------------------------
# SemanIR: Sharing Key Semantics in Transformer Makes Efficient Image Restoration
# -----------------------------------------------------------------------------------


import numbers
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import checkpoint_wrapper
from model.semanir_block import SemanticTransformerBlock
from model.semanir_block_v2 import SemanticTransformerBlockV2
from timm.models.layers import to_2tuple


class TransformerStage(nn.Module):
    """A stack of semantic transformer blocks operating at a single resolution.

    Blocks alternate between regular windows (even index, ``shift_size=0``) and
    shifted windows (odd index, ``shift_size=window_size // 2``). The first two
    blocks each return a ``measure`` that is cached by block parity; every later
    block receives the cached ``measure`` of the same parity instead of
    computing its own.

    Args:
        dim (int): Number of input channels.
        input_resolution (int | tuple[int]): Input resolution. It is passed
            through ``to_2tuple`` before being handed to the blocks.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        global_size (int): Forwarded to ``SemanticTransformerBlockV2`` only;
            the v1 block does not receive it.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2.0
        top_k (int): Forwarded to every block. Default: 32
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        graph_type (str): Forwarded to every block. Default: "mask"
        fairscale_checkpoint (bool): Wrap each block with fairscale's
            ``checkpoint_wrapper``. Default: False
        offload_to_cpu (bool): Passed to ``checkpoint_wrapper``. Default: False
        version (str): ``"v1"`` selects ``SemanticTransformerBlock``; any other
            value selects ``SemanticTransformerBlockV2``. Default: "v1"
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        global_size,
        mlp_ratio=2.0,
        top_k=32,
        qkv_bias=True,
        qk_scale=None,
        graph_type="mask",
        fairscale_checkpoint=False,
        offload_to_cpu=False,
        version="v1",
    ):
        super().__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        resolution = to_2tuple(input_resolution)
        self.blocks = nn.ModuleList()

        for i in range(depth):
            # Arguments shared by both block versions.
            block_kwargs = dict(
                dim=dim,
                input_resolution=resolution,
                num_heads=num_heads,
                window_size=window_size,
                # Even blocks use regular windows, odd blocks shifted ones.
                shift_size=0 if i % 2 == 0 else window_size // 2,
                mlp_ratio=mlp_ratio,
                top_k=top_k,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                graph_type=graph_type,
            )
            if version == "v1":
                block = SemanticTransformerBlock(**block_kwargs)
            else:
                block = SemanticTransformerBlockV2(
                    global_size=global_size, **block_kwargs
                )
            if fairscale_checkpoint:
                block = checkpoint_wrapper(block, offload_to_cpu=offload_to_cpu)
            self.blocks.append(block)

    def forward(self, x):
        """Run ``x`` through all blocks, sharing the measure by block parity."""
        # Key True = even (regular-window) blocks, False = odd (shifted) blocks.
        sim_measure = {}
        for i, blk in enumerate(self.blocks):
            parity = i % 2 == 0
            if i <= 1:
                # The first block of each parity computes the measure ...
                x, sim_measure[parity] = blk(x)
            else:
                # ... and later blocks of the same parity reuse it.
                x, _ = blk(x, measure=sim_measure[parity])
        return x

    def flops(self):
        """Return the sum of the FLOPs reported by each block's ``flops()``.

        ``nn.ModuleList`` itself has no ``flops()``, so the blocks are summed
        individually. This stage contains no convolution of its own, so no
        extra conv term is added. Assumes every block exposes ``flops()``
        (TODO confirm for SemanticTransformerBlock / V2).
        """
        return sum(blk.flops() for blk in self.blocks)


##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    """Project an image to ``embed_dim`` feature maps with one 3x3 convolution.

    Stride 1 and padding 1 keep the spatial size unchanged, so neighbouring
    positions share input pixels through the overlapping 3x3 receptive field.

    Args:
        in_c (int): Input channels. Default: 3
        embed_dim (int): Output (embedding) channels. Default: 48
        bias (bool): Whether the conv has a bias term. Default: False
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        # The attribute name ``proj`` is part of the state-dict keys.
        self.proj = nn.Conv2d(in_c, embed_dim, 3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)


##########################################################################
## Resizing modules
class Downsample(nn.Module):
    """Halve the spatial size while (for even ``n_feat``) doubling channels.

    A bias-free 3x3 conv first maps ``n_feat`` to ``n_feat // 2`` channels.
    ``PixelUnshuffle(2)`` then folds each 2x2 neighbourhood into channels,
    giving ``4 * (n_feat // 2)`` channels at half the height and width.
    """

    def __init__(self, n_feat):
        super().__init__()
        halve_channels = nn.Conv2d(
            n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False
        )
        # Keep the Sequential layout (body.0 = conv) so state-dict keys match.
        self.body = nn.Sequential(halve_channels, nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    """Double the spatial size while halving channels.

    A bias-free 3x3 conv expands ``n_feat`` to ``2 * n_feat`` channels.
    ``PixelShuffle(2)`` then spreads groups of 4 channels over 2x2 pixels,
    giving ``n_feat // 2`` channels at twice the height and width.
    """

    def __init__(self, n_feat):
        super().__init__()
        expand_channels = nn.Conv2d(
            n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False
        )
        # Keep the Sequential layout (body.0 = conv) so state-dict keys match.
        self.body = nn.Sequential(expand_channels, nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)


##########################################################################
##---------- SemanIR U-Net (Restormer-style encoder-decoder) -----------------------
class SemanIRUnet(nn.Module):
    """Restormer-style 4-level U-Net built from SemanIR ``TransformerStage``s.

    Layout: 3x3 conv embedding -> encoder levels 1-3 separated by
    ``Downsample`` -> latent stage (level 4) -> decoder levels 3-1 separated by
    ``Upsample`` with skip connections (channel concatenation) -> refinement
    stage -> 3x3 conv output head. Except for the dual-pixel task, the output
    head's result is added to the (padded) input as a global residual.

    Args:
        img_size (int): Image size the stages are built for; level ``i``
            (0-based) uses ``img_size // 2**i``.
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        dim (int): Channel width of level 1; level ``i`` uses ``dim * 2**i``.
            Decoder level 1 and the refinement stage use ``2 * dim``.
        window_size (int | sequence[int]): Window size per level. An int is
            repeated ``len(num_blocks)`` times.
        top_k (int): Forwarded to every stage.
        num_blocks (sequence[int]): Blocks for encoder levels 1-3 and the
            latent level. Decoder levels reuse the encoder counts.
        num_refinement_blocks (int): Blocks in the refinement stage.
        heads (sequence[int]): Attention heads per level.
        global_size (sequence[int]): Per-level ``global_size`` forwarded to the
            stages (only the v2 block consumes it).
        mlp_ratio (float): MLP expansion ratio for every stage.
        bias (bool): Bias for the 1x1 channel-reduction convs, the dual-pixel
            skip conv, and the output conv.
        dual_pixel_task (bool): True for dual-pixel defocus deblurring only
            (also set ``in_channels=6``). Replaces the global residual with a
            1x1-conv skip from the embedded input.
        multiple_degradation (bool): If True, the global residual uses all
            input channels except the last (``inp_img[:, :-1]``); presumably
            ``in_channels == out_channels + 1`` in that case (TODO confirm).
        graph_type, fairscale_checkpoint, offload_to_cpu, version: Forwarded
            to every ``TransformerStage``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        img_size=128,
        in_channels=3,
        out_channels=3,
        dim=48,
        window_size=8,
        top_k=16,
        num_blocks=(4, 6, 6, 8),
        num_refinement_blocks=4,
        heads=(1, 2, 4, 8),
        global_size=(128, 64, 32, 32),
        mlp_ratio=4,
        bias=True,
        dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set in_channels=6
        multiple_degradation=False,
        graph_type="mask",
        fairscale_checkpoint=False,
        offload_to_cpu=False,
        version="v1",
        **kwargs,
    ):
        super().__init__()
        # Per-level window sizes; an int applies to every level.
        if isinstance(window_size, numbers.Integral):
            window_size = [window_size] * len(num_blocks)
        self.window_size = window_size
        self.multiple_degradation = multiple_degradation

        self.patch_embed = OverlapPatchEmbed(in_channels, dim)

        stage = partial(
            TransformerStage,
            mlp_ratio=mlp_ratio,
            top_k=top_k,
            graph_type=graph_type,
            fairscale_checkpoint=fairscale_checkpoint,
            offload_to_cpu=offload_to_cpu,
            version=version,
        )

        def level(i):
            # Stage hyper-parameters shared by the encoder and decoder at level i.
            return dict(
                input_resolution=img_size // 2**i,
                num_heads=heads[i],
                window_size=window_size[i],
                global_size=global_size[i],
            )

        # ---- Encoder (modules created in the original order) ----
        self.encoder_level1 = stage(dim=dim, depth=num_blocks[0], **level(0))

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = stage(
            dim=int(dim * 2**1), depth=num_blocks[1], **level(1)
        )

        self.down2_3 = Downsample(int(dim * 2**1))  ## From Level 2 to Level 3
        self.encoder_level3 = stage(
            dim=int(dim * 2**2), depth=num_blocks[2], **level(2)
        )

        self.down3_4 = Downsample(int(dim * 2**2))  ## From Level 3 to Level 4
        self.latent = stage(dim=int(dim * 2**3), depth=num_blocks[3], **level(3))

        # ---- Decoder ----
        self.up4_3 = Upsample(int(dim * 2**3))  ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(
            int(dim * 2**3), int(dim * 2**2), kernel_size=1, bias=bias
        )
        self.decoder_level3 = stage(
            dim=int(dim * 2**2), depth=num_blocks[2], **level(2)
        )

        self.up3_2 = Upsample(int(dim * 2**2))  ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(
            int(dim * 2**2), int(dim * 2**1), kernel_size=1, bias=bias
        )
        self.decoder_level2 = stage(
            dim=int(dim * 2**1), depth=num_blocks[1], **level(1)
        )

        # From Level 2 to Level 1 (NO 1x1 conv to reduce channels).
        self.up2_1 = Upsample(int(dim * 2**1))

        self.decoder_level1 = stage(
            dim=int(dim * 2**1), depth=num_blocks[0], **level(0)
        )

        self.refinement = stage(
            dim=int(dim * 2**1), depth=num_refinement_blocks, **level(0)
        )

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim * 2**1), kernel_size=1, bias=bias)

        self.output = nn.Conv2d(
            int(dim * 2**1),
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )

    def check_image_size(self, x):
        """Reflect-pad ``x`` on the bottom/right so every level divides evenly.

        Level ``i`` sees the input downsampled by ``2**i`` and is built with
        ``window_size[i]`` (presumably the blocks need the level's resolution
        to be a multiple of it; TODO confirm in the block implementations).
        H and W are therefore padded to a multiple of
        ``lcm(window_size[i] * 2**i)`` over all levels. For a uniform window
        size this equals ``window_size[-1] * 2**(levels - 1)``.
        """
        from math import gcd

        pad_size = 1
        for i, win in enumerate(self.window_size):
            step = win * 2**i
            pad_size = pad_size * step // gcd(pad_size, step)

        _, _, h, w = x.size()
        mod_pad_h = (pad_size - h % pad_size) % pad_size
        mod_pad_w = (pad_size - w % pad_size) % pad_size
        return F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")

    def forward(self, inp_img):
        """Restore ``inp_img`` (N, C, H, W); the output is cropped back to H x W."""
        H, W = inp_img.shape[2:]
        # Padding is removed again by the final crop.
        inp_img = self.check_image_size(inp_img)

        # ---- Encoder ----
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        out_enc_level2 = self.encoder_level2(self.down1_2(out_enc_level1))
        out_enc_level3 = self.encoder_level3(self.down2_3(out_enc_level2))
        latent = self.latent(self.down3_4(out_enc_level3))

        # ---- Decoder: upsample, concatenate skip, (reduce channels), stage ----
        inp_dec_level3 = torch.cat([self.up4_3(latent), out_enc_level3], 1)
        out_dec_level3 = self.decoder_level3(self.reduce_chan_level3(inp_dec_level3))

        inp_dec_level2 = torch.cat([self.up3_2(out_dec_level3), out_enc_level2], 1)
        out_dec_level2 = self.decoder_level2(self.reduce_chan_level2(inp_dec_level2))

        # Level 1 keeps the concatenated 2*dim channels (no 1x1 reduction).
        inp_dec_level1 = torch.cat([self.up2_1(out_dec_level2), out_enc_level1], 1)
        out_dec_level1 = self.refinement(self.decoder_level1(inp_dec_level1))

        if self.dual_pixel_task:
            # Dual-pixel: skip from the embedded input, no global residual.
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out = self.output(out_dec_level1)
        elif self.multiple_degradation:
            # The last input channel is excluded from the residual.
            out = self.output(out_dec_level1) + inp_img[:, :-1, ...]
        else:
            out = self.output(out_dec_level1) + inp_img

        return out[:, :, :H, :W]

    def convert_checkpoint(self, state_dict):
        """Remove relative-position entries from ``state_dict`` in place.

        Every key containing ``"relative_coords_table"`` or
        ``"relative_position_index"`` is popped and printed. The same
        (mutated) dict is returned.
        """
        for k in list(state_dict.keys()):
            if "relative_coords_table" in k or "relative_position_index" in k:
                state_dict.pop(k)
                print(k)
        return state_dict

if __name__ == "__main__":
    # Smoke test: build a v2 model, run one forward pass on a random
    # 256x256 image, and report the output shape and trainable parameter count.
    model = SemanIRUnet(
        img_size=128,
        in_channels=3,
        out_channels=3,
        dim=48,
        window_size=8,
        top_k=16,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        mlp_ratio=4,
        bias=True,
        dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set in_channels=6
        version="v2",
    )

    print(model)

    x = torch.randn((1, 3, 256, 256))
    # Inference only: don't build an autograd graph for the demo pass.
    with torch.no_grad():
        x = model(x)
    print(x.shape)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")
