# pylint: disable=C0114, C0301, R0913, C0303, C0115, C0116, R0914, E0402
import os
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
import librosa
import transformers
import torch.nn.functional as F

from .encoder import AudioEncoder_MoShared,AudioEncoder_PaShared
from .decoder import TextDecoder, Mode
from .utils import make_pad_mask, add_optional_chunk_mask
from .beam import BeamSearch
from .greedy import GreedySearch


# Only surface ERROR-level messages from the transformers library (silences its info/warning logging).
transformers.logging.set_verbosity_error()


def prepare_input_ids(
        input_ids: torch.Tensor, pad_token_id: int=50258, ignore_token_id: int=-100
        ):
    """
    Split a batch of token ids into teacher-forcing decoder inputs and targets.

    The decoder inputs are ``input_ids[:, :-1]`` and the targets are
    ``input_ids[:, 1:]``, so position ``t`` of the input predicts position
    ``t`` of the target. Target positions equal to ``pad_token_id`` are
    replaced by ``ignore_token_id`` so a loss can skip them.

    Both returned tensors are fresh contiguous copies; ``input_ids`` itself
    is left untouched.
    """
    decoder_inputs = input_ids[:, :-1].clone().contiguous()

    targets = input_ids[:, 1:].clone()
    is_pad = targets == pad_token_id
    targets[is_pad] = ignore_token_id

    return decoder_inputs, targets.contiguous()

def make_attn_mask(seq_lens, dtype):
    """
    Turn sequence lengths into an additive attention mask.

    Entries flagged by ``make_pad_mask`` (presumably the padded positions —
    confirm against ``.utils``) become ``torch.finfo(dtype).min``; all other
    entries become 0. A query axis of size ``seq_lens.max()`` is inserted at
    dim 1 via ``expand``, so the result is a broadcast view, not a copy.
    """
    longest = seq_lens.max()
    flags = make_pad_mask(seq_lens).to(dtype)
    additive = flags * torch.finfo(dtype).min
    return additive.unsqueeze(1).expand(-1, longest, -1)


class Whisper(nn.Module):
    """
    Whisper-style speech-to-text model whose decoder policy is the only trainable part.

    After the optional checkpoint load, the speech encoders (``shard_encoder`` +
    ``ast_encoder``) and the text decoder are frozen; only
    ``ast_decoder.policy_blocks`` and ``ast_decoder.policy_head`` stay trainable.

    ``forward`` trains the decoder's route scores to regress the divergence
    (1 - cosine similarity) between the decoder logits obtained from ``mel``
    (``logits_trunc``) and from ``f_mel`` (``logits_full``).
    ``generate`` runs chunked decoding through ``BeamSearch`` / ``GreedySearch``.
    """

    def __init__(
        self,
        configs,
    ):
        super().__init__()

        self.configs = configs
        whisper_cfgs = self.configs['Whisper']
        self.whisper_cfgs = whisper_cfgs

        self.sos = whisper_cfgs['sos_token']
        self.eos = whisper_cfgs['eos_token']
        self.input_modality = 'speech'
        n_vocab = whisper_cfgs.get('n_vocab', 51865)
        self.n_vocab = n_vocab
        #-------------------------embedding-----------------------------------------
        n_text_state = whisper_cfgs.get('n_text_state', 1024)
        n_text_ctx = whisper_cfgs.get('n_text_ctx', 448)
        # Both embeddings are created here and handed to the decoder below.
        self.decoder_token_embedding = nn.Embedding(n_vocab, n_text_state)
        # NOTE: torch.empty leaves this uninitialised; it only holds meaningful
        # values once a checkpoint is loaded from base_path (it is frozen below).
        self.decoder_positional_embedding = nn.Parameter(torch.empty(n_text_ctx, n_text_state))
        #-------------------------encoder-----------------------------------------
        self.n_conv1_stride = whisper_cfgs.get('n_conv1_stride', 1)
        self.n_conv2_stride = whisper_cfgs.get('n_conv2_stride', 2)
        self.n_mels = whisper_cfgs.get('n_mels', 80)
        self.shard_encoder = AudioEncoder_MoShared(
            n_mels = self.n_mels,
            n_ctx = whisper_cfgs.get('n_audio_ctx', 1500),
            n_conv1_stride = self.n_conv1_stride,
            n_conv2_stride = self.n_conv2_stride,
            n_state = whisper_cfgs.get('n_audio_state', 1024),
            n_head = whisper_cfgs.get('n_audio_head', 16),
            n_layer = whisper_cfgs.get('n_audio_layer_moshared', 20),
            )
        self.ast_encoder = AudioEncoder_PaShared(
            n_state = whisper_cfgs.get('n_audio_state', 1024),
            n_head = whisper_cfgs.get('n_audio_head', 16),
            n_layer = whisper_cfgs.get('n_audio_layer_pashared', 4),
            )
        #-------------------------decoder-----------------------------------------
        self.ast_decoder = TextDecoder(
            n_vocab = n_vocab,
            n_ctx = n_text_ctx,
            n_state = n_text_state,
            n_head = whisper_cfgs.get('n_text_head', 16),
            n_layer = whisper_cfgs.get('n_text_layer_ast', 12),
            token_embedding = self.decoder_token_embedding,
            positional_embedding = self.decoder_positional_embedding,
        )
        #-----------------------------------------------------------------------------------
        # Load a previous state dict. `.get` so a config without the key
        # simply skips loading instead of raising KeyError.
        base_path = whisper_cfgs.get('base_path')
        if base_path is not None and os.path.exists(base_path):
            # NOTE(review): torch.load unpickles the file — only point base_path
            # at trusted checkpoints. The loaded object is passed directly to
            # load_state_dict, i.e. a bare state dict is expected.
            pre_dict = torch.load(base_path, map_location='cpu')
            # strict=False: report (rather than fail on) missing/unexpected keys.
            print(self.load_state_dict(pre_dict, strict=False))
        #-------------------------freezing---------------------------------------------------
        # Freeze the encoders and decoder, then re-enable only the policy blocks/head.
        self.decoder_positional_embedding.requires_grad_(False)
        self.shard_encoder.requires_grad_(False)
        self.ast_encoder.requires_grad_(False)
        self.ast_decoder.requires_grad_(False)
        self.ast_decoder.policy_blocks.requires_grad_(True)
        self.ast_decoder.policy_head.requires_grad_(True)

    def spch_enc_mask(
            self,
            mel: torch.Tensor,
            mel_lengths: torch.Tensor
            ):
        """
        Build the encoder attention masks and output lengths for a mel batch.

        Args:
            mel: batch of mel features (layout as expected by add_optional_chunk_mask).
            mel_lengths: valid frame count per utterance.

        Returns:
            pad_mask: additive (0 / -inf) padding mask, subsampled by the conv strides.
            stream_mask: additive mask = dynamic chunk mask OR padding mask,
                subsampled the same way.
            enc_out_lens: lengths after ceil-dividing by both conv strides.
        """
        mel_lengths = mel_lengths.to(torch.int64)
        stream_mask = add_optional_chunk_mask(mel, use_dynamic_chunk = True, use_dynamic_left_chunk = False)
        max_len = mel_lengths.max()
        pad_mask = make_pad_mask(mel_lengths, max_len).unsqueeze(1).expand(-1, max_len, -1)
        stream_mask = stream_mask | pad_mask
        # Convert boolean masks to additive masks (True -> -inf, False -> 0).
        stream_mask = torch.zeros_like(stream_mask, dtype=mel.dtype).masked_fill(stream_mask, float("-inf"))
        pad_mask = torch.zeros_like(pad_mask, dtype=mel.dtype).masked_fill(pad_mask, float("-inf")).to(mel.device)
        enc_out_lens = torch.ceil(mel_lengths / self.n_conv1_stride).long()
        enc_out_lens = torch.ceil(enc_out_lens / self.n_conv2_stride).long()
        # Subsample both axes of the masks to match each strided conv layer.
        if self.n_conv1_stride > 1:
            stream_mask = stream_mask[:, ::self.n_conv1_stride, ::self.n_conv1_stride]
            pad_mask = pad_mask[:, ::self.n_conv1_stride, ::self.n_conv1_stride]
        if self.n_conv2_stride > 1:
            stream_mask = stream_mask[:, ::self.n_conv2_stride, ::self.n_conv2_stride]
            pad_mask = pad_mask[:, ::self.n_conv2_stride, ::self.n_conv2_stride]
        return pad_mask, stream_mask, enc_out_lens

    # no_grad rather than inference_mode: the outputs are fed to the decoder in
    # `forward`, whose policy blocks are trainable, and autograd cannot save
    # inference-mode tensors for backward. Values are identical either way.
    @torch.no_grad()
    def forward_encoder(
        self,
        mel,
        mel_lens,
    ):
        """Run both (frozen) encoders; return ``(enc_out, enc_out_lens)``."""
        pad_mask, stream_mask, enc_out_lens = self.spch_enc_mask(mel, mel_lens)
        shard_enc_out = self.shard_encoder(mel, stream_mask)
        enc_out = self.ast_encoder(shard_enc_out, pad_mask)

        return enc_out, enc_out_lens

    def make_cross_mask(
        self, enc_out, enc_out_lens, ast_lens
    ):
        """
        Build an additive (0 / -inf) cross-attention padding mask.

        Encoder padding (from ``enc_out_lens``) is masked and repeated over
        ``ast_lens.max()`` query rows. Returns None when ``enc_out_lens`` is None.
        """
        if enc_out_lens is None:
            return None
        cross_mask = make_pad_mask(enc_out_lens, enc_out_lens.max()).unsqueeze(1).expand(-1, ast_lens.max(), -1)
        return torch.zeros_like(cross_mask, dtype=enc_out.dtype).masked_fill(cross_mask, float("-inf")).to(enc_out.device)

    @staticmethod
    def _policy_loss(route_scores, divergence, valid_mask):
        """
        Masked mean-squared error between policy route scores and target divergence.

        Only positions where ``valid_mask`` is True contribute, and the result is
        the mean over those positions. Returns 0 (not NaN) when there are none.
        """
        if route_scores.dim() == divergence.dim() + 1 and route_scores.shape[-1] == 1:
            # Tolerate a trailing singleton score dim; otherwise mse_loss would
            # broadcast (..., T, 1) against (..., T) into (..., T, T).
            route_scores = route_scores.squeeze(-1)
        # reduction='none' keeps per-token errors so padding can be masked out;
        # the default 'mean' collapses to a scalar before masking.
        per_token = F.mse_loss(route_scores, divergence, reduction='none')
        per_token = per_token.masked_fill(~valid_mask, 0.0)
        n_valid = valid_mask.sum().clamp(min=1)
        return per_token.sum() / n_valid

    def forward(
        self,
        mel = None,
        mel_lens = None,
        ast_ids = None,
        ast_lens = None,
        f_mel = None,
        f_mel_lens = None,
        **kwargs,
    ):
        """
        Compute the policy-training loss for one batch.

        The decoder is run on the encodings of ``mel`` and of ``f_mel`` with the
        same teacher-forced ``ast_ids``. Its route scores (from the ``mel`` pass)
        are regressed onto ``1 - cos(logits_trunc, logits_full)``, averaged
        over non-padding target tokens.

        Returns:
            dict with 'loss' (scalar tensor) and 'bsz' (batch size of ``mel``).
        """
        f_enc_out, f_enc_out_lens = self.forward_encoder(f_mel, f_mel_lens)
        enc_out, enc_out_lens = self.forward_encoder(mel, mel_lens)

        cross_mask = self.make_cross_mask(enc_out, enc_out_lens, ast_lens)
        f_cross_mask = self.make_cross_mask(f_enc_out, f_enc_out_lens, ast_lens)

        ast_ids, ast_label = prepare_input_ids(ast_ids)
        ast_mask = ast_label != -100

        logits_trunc, route_scores = self.ast_decoder(ast_ids, enc_out, cross_mask)
        # logits_full only feeds the detached target below, so no graph is needed.
        with torch.no_grad():
            logits_full, _ = self.ast_decoder(ast_ids, f_enc_out, f_cross_mask)

        # Detached target: gradients flow only through route_scores.
        divergence = 1 - F.cosine_similarity(logits_trunc, logits_full, dim=-1).detach()
        loss = self._policy_loss(route_scores, divergence, ast_mask)

        return {
            'loss': loss,
            'bsz': mel.shape[0]
        }

    def generate(
        self,
        wav,
        lang_id=None,
        task = 'transcribe',
        static_chunk_size = 64,
        beam_size = 3,
        simul_threhold = 0.5,
        tokenizer = None,
        sr = 16000,
    ):
        """
        Decode a single utterance with chunked (streaming-masked) encoding.

        Args:
            wav: numpy waveform, 1-D or 2-D (for 2-D, column 0 is used).
                Only the first 30 seconds are kept.
            lang_id: forced language token id, or a language code string that is
                looked up as ``<|{lang_id}|>`` in ``tokenizer.added_tokens_encoder``.
            task: forwarded to BeamSearch.
            static_chunk_size: chunk size of the encoder's static chunk mask.
            beam_size: > 1 uses BeamSearch, otherwise GreedySearch.
            simul_threhold: forwarded to GreedySearch; <= 0 makes the whole
                utterance a single decoding block.
            tokenizer: used for a string ``lang_id`` and for decoding tokens.
            sr: sample rate of ``wav``; resampled to 16 kHz when different.

        Returns:
            (tokens, text, latency); ``text`` is None when no tokenizer is given.

        Raises:
            ValueError: ``lang_id`` is a string but no tokenizer was given.
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        if lang_id is not None and isinstance(lang_id, str):
            if tokenizer is None:
                raise ValueError("A tokenizer is required to resolve a string lang_id.")
            lang_id = tokenizer.added_tokens_encoder[f'<|{lang_id}|>']
        if wav.ndim == 2:
            wav = wav[:, 0]
        # Keep at most 30 seconds (measured at the input rate).
        if len(wav) > 30 * sr:
            wav = wav[:30 * sr]
        if sr != 16000:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=16000, res_type="fft")

        def compute_log_mel_spectrogram(wav, num_mel_bins=80, n_fft=400, hop_length=160):
            sample_rate = 16000
            waveform = wav.squeeze(0)
            window = torch.hann_window(n_fft)
            stft = torch.stft(waveform, n_fft, hop_length, window=window, return_complex=True)
            # Power spectrum, dropping the last STFT frame.
            magnitudes = stft[..., :-1].abs()**2
            filters = torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins))
            mel_spec = filters @ magnitudes
            # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
            log_spec = torch.clamp(mel_spec, min=1e-10).log10()
            log_spec = torch.clamp(log_spec, min=-8.0, max=0.0)
            log_spec = (log_spec + 4.0) / 4.0
            return log_spec.unsqueeze(0)
        mel = compute_log_mel_spectrogram(torch.from_numpy(wav).float(), self.n_mels).to(device).to(dtype)

        stream_mask = add_optional_chunk_mask(
            mel,
            use_dynamic_chunk = False,
            use_dynamic_left_chunk = False,
            decoding_chunk_size = 0,
            static_chunk_size = static_chunk_size,
            num_decoding_left_chunks = -1
            )
        stream_mask = torch.zeros_like(stream_mask, dtype=mel.dtype).masked_fill(stream_mask, float("-inf")).to(device)
        # Subsample the mask to the encoder's output frame rate.
        if self.n_conv1_stride > 1:
            stream_mask = stream_mask[:, ::self.n_conv1_stride, ::self.n_conv1_stride]
        if self.n_conv2_stride > 1:
            stream_mask = stream_mask[:, ::self.n_conv2_stride, ::self.n_conv2_stride]

        x = self.shard_encoder(mel, stream_mask)
        if simul_threhold <= 0.0:
            # Non-simultaneous: one block larger than the whole encoder output.
            block_size = x.shape[1] + 1
        else:
            # NOTE(review): `// 2` matches the default total conv stride (1 * 2);
            # confirm before using other stride configurations.
            block_size = static_chunk_size // 2

        hs = x[0]

        if beam_size > 1:
            bs = BeamSearch(self, beam_size, block_size=block_size, forced_lang_id=lang_id, task=task)
            bs_res, latency = bs(hs)
            tokens = bs_res.yseq if bs_res is not None else []
        else:
            gs = GreedySearch(self, block_size=block_size, forced_lang_id=lang_id)
            gs_res, latency = gs(hs, simul_threhold=simul_threhold)
            tokens = gs_res.yseq[0] if gs_res is not None else []

        text = tokenizer.decode(tokens, skip_special_tokens=True) if tokenizer is not None else None

        return tokens, text, latency
        