from typing import Optional, Tuple, List
from pathlib import Path

import json
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.scaling import ScaledLinear
from .asr_decoder.joiner import Joiner
from .asr_decoder.decoder import Decoder
from .utils import add_sos

from ...auto.auto_tokenizer import AutoTokenizer
from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params
from auden.utils.icefall_asr_decode import greedy_search_batch


class ZipformerForAsrModel(ZipformerEncoderModel):
    """Zipformer encoder with optional pruned-transducer and CTC ASR heads.

    Which heads are built is controlled by ``config.use_transducer`` and
    ``config.use_ctc``; the modules of a disabled head are set to ``None``.
    """

    @classmethod
    def from_pretrained(cls, exp_dir, checkpoint_filename='pretrained.pt'):
        """Load config, tokenizer and model weights from ``exp_dir``.

        Args:
          exp_dir:
            Experiment directory. The config is read with
            ``AutoConfig.from_pretrained(exp_dir)``, the tokenizer from the
            ``tokenizer`` sub-directory, and the weights from
            ``exp_dir / checkpoint_filename``.
          checkpoint_filename:
            Name of the checkpoint file inside ``exp_dir``.
        Returns:
          The constructed model with its parameters loaded.
        """
        config = AutoConfig.from_pretrained(exp_dir)
        exp_path = Path(exp_dir)
        tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_type, exp_path / 'tokenizer'
        )
        model = cls(config, tokenizer)
        load_model_params(model, exp_path / checkpoint_filename)
        return model

    @staticmethod
    def _encoder_output_dim(encoder_dim) -> int:
        """Return the encoder output width: ``max`` of a per-stack sequence, or the int itself."""
        if isinstance(encoder_dim, int):
            return encoder_dim
        return max(encoder_dim)

    def __init__(self, config, tokenizer):
        """Build the encoder (via the parent class) and the enabled ASR heads.

        Args:
          config:
            Model config. Reads ``encoder_dim``, ``use_transducer`` and
            ``use_ctc``; when the transducer is enabled also
            ``decoder_dim``, ``joiner_dim`` and ``context_size``.
          tokenizer:
            Tokenizer providing ``blank_id``, ``vocab_size``, ``encode``
            and ``decode``.
        """
        super().__init__(config)
        self.tokenizer = tokenizer
        self.blank_id = self.tokenizer.blank_id
        self.vocab_size = self.tokenizer.vocab_size
        # All heads consume the encoder output, whose width is sized the same
        # way the joiner always was: max over the per-stack dims.
        encoder_out_dim = self._encoder_output_dim(config.encoder_dim)

        if config.use_transducer:
            self.decoder = Decoder(
                vocab_size=self.vocab_size,
                decoder_dim=config.decoder_dim,
                blank_id=self.blank_id,
                context_size=config.context_size,
            )

            self.joiner = Joiner(
                encoder_dim=encoder_out_dim,
                decoder_dim=config.decoder_dim,
                joiner_dim=config.joiner_dim,
                vocab_size=self.vocab_size,
            )

            # Linear projections used only by the "simple" (smoothed) loss that
            # produces the gradients used to compute the pruning ranges.
            self.simple_am_proj = ScaledLinear(
                encoder_out_dim, self.vocab_size, initial_scale=0.25
            )
            self.simple_lm_proj = ScaledLinear(
                config.decoder_dim, self.vocab_size, initial_scale=0.25
            )
        else:
            self.decoder = None
            self.joiner = None
            self.simple_am_proj = None
            self.simple_lm_proj = None

        if config.use_ctc:
            # CTC head: vocabulary must match the tokenizer, like the decoder/joiner.
            self.ctc_output = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(encoder_out_dim, self.vocab_size),
                nn.LogSoftmax(dim=-1),
            )
        else:
            self.ctc_output = None

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss (summed over the batch).
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          target_lengths:
            Number of target tokens per utterance, of shape (N,).
        Returns:
          A scalar tensor with the summed CTC loss.
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        ctc_loss = F.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            # Use the tokenizer's blank rather than ctc_loss's default of 0.
            blank=self.blank_id,
            reduction="sum",
        )
        return ctc_loss

    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute pruned Transducer loss.
        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          y_lens:
            Number of labels per utterance, of shape (N,).
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
        Returns:
          ``(simple_loss, pruned_loss)``, both summed over the batch.
        """
        # Now for the decoder, i.e., the prediction network
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)
        y_padded = y_padded.to(torch.int64)

        # boundary rows are [begin_symbol, begin_frame, end_symbol, end_frame].
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # The k2 losses are computed in float32 with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]
        # project_input=False since we applied the joiner's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )
        return simple_loss, pruned_loss

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        texts: List[str],
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          texts:
            A list of N transcripts, one per utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
        Returns:
          Return the transducer losses and CTC loss,
          in form of (simple_loss, pruned_loss, ctc_loss). The transducer
          losses are None when the transducer head is disabled, and ctc_loss
          is None when the CTC head is disabled.

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        y_list = self.tokenizer.encode(texts)
        device = x.device
        y = k2.RaggedTensor(y_list).to(device)
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)

        # Compute encoder outputs
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = encoder_output.encoder_out
        encoder_out_lens = encoder_output.encoder_out_lens

        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        if self.config.use_transducer:
            # Compute transducer loss (y is already on `device`).
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y,
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
        else:
            simple_loss = None
            pruned_loss = None

        if self.config.use_ctc:
            # Compute CTC loss on the concatenated (un-padded) label sequence.
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=y.values,
                target_lengths=y_lens,
            )
        else:
            ctc_loss = None

        return (simple_loss, pruned_loss, ctc_loss)

    def generate(self, input,
                 decoding_method='greedy_search',
                 blank_penalty=0):
        """Transcribe input audio with the transducer head.

        Args:
          input:
            Either an ``(x, x_lens)`` tuple of features, or anything accepted
            by ``self.extract_feature``.
          decoding_method:
            Decoding algorithm; only ``'greedy_search'`` is supported.
          blank_penalty:
            Passed through to ``greedy_search_batch``.
        Returns:
          The result of ``self.tokenizer.decode`` on the hypothesis tokens.
        Raises:
          ValueError: if ``decoding_method`` is not supported.
        """
        # Validate before running the (expensive) encoder.
        if decoding_method != 'greedy_search':
            raise ValueError(
                f"Unsupported decoding_method: {decoding_method!r}; "
                "only 'greedy_search' is supported."
            )

        # Handle flexible input
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)

        output = self.forward_encoder(x, x_lens)
        hyp_tokens = greedy_search_batch(
            model=self,
            encoder_out=output.encoder_out,
            encoder_out_lens=output.encoder_out_lens,
            blank_penalty=blank_penalty,
        )
        return self.tokenizer.decode(hyp_tokens)