import copy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

from .vision_transformer import VisionTransformer
from diffusers import DDPMScheduler

class ScaledDDPMScheduler(DDPMScheduler):
    """DDPM scheduler whose beta schedule is raised to the power ``factor``.

    After the parent scheduler is constructed, ``betas`` is replaced by
    ``betas ** factor`` and ``alphas`` / ``alphas_cumprod`` are recomputed
    from the adjusted betas.

    Args:
        factor: Exponent applied element-wise to the parent's ``betas``.
        *args, **kwargs: Forwarded unchanged to ``DDPMScheduler.__init__``.
    """

    def __init__(self, factor=1.2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._adjust_betas(factor)

    def _adjust_betas(self, factor):
        """Rescale ``betas`` and rebuild the tensors derived from them."""
        self.betas = self.betas ** factor

        # Keep the derived schedules consistent with the new betas
        # (presumably these are what ``add_noise`` reads -- see diffusers docs).
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    def noise_sampling(self, x, timesteps=None):
        """Return ``x`` with Gaussian noise added via ``self.add_noise``.

        Args:
            x: Tensor to corrupt; its first dimension is treated as the batch.
            timesteps: Optional per-sample timesteps. When ``None``, one
                timestep per batch element is drawn uniformly from
                ``[0, num_train_timesteps)``.

        Returns:
            The result of ``self.add_noise(x, noise, timesteps)``.
        """
        batch_size = x.shape[0]
        # Match the dtype of ``x`` so the noise does not silently promote
        # half-precision inputs to float32.
        noise = torch.randn(x.shape, device=x.device, dtype=x.dtype)
        # ``is None`` rather than ``== None``: identity is the intended check
        # and avoids invoking tensor comparison on a tensor argument.
        if timesteps is None:
            timesteps = torch.randint(
                0, self.config.num_train_timesteps, (batch_size,), device=x.device)
        return self.add_noise(x, noise, timesteps)

class VisionTransformerForSimMIM(VisionTransformer):
    """Vision Transformer encoder with a clean path and a masked/noised path.

    ``forward(x)`` runs the plain encoder; ``forward(x, is_masked=True)``
    replaces a saliency-selected subset of patch tokens with a learned mask
    token, injects diffusion noise into the token sequence before one of the
    transformer blocks, and additionally returns a noised copy of the input.

    Requires ``num_classes == 0`` (asserted at construction).
    """

    # Block index used for noise injection when the caller passes ``None``.
    _DEFAULT_NOISE_BLOCK = 2

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        assert self.num_classes == 0

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self._trunc_normal_(self.mask_token, std=.02)

        self.scheduler = ScaledDDPMScheduler(
            factor=1.2, num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02)

    def _trunc_normal_(self, tensor, mean=0., std=1.):
        """Initialise ``tensor`` in place from a normal truncated at +/- ``std``."""
        trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)

    @staticmethod
    def _to_feature_map(x):
        """Split off the CLS token and fold the patch tokens into a 2-D map.

        Args:
            x: Token sequence of shape ``(B, 1 + L, C)`` with the CLS token
                at index 0.

        Returns:
            ``(feature_map, cls_token)`` with shapes ``(B, C, H, W)`` and
            ``(B, C)``, where ``H = W = int(L ** 0.5)``.
        """
        cls_token = x[:, 0]
        patches = x[:, 1:]
        B, L, C = patches.shape
        H = W = int(L ** 0.5)
        return patches.permute(0, 2, 1).reshape(B, C, H, W), cls_token

    def forward(self, x, is_masked=False, noise_block=None):
        """Dispatch to ``forward_masked`` or ``forward_clean``.

        Args:
            x: Input batch passed to ``self.patch_embed``.
            is_masked: Select the masked/noised path when true.
            noise_block: Block index for noise injection (masked path only);
                ``None`` selects the default of ``forward_masked``.
        """
        if is_masked:
            return self.forward_masked(x, noise_block)
        return self.forward_clean(x)

    def forward_clean(self, x):
        """Encode ``x`` without masking or noise.

        Returns:
            ``(feature_map, cls_token)`` -- see ``_to_feature_map``.
        """
        x = self.patch_embed(x)

        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias=rel_pos_bias)
        x = self.norm(x)

        return self._to_feature_map(x)

    def saliency_guided_masking(self, x, mask_ratio=0.6, mask_ratio_var=0.15, delta=0.1):
        """Build a token mask and per-sample diffusion timesteps from saliency.

        Saliency of a token is the column sum of the row-softmaxed token
        affinity matrix ``x @ x^T``, min-max normalised per sample.

        Args:
            x: Patch embeddings of shape ``(N, L, D)``.
            mask_ratio: Fraction of tokens to mask; ``int(L * (1 - mask_ratio))``
                tokens are kept.
            mask_ratio_var: Currently unused.
            delta: Normalised-saliency threshold used when counting salient
                tokens for the timestep.

        Returns:
            ``(mask, timesteps)``. ``mask`` is ``(N, L)`` with 1 for masked
            and 0 for kept tokens; the tokens with the lowest
            ``saliency + U(0, 0.5)`` score are the ones kept. ``timesteps``
            is ``(N,)`` int64: ``clamp(10 * count(saliency > delta), 0, 700)``
            plus a random integer in ``[0, 99]``.
        """
        N, L, D = x.shape

        aff = torch.matmul(x, x.permute(0, 2, 1))
        aff = nn.functional.softmax(aff, dim=2)
        aff_sum = torch.sum(aff, dim=1)

        aff_min = aff_sum.min(dim=1, keepdim=True)[0]
        aff_max = aff_sum.max(dim=1, keepdim=True)[0]
        # Guard the min-max normalisation: if every token has the same
        # saliency the range is 0 and the original division produced NaNs.
        # Clamping to the smallest normal float leaves non-degenerate
        # inputs unchanged and maps the degenerate case to all zeros.
        span = (aff_max - aff_min).clamp_min(torch.finfo(aff_sum.dtype).tiny)
        aff_sum_normalized = (aff_sum - aff_min) / span

        y = (aff_sum_normalized > delta).sum(dim=1)
        d_noise = (torch.rand(N, device=x.device) * 100).long()
        timesteps = (y.float() * 10).long()
        timesteps = torch.clamp(timesteps, min=0, max=700)
        timesteps = timesteps + d_noise

        len_keep = int(L * (1 - mask_ratio))

        # Random jitter in [0, 0.5) so the selection is not fully deterministic.
        noise = torch.rand(N, L, device=x.device) / 2
        saliency_guided_noise = aff_sum_normalized + noise

        ids_shuffle = torch.argsort(saliency_guided_noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Mark the first ``len_keep`` positions of the sorted order as kept,
        # then scatter back to the original token order.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return mask, timesteps

    def forward_masked(self, x, noise_block=2):
        """Encode ``x`` with saliency-guided masking and diffusion noise.

        Args:
            x: Input batch passed to ``self.patch_embed``.
            noise_block: Index of the transformer block before which the
                time embedding is added and noise is injected. ``None``
                falls back to the default (2); out-of-range values are
                clamped to ``[0, len(self.blocks) - 1]`` so noise is always
                injected exactly once.

        Returns:
            ``(feature_map, noise_x, cls_token, mask)`` where ``feature_map``
            is ``(B, C, H, W)``, ``noise_x`` is the raw input noised with the
            same timesteps, ``cls_token`` is ``(B, C)`` and ``mask`` is the
            token mask tiled to ``(B, L, D)``.
        """
        noise_x = x.clone()
        x = self.patch_embed(x)

        B, L, D = x.shape

        mask, timesteps = self.saliency_guided_masking(x)
        # NOTE(review): ``time_embed`` is not defined in this class; it is
        # presumably provided by ``VisionTransformer`` -- confirm.
        t = self.time_embed(timesteps, L + 1)
        mask_token = self.mask_token.expand(B, L, -1)
        w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
        x = x * (1 - w) + mask_token * w

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None

        # ``forward`` passes ``None`` through by default, which previously
        # crashed inside ``enumerate``.
        if noise_block is None:
            noise_block = self._DEFAULT_NOISE_BLOCK
        inject_at = min(max(int(noise_block), 0), len(self.blocks) - 1)

        # The original used ``enumerate(self.blocks, noise_block)``; the
        # second argument of ``enumerate`` is the *start* value, so the
        # comparison was true on the first iteration and noise was always
        # injected before block 0 regardless of ``noise_block``.
        for idx, blk in enumerate(self.blocks):
            if idx == inject_at:
                x = x + t
                x = self.scheduler.noise_sampling(x, timesteps)
            x = blk(x, rel_pos_bias=rel_pos_bias)
        x = self.norm(x)

        x, cls_token = self._to_feature_map(x)

        noise_x = self.scheduler.noise_sampling(noise_x, timesteps)
        # NOTE(review): the mask is tiled to the embedding width D, while
        # ``SimMIM.unpatchify`` reshapes it to ``patch_size**2 * 3`` values
        # per token; the two only agree when D == patch_size**2 * 3
        # (e.g. 768 for 16x16 RGB patches).
        mask = mask.unsqueeze(-1).repeat(1, 1, D)
        return x, noise_x, cls_token, mask

class ProjectionHead(nn.Module):
    """Two-layer MLP: Linear -> LayerNorm -> ReLU -> Linear.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Width of both the hidden layer and the output.
    """

    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        # Layers are created in application order so parameter
        # initialisation consumes the RNG in the same sequence.
        stages = [
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        projected = self.proj(x)
        return projected

class PredictionHead(nn.Module):
    """Two-layer MLP predictor: Linear -> LayerNorm -> ReLU -> Linear.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer. Previously this argument was
            accepted but ignored (the hidden width was always ``out_dim``);
            it is now honoured. With ``hidden_dim == out_dim`` -- which
            includes the defaults -- the layer shapes are unchanged.
        out_dim: Size of the last dimension of the output.
    """

    def __init__(self, in_dim, hidden_dim=256, out_dim=256):
        super().__init__()
        self.pred = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.pred(x)

class SimMIM(nn.Module):
    """Masked-image-modelling wrapper with a momentum target encoder.

    A source encoder processes a masked, noised view of the input; a
    frozen deep copy of it (the target encoder, updated as an exponential
    moving average of the source) processes the clean input. The training
    loss combines masked reconstruction, visible-region denoising and a
    CLS-token consistency term.

    Args:
        encoder: Source encoder. Must expose ``embed_dim``, ``num_features``,
            ``in_chans`` and ``patch_size`` and accept
            ``encoder(x, is_masked=..., noise_block=...)``.
        encoder_stride: Upsampling factor of the pixel-shuffle decoder.
        momentum: EMA coefficient for the target encoder.
        device: Device for the target encoder; defaults to CUDA when
            available, otherwise CPU.
    """

    def __init__(self, encoder, encoder_stride, momentum=0.99, device=None):
        super().__init__()
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder = encoder
        self.target_encoder = copy.deepcopy(encoder)

        self.encoder_stride = encoder_stride

        self.momentum = momentum

        for param in self.target_encoder.parameters():
            param.requires_grad = False  # Stop gradient for target encoder

        # NOTE(review): only the target encoder is moved here; the remaining
        # sub-modules stay where they were created until the caller moves
        # the whole model.
        self.target_encoder = self.target_encoder.to(self.device)

        self.projection_head = ProjectionHead(encoder.embed_dim)
        self.prediction_head = PredictionHead(256)
        self.target_projection = ProjectionHead(encoder.embed_dim)

        # 1x1 conv to ``stride**2 * 3`` channels followed by pixel shuffle
        # yields a 3-channel image upsampled by ``encoder_stride``.
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=self.encoder.num_features,
                out_channels=self.encoder_stride ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(self.encoder_stride),
        )

        self.in_chans = self.encoder.in_chans
        self.patch_size = self.encoder.patch_size

    def _update_target_encoder(self):
        """EMA step: ``target = m * target + (1 - m) * source`` per parameter."""
        for param_s, param_t in zip(self.encoder.parameters(), self.target_encoder.parameters()):
            param_t.data = self.momentum * param_t.data + (1 - self.momentum) * param_s.data

    def unpatchify(self, x):
        """Rearrange per-patch values into an image.

        Args:
            x: Tensor of shape ``(N, L, patch_size**2 * 3)`` where ``L`` is a
                perfect square.

        Returns:
            Tensor of shape ``(N, 3, h * patch_size, w * patch_size)`` with
            ``h = w = sqrt(L)``.
        """
        p = self.patch_size
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
        return imgs

    def forward(self, x, mae_loss_coef=1.0):
        """Compute the combined training loss for a batch of images.

        Args:
            x: Input image batch.
            mae_loss_coef: Currently unused.

        Returns:
            Scalar tensor
            ``loss_recon + 0.05 * loss_denoise + w * loss_contrastive`` with
            ``w = 0.01 * (1 - tanh((loss_recon + loss_denoise) / 2))``.
        """
        if not hasattr(self, "initialized_target_encoder"):
            self._update_target_encoder()
            self.initialized_target_encoder = True

        # Source encoder sees the masked / noised input.
        z_source, noise_x, cls_source, mask = self.encoder(x, is_masked=True, noise_block=2)
        x_rec = self.decoder(z_source)

        with torch.no_grad():
            # NOTE(review): the EMA step runs on every forward call,
            # regardless of ``self.training``.
            self._update_target_encoder()
            # Target encoder sees the clean input.
            z_target, cls_target = self.target_encoder(x)
            # The target projection was already detached; computing it under
            # no_grad gives the same values without building a graph.
            z_target_proj = self.target_projection(cls_target).detach()

        # Projection and prediction on the source branch.
        z_source_proj = self.projection_head(cls_source)
        z_source_pred = self.prediction_head(z_source_proj)

        # Consistency between predicted source and projected target CLS tokens.
        loss_contrastive = F.mse_loss(z_source_pred, z_target_proj)

        mask = self.unpatchify(mask)

        # Masked pixels: reconstruct the noised image.
        recon_map = F.l1_loss(noise_x, x_rec, reduction='none')
        loss_recon = (recon_map * mask).sum() / (mask.sum() + 1e-5) / self.in_chans

        # Visible pixels: recover the clean image. The original multiplied the
        # already-reduced ``loss_recon`` scalar by ``visible`` here, which made
        # this term a rescaled copy of ``loss_recon`` and discarded the
        # per-pixel denoising error.
        visible = 1 - mask
        denoise_map = F.l1_loss(x, x_rec, reduction='none')
        loss_denoise = (denoise_map * visible).sum() / (visible.sum() + 1e-5) / self.in_chans

        # The consistency term is weighted up as the pixel losses shrink.
        avg_loss = (loss_recon + loss_denoise) / 2.0
        base_weight = 0.01
        scaling_factor = 1.0
        contrastive_weight = base_weight * (1.0 - torch.tanh(scaling_factor * avg_loss))

        # NOTE: the denoise loss is down-weighted to 0.05.
        loss = loss_recon + loss_denoise * 0.05 + loss_contrastive * contrastive_weight

        return loss

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return encoder parameter names (prefixed ``encoder.``) exempt from decay."""
        if hasattr(self.encoder, 'no_weight_decay'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay()}
        # Empty *set* (not ``{}``, a dict) so the return type is consistent.
        return set()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return encoder name keywords (prefixed ``encoder.``) exempt from decay."""
        if hasattr(self.encoder, 'no_weight_decay_keywords'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
        return set()


def build_simmim(config):
    """Build a ``SimMIM`` model from a configuration object.

    Args:
        config: Configuration exposing ``MODEL.TYPE``, ``MODEL.DROP_RATE``,
            ``MODEL.DROP_PATH_RATE``, ``DATA.IMG_SIZE`` and the
            ``MODEL.VIT.*`` fields read below.

    Returns:
        A ``SimMIM`` instance wrapping a ``VisionTransformerForSimMIM``.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not ``'vit'``.
    """
    model_type = config.MODEL.TYPE
    if model_type != 'vit':
        raise NotImplementedError(f"Unknown pre-train model: {model_type}")

    vit_cfg = config.MODEL.VIT
    encoder = VisionTransformerForSimMIM(
        img_size=config.DATA.IMG_SIZE,
        patch_size=vit_cfg.PATCH_SIZE,
        in_chans=vit_cfg.IN_CHANS,
        num_classes=0,
        embed_dim=vit_cfg.EMBED_DIM,
        depth=vit_cfg.DEPTH,
        num_heads=vit_cfg.NUM_HEADS,
        mlp_ratio=vit_cfg.MLP_RATIO,
        qkv_bias=vit_cfg.QKV_BIAS,
        drop_rate=config.MODEL.DROP_RATE,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=vit_cfg.INIT_VALUES,
        use_abs_pos_emb=vit_cfg.USE_APE,
        use_rel_pos_bias=vit_cfg.USE_RPB,
        use_shared_rel_pos_bias=vit_cfg.USE_SHARED_RPB,
        use_mean_pooling=vit_cfg.USE_MEAN_POOLING)

    # The decoder pixel-shuffles the (H / patch, W / patch) feature map back
    # to image resolution, so its stride has to equal the patch size. This
    # was hard-coded to 16, which is only correct for PATCH_SIZE == 16.
    encoder_stride = vit_cfg.PATCH_SIZE

    return SimMIM(encoder=encoder, encoder_stride=encoder_stride)
