# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                       Wei Kang,
#                                                       Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from .encoder_interface import EncoderInterface
from .zipformer import Zipformer2
#from .zipformer_moe import ZipformerMoe
from .subsampling import Conv2dSubsampling
from .joiner import Joiner
from .decoder import Decoder
from .scaling import ScheduledFloat
from .model_config import ZipformerConfig
from ..modeling_output import AsrLossComponents, IcefallAsrModelOutput, ZipformerEncoderOutput

from auden.utils.icefall_utils import add_sos, make_pad_mask

class ZipformerAsrModel(nn.Module):
    """Zipformer2-based ASR model with optional transducer, CTC and LID heads.

    Always built:
      * ``encoder_embed``: a ``Conv2dSubsampling`` front-end.
      * ``encoder``: a ``Zipformer2`` encoder.

    Built depending on ``config``:
      * ``config.use_transducer``: ``decoder``, ``joiner``, ``simple_am_proj``
        and ``simple_lm_proj`` (used by the pruned RNN-T loss). Otherwise
        ``decoder`` and ``joiner`` are set to ``None``.
      * ``config.use_ctc``: ``ctc_output`` (dropout -> linear -> log-softmax).

    Not built here:
      * ``attention_decoder``: must be attached externally before
        ``config.use_attention_decoder`` can be used.
      * ``lid_predictor``: initialised to ``None``; must be attached externally
        before ``config.predict_lid`` can be used.
    """

    def __init__(self, config: ZipformerConfig):
        super().__init__()
        self.config = config

        # config.encoder_dim is a per-stack sequence (indexed with [0] for the
        # front-end). Every head that consumes encoder_out is sized with its
        # max, so compute it once and reuse it.
        encoder_out_dim = max(config.encoder_dim)

        # Initialize encoder embedding
        self.encoder_embed = Conv2dSubsampling(
            in_channels=config.feature_dim,
            out_channels=config.encoder_dim[0],
            dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        )

        # Initialize encoder
        self.encoder = Zipformer2(
            output_downsampling_factor=config.output_downsampling_factor,
            downsampling_factor=tuple(config.downsampling_factor),
            encoder_dim=config.encoder_dim,
            num_encoder_layers=config.num_encoder_layers,
            encoder_unmasked_dim=config.encoder_unmasked_dim,
            query_head_dim=config.query_head_dim,
            pos_head_dim=config.pos_head_dim,
            value_head_dim=config.value_head_dim,
            num_heads=config.num_heads,
            feedforward_dim=config.feedforward_dim,
            cnn_module_kernel=config.cnn_module_kernel,
            pos_dim=config.pos_dim,
            dropout=config.dropout,
            warmup_batches=config.warmup_batches,
            causal=config.causal,
            chunk_size=tuple(config.chunk_size),
            left_context_frames=tuple(config.left_context_frames),
            num_experts=config.num_experts,
            top_k=config.top_k,
            granularity=config.granularity,
            num_shared_experts=config.num_shared_experts,
            moe_layers=config.moe_type,
        )

        # Initialize transducer decoder / joiner
        if config.use_transducer:
            self.decoder = Decoder(
                vocab_size=config.vocab_size,
                decoder_dim=config.decoder_dim,
                blank_id=config.blank_id,
                context_size=config.context_size,
            )

            self.joiner = Joiner(
                encoder_dim=encoder_out_dim,
                decoder_dim=config.decoder_dim,
                joiner_dim=config.joiner_dim,
                vocab_size=config.vocab_size,
            )

            # Projections used by k2.rnnt_loss_smoothed (the "simple" loss).
            self.simple_am_proj = nn.Linear(
                encoder_out_dim, config.vocab_size,
            )
            self.simple_lm_proj = nn.Linear(
                config.decoder_dim, config.vocab_size,
            )
        else:
            self.decoder = None
            self.joiner = None

        if config.use_ctc:
            # Modules for CTC head. The input width must match encoder_out
            # (max(encoder_dim)); passing the raw encoder_dim sequence to
            # nn.Linear would fail at construction time.
            self.ctc_output = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(encoder_out_dim, config.vocab_size),
                nn.LogSoftmax(dim=-1),
            )

        # Attached externally when config.predict_lid is used.
        self.lid_predictor = None

    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> ZipformerEncoderOutput:
        """Compute encoder outputs.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.

        Returns:
          A ``ZipformerEncoderOutput`` with:
            encoder_out:
              Encoder output, of shape (N, T', C') after subsampling.
            encoder_out_lens:
              Encoder output lengths, of shape (N,).
        """
        x, x_lens = self.encoder_embed(x, x_lens)
        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # NOTE(review): the encoder returns a third value (presumably an MoE
        # load-balancing loss) that is discarded here, so it never reaches the
        # training loss. Confirm this is intended.
        encoder_out, encoder_out_lens, _balance_loss = self.encoder(
            x, x_lens, src_key_padding_mask
        )

        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return ZipformerEncoderOutput(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          target_lengths:
            A 1-D tensor of shape (N,) with the number of targets per utterance.

        Returns:
          The CTC loss summed over the batch (``reduction="sum"``).
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        # Targets and lengths are moved to CPU before calling F.ctc_loss.
        ctc_loss = F.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction="sum",
        )
        return ctc_loss

    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute pruned RNN-T (transducer) losses.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          y_lens:
            A 1-D tensor of shape (N,) with the number of labels per utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part

        Returns:
          ``(simple_loss, pruned_loss)``, both summed over the batch.
        """
        # Now for the decoder, i.e., the prediction network
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        # boundary rows are [0, 0, num_symbols, num_frames] per utterance.
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # The losses are computed in float32 with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]

        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )
        return simple_loss, pruned_loss

    def forward_lid_predictor(
        self,
        gate_scores: torch.Tensor,
        language,
    ):
        """Run the language-ID predictor and compute its cross-entropy loss.

        Args:
          gate_scores:
            Input passed to ``self.lid_predictor``. When ``detach_lid`` is
            true it is iterated element-wise and each element is detached,
            so the predictor then receives a list.
          language:
            Language labels; converted to class indices by
            ``self.lid_predictor.convert_language_ids``.

        Returns:
          ``(lid_output, lid_prediction_loss)`` where the loss is the mean
          cross-entropy between ``lid_output`` and the converted labels.

        Raises:
          RuntimeError: if no ``lid_predictor`` has been attached.
        """
        if self.lid_predictor is None:
            raise RuntimeError(
                "forward_lid_predictor called but no lid_predictor is attached "
                "to the model."
            )
        # detach_lid is never set by __init__; default to keeping gradients
        # unless it has been assigned externally.
        if getattr(self, "detach_lid", False):
            # Stop the LID loss from back-propagating into its inputs.
            detached_gate_scores = [gate_score.detach() for gate_score in gate_scores]
            lid_output = self.lid_predictor(detached_gate_scores)
        else:
            lid_output = self.lid_predictor(gate_scores)
        language_labels = self.lid_predictor.convert_language_ids(language)
        lid_prediction_loss = F.cross_entropy(lid_output, language_labels, reduction='mean')

        return lid_output, lid_prediction_loss

    def calculate_accuracy(self, lid_output, langs):
        """Delegate LID accuracy computation to the attached ``lid_predictor``."""
        return self.lid_predictor.calculate_accuracy(lid_output, langs)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        language=None
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """Compute all enabled training losses.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
          language:
            Language labels, only used when ``config.predict_lid`` is true.

        Returns:
          A 6-tuple
          ``(simple_loss, pruned_loss, ctc_loss, attention_decoder_loss,
          lid_output, lid_loss)``.
          Entries for disabled heads are ``None``, except
          ``attention_decoder_loss``, which is ``torch.empty(0)`` when the
          attention decoder is disabled.

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)

        device = x.device

        # Compute encoder outputs
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out = encoder_output.encoder_out
        encoder_out_lens = encoder_output.encoder_out_lens

        # Number of labels per utterance, from the ragged tensor's row splits.
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        if self.config.use_transducer:
            # Compute transducer loss
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y.to(device),
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
        else:
            simple_loss = None
            pruned_loss = None

        if self.config.use_ctc:
            # Compute CTC loss
            targets = y.values
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=targets,
                target_lengths=y_lens,
            )
        else:
            ctc_loss = None

        if self.config.use_attention_decoder:
            # NOTE(review): __init__ never builds self.attention_decoder; it
            # must be attached externally before enabling this option.
            attention_decoder_loss = self.attention_decoder.calc_att_loss(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                ys=y.to(device),
                ys_lens=y_lens.to(device),
            )
        else:
            attention_decoder_loss = torch.empty(0)

        if self.config.predict_lid:
            # NOTE(review): forward_lid_predictor names this argument
            # gate_scores, but the encoder output is what is passed here.
            lid_output, lid_loss = self.forward_lid_predictor(encoder_out, language)
        else:
            lid_output = None
            lid_loss = None

        # TODO: return an IcefallAsrModelOutput(loss=AsrLossComponents(...))
        # instead of a plain tuple once the encoder exposes balance_loss,
        # gate_logits and padding_mask.
        return (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss,
                lid_output, lid_loss,
                )