import warnings
import logging
import torch
import torch.nn.functional as F
from .base import BaseTrainer
from ..data.sound_event_detection_data_module import SoundEventDetectionDatamodule
from ..utils.metric_tracker import MetricsTracker
from ..utils.dist import ddp_all_gather_to_rank0
from ..models.audio_frame_classification.utils import (
    tag2_framelevel_multihot,
    decode_intervals,
    get_timestamps,
)
from ..models.audio_tag.utils import tag2multihot

import numpy as np
import scipy.signal
from sed_scores_eval.utils.scores import create_score_dataframe
from sed_scores_eval.utils.array_ops import get_first_index_where
from dcase_util.data import DecisionEncoder
from desed_task.evaluation.evaluation_measures import compute_psds_from_scores
import torch.distributed as dist

class SoundEventDetectionTrainer(BaseTrainer):
    """Trainer for frame-level sound event detection, validated with PSDS."""

    def unwrap_model(self):
        """Return the underlying model, stripping a DistributedDataParallel wrapper if present."""
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            return self.model.module
        return self.model

    def build_dataloaders(self, cfg):
        """Build the data module from ``cfg`` and return ``(train_dl, valid_dl)``.

        The data module is also stored on ``self.data_module``.
        """
        self.data_module = SoundEventDetectionDatamodule(cfg)
        return self.data_module.train_dl, self.data_module.valid_dl

    def _forward_one_batch(self, batch: dict, is_training: bool, return_logits=False, return_lens=False, return_gt=False):
        """Run the model on one batch and return the loss plus bookkeeping metrics.

        Args:
            batch: Dict with ``"inputs"`` (a 3-D feature tensor) and
                ``"supervisions"`` (read keys: ``"audio_tag"``, ``"num_frames"``,
                ``"start_times_seconds"``, ``"end_times_seconds"``).
            is_training: Enables gradient tracking for the forward pass when true.
            return_logits: Append the frame-level logits to the returned tuple.
            return_lens: Append the model's output lengths to the returned tuple.
            return_gt: Append the ground truth returned by the model to the tuple.

        Returns:
            ``(loss, info)`` followed, in this order, by whichever of logits,
            output lengths and ground truth were requested.
        """
        device = self.device
        feature = batch["inputs"]
        # at entry, feature is (N, T, C)
        assert feature.ndim == 3
        feature = feature.to(device)

        supervisions = batch["supervisions"]
        tags = supervisions["audio_tag"]
        batch_size = len(tags)

        feature_lens = supervisions["num_frames"].to(device)

        # ``Tensor.max()`` is a single reduction; the builtin ``max`` would walk
        # the (possibly on-device) tensor element by element.
        frame_level_target = tag2_framelevel_multihot(
            tags,
            self.unwrap_model().label2id,
            supervisions["start_times_seconds"],
            supervisions["end_times_seconds"],
            feature_lens.max(),
        ).to(device)

        with torch.set_grad_enabled(is_training):
            results = self.model(
                x=feature,
                x_lens=feature_lens,
                frame_target=frame_level_target,
                return_ground_truth=True,
            )

        frame_results = results['frame']
        loss = frame_results['loss']
        logits = frame_results['logits']
        output_lens = frame_results['output_lens']
        ground_truth = frame_results['ground_truth']
        assert loss.requires_grad == is_training

        info = MetricsTracker()
        # NOTE(review): the ``// 4`` presumably mirrors the model's frame
        # subsampling factor -- confirm against the model definition.
        info.set_value('frames', (feature_lens // 4).sum().item(), normalization='sum')
        info.set_value('samples', batch_size, normalization='sum')
        info.set_value('loss', loss.detach().cpu().item(), normalization='sample_avg')

        return_tuple = (loss, info)
        if return_logits:
            return_tuple += (logits,)
        if return_lens:
            return_tuple += (output_lens,)
        if return_gt:
            return_tuple += (ground_truth,)
        return return_tuple

    @staticmethod
    def _median_smooth(probs, kernel_sizes):
        """Median-filter each class column of ``probs`` (frames x classes) along time."""
        return np.stack(
            [
                scipy.signal.medfilt(class_scores, kernel_size=kernel_size)
                for class_scores, kernel_size in zip(probs.T, kernel_sizes)
            ],
            axis=1,
        )

    def _collect_validation_outputs(self, valid_dl_i, event_classes, median_filter):
        """Run one validation subset and gather what PSDS scoring needs.

        Returns:
            ``(tot_info, ground_truth, scores, audio_durations)`` where the last
            three are dicts keyed by audio id. ``ground_truth`` only keeps clips
            that ended up with at least one event.
        """
        ground_truth = {}
        audio_durations = {}
        val_scores_postprocessed = {}
        tot_info = MetricsTracker()

        for batch in valid_dl_i:
            loss, info, logits, output_lens = self._forward_one_batch(
                batch=batch,
                is_training=False,
                return_logits=True,
                return_lens=True,
            )

            assert loss.requires_grad is False
            tot_info.update(info)

            supervisions = batch["supervisions"]
            audio_ids = supervisions["ids"]
            start_times = supervisions["start_times_seconds"]
            end_times = supervisions["end_times_seconds"]
            tags = supervisions["audio_tag"]
            durations = supervisions["duration_sec"]
            output_lens = output_lens.cpu().numpy().tolist()

            pred_probs = torch.sigmoid(logits).cpu().numpy()
            # One more timestamp than frames: the score dataframe receives
            # ``valid_len + 1`` timestamps for ``valid_len`` score rows.
            timestamps = get_timestamps(np.arange(max(output_lens) + 1))

            for k in range(logits.shape[0]):
                audio_id = audio_ids[k]
                audio_durations[audio_id] = durations[k]

                valid_len = output_lens[k]
                ts_k = timestamps[:valid_len + 1]

                events = ground_truth.setdefault(audio_id, [])
                label_list = tags[k].split(";")
                for label, start, end in zip(label_list, start_times[k], end_times[k]):
                    onset_idx = max(get_first_index_where(ts_k, 'gt', start) - 1, 0)
                    offset_idx = min(get_first_index_where(ts_k, 'geq', end), len(ts_k) - 1)
                    # Keep only events whose onset index lies strictly before
                    # their offset index on this clip's timestamp grid.
                    if onset_idx < offset_idx:
                        events.append([start, end, label])

                smoothed_scores = self._median_smooth(
                    pred_probs[k, :valid_len, :], median_filter
                )
                val_scores_postprocessed[audio_id] = create_score_dataframe(
                    scores=smoothed_scores,
                    timestamps=ts_k,
                    event_classes=event_classes,
                )

        ground_truth = {k: v for k, v in ground_truth.items() if v}
        return tot_info, ground_truth, val_scores_postprocessed, audio_durations

    def validate(self, epoch):
        """Run the validation process.

        For every validation dataloader, predictions are median-filtered,
        gathered to rank 0 and scored with each PSDS type listed in
        ``cfg.trainer.psds_types``. The per-type scores are averaged over the
        dataloaders, then logged and written to TensorBoard on rank 0.

        Args:
            epoch: Current epoch index; only used in the log message.
        """
        num_subsets = len(self.valid_dl)
        if num_subsets == 0:
            # Nothing to evaluate. Without this guard the code after the subset
            # loop would reference a metrics tracker that was never created.
            if self.rank == 0:
                logging.warning("No validation dataloaders available; skipping validation.")
            return

        self.model.eval()
        if self.rank == 0:
            logging.info("Validating model...")

        model = self.unwrap_model()
        # One median-filter kernel size per event class.
        median_filter = [7] * len(model.label2id)
        psds_scores = {k: torch.tensor(0.0, device=self.device) for k in self.cfg.trainer.psds_types}

        with torch.no_grad():
            id2label = model.id2label
            event_classes = [id2label[str(i)] for i in range(len(id2label))]

            for i, valid_dl_i in enumerate(self.valid_dl):
                # NOTE(review): a fresh tracker is produced per subset, so the
                # loss/frames/samples logged below come from the last subset
                # only, while PSDS is averaged over all subsets -- confirm
                # this is intended.
                tot_info, ground_truth, val_scores_postprocessed, audio_durations = (
                    self._collect_validation_outputs(valid_dl_i, event_classes, median_filter)
                )

                # Every rank takes part in the gathers, always in this order.
                ground_truth_gather = ddp_all_gather_to_rank0(ground_truth)
                val_scores_postprocessed_gather = ddp_all_gather_to_rank0(val_scores_postprocessed)
                audio_durations_gather = ddp_all_gather_to_rank0(audio_durations)

                if self.rank != 0:
                    continue

                # These are list of dicts, so we need to concatenate them
                if self.world_size > 1:
                    ground_truth_gather = merge_dicts(ground_truth_gather)
                    val_scores_postprocessed_gather = merge_dicts(val_scores_postprocessed_gather)
                    audio_durations_gather = merge_dicts(audio_durations_gather)

                for psds_type in psds_scores:
                    score = compute_psds(
                        val_scores_postprocessed_gather,
                        ground_truth_gather,
                        audio_durations_gather,
                        psds_type=psds_type
                    )
                    logging.info(f"PSDS scores for validation subset {i}, {psds_type}: {score}")
                    psds_scores[psds_type] += score

            # Only rank 0 accumulated PSDS; other ranks still hold 0 and the
            # value is registered with 'sum' normalization before the reduce.
            # Presumably the reduce sums across ranks so the rank-0 value
            # survives -- verify against MetricsTracker.reduce.
            for psds_type in psds_scores:
                psds_scores[psds_type] = psds_scores[psds_type] / num_subsets
                tot_info.set_value(
                    psds_type, psds_scores[psds_type].item(), normalization='sum'
                )

            if self.world_size > 1:
                tot_info.reduce(self.device)

            if self.rank == 0:
                logging.info(f"Epoch {epoch}, step {self.global_step}, valid: {tot_info}")
                if self.tb_writer is not None:
                    tot_info.write_summary(self.tb_writer, "valid/", self.global_step)

        self.model.train()

class DESEDTrainer(SoundEventDetectionTrainer):
    """
    Trainer for the DESED sound event detection task.

    Extends SoundEventDetectionTrainer with a forward pass that handles batches
    mixing clips that carry event start/end times with clips that only carry
    tags.
    """

    def _forward_one_batch(self, batch: dict, is_training: bool, return_logits=False, return_lens=False, return_gt=False):
        """Forward one batch that may mix strongly and weakly labelled clips.

        A clip counts as "strong" when its ``start_times_seconds`` entry is
        non-empty and as "weak" otherwise. Frame-level targets are built from
        the strong clips only; when weak clips are present a clip-level
        multi-hot target is built from their tags and passed alongside.

        Returns:
            ``(loss, info)`` followed, in this order, by whichever of the
            frame logits, frame output lengths and frame ground truth were
            requested through the ``return_*`` flags.
        """
        device = self.device
        feature = batch["inputs"]
        # at entry, feature is (N, T, C)
        assert feature.ndim == 3
        feature = feature.to(device)

        sup = batch["supervisions"]
        all_tags = sup["audio_tag"]
        label2id = self.unwrap_model().label2id
        all_lens = sup["num_frames"].to(device)
        all_starts = sup["start_times_seconds"]
        all_ends = sup['end_times_seconds']

        # Partition the batch by whether a clip has any timed event.
        has_events = np.array([len(x) > 0 for x in all_starts], dtype=bool)
        strong_indices = list(np.flatnonzero(has_events))
        weak_indices = list(np.flatnonzero(~has_events))

        assert len(strong_indices) > 0, "No strong annotations found in the batch"

        strong_tags, strong_starts, strong_ends = [], [], []
        for idx in strong_indices:
            strong_tags.append(all_tags[idx])
            strong_starts.append(all_starts[idx])
            strong_ends.append(all_ends[idx])

        frame_level_target = tag2_framelevel_multihot(
            strong_tags, label2id, strong_starts, strong_ends, max(all_lens)
        ).to(device)

        if weak_indices:
            clip_level_target = tag2multihot(
                [all_tags[idx] for idx in weak_indices], label2id
            ).to(device)
            target = (frame_level_target, clip_level_target)
        else:
            # No weak clips: ``weak_indices`` is already an empty list here.
            target = frame_level_target

        with torch.set_grad_enabled(is_training):
            results = self.model(
                x=feature,
                x_lens=all_lens,
                frame_target=target,
                clip_only_indices=weak_indices,
                return_ground_truth=True
            )

        loss = results['total']['loss']
        assert loss.requires_grad == is_training

        frame_out = results['frame']
        frame_loss = frame_out['loss']
        frame_logits = frame_out['logits']
        frame_output_lens = frame_out['output_lens']
        frame_ground_truth = frame_out['ground_truth']

        # The clip branch is optional in the model output.
        clip_out = results.get('clip')
        clip_loss = None if clip_out is None else clip_out['loss']

        info = MetricsTracker()
        info.set_value('frames', (all_lens // 4).sum().item(), normalization='sum')
        info.set_value('samples', len(all_tags), normalization='sum')
        for key, value in (('loss', loss), ('frame_loss', frame_loss)):
            info.set_value(key, value.detach().cpu().item(), normalization='sample_avg')
        if clip_loss is not None:
            info.set_value('clip_loss', clip_loss.detach().cpu().item(), normalization='sample_avg')

        outputs = [loss, info]
        optional_outputs = (
            (return_logits, frame_logits),
            (return_lens, frame_output_lens),
            (return_gt, frame_ground_truth),
        )
        for wanted, value in optional_outputs:
            if wanted:
                outputs.append(value)
        return tuple(outputs)

def merge_dicts(dicts):
    """Combine an iterable of dictionaries into one new dictionary.

    Later dictionaries win when a key appears more than once; the inputs are
    left untouched.
    """
    return {key: value for part in dicts for key, value in part.items()}

def compute_psds(scores, ground_truth, audio_durations, psds_type='psds1', num_jobs=20):
    """
    Compute a PSDS score from the predicted scores and ground truth.

    Args:
        scores: Dictionary of predicted scores.
        ground_truth: Dictionary of ground truth annotations.
        audio_durations: Dictionary of audio durations.
        psds_type: Which preset of thresholds/penalties to use; one of
            'psds1', 'psds1_no_penalty' or 'psds2'.
        num_jobs: Forwarded to ``compute_psds_from_scores`` as ``num_jobs``.
            Defaults to the previously hard-coded value of 20.

    Returns:
        Whatever ``compute_psds_from_scores`` returns for the selected preset.
        The validation loop in this file adds it to a scalar accumulator, so
        it is used as a number rather than a dictionary.

    Raises:
        AssertionError: If ``psds_type`` is not one of the supported presets.
    """
    # Settings per PSDS variant; the variants differ only in these keyword
    # arguments to ``compute_psds_from_scores``.
    presets = {
        'psds1': dict(
            dtc_threshold=0.7, gtc_threshold=0.7,
            cttc_threshold=None, alpha_ct=0, alpha_st=1.0,
        ),
        'psds1_no_penalty': dict(
            dtc_threshold=0.7, gtc_threshold=0.7,
            cttc_threshold=None, alpha_ct=0, alpha_st=0.0,
        ),
        'psds2': dict(
            dtc_threshold=0.1, gtc_threshold=0.1,
            cttc_threshold=0.3, alpha_ct=0.5, alpha_st=1.0,
        ),
    }
    valid_types = list(presets)
    # Kept as an assert so callers see the same exception type as before.
    assert psds_type in valid_types, f"Type must be one of {valid_types}"

    return compute_psds_from_scores(
        scores, ground_truth, audio_durations,
        num_jobs=num_jobs, **presets[psds_type],
    )