import os
import json
import warnings
import logging
from typing import Dict, Tuple
import torch
from .base import BaseTrainer
from ..utils.metric_tracker import MetricsTracker

from ..data.spk_veri_data_module import SpeakerVerificationDatamodule
from ..models.spk_veri.utils import build_verification_pairs, compute_verification_metrics

class SpeakerVerificationTrainer(BaseTrainer):
    """Trainer for speaker-verification models.

    Extends ``BaseTrainer`` with:
    - speaker-verification data loading;
    - a per-batch forward pass that records loss, accuracy and
      embedding-norm statistics;
    - a validation loop that scores sampled embedding pairs and reports
      AUC, EER and minDCF.
    """

    def unwrap_model(self):
        """Return the underlying model, stripping a ``DistributedDataParallel`` wrapper if present."""
        model = self.model
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            return model.module
        return model

    def build_dataloaders(self, cfg):
        """Create the speaker-verification data module and return its loaders.

        The data module is stored on ``self.data_module``.

        Args:
            cfg: Configuration object passed to ``SpeakerVerificationDatamodule``.

        Returns:
            A ``(train_dl, valid_dl)`` tuple taken from the data module.
        """
        self.data_module = SpeakerVerificationDatamodule(cfg)
        return self.data_module.train_dl, self.data_module.valid_dl

    def _label_field(self):
        """Return the supervision key holding labels (``cfg.data.label_field``, default ``"speaker"``)."""
        return getattr(self.cfg.data, "label_field", "speaker")

    def _forward_one_batch(self, batch: dict, is_training: bool, return_embeddings=False):
        """Run the model on one batch and collect per-batch statistics.

        Args:
            batch: Dict with ``"inputs"`` (a 3-D feature tensor, documented as
                (N, T, C)) and ``"supervisions"``. ``"supervisions"`` holds the
                per-utterance labels under the configured label field and a
                ``"num_frames"`` tensor of feature lengths.
            is_training: Whether gradients are tracked for this forward pass.
            return_embeddings: If True, also return the embeddings produced by
                the model.

        Returns:
            ``(loss, info)``, or ``(loss, info, embeddings)`` when
            ``return_embeddings`` is True. ``info`` is a ``MetricsTracker`` with
            the ``frames``, ``samples``, ``loss``, ``batch_acc`` and ``emb_norm``
            entries.
        """
        device = self.device
        feature = batch["inputs"]
        # At entry, feature is (N, T, C).
        assert feature.ndim == 3
        feature = feature.to(device)

        supervisions = batch["supervisions"]
        tags = supervisions[self._label_field()]
        num_samples = len(tags)
        feature_lens = supervisions["num_frames"].to(device)

        with torch.set_grad_enabled(is_training):
            loss, embeddings, acc, embedding_norm = self.model(
                x=feature,
                x_lens=feature_lens,
                target=tags,
            )
        # Sanity check: the loss is attached to the autograd graph iff training.
        assert loss.requires_grad == is_training

        info = MetricsTracker()
        # NOTE(review): `// 4` presumably mirrors the model's frame subsampling
        # factor — confirm against the model definition.
        num_frames = (feature_lens // 4).sum().item()
        info.set_value('frames', num_frames, normalization='sum')
        info.set_value('samples', num_samples, normalization='sum')
        # NOTE(review): the loss is divided by the batch size here and also tagged
        # 'sample_avg'; confirm with MetricsTracker that this is not normalized twice.
        info.set_value('loss', loss.item() / num_samples, normalization='sample_avg')
        info.set_value('batch_acc', acc, normalization='sample_avg')
        info.set_value('emb_norm', embedding_norm, normalization='sample_avg')

        if return_embeddings:
            return loss, info, embeddings
        return loss, info

    def validate(self, epoch):
        """Run the validation process.

        For each validation dataloader this method:
        - collects embeddings and labels over all batches;
        - scores verification pairs sampled by ``build_verification_pairs``
          (controlled by ``cfg.data.sample_ratio``);
        - records AUC, EER and minDCF next to the per-batch statistics from
          ``_forward_one_batch``.

        On rank 0 the results are logged and, when ``self.tb_writer`` is set,
        written as summaries. The model is put back in training mode on exit,
        even if validation raises.

        Args:
            epoch: Current epoch number, used for logging.

        Raises:
            RuntimeError: If a validation dataloader yields no batches.
        """
        label_field = self._label_field()
        self.model.eval()
        try:
            with torch.no_grad():
                for i, valid_dl_i in enumerate(self.valid_dl):
                    speaker_ids_all_list = []
                    embeddings_list = []
                    tot_info = MetricsTracker()

                    for batch in valid_dl_i:
                        loss, info, embeddings = self._forward_one_batch(
                            batch=batch,
                            is_training=False,
                            return_embeddings=True,
                        )
                        assert loss.requires_grad is False
                        tot_info.update(info)

                        embeddings_list.append(embeddings)
                        speaker_ids_all_list.extend(batch["supervisions"][label_field])

                    # Fail clearly instead of letting torch.cat() reject an empty
                    # list (and `loss` below be unbound).
                    if not embeddings_list:
                        raise RuntimeError(
                            f"Validation dataloader {i} produced no batches; "
                            "cannot compute verification metrics."
                        )
                    embeddings_all = torch.cat(embeddings_list, dim=0)

                    scores, labels = build_verification_pairs(
                        embeddings_all, speaker_ids_all_list, self.cfg.data.sample_ratio
                    )
                    auc, eer, _threshold, min_dcf = compute_verification_metrics(scores, labels)

                    # NOTE(review): with world_size > 1 these metrics come from each
                    # rank's own shard and are then combined by tot_info.reduce();
                    # confirm that is the intended aggregation for EER/AUC/minDCF.
                    tot_info.set_value('eer', eer, normalization='batch_avg')
                    tot_info.set_value('auc', auc, normalization='batch_avg')
                    tot_info.set_value('min_dcf', min_dcf, normalization='batch_avg')

                    if self.world_size > 1:
                        tot_info.reduce(loss.device)

                    if self.rank == 0:
                        num_embeddings = len(embeddings_all)
                        total_pairs = num_embeddings * (num_embeddings - 1) // 2
                        logging.info(f"[Validation] Randomly sampled {len(scores)} pairs out of {total_pairs} total pairs (sample_ratio={self.cfg.data.sample_ratio}) for EER calculation.")
                        logging.info(f"Epoch {epoch}, global batch {self.global_step}, validation: {tot_info}")
                        if self.tb_writer is not None:
                            tot_info.write_summary(
                                self.tb_writer, f"train/valid_{i}", self.global_step
                            )
        finally:
            # Always restore training mode, even if validation failed midway.
            self.model.train()