import os
import json
import warnings
import logging
from typing import Dict, Tuple
import torch
from .base import BaseTrainer
from ..data.audio_tag_data_module import AudioTagDatamodule
from ..utils.metric_tracker import MetricsTracker
from sklearn.metrics import average_precision_score

class AudioTagTrainer(BaseTrainer):
    """Trainer for audio-tagging models.

    Builds dataloaders from ``AudioTagDatamodule``, runs the model on batches
    of features plus tag supervisions, and during validation additionally
    reports mAP when the underlying model is multi-label.
    """

    def unwrap_model(self):
        """Return the bare model, stripping a (Distributed)DataParallel wrapper.

        Both wrappers expose the wrapped network as ``.module``; attributes
        such as ``tag2multihot`` / ``is_multilabel`` live on that inner model.
        """
        wrappers = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
        return self.model.module if isinstance(self.model, wrappers) else self.model

    def build_dataloaders(self, cfg):
        """Create the audio-tag data module from ``cfg``.

        Stores it on ``self.data_module`` and returns ``(train_dl, valid_dl)``.
        """
        self.data_module = AudioTagDatamodule(cfg)
        return self.data_module.train_dl, self.data_module.valid_dl

    def _label_field(self):
        """Supervision key holding the tags (``cfg.data.label_field``, default ``"audio_tag"``)."""
        return getattr(self.cfg.data, "label_field", "audio_tag")

    def _forward_one_batch(self, batch: dict, is_training: bool, return_logits=False):
        """Run the model on one batch and collect per-batch metrics.

        Args:
            batch: Dict with ``"inputs"`` (a 3-D feature tensor; per the original
                comment laid out as (N, T, C)) and ``"supervisions"`` (must contain
                ``"num_frames"`` and the configured label field).
            is_training: Enable autograd for the forward pass when True.
            return_logits: Also return the logits produced by the model.

        Returns:
            ``(loss, info)``, or ``(loss, info, logits)`` when ``return_logits``.
        """
        feature = batch["inputs"]
        # at entry, feature is (N, T, C)
        assert feature.ndim == 3
        feature = feature.to(self.device)

        supervisions = batch["supervisions"]
        tags = supervisions[self._label_field()]
        batch_size = len(tags)
        feature_lens = supervisions["num_frames"].to(self.device)

        with torch.set_grad_enabled(is_training):
            loss, logits, top1_acc, top5_acc = self.model(
                x=feature,
                x_lens=feature_lens,
                tags=tags,
            )
        assert loss.requires_grad == is_training

        info = MetricsTracker()
        # NOTE(review): `// 4` presumably mirrors the encoder's subsampling factor — confirm.
        num_frames = (feature_lens // 4).sum().item()
        info.set_value('frames', num_frames, normalization='sum')
        info.set_value('samples', batch_size, normalization='sum')
        info.set_value('loss', loss.detach().cpu().item() / batch_size, normalization='sample_avg')
        info.set_value('top1_acc', top1_acc, normalization='sample_avg')
        info.set_value('top5_acc', top5_acc, normalization='sample_avg')

        if return_logits:
            return loss, info, logits
        return loss, info

    def validate(self, epoch):
        """Run the validation process over every loader in ``self.valid_dl``.

        For each loader, per-batch metrics are accumulated. If the model is
        multi-label, mAP is also computed via sklearn's ``average_precision_score``
        over the sigmoid of the logits and the multi-hot labels seen on this rank.
        Metrics are reduced across ranks when ``world_size > 1`` and logged (and
        written to TensorBoard) on rank 0. The model is put back into train mode
        even if validation raises.
        """
        model = self.unwrap_model()
        label_field = self._label_field()
        is_multilabel = model.is_multilabel

        self.model.eval()
        try:
            with torch.no_grad():
                for i, valid_dl_i in enumerate(self.valid_dl):
                    logits_all = []
                    labels_all = []
                    tot_info = MetricsTracker()
                    for batch in valid_dl_i:
                        _, info, logits = self._forward_one_batch(
                            batch=batch,
                            is_training=False,
                            return_logits=True,
                        )
                        tot_info.update(info)

                        if is_multilabel:
                            tags = batch["supervisions"][label_field]
                            # Accumulate on CPU so device memory does not grow with
                            # the size of the validation set; `.numpy()` needs CPU anyway.
                            logits_all.append(logits.detach().cpu())
                            labels_all.append(model.tag2multihot(tags).cpu())

                    if is_multilabel and logits_all:
                        # NOTE(review): mAP is computed on this rank's shard only; how
                        # ranks are combined depends on MetricsTracker.reduce — confirm.
                        mAP = average_precision_score(
                            y_true=torch.cat(labels_all, dim=0).numpy(),
                            # float() so sigmoid also works for half-precision logits on CPU.
                            y_score=torch.cat(logits_all, dim=0).float().sigmoid().numpy(),
                        )
                        tot_info.set_value('mAP', mAP, normalization='sample_avg')

                    if self.world_size > 1:
                        # self.device rather than a loop variable: defined even for an empty loader.
                        tot_info.reduce(self.device)

                    if self.rank == 0:
                        logging.info(f"Epoch {epoch}, global batch {self.global_step}, validation: {tot_info}")
                        if self.tb_writer is not None:
                            tot_info.write_summary(
                                self.tb_writer, f"train/valid_{i}", self.global_step
                            )
        finally:
            self.model.train()