from typing import Optional, Tuple, List

import math
import logging
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

from ..asr.model import ZipformerForAsrModel
from ..zipformer.utils.padding import make_pad_mask
from .asr_decoder.attention_decoder import AttentionDecoderModel
from .loss import siglip_loss
from .decode import greedy_search_batch, attention_beam_search

# Frozen text encoders available for speech-to-text (S2T) embedding alignment.
# Each entry maps a pretrained model name to a triple of
# (model class, tokenizer class, embedding dimension used as the output size
# of the speech-to-text projection layer).
TEXT_MODELS = dict(
    [
        ('bert-base-multilingual-uncased', (BertModel, BertTokenizer, 768)),
    ]
)


class ZipformerForTtaModel(ZipformerForAsrModel):
    """Zipformer speech model with optional attention-decoder and S2T-alignment heads.

    On top of :class:`ZipformerForAsrModel` (transducer / CTC losses), this
    model can add:

    * an attention (encoder-decoder) head, enabled by
      ``config.use_attention_decoder``, whose targets are prefixed with a
      source-language tag and a task tag (transcribe / translate);
    * a speech-to-text (S2T) alignment loss between mean-pooled encoder output
      and mean-pooled embeddings of a frozen pretrained text encoder, enabled
      by ``config.use_s2t_alignment``.
    """

    def __init__(self, config, tokenizer):
        super().__init__(config, tokenizer)

        if config.use_attention_decoder:
            # Special tokens (language / task tags) get ids directly after the
            # base vocabulary: vocab_size, vocab_size + 1, ...
            self.special_tokens = config.special_tokens
            if self.special_tokens is not None:
                self.special_to_id = {
                    spec: self.vocab_size + i
                    for i, spec in enumerate(self.special_tokens)
                }
                self.id_to_special = {
                    self.vocab_size + i: spec
                    for i, spec in enumerate(self.special_tokens)
                }
                vocab_size = self.vocab_size + len(self.special_tokens)
            else:
                vocab_size = self.vocab_size
            # Modules for attention decoder head. Note that <sos> and <eos>
            # share token id 1.
            self.attention_decoder = AttentionDecoderModel(
                vocab_size=vocab_size,
                decoder_dim=512,
                num_decoder_layers=6,
                attention_dim=512,
                num_heads=8,
                feedforward_dim=2048,
                memory_dim=max(config.encoder_dim),
                sos_id=1,
                eos_id=1,
                ignore_id=-1,
                label_smoothing=0.1,
            )

        self.use_s2t_alignment = config.use_s2t_alignment
        if self.use_s2t_alignment:
            model_name = config.text_embed_model
            model_cls, tokenizer_cls, embed_dim = TEXT_MODELS[model_name]
            logging.info(f"Load S2T text embedding model from {model_name}")
            # The text encoder is loaded in fp16 and frozen; it only provides
            # alignment targets.
            self.text_embed_model = model_cls.from_pretrained(
                model_name, torch_dtype=torch.float16)
            self.text_embed_tokenizer = tokenizer_cls.from_pretrained(model_name)
            for param in self.text_embed_model.parameters():
                param.requires_grad = False

            # Projects encoder frames into the text-embedding space.
            self.encoder_align2text_proj = nn.Sequential(
                nn.Dropout(p=0.1),
                torch.nn.Linear(max(config.encoder_dim), embed_dim),
            )
            # SigLIP-style logit scale and bias. The scale is stored in log
            # space (t' initialised to log(10)) and applied as exp(t'); the
            # bias starts at -10 so initial positive-pair logits are ~0.
            self.s2t_align_temp = nn.Parameter(torch.ones([]) * math.log(10))
            self.s2t_align_bias = nn.Parameter(torch.ones([]) * -10.0)

    def calc_s2t_align_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        texts: List[str],
    ):
        """Compute the SigLIP-style speech-to-text alignment loss.

        Speech side: project encoder frames to the text-embedding size,
        mean-pool over the valid (non-padded) frames, L2-normalise.
        Text side: run the frozen text encoder without gradients, mean-pool
        ``last_hidden_state`` over the attention mask, L2-normalise.
        Pairwise logits ``cos_sim * exp(temp) + bias`` go to ``siglip_loss``.

        Args:
          encoder_out: encoder output; assumed (N, T, C) with
            C == max(config.encoder_dim) -- TODO confirm.
          encoder_out_lens: (N,) number of valid frames per utterance.
          texts: N texts, one per utterance, in batch order.

        Returns:
          Whatever ``siglip_loss`` returns for the (N, N) logit matrix.
        """
        device = encoder_out.device

        # Masked mean pooling of the projected speech frames.
        speech_embed = self.encoder_align2text_proj(encoder_out)
        speech_mask = ~make_pad_mask(encoder_out_lens, encoder_out.size(1)).to(device)
        speech_embed = (speech_embed * speech_mask.unsqueeze(-1)).sum(dim=1)
        # Clamp guards against division by zero for an empty utterance.
        speech_embed = speech_embed / encoder_out_lens.reshape(-1, 1).clamp(min=1)
        speech_embed = F.normalize(speech_embed, p=2, dim=-1)

        with torch.no_grad():
            text_inputs = self.text_embed_tokenizer(
                texts, return_tensors="pt", padding=True
            ).to(device)
            # Pool in fp32: the text model runs in fp16, and summing fp16
            # hidden states loses precision.
            hidden = self.text_embed_model(**text_inputs)["last_hidden_state"].float()
            text_mask = text_inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            text_embed = (hidden * text_mask).sum(dim=1) / text_mask.sum(dim=1)

        text_embed = F.normalize(text_embed, p=2, dim=-1)
        speech_text_logits = (
            (speech_embed @ text_embed.T) * self.s2t_align_temp.exp()
            + self.s2t_align_bias
        )
        return siglip_loss(speech_text_logits)

    def forward_attention_decoder_language_token(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
    ):
        """Predict the first decoder token (the source-language tag) per utterance.

        Feeds only ``<sos>`` to the attention decoder and takes the argmax of
        that single step.

        Returns:
          A list with one entry per utterance: the special-token string of the
          argmax id, or ``'ERR'`` if the argmax is not a special token.
        """
        batch_size = encoder_out.size(0)
        prefix_tokens = torch.full(
            (batch_size, 1),
            self.attention_decoder.sos_id,
            dtype=torch.long,
            device=encoder_out.device,
        )
        logp = self.attention_decoder.forward_one_step(
            encoder_out, encoder_out_lens, prefix_tokens)
        # reshape(-1) (rather than squeeze) keeps a list even when batch_size == 1.
        best_ids = logp.argmax(-1).reshape(-1).tolist()
        return [self.id_to_special.get(i, 'ERR') for i in best_ids]

    def decorate_decoder_input(
        self,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        task: List[str],
        language: Optional[List[str]] = None,
        language_translated: Optional[List[str]] = None,
        mode: str = "src2tgt"
    ):
        """Prefix attention-decoder targets with a language tag and a task tag.

        Every utterance gets two tokens prepended: the source-language tag,
        then a task tag:

        * ``mode='src2tgt'``:
            - transcribe: ``<src_lang><src_lang>``
            - translate:  ``<src_lang><translate_{tgt_lang}>``
        * ``mode='src2tgt_unified'``:
            - transcribe: ``<src_lang><src_lang>``
            - translate:  ``<src_lang><{tgt_lang}>``

        In both modes the source-language tag comes first. This matches the
        inference prefix built in ``generate`` ([sos, <language>, <task>]).

        Args:
          y: original decoder targets, ragged [utt][token].
          y_lens: (B,) lengths of ``y``.
          task: per-utterance list of 'transcribe' / 'translate', or a single
            string applied to the whole batch.
          language: per-utterance source-language names (required).
          language_translated: per-utterance target-language names; required
            when any task is 'translate'.
          mode: 'src2tgt' or 'src2tgt_unified'.

        Returns:
          ``(y_aed, y_aed_lens, ignore_indices)``:
          * ``y_aed`` is ``y`` with the two tags prepended;
          * ``y_aed_lens`` is ``y_lens + 2``;
          * ``ignore_indices`` is ``[1]`` for every utterance. Position 1 is
            the task tag, presumably excluded from the loss by ``calc_att_loss``
            (TODO confirm).

        Raises:
          ValueError: invalid mode or task; missing ``language`` or
            ``language_translated``; or task/language lists whose length differs
            from the batch size.
          KeyError: a tag is not among the configured special tokens.
        """
        batch_size = y.dim0
        if mode not in ("src2tgt", "src2tgt_unified"):
            raise ValueError(f"Invalid mode: {mode}")
        if isinstance(task, str):
            task = [task] * batch_size
        if language is None:
            raise ValueError("`language` is required to build decoder prefix tags")
        if len(task) != batch_size or len(language) != batch_size:
            raise ValueError(
                f"Expected {batch_size} tasks and languages, "
                f"got {len(task)} and {len(language)}"
            )

        lid_tag = [self.special_to_id[f"<{lang}>"] for lang in language]
        task_tag = []
        for i, t in enumerate(task):
            if t == "transcribe":
                # ASR reuses the source-language tag as its task tag.
                task_tag.append(lid_tag[i])
                continue
            if t != "translate":
                raise ValueError(f"Invalid task: {t}")
            if language_translated is None:
                raise ValueError("`language_translated` is required for task 'translate'")
            tgt_lang = language_translated[i]
            tag = f"<{t}_{tgt_lang}>" if mode == "src2tgt" else f"<{tgt_lang}>"
            task_tag.append(self.special_to_id[tag])

        # [B, 2]: source-language tag first, then the task tag.
        insert_tag = torch.stack(
            (
                torch.tensor(lid_tag, dtype=y.dtype),
                torch.tensor(task_tag, dtype=y.dtype),
            ),
            dim=1,
        ).to(y.device)

        y_aed = k2.ragged.cat([k2.RaggedTensor(insert_tag), y], axis=1)
        y_aed_lens = y_lens + insert_tag.size(1)
        ignore_indices = [[1] for _ in range(batch_size)]

        return y_aed, y_aed_lens, ignore_indices

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        y_translated: Optional[k2.RaggedTensor] = None,
        y_translated_lens: Optional[torch.Tensor] = None,
        language: Optional[List[str]] = None,
        language_translated: Optional[List[str]] = None,
        task: Optional[List[str]] = None,
        text_align: Optional[List[str]] = None,
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance, used for the transducer and CTC losses. (Could be
            transcription or translation.)
          y_lens:
            A 1-D tensor of shape (N,) with the label count of each utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
          y_translated, y_translated_lens:
            Targets (and their lengths) for the attention decoder.
          language:
            The speech languages of input.
          language_translated:
            Target languages for utterances whose task is 'translate'.
          task:
            Per-utterance 'transcribe' / 'translate'.
          text_align:
            Per-utterance texts for the S2T alignment loss.
        Returns:
          A 5-tuple (simple_loss, pruned_loss, ctc_loss,
          attention_decoder_loss, s2t_align_loss); each entry is None when
          the corresponding head is disabled.

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes
        assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)

        device = x.device

        # Compute encoder outputs
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = encoder_output.encoder_out
        encoder_out_lens = encoder_output.encoder_out_lens

        simple_loss = pruned_loss = None
        if self.config.use_transducer:
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y.to(device),
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )

        ctc_loss = None
        if self.config.use_ctc:
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=y.values,
                target_lengths=y_lens,
            )

        attention_decoder_loss = None
        if self.config.use_attention_decoder:
            y_aed, y_aed_lens, prefix_ignore_indices = self.decorate_decoder_input(
                y_translated,
                y_translated_lens,
                task=task,
                language=language,
                language_translated=language_translated,
                mode=self.config.translate_mode,
            )
            attention_decoder_loss, _ = self.attention_decoder.calc_att_loss(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                ys=y_aed.to(device),
                ys_lens=y_aed_lens.to(device),
                prefix_ignore_indices=prefix_ignore_indices,
            )

        s2t_align_loss = None
        if self.use_s2t_alignment:
            s2t_align_loss = self.calc_s2t_align_loss(encoder_out, encoder_out_lens, text_align)

        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, s2t_align_loss

    def generate(
        self,
        input,
        decoding_method: str = 'greedy_search',
        blank_penalty: float = 0,
        language: Optional[str] = None,   # e.g. zh
        task: Optional[str] = None,   # e.g. zh or translate_en
    ):
        """Decode input audio into text.

        Args:
          input: either an ``(x, x_lens)`` tuple of features, or anything
            accepted by ``self.extract_feature``.
          decoding_method: 'greedy_search' (transducer) or
            'attention_beam_search' (attention decoder, beam size 1).
          blank_penalty: forwarded to ``greedy_search_batch``.
          language: source-language tag name, e.g. 'zh'. Required for
            'attention_beam_search'.
          task: tag placed after the language tag, e.g. 'zh' (transcribe) or
            'translate_en'. Defaults to ``language``, i.e. transcription.

        Returns:
          The result of ``self.tokenizer.decode`` on the hypothesis tokens.

        Raises:
          ValueError: unknown ``decoding_method``, or ``language`` missing for
            attention decoding.
        """
        if decoding_method not in ("attention_beam_search", "greedy_search"):
            raise ValueError(f"Invalid decoding_method: {decoding_method}")

        # Handle flexible input
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)
        output = self.forward_encoder(x, x_lens)

        if decoding_method == "attention_beam_search":
            if language is None:
                raise ValueError("`language` is required for attention_beam_search")
            # default as ASR
            if task is None:
                task = language
            prefix_tokens = torch.tensor([
                self.attention_decoder.sos_id,
                self.special_to_id[f"<{language}>"],
                self.special_to_id[f"<{task}>"],
            ], dtype=torch.long)
            decoding_results = attention_beam_search(
                model=self,
                encoder_out=output.encoder_out,
                encoder_out_lens=output.encoder_out_lens,
                beam_size=1,
                prefix_tokens=prefix_tokens,
                sos=self.attention_decoder.sos_id,
                eos=self.attention_decoder.eos_id,
            )
            hyp_tokens = decoding_results.hyps
        else:
            hyp_tokens = greedy_search_batch(
                model=self,
                encoder_out=output.encoder_out,
                encoder_out_lens=output.encoder_out_lens,
                blank_penalty=blank_penalty,
            )
        return self.tokenizer.decode(hyp_tokens)