import torch
from torch import nn
import math

class StandardViT(nn.Module):
    """Vision Transformer classifier built on ``nn.TransformerEncoderLayer``.

    The image is cut into non-overlapping patches by a strided convolution,
    a learnable CLS token is prepended, learnable positional embeddings are
    added, and the sequence is passed through ``depth`` encoder layers. The
    normalized CLS token feeds a linear classification head.

    Args:
        img_size: Side length of the (square) input image. It fixes the
            number of positional embeddings, so inputs that produce a
            different number of patches will fail at the positional add.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        num_classes: Output size of the classification head.
        embed_dim: Token embedding width.
        depth: Number of encoder layers.
        num_heads: Attention heads per layer.
        mlp_ratio: Feed-forward hidden width as a multiple of ``embed_dim``.
        drop_rate: Dropout probability passed to every encoder layer.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=10,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 drop_rate=0.1):
        super().__init__()

        # Kept as an attribute so external helpers can size a new head.
        self.embed_dim = embed_dim

        # Patch embedding: one conv step per patch.
        self.patch_embed = nn.Conv2d(in_chans, embed_dim,
                                     kernel_size=patch_size,
                                     stride=patch_size)
        num_patches = (img_size // patch_size) ** 2

        # Positional embedding (+1 slot for the CLS token) and CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer blocks. batch_first=True is required because forward()
        # feeds [B, N, E] tensors; with the default (sequence-first) layout
        # attention would run across the batch dimension instead of patches.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=drop_rate,
                activation='gelu',
                batch_first=True,
            ) for _ in range(depth)])

        # Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x, output_feature=False):
        """Classify a batch of images.

        Args:
            x: Image batch of shape [B, in_chans, H, W].
            output_feature: When true, also return the normalized CLS feature.

        Returns:
            ``logits`` of shape [B, num_classes], or the tuple
            ``(logits, cls_feature)`` when ``output_feature`` is true.
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # [B, E, H, W]
        x = x.flatten(2).transpose(1, 2)  # [B, N, E]

        # Prepend the CLS token, then add positions (out-of-place add).
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        # Transformer
        for blk in self.blocks:
            x = blk(x)

        # Classification from the CLS token only.
        cls_feature = self.norm(x[:, 0])
        logits = self.head(cls_feature)

        if output_feature:
            return logits, cls_feature
        return logits

def vit_base_patch16(**kwargs):
    """Build the base-sized ``StandardViT`` with 16x16 patches.

    Pins patch_size=16, embed_dim=768, depth=12, num_heads=12 and
    mlp_ratio=4. Any other keyword arguments are forwarded unchanged;
    repeating one of the pinned keywords raises ``TypeError``.
    """
    preset = {
        'patch_size': 16,
        'embed_dim': 768,
        'depth': 12,
        'num_heads': 12,
        'mlp_ratio': 4,
    }
    return StandardViT(**preset, **kwargs)

def vit_large_patch16(**kwargs):
    """Build the large-sized ``StandardViT``.

    Pins embed_dim=1024, depth=24 and num_heads=16. Any other keyword
    arguments are forwarded unchanged; repeating one of the pinned
    keywords raises ``TypeError``.
    """
    preset = {'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
    return StandardViT(**preset, **kwargs)

def vit_small_patch16(**kwargs):
    """Build the small-sized ``StandardViT``.

    Pins embed_dim=384, depth=9 and num_heads=6. Any other keyword
    arguments are forwarded unchanged; repeating one of the pinned
    keywords raises ``TypeError``.
    """
    preset = {'embed_dim': 384, 'depth': 9, 'num_heads': 6}
    return StandardViT(**preset, **kwargs)

class MAE_ViT(nn.Module):
    """Masked autoencoder: ViT encoder on visible patches plus a small decoder.

    The encoder embeds image patches, drops a random subset, and encodes the
    remaining ones together with a CLS token. The decoder re-inserts a shared
    learnable mask token at the dropped positions and predicts the pixels of
    every patch. The loss is the mean squared error on masked patches only.

    Args:
        img_size: Side length of the (square) input image; fixes the number
            of positional embeddings.
        patch_size: Side length of each square patch.
        in_chans: Number of image channels.
        num_classes: Accepted for signature compatibility; this class builds
            no classification head and does not use it.
        embed_dim: Encoder token width.
        depth: Number of encoder blocks.
        num_heads: Attention heads per encoder block.
        decoder_embed_dim: Decoder token width.
        decoder_depth: Number of decoder blocks.
        decoder_num_heads: Attention heads per decoder block.
        mlp_ratio: MLP hidden width as a multiple of the block width.
        attn_drop_rate: Attention dropout for the encoder blocks.
        drop_rate: MLP dropout for the encoder blocks.
        norm_pix_loss: When true, each target patch is standardized (zero
            mean, unit variance over its pixels) before computing the loss.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=10,
                 embed_dim=384,
                 depth=9,
                 num_heads=6,
                 decoder_embed_dim=192,
                 decoder_depth=4,
                 decoder_num_heads=3,
                 mlp_ratio=3.,
                 attn_drop_rate=0.1,
                 drop_rate=0.1,
                 norm_pix_loss=False):
        super().__init__()

        # forward_loss reads this flag; it must live on the instance.
        self.norm_pix_loss = norm_pix_loss
        # Kept as an attribute so external helpers can size a new head.
        self.embed_dim = embed_dim

        # --------------------------
        # Encoder (Original ViT)
        # --------------------------
        self.patch_embed = nn.Conv2d(in_chans, embed_dim,
                                     kernel_size=patch_size,
                                     stride=patch_size)
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate
            ) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        # --------------------------
        # Decoder (For Pre-training)
        # --------------------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim))
        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio)
            for _ in range(decoder_depth)])
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)

        # --------------------------
        # Initialization
        # --------------------------
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # The decoder positions get the same init as the encoder positions
        # instead of being left at zero.
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-init Linear weights; reset Linear/LayerNorm biases and scales."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio=0.75):
        """Randomly keep a subset of tokens, independently per sample.

        Args:
            x: Token tensor of shape [B, N, D].
            mask_ratio: Fraction of the N tokens to drop.

        Returns:
            x_masked: The kept (visible) tokens, [B, len_keep, D].
            mask: Binary mask [B, N] in original token order
                (0 = kept, 1 = masked).
            ids_restore: Indices [B, N] that undo the shuffle.
        """
        B, N, D = x.shape
        len_keep = int(N * (1 - mask_ratio))

        # Sorting uniform noise yields a random permutation per sample.
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the mask in shuffled order, then map back to original order.
        mask = torch.ones([B, N], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Encode the visible patches of an image batch.

        Args:
            x: Image batch [B, in_chans, H, W].
            mask_ratio: Fraction of patches to drop before encoding.

        Returns:
            Tuple ``(latent, mask, ids_restore)`` where ``latent`` is
            [B, 1 + len_keep, embed_dim] (CLS token first) and the other two
            are as returned by ``random_masking``.
        """
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        # Positions are added before masking so kept tokens retain theirs;
        # slot 0 of pos_embed belongs to the CLS token.
        x = x + self.pos_embed[:, 1:, :]

        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        cls_tokens = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_tokens.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x)
        # Final encoder LayerNorm (previously constructed but never applied).
        x = self.norm(x)
        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Predict the pixels of every patch from the encoder output.

        Args:
            x: Encoder output [B, 1 + len_keep, embed_dim], CLS token first.
            ids_restore: Indices [B, N] that restore original patch order.

        Returns:
            Predictions of shape [B, N, patch_size**2 * in_chans]. The CLS
            position is dropped so the result lines up with ``patchify``.
        """
        x = self.decoder_embed(x)

        # Append one mask token per dropped patch, then unshuffle.
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        x = x + self.decoder_pos_embed

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        pred = self.decoder_pred(x)
        # Remove the CLS prediction; the target has one row per patch only.
        return pred[:, 1:, :]

    def forward_loss(self, imgs, pred, mask):
        """Mean squared error over masked patches.

        Args:
            imgs: Image batch [B, in_chans, H, W].
            pred: Per-patch predictions [B, N, patch_size**2 * in_chans].
            mask: Binary mask [B, N] (1 = masked, contributes to the loss).

        Returns:
            Scalar loss tensor. Optionally uses per-patch standardized
            targets (see ``norm_pix_loss``).
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        # Any non-empty mask sums to >= 1, so the clamp only changes the
        # degenerate "nothing masked" case (0 / 0 -> 0 instead of NaN).
        loss = (loss * mask).sum() / mask.sum().clamp(min=1)
        return loss

    def patchify(self, imgs):
        """Convert images into a sequence of flattened patches.

        Args:
            imgs: Image batch [B, C, H, W]; H and W must be multiples of the
                patch size for the reshape to succeed.

        Returns:
            Tensor of shape [B, (H // p) * (W // p), p**2 * C], with patches
            in row-major order and each patch flattened as (row, col, chan).
        """
        p = self.patch_embed.kernel_size[0]
        B, C, H, W = imgs.shape
        x = imgs.reshape(B, C, H // p, p, W // p, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(B, (H // p) * (W // p), p**2 * C)
        return x

    def forward(self, imgs, mask_ratio=0.75):
        """Run masking, encoding, decoding and the reconstruction loss.

        Args:
            imgs: Image batch [B, in_chans, H, W].
            mask_ratio: Fraction of patches to mask.

        Returns:
            Tuple ``(loss, pred, mask)``.
        """
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention then MLP, each with a residual.

    Args:
        dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        drop: Dropout probability used inside the MLP.
        attn_drop: Dropout probability applied to attention weights.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        """Apply the block to a [B, N, dim] token tensor; shape is preserved."""
        # nn.MultiheadAttention needs query, key and value; for
        # self-attention all three are the same normalized tokens. The
        # attention weights are not used, so they are not requested.
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x
    

def mae_vit_base(**kwargs):
    """Build the base-sized ``MAE_ViT``.

    Pins the encoder to embed_dim=768, depth=12, num_heads=12 and the
    decoder to decoder_embed_dim=512, decoder_depth=8,
    decoder_num_heads=16, with mlp_ratio=4. Any other keyword arguments are
    forwarded unchanged; repeating a pinned keyword raises ``TypeError``.
    """
    encoder_cfg = {'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    decoder_cfg = {
        'decoder_embed_dim': 512,
        'decoder_depth': 8,
        'decoder_num_heads': 16,
    }
    return MAE_ViT(**encoder_cfg, **decoder_cfg, mlp_ratio=4, **kwargs)

def mae_vit_small(**kwargs):
    """Build the small-sized ``MAE_ViT``.

    Pins the encoder to embed_dim=384, depth=9, num_heads=6 and the
    decoder to decoder_embed_dim=192, decoder_depth=4,
    decoder_num_heads=3, with mlp_ratio=3. Any other keyword arguments are
    forwarded unchanged; repeating a pinned keyword raises ``TypeError``.
    """
    encoder_cfg = {'embed_dim': 384, 'depth': 9, 'num_heads': 6}
    decoder_cfg = {
        'decoder_embed_dim': 192,
        'decoder_depth': 4,
        'decoder_num_heads': 3,
    }
    return MAE_ViT(**encoder_cfg, **decoder_cfg, mlp_ratio=3, **kwargs)

def load_pretrained_mae(model, path):
    """Load encoder weights from an MAE checkpoint into ``model``.

    Decoder-only entries (any key containing ``'decoder'`` and the
    ``'mask_token'`` parameter) are dropped, and the remainder is loaded
    non-strictly so keys that do not match are reported rather than raised.

    Args:
        model: Module that receives the weights.
        path: Checkpoint file. Either a dict holding the state dict under
            the ``'model'`` key, or a bare state dict.

    Returns:
        The result of ``model.load_state_dict`` (missing / unexpected keys).
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code depending on the PyTorch version's ``weights_only``
    # default. Only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location='cpu')

    # Accept both {'model': state_dict, ...} and a bare state dict.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint

    state_dict = {k: v for k, v in weights.items()
                  if 'decoder' not in k and k != 'mask_token'}

    msg = model.load_state_dict(state_dict, strict=False)
    print(f'Loaded pretrained MAE weights with msg: {msg}')
    return msg

def replace_classifier(model, num_classes=10):
    """Attach a fresh linear classifier and freeze the earliest layers.

    The new layer replaces ``model.head`` when that attribute exists;
    otherwise it is stored as ``model.classifier``. Afterwards every
    parameter whose name contains ``'patch_embed'`` or ``'blocks.0'`` has
    ``requires_grad`` set to False.

    Args:
        model: Module to modify in place.
        num_classes: Output size of the new classifier.

    Raises:
        AttributeError: If the input width cannot be determined, i.e. the
            model has no ``embed_dim`` attribute and no existing
            ``nn.Linear`` classifier to read it from.
    """
    attr = 'head' if hasattr(model, 'head') else 'classifier'
    current = getattr(model, attr, None)

    # Prefer the model's declared width; fall back to the layer being
    # replaced so models that never stored ``embed_dim`` still work.
    if hasattr(model, 'embed_dim'):
        in_features = model.embed_dim
    elif isinstance(current, nn.Linear):
        in_features = current.in_features
    else:
        raise AttributeError(
            "cannot infer classifier input size: model has no 'embed_dim' "
            f"attribute and no existing nn.Linear '{attr}'")

    setattr(model, attr, nn.Linear(in_features, num_classes))

    for name, param in model.named_parameters():
        if 'patch_embed' in name or 'blocks.0' in name:
            param.requires_grad = False