import logging
import warnings
import torch
from .base import BaseTrainer
from ..data.asr_data_module import AsrDatamodule
from ..utils.metric_tracker import MetricsTracker
# Placeholder placed at the start of the user prompt built in
# AsrLLMTrainer._forward_one_batch. NOTE(review): presumably the model swaps
# this token for the encoded speech features — confirm against the model code.
DEFAULT_SPEECH_TOKEN: str = "<speech>"


class AsrLLMTrainer(BaseTrainer):
    """Trainer that feeds speech features plus a chat-formatted prompt to ``self.model``.

    Dataloaders come from ``AsrDatamodule``. Each transcript is wrapped in a
    two-turn chat: a user turn with ``DEFAULT_SPEECH_TOKEN`` and a
    transcription instruction, and an assistant turn with the reference text.
    Loss and accuracy are collected in a ``MetricsTracker``.
    """

    def build_dataloaders(self, cfg):
        """Create the ASR data module and return its dataloaders.

        The data module is stored on ``self.data_module``.

        Args:
            cfg: Configuration passed unchanged to ``AsrDatamodule``.

        Returns:
            Tuple ``(train_dl, valid_dl)`` read from the data module's
            ``train_dl`` and ``valid_dl`` attributes.
        """
        self.data_module = AsrDatamodule(cfg)
        return self.data_module.train_dl, self.data_module.valid_dl

    def _forward_one_batch(self, batch: dict, is_training: bool, return_emb=False):
        """Run the model on one batch and collect its metrics.

        Args:
            batch: Mapping with ``"inputs"`` (a 3-D feature tensor, (N, T, C))
                and ``"supervisions"``, which must provide ``"num_frames"``
                (per-utterance frame counts, a tensor) and ``"text"``
                (reference transcripts).
            is_training: Enables autograd when True. The returned loss must
                require grad exactly when this is True.
            return_emb: Currently unused; kept for interface compatibility.

        Returns:
            ``(loss, info)``. ``loss`` is ``model_outputs.loss`` from the
            model. ``info`` is a ``MetricsTracker`` with ``frames``,
            ``samples``, ``loss`` and ``acc`` entries.

        Raises:
            AssertionError: If the features are not 3-D, or if
                ``loss.requires_grad`` disagrees with ``is_training``.
        """
        device = self.device
        feature = batch["inputs"]
        # at entry, feature is (N, T, C)
        assert feature.ndim == 3
        feature = feature.to(device)

        supervisions = batch["supervisions"]
        feature_lens = supervisions["num_frames"].to(device)
        texts = supervisions["text"]
        batch_size = len(texts)

        # One two-turn chat per utterance. The user prompt is the speech
        # placeholder followed by "请转写音频为文字" ("please transcribe the
        # audio into text"); the assistant reply is the reference transcript.
        messages = [
            [
                {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
                {"role": "assistant", "content": text},
            ]
            for text in texts
        ]

        with torch.set_grad_enabled(is_training):
            model_outputs, acc = self.model(
                fbank=feature,
                feature_lens=feature_lens,
                messages=messages,
            )
            loss = model_outputs.loss

        assert loss.requires_grad == is_training

        info = MetricsTracker()
        # Frame count after the model's subsampling. Presumably this is the
        # denominator for the "frame_avg" metrics below (confirm in MetricsTracker).
        num_frames = (feature_lens // self.cfg.model.config.subsampling_factor).sum().item()
        info.set_value("frames", num_frames, normalization="sum")
        info.set_value("samples", batch_size, normalization="sum")
        # NOTE(review): if model_outputs.loss is already a per-token mean,
        # "frame_avg" may normalise it a second time — confirm intended.
        info.set_value("loss", loss.detach().cpu().item(), normalization="frame_avg")
        info.set_value("acc", acc, normalization="sample_avg")

        return loss, info

    def validate(self, epoch):
        """Evaluate on every validation dataloader and log the results.

        For each dataloader in ``self.valid_dl``:

        - per-batch metrics are accumulated into one ``MetricsTracker``;
        - the tracker is reduced across processes when ``self.world_size > 1``;
        - on rank 0, the metrics are logged and, if a TensorBoard writer is
          set, written under ``train/valid_{i}``.

        The model is in eval mode throughout and is switched back to train
        mode afterwards, even if evaluation raises.

        Args:
            epoch: Current epoch; used only in the log message.
        """
        self.model.eval()
        try:
            with torch.no_grad():
                for i, valid_dl_i in enumerate(self.valid_dl):
                    tot_info = MetricsTracker()
                    # Device used for the cross-process reduce. It follows the
                    # loss device, as before. It falls back to self.device so
                    # an empty loader no longer hits an unbound/stale `loss`.
                    reduce_device = self.device
                    for batch in valid_dl_i:
                        loss, info = self._forward_one_batch(
                            batch=batch,
                            is_training=False,
                        )
                        assert loss.requires_grad is False
                        tot_info.update(info)
                        reduce_device = loss.device

                    if self.world_size > 1:
                        tot_info.reduce(reduce_device)

                    if self.rank == 0:
                        logging.info(f"Epoch {epoch}, global batch {self.global_step}, validation: {tot_info}")
                        if self.tb_writer is not None:
                            tot_info.write_summary(
                                self.tb_writer, f"train/valid_{i}", self.global_step
                            )
        finally:
            # Always restore training mode, even if evaluation failed.
            self.model.train()