import torch
import torch.nn as nn

import transformers

from .decoder import TextDecoder
from .encoder import TextEncoder
from .utils import make_pad_mask


# Import-time side effect: restrict the transformers library's own logging to
# ERROR-level messages so its warnings/info output does not clutter the logs.
transformers.logging.set_verbosity_error()


def prepare_input_ids(
        input_ids: torch.Tensor, pad_token_id: int = 50258, ignore_token_id: int = -100
        ):
    """
    Split a batch of token ids into teacher-forcing inputs and next-token labels.

    ``input_ids`` is indexed as ``[:, :-1]`` / ``[:, 1:]``, so it must have at
    least two dimensions; the second one is treated as the sequence axis.

    Args:
        input_ids: Token ids; left unmodified.
        pad_token_id: Label value to be hidden from the loss.
        ignore_token_id: Value written over every label equal to
            ``pad_token_id``.

    Returns:
        A ``(decoder_inputs, labels)`` tuple of contiguous tensors:
        ``decoder_inputs`` is every position except the last, ``labels`` is
        every position except the first with pad entries replaced.
    """
    # Inputs: drop the final position. Copy so the result does not alias the
    # caller's tensor.
    decoder_inputs = input_ids[:, :-1].clone()

    # Labels: drop the first position, so label[t] is the token following
    # input[t]. Out-of-place masked_fill already yields a fresh tensor.
    next_tokens = input_ids[:, 1:]
    targets = next_tokens.masked_fill(next_tokens == pad_token_id, ignore_token_id)

    return decoder_inputs.contiguous(), targets.contiguous()

def make_attn_mask(seq_lens, dtype):
    """
    Build an additive attention mask from per-sequence lengths.

    The result of ``make_pad_mask(seq_lens)`` is cast to ``dtype`` and scaled
    by the most negative finite value of that dtype, then a new axis is
    inserted at position 1 and expanded to ``seq_lens.max()`` rows.

    NOTE(review): assumes ``make_pad_mask`` returns a (batch, max_len) mask
    that is truthy at padded positions, giving a (batch, max_len, max_len)
    result -- confirm against ``.utils.make_pad_mask``.

    Args:
        seq_lens: Tensor of sequence lengths.
        dtype: Floating-point dtype of the returned mask (``torch.finfo`` is
            applied to it).

    Returns:
        The expanded additive mask (a broadcast view, not a copy).
    """
    pad_flags = make_pad_mask(seq_lens).to(dtype)
    most_negative = torch.finfo(dtype).min
    additive = pad_flags * most_negative

    # Repeat the same key-side mask for every row of the new axis.
    n_rows = seq_lens.max()
    return additive.unsqueeze(1).expand(-1, n_rows, -1)


class WhisperTextTeacher(nn.Module):
    """
    Frozen teacher model: a ``TextEncoder`` over source text feeding a
    ``TextDecoder`` that is scored against target token ids.

    The token embedding and positional embedding are created once and handed
    to both the encoder and the decoder. Weights are loaded from
    ``configs['Whisper']['teacher_path']`` and every parameter is then frozen
    (``requires_grad_(False)``).

    Args:
        configs: Mapping with a ``'Whisper'`` entry. That entry must provide
            ``sos_token``, ``eos_token`` and ``teacher_path``; ``n_vocab``,
            ``n_text_state``, ``n_text_ctx``, ``n_text_head``,
            ``n_text_layer_ast`` and ``n_text_encoder_layer`` are optional
            and fall back to the defaults used below.
    """

    def __init__(
        self,
        configs,
    ):
        super().__init__()

        self.configs = configs
        whisper_cfgs = self.configs['Whisper']
        self.whisper_cfgs = whisper_cfgs

        self.sos = whisper_cfgs['sos_token']
        self.eos = whisper_cfgs['eos_token']

        # Read each shared hyper-parameter once. ``n_vocab`` is optional here:
        # a hard ``['n_vocab']`` lookup would raise KeyError before the
        # default could ever apply.
        n_vocab = whisper_cfgs.get('n_vocab', 51865)
        n_text_state = whisper_cfgs.get('n_text_state', 1024)
        n_text_ctx = whisper_cfgs.get('n_text_ctx', 448)
        n_text_head = whisper_cfgs.get('n_text_head', 16)
        self.n_vocab = n_vocab

        # ------------------------- shared embeddings -------------------------
        self.decoder_token_embedding = nn.Embedding(n_vocab, n_text_state)
        # NOTE: allocated uninitialized; real values are expected to come
        # from the checkpoint loaded below. Because that load is non-strict,
        # a checkpoint lacking this key would leave arbitrary memory here.
        self.decoder_positional_embedding = nn.Parameter(torch.empty(n_text_ctx, n_text_state))

        # ------------------------- decoder -----------------------------------
        self.ast_decoder = TextDecoder(
            n_vocab=n_vocab,
            n_ctx=n_text_ctx,
            n_state=n_text_state,
            n_head=n_text_head,
            n_layer=whisper_cfgs.get('n_text_layer_ast', 12),
            token_embedding=self.decoder_token_embedding,
            positional_embedding=self.decoder_positional_embedding,
            use_lite_blocks=True,
        )

        # ------------------------- text encoder ------------------------------
        self.text_encoder = TextEncoder(
            n_state=n_text_state,
            n_head=n_text_head,
            n_layer=whisper_cfgs.get('n_text_encoder_layer', 6),
            token_embedding=self.decoder_token_embedding,
            positional_embedding=self.decoder_positional_embedding,
        )

        # ------------------------- load teacher weights ----------------------
        # NOTE(review): torch.load unpickles the file, so only point
        # ``teacher_path`` at a trusted checkpoint (consider
        # ``weights_only=True`` if the file is a plain state dict).
        pre_dict = torch.load(whisper_cfgs['teacher_path'], map_location='cpu')
        # Non-strict load; printing the result reports missing/unexpected keys.
        print(self.load_state_dict(pre_dict, strict=False))

        # ------------------------- freeze ------------------------------------
        # The teacher is never trained: disable gradients on all parameters.
        self.requires_grad_(False)
        # Default ignore_index of CrossEntropyLoss is -100, which is the
        # value prepare_input_ids writes over padded labels.
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def forward_ast_decoder(self, encoder_out, encoder_out_lens, token, token_lens):
        """
        Run the decoder over ``token`` while masking padded encoder positions.

        Args:
            encoder_out: Encoder output passed to the decoder as memory.
            encoder_out_lens: Tensor of valid lengths of ``encoder_out``.
            token: Decoder input ids; dimension 1 is taken as the number of
                query positions.
            token_lens: Kept for interface compatibility; not read. The mask
                is sized from ``token`` itself so that it always matches the
                tokens actually decoded (``forward`` passes already-shifted
                tokens, which are one position shorter than the raw ids).

        Returns:
            Whatever ``self.ast_decoder`` returns (used as logits by
            ``forward``).
        """
        max_encoder_out_len_in_batch = encoder_out_lens.max()
        n_query = token.size(1)

        # NOTE(review): assumes make_pad_mask yields a boolean
        # (batch, max_len) mask that is True at padded positions -- confirm.
        pad_mask = make_pad_mask(encoder_out_lens, max_encoder_out_len_in_batch)
        pad_mask = pad_mask.unsqueeze(1).expand(-1, n_query, -1)

        # Additive mask: 0 where attention is allowed, -inf where masked.
        memory_mask = torch.zeros_like(pad_mask, dtype=encoder_out.dtype)
        memory_mask = memory_mask.masked_fill(pad_mask, float("-inf")).to(encoder_out.device)

        return self.ast_decoder(token, encoder_out, None, memory_mask)

    def forward(
        self,
        text=None,
        text_lens=None,
        ast_ids=None,
        ast_lens=None,
        **kwargs,
    ):
        """
        Encode ``text``, decode ``ast_ids`` with teacher forcing, and score it.

        Args:
            text: Source token ids fed to the text encoder. Required.
            text_lens: Tensor of valid lengths of ``text``. Required.
            ast_ids: Target token ids; split by ``prepare_input_ids`` into
                decoder inputs (all but the last position) and labels (all
                but the first position). Required.
            ast_lens: Forwarded to ``forward_ast_decoder``.
            **kwargs: Accepted and ignored.

        Returns:
            dict with
              * ``'mask'``   -- boolean tensor, True where the label is not
                the ignore value ``-100``;
              * ``'logits'`` -- decoder output for the shifted inputs;
              * ``'loss'``   -- cross-entropy between logits and labels.

        Raises:
            ValueError: If ``text``, ``text_lens`` or ``ast_ids`` is ``None``.
        """
        # The signature defaults to None only for keyword-style calling;
        # fail early with a clear message instead of an opaque error deeper in.
        required = (('text', text), ('text_lens', text_lens), ('ast_ids', ast_ids))
        missing = [name for name, value in required if value is None]
        if missing:
            raise ValueError(
                f"forward() requires non-None arguments: {', '.join(missing)}"
            )

        # Build the encoder mask in the same dtype as the model weights.
        dtype = next(self.parameters()).dtype

        encoder_out_lens = text_lens
        encoder_out_mask = make_attn_mask(encoder_out_lens, dtype)
        hidden_feature = self.text_encoder(text, encoder_out_mask)

        ast_ids, ast_label = prepare_input_ids(ast_ids)
        ast_mask = ast_label != -100

        ast_logits = self.forward_ast_decoder(hidden_feature, encoder_out_lens, ast_ids, ast_lens)
        loss_ast = self.ce_loss(ast_logits.reshape(-1, self.n_vocab), ast_label.reshape(-1))

        return {
            'mask': ast_mask,
            'logits': ast_logits,
            'loss': loss_ast,
        }

            
        
        