import logging
import torch
from .base import BaseTrainer
from ..data.audio_caption_data_module import AudioCaptionDatamodule
from ..utils.metric_tracker import MetricsTracker
from ..utils.dist import ddp_all_gather_to_rank0

from aac_metrics import Evaluate

class AudioCaptioningTrainer(BaseTrainer):
    """Trainer for audio captioning models.

    Builds the train/validation dataloaders from ``AudioCaptionDatamodule``,
    runs the model forward pass to obtain a caption loss (and, during
    validation, generated captions), and scores the generated captions
    against the reference captions with ``aac_metrics``.
    """

    # Indices of the validation dataloaders on which corpus-level caption
    # metrics (BLEU/METEOR/ROUGE...) are computed. Statistics such as the
    # caption loss are reported for every validation set; the caption metrics
    # only for these. Defaults to the first set only.
    metric_valid_set_indices = (0,)

    def __init__(self, cfg, model, rank=0, local_rank=0, world_size=1):
        super().__init__(cfg, model, rank=rank, local_rank=local_rank, world_size=world_size)

        # One aac_metrics evaluator per metric. The dict key is also used to
        # look up that metric's score in the evaluator's corpus output.
        self.evaluators = {
            "bleu_4": Evaluate(metrics=["bleu_4"]),
            "meteor": Evaluate(metrics=["meteor"]),
            "rouge_l": Evaluate(metrics=["rouge_l"]),
            # "spider": Evaluate(metrics=["spider"]),  # currently disabled
        }

        if self.rank == 0:
            logging.info(f"Trainer initialized with model: {model.__class__.__name__}")
            # Report the metrics that are actually configured above.
            logging.info(f"Using {', '.join(self.evaluators)} for evaluation.")

    def build_dataloaders(self, cfg):
        """Create the caption data module and return ``(train_dl, valid_dl)``.

        The data module itself is kept on ``self.data_module``.
        """
        self.data_module = AudioCaptionDatamodule(cfg)
        return self.data_module.train_dl, self.data_module.valid_dl

    @staticmethod
    def _pick_reference_caption(text_i, is_training):
        """Return a single reference caption from a str or a list of str.

        For multi-reference datasets (e.g. Clotho) a uniformly random
        reference is drawn during training, and the first one is used
        otherwise.

        Raises:
            ValueError: if ``text_i`` is an empty list or neither str nor list.
        """
        if isinstance(text_i, str):
            return text_i
        if isinstance(text_i, list):
            if not text_i:
                raise ValueError("Empty list of reference captions")
            idx = torch.randint(high=len(text_i), size=(1,)).item() if is_training else 0
            return text_i[idx]
        raise ValueError(f"Unsupported text type: {type(text_i)}")

    def _forward_one_batch(self, batch: dict, is_training: bool):
        """Run the model on one batch and collect per-batch statistics.

        Args:
            batch: dict with ``"inputs"`` (a 3-D feature tensor) and
                ``"supervisions"`` holding ``"num_frames"`` (feature lengths)
                and ``"audio_caption"`` (each entry a str or a list of str).
            is_training: enables autograd and random reference sampling.

        Returns:
            ``(loss, info)`` when training, ``(loss, info, output_text)``
            otherwise. ``loss`` may be ``None``.

        Raises:
            ValueError: on an unsupported caption type or an empty reference list.
        """
        device = self.device
        feature = batch["inputs"]  # at entry, feature is (N, T, C)
        assert feature.ndim == 3, f"Expected feature to be 3D, got {feature.ndim}D"
        feature = feature.to(device)

        supervisions = batch["supervisions"]
        feature_lens = supervisions["num_frames"].to(device)
        # Handle the case where text can have multiple references (e.g. Clotho)
        text = [
            self._pick_reference_caption(text_i, is_training)
            for text_i in supervisions["audio_caption"]
        ]
        batch_size = len(text)

        with torch.set_grad_enabled(is_training):
            loss, output_text = self.model(
                x=feature,
                x_lens=feature_lens,
                text=text,
            )

        if loss is not None:
            assert loss.requires_grad == is_training, f"Loss requires grad: {loss.requires_grad}, is_training: {is_training}"

        info = MetricsTracker()
        # NOTE(review): the ``// 4`` presumably mirrors the encoder's frame
        # subsampling factor -- confirm against the model.
        num_frames = (feature_lens // 4).sum().item()
        info.set_value('frames', num_frames, normalization='sum')
        info.set_value('samples', batch_size, normalization='sum')
        if loss is not None:
            info.set_value('caption_loss', loss.detach().cpu().item(), normalization='sample_avg')

        if is_training:
            return loss, info
        return loss, info, output_text

    def _gather_caption_texts(self, pred_text, ref_text):
        """Gather predictions and references from all ranks onto rank 0.

        Must be called on every rank because it issues collective calls.
        Returns the flattened ``(preds, refs)`` lists on rank 0 and
        ``(None, None)`` on every other rank.
        """
        distributed = self.world_size > 1
        # Barriers keep all ranks in lock-step around each gather.
        if distributed:
            torch.distributed.barrier()
        pred_gathered = ddp_all_gather_to_rank0(pred_text)
        if distributed:
            torch.distributed.barrier()
        ref_gathered = ddp_all_gather_to_rank0(ref_text)
        if distributed:
            torch.distributed.barrier()

        if self.rank != 0:
            return None, None
        if not distributed:
            # Single process: the gathered lists are used directly.
            return pred_gathered, ref_gathered
        # Multi-process: one list per rank, concatenated in rank order.
        preds = [p for per_rank in pred_gathered for p in per_rank]
        refs = [r for per_rank in ref_gathered for r in per_rank]
        return preds, refs

    def _compute_caption_metrics(self, pred_str_all, ref_str_all, tot_info):
        """Normalize captions and add corpus-level metric scores to ``tot_info``.

        Newlines are replaced by spaces and surrounding whitespace is stripped
        in predictions and references. A metric that fails is logged with its
        traceback and skipped, so the remaining metrics are still reported.
        """
        ref_str_all = [[s.replace('\n', ' ').strip() for s in ref] for ref in ref_str_all]
        pred_str_all = [s.replace('\n', ' ').strip() for s in pred_str_all]
        logging.info(f"Normalized sample prediction: {pred_str_all[0]} | references: {ref_str_all[0]}")

        for metric_name, evaluator in self.evaluators.items():
            try:
                corpus_scores, _ = evaluator(pred_str_all, ref_str_all)
                score = corpus_scores[metric_name]
            except Exception:
                # Best effort: one failing metric must not abort validation.
                logging.exception(f"Error computing metric {metric_name}")
                continue
            tot_info.set_value(metric_name, score, normalization='sample_avg')

    def _validate_caption_set(self, epoch, i, valid_dl_i):
        """Validate on the ``i``-th validation dataloader and report results."""
        pred_str_all, ref_str_all = [], []
        tot_info = MetricsTracker()
        for batch_idx, batch in enumerate(valid_dl_i):
            ref_text = batch["supervisions"]["audio_caption"]
            _, info, pred_text = self._forward_one_batch(
                batch=batch,
                is_training=False,
            )
            tot_info.update(info)

            # Collective call: every rank must take part, not only rank 0.
            pred_gather, ref_gather = self._gather_caption_texts(pred_text, ref_text)
            if self.rank != 0:
                continue

            pred_str_all.extend(pred_gather)
            # Wrap single references in a list so every entry is a list of str.
            ref_str_all.extend(r if isinstance(r, list) else [r] for r in ref_gather)
            if batch_idx == 0 and pred_gather:
                logging.info(f"Sampled predictions: {pred_gather[0]}")
                logging.info(f"Sampled references: {ref_gather[0]}")

        if self.world_size > 1:
            # Use self.device rather than the last batch's loss: the loss may be
            # None, or undefined/stale when this dataloader yields no batches.
            tot_info.reduce(self.device)

        if self.rank != 0:
            return

        assert len(pred_str_all) > 0, "No predictions found"
        assert len(ref_str_all) > 0, "No references found"
        assert len(pred_str_all) == len(ref_str_all), "Predictions and references must have the same length"
        if i in self.metric_valid_set_indices:
            self._compute_caption_metrics(pred_str_all, ref_str_all, tot_info)

        logging.info(f"Epoch {epoch}, global batch {self.global_step}, validation: {tot_info}")
        if self.tb_writer is not None:
            tot_info.write_summary(
                self.tb_writer, f"valid_{i}", self.global_step
            )

    def validate(self, epoch):
        """Run the validation process on every validation dataloader.

        For each validation set, the per-batch statistics are accumulated and
        reduced across ranks when running distributed. Predictions and
        references are gathered on rank 0. Corpus caption metrics are computed
        for the sets listed in ``metric_valid_set_indices``. Rank 0 then logs
        the results and writes them to TensorBoard when a writer is set.

        The model is switched back to training mode afterwards, even if
        validation raises.
        """
        if self.rank == 0:
            logging.info(f"Validating epoch {epoch} at global step {self.global_step}")
        self.model.eval()

        try:
            with torch.no_grad():
                for i, valid_dl_i in enumerate(self.valid_dl):
                    self._validate_caption_set(epoch, i, valid_dl_i)
        finally:
            self.model.train()