import torch
import torch.nn as nn
import numpy as np

import timm
from timm.models.vision_transformer import Block


def get_mae_vit_tiny_encoder(model_name='vit_tiny_patch16_224', pretrained=False, **kwargs):
    """Create the ViT encoder used by the MAE model via ``timm.create_model``.

    Args:
        model_name: timm model identifier. The default is timm's
            ``vit_tiny_patch16_224`` (by name: ViT-Tiny, 16-px patches,
            224-px input).
        pretrained: whether timm should load pretrained weights.
        **kwargs: forwarded unchanged to ``timm.create_model``.

    Returns:
        The model returned by ``timm.create_model``.
    """
    # ``timm`` is imported unconditionally at module level, so a "timm is None"
    # check can never be true here; a missing timm fails at import time instead.
    return timm.create_model(model_name, pretrained=pretrained, **kwargs)

class MAEDecoder(nn.Module):
    """Lightweight transformer decoder that reconstructs pixels for every patch.

    It receives encoder tokens for the visible patches (class token first),
    inserts a shared learnable mask token for every hidden patch, restores the
    original patch order, and predicts ``patch_size**2 * 3`` values per patch.

    Args:
        encoder_embed_dim: feature width of the incoming encoder tokens.
        decoder_embed_dim: internal width of the decoder.
        decoder_depth: number of transformer blocks.
        decoder_num_heads: attention heads per block.
        patch_size: patch side length in pixels.
        num_patches: number of patch positions (class token excluded).
    """

    def __init__(self, encoder_embed_dim=192, decoder_embed_dim=128, decoder_depth=4,
                 decoder_num_heads=8, patch_size=16, num_patches=14 * 14):
        super().__init__()
        width = decoder_embed_dim

        # Project encoder features into the decoder width.
        self.decoder_embed = nn.Linear(encoder_embed_dim, width, bias=True)
        # A single learnable token, reused for every masked position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, width))
        # Learnable position table; row 0 lines up with the class token.
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, width))

        blocks = [Block(width, decoder_num_heads, mlp_ratio=4., qkv_bias=True)
                  for _ in range(decoder_depth)]
        self.decoder_blocks = nn.ModuleList(blocks)

        self.decoder_norm = nn.LayerNorm(width)
        pixels_per_patch = 3 * patch_size ** 2
        self.decoder_pred = nn.Linear(width, pixels_per_patch, bias=True)

    def forward(self, x, ids_restore):
        """Decode visible-patch features into per-patch pixel predictions.

        Args:
            x: encoder output of shape (N, 1 + n_visible, encoder_embed_dim);
                index 0 along dim 1 is the class token.
            ids_restore: (N, L) integer indices that undo the masking shuffle.

        Returns:
            Tensor of shape (N, L, patch_size**2 * 3); the class-token
            prediction is dropped.
        """
        tokens = self.decoder_embed(x)
        batch, n_tokens, width = tokens.shape

        cls_part = tokens[:, :1, :]
        visible = tokens[:, 1:, :]

        # One mask token per hidden patch (the +1 accounts for the class token).
        n_hidden = ids_restore.shape[1] + 1 - n_tokens
        fillers = self.mask_token.repeat(batch, n_hidden, 1)

        # Visible tokens come first in shuffled order; gather puts every token
        # back at its original patch position.
        shuffled = torch.cat([visible, fillers], dim=1)
        gather_index = ids_restore.unsqueeze(-1).repeat(1, 1, width)
        unshuffled = torch.gather(shuffled, dim=1, index=gather_index)

        hidden = torch.cat([cls_part, unshuffled], dim=1) + self.decoder_pos_embed
        for block in self.decoder_blocks:
            hidden = block(hidden)

        pixels = self.decoder_pred(self.decoder_norm(hidden))
        return pixels[:, 1:, :]

class MaskedAutoencoderViT(nn.Module):
    """Masked autoencoder (MAE) built from a ViT encoder and a separate decoder.

    The encoder is used piecewise: ``patch_embed`` (whose ``patch_size[0]``
    gives the patch side), ``pos_embed`` (class-token position in slot 0),
    ``cls_token``, ``blocks`` and ``norm``. The decoder must own a
    ``decoder_pos_embed`` parameter of shape ``(1, num_patches + 1, dim)``
    covering a square patch grid plus a class token, and be callable as
    ``decoder(latent, ids_restore)`` (e.g. ``MAEDecoder``).

    Args:
        encoder: timm-style ViT exposing the attributes listed above.
        decoder: decoder module as described above.
        mask_ratio: fraction of patches hidden from the encoder.
    """

    def __init__(self, encoder, decoder, mask_ratio=0.75):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.mask_ratio = mask_ratio

        # Initialise the decoder position table with fixed 2D sin-cos values.
        # The grid size comes from the table itself (one extra row for the
        # class token) instead of a hard-coded 14, so decoders built for other
        # image/patch sizes also work.
        pos_param = self.decoder.decoder_pos_embed
        num_patches = pos_param.shape[1] - 1
        grid_size = int(round(num_patches ** 0.5))
        if grid_size * grid_size != num_patches:
            raise ValueError(
                f"decoder_pos_embed must cover a square patch grid plus a class "
                f"token; got {num_patches} patch positions")
        pos_embed = self.get_2d_sincos_pos_embed(pos_param.shape[-1], grid_size)
        with torch.no_grad():
            pos_param.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def get_2d_sincos_pos_embed(self, embed_dim, grid_size, cls_token=True):
        """Build 2D sin-cos position embeddings for a square patch grid.

        Returns:
            numpy array of shape (grid_size**2, embed_dim), or
            (1 + grid_size**2, embed_dim) when ``cls_token`` is True, in which
            case row 0 is all zeros.
        """
        grid_h = np.arange(grid_size, dtype=np.float32)
        grid_w = np.arange(grid_size, dtype=np.float32)
        # Default 'xy' indexing: grid[0] holds the column (w) coordinate and
        # grid[1] the row (h) coordinate of each cell.
        grid = np.meshgrid(grid_w, grid_h)
        grid = np.stack(grid, axis=0)
        grid = grid.reshape([2, 1, grid_size, grid_size])

        pos_embed = self.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
        if cls_token:
            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
        return pos_embed

    def get_2d_sincos_pos_embed_from_grid(self, embed_dim, grid):
        """Encode each grid axis with half the channels and concatenate them.

        Raises:
            ValueError: if ``embed_dim`` is odd.
        """
        if embed_dim % 2 != 0:
            raise ValueError(f"embed_dim must be even, got {embed_dim}")
        # NOTE: the half named emb_h is computed from grid[0] (column index);
        # the ordering is kept so the embedding values stay unchanged.
        emb_h = self.get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
        emb_w = self.get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
        return np.concatenate([emb_h, emb_w], axis=1)

    def get_1d_sincos_pos_embed_from_grid(self, embed_dim, pos):
        """Return (pos.size, embed_dim) sin/cos features of the flattened positions."""
        omega = np.arange(embed_dim // 2, dtype=np.float32)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega  # geometric frequency ladder
        pos = pos.reshape(-1)
        out = np.einsum('m,d->md', pos, omega)  # outer product: positions x frequencies
        emb_sin = np.sin(out)
        emb_cos = np.cos(out)
        return np.concatenate([emb_sin, emb_cos], axis=1)

    def patchify(self, imgs):
        """Split images into flattened patches.

        Args:
            imgs: (N, 3, H, W) tensor; H and W are assumed to be multiples of
                the encoder patch size.

        Returns:
            (N, (H/p)*(W/p), p*p*3) tensor, pixel values ordered (row, col,
            channel) within each patch.
        """
        p = self.encoder.patch_embed.patch_size[0]
        h, w = imgs.shape[2] // p, imgs.shape[3] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))

    def random_masking(self, x):
        """Randomly keep a subset of patches per sample.

        Args:
            x: (N, L, D) patch tokens.

        Returns:
            x_masked: (N, len_keep, D) kept tokens, in shuffled order.
            mask: (N, L) float tensor in original order, 0 = kept, 1 = removed.
            ids_restore: (N, L) indices that undo the shuffle.
        """
        N, L, D = x.shape
        # Round away floating-point noise before flooring, e.g.
        # 10 * (1 - 0.9) == 0.9999999999999998 must keep 1 patch, not 0.
        # Genuinely fractional counts are still floored.
        len_keep = int(round(L * (1 - self.mask_ratio), 6))

        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)  # ascending: smallest noise kept
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the mask in shuffled order, then map it back to patch order.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x):
        """Embed, mask and encode images; returns (latent, mask, ids_restore)."""
        x = self.encoder.patch_embed(x)
        # Slot 0 of pos_embed belongs to the class token; patches use the rest.
        x = x + self.encoder.pos_embed[:, 1:, :]
        x, mask, ids_restore = self.random_masking(x)
        cls_token = self.encoder.cls_token + self.encoder.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        latent = self.encoder.blocks(x)
        latent = self.encoder.norm(latent)
        return latent, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Run the decoder on the encoder latent."""
        return self.decoder(x, ids_restore)

    def forward_loss(self, imgs, pred, mask):
        """Mean squared reconstruction error averaged over masked patches only.

        Returns 0 when no patch is masked, instead of 0/0 = NaN.
        """
        target = self.patchify(imgs)
        per_patch = ((pred - target) ** 2).mean(dim=-1)
        mask = mask.to(torch.float32)
        return (per_patch * mask).sum() / mask.sum().clamp(min=1.0)

    def forward(self, imgs):
        """Full MAE pass; returns (loss, pred, mask)."""
        latent, mask, ids_restore = self.forward_encoder(imgs)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask