import logging
import warnings
import torch
from .base import BaseTrainer
from ..data.asr_data_module import AsrDatamodule
from ..utils.metric_tracker import MetricsTracker

class AsrTrainer(BaseTrainer):
    """Trainer that combines the model's simple/pruned and CTC losses.

    Data loading is delegated to ``AsrDatamodule``. The per-batch loss
    combination (with a warm-up schedule driven by
    ``cfg.trainer.rnnt_warm_step``), metric bookkeeping and the validation
    loop live here. Attributes such as ``self.model``, ``self.cfg``,
    ``self.device``, ``self.global_step``, ``self.rank``, ``self.world_size``,
    ``self.tb_writer`` and ``self.valid_dl`` are presumably set up by
    ``BaseTrainer`` -- not visible here, TODO confirm.
    """

    def build_dataloaders(self, cfg):
        """Create the ASR data module and return its dataloaders.

        The data module is stored on ``self.data_module``.

        Args:
            cfg: configuration object passed unchanged to ``AsrDatamodule``.

        Returns:
            ``(train_dl, valid_dl)`` as exposed by the data module.
        """
        self.data_module = AsrDatamodule(cfg)
        return self.data_module.train_dl, self.data_module.valid_dl

    def _forward_one_batch(self, batch: dict, is_training: bool, return_emb=False):
        """Run the model on one batch and combine its losses.

        Args:
            batch: dict with ``"inputs"`` (a 3-D feature tensor, (N, T, C) per
                the inline comment) and ``"supervisions"``, a dict holding
                ``"num_frames"`` (per-utterance feature lengths) and
                ``"text"`` (one entry per utterance).
            is_training: if True, gradients are tracked and the returned loss
                must require grad; if False, it must not.
            return_emb: accepted for interface compatibility; currently unused.

        Returns:
            ``(loss, info)``: the combined loss tensor and a ``MetricsTracker``
            holding per-frame loss values plus frame and sample counts.

        Raises:
            RuntimeError: if the model returned no loss tensor at all.
        """
        device = self.device
        feature = batch["inputs"]
        # at entry, feature is (N, T, C)
        assert feature.ndim == 3
        feature = feature.to(device)

        supervisions = batch["supervisions"]
        feature_lens = supervisions["num_frames"].to(device)
        texts = supervisions["text"]
        batch_size = len(texts)

        batch_idx_train = self.global_step
        warm_step = self.cfg.trainer.rnnt_warm_step

        with torch.set_grad_enabled(is_training):
            simple_loss, pruned_loss, ctc_loss = self.model(
                x=feature,
                x_lens=feature_lens,
                texts=texts,
                prune_range=self.cfg.trainer.prune_range,
                am_scale=self.cfg.trainer.am_scale,
                lm_scale=self.cfg.trainer.lm_scale,
            )

            # A loss term is active when the model returned a tensor for it.
            # Truth-testing a tensor (`if simple_loss:`) would silently drop a
            # term that is exactly zero and forces a device->host sync.
            use_simple_pruned = torch.is_tensor(simple_loss)
            use_ctc = torch.is_tensor(ctc_loss)

            loss = 0.0

            if use_simple_pruned:
                s = self.cfg.trainer.simple_loss_scale
                if batch_idx_train >= warm_step:
                    simple_loss_scale, pruned_loss_scale = s, 1.0
                else:
                    # During warm-up, linearly move the simple-loss scale from
                    # 1.0 down to `s` and the pruned-loss scale from 0.1 up
                    # to 1.0.
                    frac = batch_idx_train / warm_step
                    simple_loss_scale = 1.0 - frac * (1.0 - s)
                    pruned_loss_scale = 0.1 + 0.9 * frac
                loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

            if use_ctc:
                loss += self.cfg.trainer.ctc_loss_scale * ctc_loss

        if not torch.is_tensor(loss):
            raise RuntimeError(
                "The model returned no loss tensor (simple/pruned and CTC "
                "losses are all disabled); nothing to optimize."
            )
        assert loss.requires_grad == is_training

        info = MetricsTracker()
        # `// 4` presumably converts input frames to encoder output frames
        # (4x subsampling) -- TODO confirm against the encoder.
        num_frames = (feature_lens // 4).sum().item()
        info.set_value('frames', num_frames, normalization='sum')
        info.set_value('samples', batch_size, normalization='sum')

        # Note: We use reduction=sum while computing the loss, so values are
        # divided by the frame count. The max() guards a batch whose
        # utterances are all shorter than 4 frames (num_frames == 0).
        denom = max(num_frames, 1)

        def per_frame(value):
            return value.detach().item() / denom

        info.set_value("loss", per_frame(loss), normalization='frame_avg')
        if use_simple_pruned:
            info.set_value("simple_loss", per_frame(simple_loss), normalization='frame_avg')
            info.set_value("pruned_loss", per_frame(pruned_loss), normalization='frame_avg')
        if use_ctc:
            info.set_value("ctc_loss", per_frame(ctc_loss), normalization='frame_avg')

        return loss, info

    def validate(self, epoch):
        """Run the validation process.

        Iterates over every dataloader in ``self.valid_dl`` separately,
        accumulating a ``MetricsTracker`` per loader. When
        ``self.world_size > 1`` the tracker's ``reduce`` is called, presumably
        aggregating across ranks. Rank 0 logs the result and, if a
        TensorBoard writer is set, writes a summary. The model is returned to
        training mode afterwards, even if validation raises.

        Args:
            epoch: current epoch number; used only in the log message.
        """
        self.model.eval()
        try:
            with torch.no_grad():
                for i, valid_dl_i in enumerate(self.valid_dl):
                    tot_info = MetricsTracker()
                    for batch in valid_dl_i:
                        loss, info = self._forward_one_batch(
                            batch=batch,
                            is_training=False,
                        )
                        assert loss.requires_grad is False
                        tot_info.update(info)

                    if self.world_size > 1:
                        # Use the trainer's device, not `loss.device`: `loss`
                        # is unbound (or stale from the previous loader) when
                        # this loader yielded no batches.
                        tot_info.reduce(self.device)

                    if self.rank == 0:
                        logging.info(f"Epoch {epoch}, global batch {self.global_step}, validation: {tot_info}")
                        if self.tb_writer is not None:
                            tot_info.write_summary(
                                self.tb_writer, f"train/valid_{i}", self.global_step
                            )
        finally:
            # Restore training mode even if validation failed part-way.
            self.model.train()