# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn
import numpy as np

from timm.models.vision_transformer import PatchEmbed, Block
from timm.models.layers.helpers import to_2tuple
from timm.models.layers.trace_utils import _assert

from util.pos_embed import get_2d_sincos_pos_embed
# from diff_aug import DiffAugment
import scipy.signal
# from torch_utils.ops import upfirdn2d

# Filter coefficients keyed by wavelet name. Within this file, MaskedAutoencoderViT.__init__
# reads 'sym2' as the low-pass prototype H(z) of its band-pass filter bank and passes 'sym6'
# to upfirdn2d.setup_filter.
# NOTE(review): the keys look like PyWavelets names (Haar, Daubechies dbN, Symlets symN) and the
# values presumably are the corresponding low-pass filters - confirm against the original source.
# 'haar'/'db1', 'db2'/'sym2' and 'db3'/'sym3' hold identical coefficients in this table.
wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1': [0.7071067811865476, 0.7071067811865476],
    'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388,
            0.3326705529509569],
    'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114,
            -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515,
            -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729,
            0.160102397974125],
    'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156,
            0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432,
            0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784,
            -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572,
            -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066,
            0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186,
            0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128,
            0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388,
             0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736,
             -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206,
             0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855,
             0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466,
             0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578,
             0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357,
             0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164,
             0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663,
             0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605,
             -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314,
             -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615,
             -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}


def toogle_grad(model, requires_grad):
    """Set ``requires_grad`` on every parameter of ``model`` in place.

    Args:
        model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        requires_grad: flag forwarded to each parameter's ``requires_grad_``.
    """
    for _name, weight in model.named_parameters():
        weight.requires_grad_(requires_grad)


class PatchEmbedWithPadding(nn.Module):
    """2D image to patch embedding, using a padded (overlapping) convolution.

    The projection is one ``nn.Conv2d`` whose stride is ``patch_size`` and whose
    kernel is ``patch_size + 2 * padding``; the convolution pads its input by
    ``padding`` using ``padding_mode``.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True,
                 padding=2, padding_mode='zeros'):
        super().__init__()
        # The kernel size is computed from the raw patch_size, before it is expanded to a pair.
        kernel_hw = to_2tuple(patch_size + padding * 2)
        image_hw = to_2tuple(img_size)
        patch_hw = to_2tuple(patch_size)
        rows, cols = image_hw[0] // patch_hw[0], image_hw[1] // patch_hw[1]

        self.img_size = image_hw
        self.patch_size = patch_hw
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten
        self.padding = padding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_hw, stride=patch_hw,
                              padding=padding, padding_mode=padding_mode)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Project a ``(B, C, H, W)`` batch; returns ``(B, N, embed_dim)`` tokens when ``flatten`` is set."""
        _, _, H, W = x.shape
        # The input resolution must match the size given at construction time.
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        tokens = self.proj(x)
        if self.flatten:
            # BCHW -> BNC
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)


class MaskedAutoencoderViT(nn.Module):
    """ViT classifier assembled from MAE-style encoder parts.

    Images are embedded into patch tokens and a class token is prepended. Every
    token gets a segment id computed as ``mask_s + 2 * mask_t`` (0: outside both
    masks, 1: S, 2: T; the class token always gets id 1). Tokens with id 0 are
    overwritten by the ``noST_token`` parameter. The token embedding and its
    segment embedding are concatenated, so the transformer blocks run at width
    ``2 * embed_dim``. The class-token feature is mapped to ``num_class`` logits.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, diff_aug=None, num_class=None,
                 DLastNorm='LN', padding=0, padding_mode='zeros', drop=0.0):
        """Build the model.

        Args:
            img_size, patch_size, in_chans, padding, padding_mode: forwarded to
                ``PatchEmbedWithPadding``.
            embed_dim: width of the patch embedding (blocks use ``2 * embed_dim``).
            depth, num_heads, mlp_ratio, norm_layer, drop: transformer block settings.
            diff_aug: augmentation policy handed to ``DiffAugment`` in ``forward``;
                ``None`` disables augmentation. Substrings ``'filter'`` and ``'geo'``
                trigger registration of the ``Hz_fbank`` / ``Hz_geom`` buffers.
            num_class: output size of the head; must be given (``nn.Linear`` needs it).
            DLastNorm: final layer applied to the tokens: ``'LN'`` (``norm_layer``),
                ``'NO'`` (identity), ``'LR'`` (LeakyReLU(0.2)) or ``'BN'`` (BatchNorm1d).

        Raises:
            ValueError: if ``DLastNorm`` is not one of the supported values.
        """
        super().__init__()

        # Fail early: previously an unknown value left ``self.norm`` undefined and
        # only surfaced as an AttributeError inside forward_feature.
        if DLastNorm not in ('LN', 'NO', 'LR', 'BN'):
            raise ValueError(
                f"Unsupported DLastNorm {DLastNorm!r}; expected one of 'LN', 'NO', 'LR', 'BN'.")

        self.num_class = num_class
        self.depth = depth
        self.L = (img_size // patch_size) ** 2 + 1  # patch tokens + class token

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbedWithPadding(img_size, patch_size, in_chans, embed_dim,
                                                 padding=padding, padding_mode=padding_mode)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.noST_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=False)  # fixed sin-cos embedding
        self.seg_embed = nn.Embedding(3, embed_dim)  # 0: x/(SUT), 1:S, 2:T

        # Token and segment embeddings are concatenated, hence the doubled width.
        width = embed_dim * 2
        self.blocks = nn.ModuleList([
            Block(width, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, drop=drop)
            for _ in range(depth)])

        if DLastNorm == 'LN':
            self.norm = norm_layer(width)
        elif DLastNorm == 'NO':
            self.norm = nn.Identity()
        elif DLastNorm == 'LR':
            self.norm = nn.LeakyReLU(0.2)
        else:  # 'BN'
            # NOTE(review): forward_feature feeds (B, L, 2 * embed_dim) tokens to this layer,
            # while BatchNorm1d treats dim 1 as the channel axis - confirm this option is usable.
            self.norm = nn.BatchNorm1d(width, affine=True)

        # --------------------------------------------------------------------------
        self.drop = nn.Dropout(p=drop)
        self.head = nn.Linear(width, self.num_class)  # real or fake
        self.initialize_weights()

        self.diff_aug = diff_aug
        use_filter = diff_aug is not None and 'filter' in diff_aug
        use_geo = diff_aug is not None and 'geo' in diff_aug

        if use_filter:
            Hz_lo = np.asarray(wavelets['sym2'])  # H(z)
            Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size))  # H(-z)
            Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2  # H(z) * H(z^-1) / 2
            Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2  # H(-z) * H(-z^-1) / 2
            Hz_fbank = np.eye(4, 1)  # Bandpass(H(z), b_i)
            for i in range(1, Hz_fbank.shape[0]):
                # Interleave zeros between the taps, then smooth with the low-pass product.
                Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
                Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
                # Add the high-pass product to the centre of row i.
                lo = (Hz_fbank.shape[1] - Hz_hi2.size) // 2
                hi = (Hz_fbank.shape[1] + Hz_hi2.size) // 2
                Hz_fbank[i, lo:hi] += Hz_hi2
            self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
        else:
            self.Hz_fbank = None

        if use_geo:
            # NOTE(review): `upfirdn2d` is not imported in this file (its import is commented
            # out), so this branch raises NameError until that import is restored.
            self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
        else:
            self.Hz_geom = None

    def initialize_weights(self):
        """Initialise the positional embedding, patch projection, class token and all sub-layers."""
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        # (noST_token is not touched here and keeps its zero initialisation.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Split square images into flattened non-overlapping patches.

        imgs: (N, C, H, W) with H == W divisible by the patch size
        returns: (N, L, patch_size**2 * C)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        n, c = imgs.shape[0], imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(shape=(n, h * w, p ** 2 * c))

    def unpatchify(self, x):
        """Inverse of ``patchify``.

        x: (N, L, patch_size**2 * C) with L a perfect square
        returns: (N, C, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        c = x.shape[2] // (p * p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(shape=(x.shape[0], c, h * p, w * p))

    def _encode_cls(self, tokens):
        """Run every block and the final norm, then return the class-token feature."""
        for blk in self.blocks:
            tokens = blk(tokens)
        return self.norm(tokens)[:, 0]

    def forward_feature(self, x, mask_s, mask_t, D_no_padST=False):
        """Compute the class-token feature for a batch of images.

        Args:
            x: images accepted by ``self.patch_embed``.
            mask_s, mask_t: per-patch indicators, one entry per patch token
                (assumed (B, L) with 0/1 values and not both set for the same
                patch, since ``seg_embed`` only has 3 rows - TODO confirm with callers).
            D_no_padST: when False all tokens go through the blocks. When True the
                tokens with segment id 0 are dropped first; if the first and last
                sample keep different numbers of tokens, the two halves of the
                batch are processed separately.

        Returns:
            Tensor of shape (B, 2 * embed_dim).
        """
        x = self.patch_embed(x)  # (B, L, E)
        s_t_index = (mask_s + 2 * mask_t).long()  # (B, L)
        # The class token always counts as segment 1.
        s_t_index = torch.cat([torch.ones_like(s_t_index[:, :1]), s_t_index], dim=1)

        # add cls token for classification
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        # tokens outside both masks are replaced by the noST token
        x[s_t_index == 0] = self.noST_token

        # add position embedding
        x = x + self.pos_embed

        # append the segment embedding along the feature axis
        seg_embed = self.seg_embed(s_t_index)
        x = torch.cat([x, seg_embed], dim=2)
        BB, _, EE = x.shape

        if not D_no_padST:  # always pad to the same L length
            return self._encode_cls(x)

        # upper/lower parts may have different lengths
        keep = s_t_index > 0
        st_num = keep.sum(-1)
        if st_num[0] == st_num[-1]:  # two parts have same st_num
            return self._encode_cls(x[keep].view(BB, -1, EE))

        # two parts have different st_num: split the batch into its two halves
        h_bs = BB // 2
        x_up = x[:h_bs][keep[:h_bs]].view(h_bs, -1, EE)
        # BB - h_bs rows, so an odd batch size does not break the reshape
        x_down = x[h_bs:][keep[h_bs:]].view(BB - h_bs, -1, EE)

        # Blocks are applied to both halves in lockstep, matching the original call order.
        for blk in self.blocks:
            x_up = blk(x_up)
            x_down = blk(x_down)

        x_up = self.norm(x_up)
        x_down = self.norm(x_down)

        return torch.cat((x_up[:, 0], x_down[:, 0]), dim=0)  # cls token

    def forward(self, imgs, mask_s=None, mask_t=None, label=None, method=None, aug=True, D_no_padST=False):
        """Classify ``imgs`` and optionally pick one logit per sample.

        Args:
            imgs: input images.
            mask_s, mask_t: patch masks forwarded to ``forward_feature`` (required
                in practice; ``None`` fails inside ``forward_feature``).
            label: optional integer tensor with one class index per sample.
            method: unused; kept for call compatibility.
            aug: when ``diff_aug`` is set, use the full policy if True, otherwise
                only ``"translation"``.
            D_no_padST: forwarded to ``forward_feature``.

        Returns:
            ``(B,)`` tensor with the logit selected by ``label``, or the full
            ``(B, num_class)`` logits when ``label`` is None.
        """
        if self.diff_aug is not None:
            # NOTE(review): `DiffAugment` is not imported in this file (its import is commented
            # out), so this path raises NameError until that import is restored.
            policy = self.diff_aug if aug else "translation"
            imgs = DiffAugment(imgs, policy, True, [self.Hz_geom, self.Hz_fbank])

        latent = self.forward_feature(imgs, mask_s, mask_t, D_no_padST)
        out = self.head(self.drop(latent))

        if label is None:
            return out

        # Build the row index directly on the logits' device.
        index = torch.arange(out.size(0), device=out.device)
        return out[index, label.to(out.device)]


def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build the base model: 16-pixel patches, embed_dim 768, 12 blocks, 12 heads.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mae_vit_base_patch16_dec512d12h16(**kwargs):
    """Build a model with 16-pixel patches, embed_dim 768, 12 blocks, 12 heads.

    The fixed arguments are the same as in ``mae_vit_base_patch16_dec512d8b``.
    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build the large model: 16-pixel patches, embed_dim 1024, 24 blocks, 16 heads.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build the huge model: 14-pixel patches, embed_dim 1280, 32 blocks, 16 heads.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mae_vit_small_patch16_dec512d8b(**kwargs):
    """Build the small model: 16-pixel patches, embed_dim 768, 6 blocks, 6 heads, mlp_ratio 2.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=6,
        num_heads=6,
        mlp_ratio=2,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


# set recommended archs
# NOTE(review): the "dec512d8b"-style suffixes presumably come from upstream MAE naming; the
# factories in this file only call MaskedAutoencoderViT and pass no decoder-related arguments.
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # patch 16, embed_dim 768, depth 12, 12 heads
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # patch 16, embed_dim 1024, depth 24, 16 heads
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # patch 14, embed_dim 1280, depth 32, 16 heads

mae_vit_small = mae_vit_small_patch16_dec512d8b  # patch 16, embed_dim 768, depth 6, 6 heads, mlp_ratio 2
mae_vit_largeDec = mae_vit_base_patch16_dec512d12h16  # same fixed arguments as the base factory
