from transformers import AutoModel
import torch.nn as nn
import torch

class EATModelWrapper(nn.Module):
    """Classification wrapper around a pretrained backbone loaded via ``AutoModel``.

    The backbone's first output token is passed through a freshly initialised
    ``LayerNorm`` (``fc_norm``) and a linear layer (``head``) to produce class
    logits. In training mode the patch tokens are randomly masked along the
    time and frequency axes before they reach the transformer blocks.
    """

    # Patch-grid layout ``(time patches, frequency patches)`` keyed by
    # ``img_size[0]``.
    _PATCH_GRIDS = {
        1024: (64, 8),  # for AS
        512: (32, 8),   # for ESC
        128: (8, 8),    # for SPC
    }

    def __init__(self, model_id, img_size, embed_dim=768, num_classes=527, **kwargs):
        """Load the backbone and create the classification head.

        Args:
            model_id: Identifier handed to ``AutoModel.from_pretrained``.
            img_size: Indexable input size; ``img_size[0]`` selects the patch
                grid used by :meth:`random_masking_2d`.
            embed_dim: Width of the features fed to ``fc_norm`` and ``head``.
            num_classes: Number of output logits.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.img_size = img_size
        # NOTE: trust_remote_code=True executes Python code shipped with the
        # model repository; only pass a ``model_id`` you trust.
        backbone = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval()
        # Only move to the GPU when one exists so that construction does not
        # fail on CPU-only hosts.
        if torch.cuda.is_available():
            backbone = backbone.cuda()
        self.model = backbone
        # Alias of the backbone blocks, kept so existing attribute accesses and
        # state-dict keys stay the same.
        self.blocks = self.model.model.blocks
        self.head = nn.Linear(embed_dim, num_classes)
        self.head.weight.data.normal_(mean=0.0, std=2.5e-5)
        self.head.bias.data.zero_()
        self.fc_norm = nn.LayerNorm(embed_dim, eps=1.e-6)
        nn.init.constant_(self.fc_norm.bias, 0)
        nn.init.constant_(self.fc_norm.weight, 1.0)

    def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
        """Randomly drop whole time rows and frequency columns of the patch grid.

        2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob).
        Per-sample random masking is done by per-sample shuffling, and the
        shuffling is done by argsort of random noise.

        Because the kept indices are taken from a random permutation, the
        surviving patches come back in shuffled order; with both probabilities
        at 0 every patch is kept but the order is still permuted.

        Args:
            x: ``[N, L, D]`` sequence with ``L == T * F`` for the patch grid
                selected by ``self.img_size[0]``.
            mask_t_prob: Fraction of time rows to drop.
            mask_f_prob: Fraction of frequency columns to drop.

        Returns:
            ``(x_masked, None, None)`` where ``x_masked`` has shape
            ``[N, T' * F', D]``, ``T' = int(T * (1 - mask_t_prob))`` and
            ``F' = int(F * (1 - mask_f_prob))``.

        Raises:
            ValueError: If ``self.img_size[0]`` has no known patch grid.
        """
        N, L, D = x.shape  # batch, length, dim
        grid = self._PATCH_GRIDS.get(self.img_size[0])
        if grid is None:
            raise ValueError(
                f"Unsupported img_size[0]={self.img_size[0]!r}; "
                f"expected one of {sorted(self._PATCH_GRIDS)}."
            )
        T, F = grid

        # mask T
        x = x.reshape(N, T, F, D)
        len_keep_T = int(T * (1 - mask_t_prob))
        noise = torch.rand(N, T, device=x.device)  # noise in [0, 1]
        # ascending sort of the noise: the smallest entries are kept
        ids_keep = torch.argsort(noise, dim=1)[:, :len_keep_T]
        # expand() yields a broadcast view, so the full-size index is never
        # materialised
        index = ids_keep[:, :, None, None].expand(-1, -1, F, D)
        x = torch.gather(x, dim=1, index=index)  # N, T', F, D

        # mask F
        len_keep_F = int(F * (1 - mask_f_prob))
        noise = torch.rand(N, F, device=x.device)  # noise in [0, 1]
        ids_keep = torch.argsort(noise, dim=1)[:, :len_keep_F]
        index = ids_keep[:, None, :, None].expand(-1, len_keep_T, -1, D)
        # gathering on the frequency axis directly selects the same elements as
        # permute -> gather(dim=1) -> permute back
        x_masked = torch.gather(x, dim=2, index=index)  # N, T', F', D
        x_masked = x_masked.reshape(N, len_keep_T * len_keep_F, D)

        return x_masked, None, None

    def encode(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
        """Run the backbone on ``x`` and return the output of its last block.

        The input goes through the backbone's local encoder, optional fixed
        positional encoder, extra tokens, pre-norm and positional dropout. In
        training mode the tokens after the first one are masked with
        :meth:`random_masking_2d` before the blocks are applied.

        Args:
            x: Input batch accepted by the backbone's ``local_encoder``.
            mask_t_prob: Time masking ratio, used only in training mode.
            mask_f_prob: Frequency masking ratio, used only in training mode.

        Returns:
            The token sequence produced by the final block.
        """
        backbone = self.model.model
        batch_size = x.shape[0]
        x = backbone.local_encoder(x)
        if backbone.fixed_positional_encoder is not None:
            x = x + backbone.fixed_positional_encoder(x, None)[:, :x.size(1), :]
        x = torch.cat((backbone.extra_tokens.expand(batch_size, -1, -1), x), dim=1)
        x = backbone.pre_norm(x)
        x = backbone.pos_drop(x)

        if self.training:
            # Only the first token is held out of masking; this presumably
            # assumes the backbone prepends exactly one extra (CLS) token.
            cls, x = x[:, :1, :], x[:, 1:, :]
            x, _, _ = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
            x = torch.cat((cls, x), dim=1)

        for blk in backbone.blocks:
            x, _ = blk(x)
        return x

    def no_weight_decay(self):
        """Set of parameters that should not use weight decay."""
        return {'pos_embed', 'cls_token', 'dist_token'}

    def adjust_linear_prob_train(self):
        """Put the backbone in eval mode and ``fc_norm``/``head`` in train mode."""
        self.model.eval()
        self.fc_norm.train()
        self.head.train()

    def _linear_prob_freeze(self):
        """Freeze every parameter except those of ``head`` and ``fc_norm``."""
        for param in self.parameters():
            param.requires_grad = False
        self.head.weight.requires_grad = True
        self.head.bias.requires_grad = True
        self.fc_norm.weight.requires_grad = True
        self.fc_norm.bias.requires_grad = True

    def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
        """Return class logits computed from the first encoded token.

        Args:
            x: Input batch, forwarded to :meth:`encode`.
            mask_t_prob: Time masking ratio, forwarded to :meth:`encode`.
            mask_f_prob: Frequency masking ratio, forwarded to :meth:`encode`.

        Returns:
            Output of ``head`` applied to the normalised first token.
        """
        features = self.encode(x, mask_t_prob, mask_f_prob)
        features = self.fc_norm(features[:, 0])
        logits = self.head(features)
        return logits

def eat_classifier(**kwargs):
    """Build an ``EATModelWrapper``, forwarding all keyword arguments to its constructor."""
    classifier = EATModelWrapper(**kwargs)
    return classifier
