import logging
import k2
import torch
from .base import BaseTrainer
from ..data.asr_data_module import AsrDatamodule
from ..utils.metric_tracker import MetricsTracker
from ..models.tta.utils import filter_uneven_sized_batch
from torch.nn.parallel import DistributedDataParallel as DDP

def get_lens_from_raggedtensor(y: k2.RaggedTensor):
    """Return the length of every row along axis 1 of a ragged tensor.

    The lengths are the differences between consecutive entries of
    ``y.shape.row_splits(1)``, so the result has one element per row.
    """
    boundaries = y.shape.row_splits(1)
    row_starts, row_ends = boundaries[:-1], boundaries[1:]
    return row_ends - row_starts


class TtaTrainer(BaseTrainer):
    """Trainer that computes the combined loss of the TTA model.

    The model is expected to return five loss terms (simple, pruned, CTC,
    attention and speech-to-text alignment); any of them may be ``None``.
    This class weights the available terms with the scales found under
    ``cfg.trainer`` and reports per-batch statistics via ``MetricsTracker``.
    """

    def build_dataloaders(self, cfg):
        """Create the ASR data module and return its dataloaders.

        The data module is also stored on ``self.data_module``.

        Args:
            cfg: Configuration object forwarded to ``AsrDatamodule``.

        Returns:
            The ``(train_dl, valid_dl)`` pair exposed by the data module.
        """
        self.data_module = AsrDatamodule(cfg)
        return self.data_module.train_dl, self.data_module.valid_dl

    def _forward_one_batch(self, batch: dict, is_training: bool, return_emb=False):
        """Run the model on one batch and build the weighted total loss.

        Args:
            batch: Dict with ``"inputs"`` (a 3-D feature tensor, (N, T, C))
                and ``"supervisions"``, from which ``num_frames``, ``text``,
                ``text_translated``, ``language``, ``language_translated``
                and ``task`` are read.
            is_training: Enables gradient tracking and the uneven-batch
                filtering when true.
            return_emb: Unused; kept so the signature stays compatible with
                existing callers.

        Returns:
            A ``(loss, info)`` tuple: the weighted loss tensor and a
            ``MetricsTracker`` holding the batch statistics.

        Raises:
            RuntimeError: If the model returns no loss term at all, so there
                is nothing to optimise or report.
        """
        device = self.device

        if is_training:
            # NOTE(review): the factor 110 is inherited as-is; its meaning
            # depends on filter_uneven_sized_batch -- TODO confirm.
            batch = filter_uneven_sized_batch(batch, self.cfg.data.max_duration * 110)

        feature = batch["inputs"]
        # at entry, feature is (N, T, C)
        assert feature.ndim == 3
        feature = feature.to(device)

        supervisions = batch["supervisions"]
        feature_lens = supervisions["num_frames"].to(device)

        batch_idx_train = self.global_step
        warm_step = self.cfg.trainer.rnnt_warm_step

        # The tokenizer lives on the wrapped module when the model is DDP.
        if isinstance(self.model, DDP):
            tokenizer = self.model.module.tokenizer
        else:
            tokenizer = self.model.tokenizer

        # TODO: refactor this into tokenizer
        def tokenize_texts(texts, spm_start_space=259):
            """Encode ``texts`` and drop a leading ``spm_start_space`` id.

            Presumably 259 is the id of the tokenizer's leading-space piece;
            verify against the tokenizer in use. Pass ``None`` to keep the
            encoded ids untouched.
            """
            text_ids = tokenizer.encode(texts)
            if spm_start_space is not None:
                text_ids = [
                    ids[1:] if len(ids) and ids[0] == spm_start_space else ids
                    for ids in text_ids
                ]
            return k2.RaggedTensor(text_ids)

        # 'text' for ASR
        texts = supervisions["text"]
        y = tokenize_texts(texts).to(device)
        y_lens = get_lens_from_raggedtensor(y)

        langs = supervisions["language"]
        langs_translated = supervisions["language_translated"]
        tasks = supervisions["task"]
        y_translated = tokenize_texts(supervisions["text_translated"]).to(device)
        y_translated_lens = get_lens_from_raggedtensor(y_translated)

        # NOTE(review): the division by 4 presumably mirrors the encoder's
        # frame subsampling -- TODO confirm against the model.
        num_frames = (feature_lens // 4).sum().item()
        num_samples = len(texts)
        # Distinct language directions in the batch: "src-tgt" when the two
        # differ, plain "src" otherwise.
        num_langs = len({
            f"{src}-{tgt}" if src != tgt else src
            for src, tgt in zip(langs, langs_translated)
        })

        with torch.set_grad_enabled(is_training):
            model_results = self.model(
                x=feature,
                x_lens=feature_lens,
                y=y,
                y_lens=y_lens,
                y_translated=y_translated,
                y_translated_lens=y_translated_lens,
                task=tasks,
                prune_range=self.cfg.trainer.prune_range,
                am_scale=self.cfg.trainer.am_scale,
                lm_scale=self.cfg.trainer.lm_scale,
                language=langs,
                language_translated=langs_translated,
                text_align=texts,
            )
            simple_loss, pruned_loss, ctc_loss, attention_loss, s2t_align_loss = model_results

            loss = 0.0
            if simple_loss is not None:
                s = self.cfg.trainer.simple_loss_scale
                if batch_idx_train >= warm_step:
                    simple_loss_scale = s
                    pruned_loss_scale = 1.0
                else:
                    # During warm-up, take the simple-loss scale down from
                    # 1.0 towards s while the pruned-loss scale rises from
                    # 0.1 towards 1.0.
                    warm_frac = batch_idx_train / warm_step
                    simple_loss_scale = 1.0 - warm_frac * (1.0 - s)
                    pruned_loss_scale = 0.1 + 0.9 * warm_frac
                loss = loss + simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

            if ctc_loss is not None:
                loss = loss + self.cfg.trainer.ctc_loss_scale * ctc_loss

            if attention_loss is not None:
                loss = loss + self.cfg.trainer.attention_loss_scale * attention_loss

            if s2t_align_loss is not None:
                s2t_align_warm_step = self.cfg.trainer.s2t_align_warm_step
                s2t_align_loss_scale = self.cfg.trainer.s2t_align_loss_scale
                if batch_idx_train < s2t_align_warm_step:
                    # Linear ramp-up of the alignment loss weight.
                    s2t_align_loss_scale = s2t_align_loss_scale * (
                        batch_idx_train / s2t_align_warm_step
                    )
                # Rescale by num_frames / num_samples. Done out of place so
                # the tensor returned by the model is not mutated.
                s2t_align_loss = s2t_align_loss / num_samples * num_frames
                loss = loss + s2t_align_loss_scale * s2t_align_loss

        if not torch.is_tensor(loss):
            # Every loss term was None: `loss` is still the float 0.0.
            raise RuntimeError(
                "The model returned no loss term; cannot build a total loss."
            )

        assert loss.requires_grad == is_training

        info = MetricsTracker()
        info.set_value('frames', num_frames, normalization='sum')
        info.set_value('samples', num_samples, normalization='sum')
        info.set_value('langs', num_langs, normalization='batch_avg')

        def per_frame(value):
            """Return the scalar value of a loss tensor divided by num_frames."""
            return value.item() / num_frames

        # Note: We use reduction=sum while computing the loss.
        info.set_value("loss", per_frame(loss), normalization='frame_avg')
        if simple_loss is not None:
            info.set_value("simple_loss", per_frame(simple_loss), normalization='frame_avg')
            info.set_value("pruned_loss", per_frame(pruned_loss), normalization='frame_avg')
        if ctc_loss is not None:
            info.set_value("ctc_loss", per_frame(ctc_loss), normalization='frame_avg')
        if attention_loss is not None:
            info.set_value("attention_loss", per_frame(attention_loss), normalization='frame_avg')
        if s2t_align_loss is not None:
            # Logged from the rescaled value computed above.
            info.set_value("s2t_align_loss", per_frame(s2t_align_loss), normalization='sample_avg')

        return loss, info

    @torch.no_grad()
    def validate(self, epoch):
        """Run the validation process.

        Iterates over every dataloader in ``self.valid_dl``, accumulates the
        per-batch statistics, reduces them across workers when
        ``self.world_size > 1`` and, on rank 0, logs them and writes them to
        TensorBoard if a writer is configured. The model is put back into
        training mode afterwards, even if validation raises.

        Args:
            epoch: Epoch number, used only in the log message.
        """
        self.model.eval()
        try:
            for i, valid_dl_i in enumerate(self.valid_dl):
                tot_info = MetricsTracker()
                for batch in valid_dl_i:
                    loss, info = self._forward_one_batch(
                        batch=batch,
                        is_training=False,
                    )
                    assert loss.requires_grad is False
                    tot_info.update(info)
                    # Release cached CUDA memory after every batch.
                    torch.cuda.empty_cache()

                if self.world_size > 1:
                    # Use the trainer's device rather than `loss.device`:
                    # `loss` is unbound when this dataloader yields no batch.
                    tot_info.reduce(self.device)

                if self.rank == 0:
                    logging.info(f"Epoch {epoch}, global batch {self.global_step}, validation: {tot_info}")
                    if self.tb_writer is not None:
                        tot_info.write_summary(
                            self.tb_writer, f"train/valid_{i}", self.global_step
                        )
        finally:
            self.model.train()