import logging
import json
import os
import yaml
import hydra
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import DynamicBucketingSampler, OnTheFlyFeatures
from auden.auto.auto_model import AutoModel
from auden.data.dataset.sound_event_detection_dataset import SoundEventDetectionDataset
from auden.trainer.sound_event_detection import compute_psds
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model
from auden.auto.auto_config import AutoConfig

from auden.models.audio_frame_classification.utils import (
    tag2_framelevel_multihot,
    decode_intervals,
    get_timestamps,
)

import numpy as np
import scipy.signal
from sed_scores_eval.utils.scores import create_score_dataframe
from sed_scores_eval.utils.array_ops import get_first_index_where
from dcase_util.data import DecisionEncoder
from desed_task.evaluation.evaluation_measures import compute_psds_from_scores

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

def get_test_dataloaders(cfg):
    """Build one evaluation dataloader per entry of the test-data config.

    The YAML file at ``cfg.data.test_data_config`` is iterated as a sequence
    of mappings, each providing a ``name`` and a ``manifest`` path. Every
    manifest is loaded as a lhotse ``CutSet``, passed through
    ``resample(16000)`` and filtered so that only cuts lasting at least one
    second remain, then wrapped in a ``DataLoader`` that computes 80-bin
    Fbank features on the fly.

    Args:
        cfg: Configuration object exposing ``data.test_data_config``,
            ``data.max_duration`` and ``data.num_workers``.

    Returns:
        A ``(test_names, test_dls)`` tuple of two parallel lists: the ``name``
        of each test set and the matching ``DataLoader``. Both lists are empty
        when the config file has no entries.
    """
    with open(cfg.data.test_data_config, "r", encoding="utf-8") as file:
        # NOTE(review): FullLoader can construct more than plain data types.
        # If this file can ever come from an untrusted source, switch to
        # yaml.safe_load.
        test_data_config = yaml.load(file, Loader=yaml.FullLoader)

    # An empty YAML document parses to None; treat that as "no test sets"
    # instead of failing with a TypeError in the loop below.
    test_data_config = test_data_config or []

    def remove_short_utterance(cut):
        return cut.duration >= 1.0

    test_names = []
    test_dls = []
    for test_set in test_data_config:
        manifest = test_set['manifest']
        logger.info("Getting %s cuts", manifest)
        cutset = CutSet.from_file(manifest).resample(16000)
        cutset = cutset.filter(remove_short_utterance)

        testset = SoundEventDetectionDataset(
            cut_transforms=[],
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=True,
        )
        # The sampler already yields whole batches (bounded by max_duration),
        # hence batch_size=None on the DataLoader.
        sampler = DynamicBucketingSampler(
            cutset,
            max_duration=cfg.data.max_duration,
            shuffle=False,
        )
        test_dl = DataLoader(
            testset,
            batch_size=None,
            sampler=sampler,
            num_workers=cfg.data.num_workers,
        )
        test_names.append(test_set['name'])
        test_dls.append(test_dl)
    return test_names, test_dls


def _smooth_scores(probs, kernel_sizes):
    """Median-filter every class column of a per-frame score array.

    Args:
        probs: 2-D array whose columns are per-class score sequences.
        kernel_sizes: One median-filter kernel size per class column.

    Returns:
        Array with the filtered class sequences stacked along axis 1.
    """
    return np.stack(
        [
            scipy.signal.medfilt(class_scores, kernel_size=kernel_size)
            for class_scores, kernel_size in zip(probs.T, kernel_sizes)
        ],
        axis=1,
    )


def _select_events(labels, starts, ends, timestamps):
    """Return ``[start, end, label]`` entries whose onset index precedes the offset index.

    Onset and offset are located on ``timestamps`` and clamped to its index
    range; events for which the onset index is not strictly smaller than the
    offset index are dropped.
    """
    last_idx = len(timestamps) - 1
    events = []
    for label, start, end in zip(labels, starts, ends):
        onset_idx = max(get_first_index_where(timestamps, 'gt', start) - 1, 0)
        offset_idx = min(get_first_index_where(timestamps, 'geq', end), last_idx)
        if onset_idx < offset_idx:
            events.append([start, end, label])
    return events


@hydra.main(version_base=None, config_path='configs', config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Evaluate a checkpoint on every configured test set and log PSDS scores.

    The function builds the test dataloaders, resolves the checkpoint file
    (generating an averaged model first when the resolved name starts with
    ``averaged``), loads the model, and then, for each test set, runs the
    model over all batches, median-filters the per-class sigmoid scores,
    collects the reference events, and calls ``compute_psds`` once per entry
    of ``cfg.metrics.psds_types``.

    Args:
        cfg: Hydra configuration. Reads ``cfg.exp_dir``, ``cfg.checkpoint``
            (optional keys ``pretrained_model``, ``epoch``, ``iter``,
            ``avg``), ``cfg.metrics.psds_types`` and the ``cfg.data`` keys
            consumed by ``get_test_dataloaders``.
    """
    logger.info("\n%s", OmegaConf.to_yaml(cfg))

    # Get dataloaders
    test_sets, test_dls = get_test_dataloaders(cfg)

    # Initialize model; read the checkpoint selectors once and reuse them
    ckpt_epoch = cfg.checkpoint.get("epoch", 0)
    ckpt_iter = cfg.checkpoint.get("iter", 0)
    ckpt_avg = cfg.checkpoint.get("avg", 1)
    checkpoint_filename = resolve_checkpoint_filename(
        checkpoint_filename=cfg.checkpoint.get("pretrained_model", None),
        epoch=ckpt_epoch,
        iter=ckpt_iter,
        avg=ckpt_avg,
    )
    if checkpoint_filename.startswith('averaged'):
        generate_and_save_averaged_model(
            cfg.exp_dir,
            epoch=ckpt_epoch,
            iter=ckpt_iter,
            avg=ckpt_avg,
        )
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logger.info("Number of model parameters: %s", num_param)

    id2label = model.id2label
    label2id = model.label2id
    # Same median-filter kernel size (7) for every class
    median_filter = [7] * len(label2id)

    # id2label is indexed with string keys here
    event_classes = [id2label[str(i)] for i in range(len(id2label))]

    # do evaluation for the dataset
    for test_set_name, test_dl in zip(test_sets, test_dls):
        num_cuts = 0
        ground_truth = {}
        audio_durations = {}
        val_scores_postprocessed = {}

        for batch_idx, batch in enumerate(test_dl):
            supervisions = batch["supervisions"]
            num_cuts += len(supervisions["cut"])

            feature = batch["inputs"].to(device)
            feature_lens = supervisions["num_frames"].to(device)

            tags = supervisions["audio_tag"]
            audio_ids = supervisions["ids"]
            start_times = supervisions["start_times_seconds"]
            end_times = supervisions['end_times_seconds']
            durations = supervisions['duration_sec']

            # Tensor reduction instead of the builtin max(), which would
            # iterate the device tensor element by element.
            frame_level_target = tag2_framelevel_multihot(
                tags, label2id, start_times, end_times, feature_lens.max()
            ).to(device)

            # Gradients are already disabled by the @torch.no_grad() decorator.
            frame_results = model(
                x=feature,
                x_lens=feature_lens,
                frame_target=frame_level_target,
                return_ground_truth=True,
            )['frame']

            logits = frame_results["logits"]
            output_lens = frame_results["output_lens"].cpu().numpy().tolist()
            pred_probs = torch.sigmoid(logits).cpu().numpy()
            timestamps = get_timestamps(np.arange(max(output_lens) + 1))

            for k in range(pred_probs.shape[0]):
                audio_id = audio_ids[k]
                audio_durations[audio_id] = durations[k]
                valid_len = output_lens[k]
                # One more timestamp than score rows: the score dataframe
                # takes frame boundaries, not frame centers.
                ts_k = timestamps[:valid_len + 1]

                ground_truth.setdefault(audio_id, []).extend(
                    _select_events(
                        tags[k].split(";"), start_times[k], end_times[k], ts_k
                    )
                )

                # Calculate scores
                val_scores_postprocessed[audio_id] = create_score_dataframe(
                    scores=_smooth_scores(pred_probs[k, :valid_len, :], median_filter),
                    timestamps=ts_k,
                    event_classes=event_classes,
                )

            if batch_idx % 20 == 1:
                logger.info("Processed %s cuts already.", num_cuts)

        logger.info("Finish collecting audio logits")

        # Keep only the audio files that have at least one reference event
        ground_truth = {
            audio_id: events
            for audio_id, events in ground_truth.items()
            if len(events) > 0
        }
        logger.info("Number of audio files with ground truth: %s", len(ground_truth))

        for psds_type in cfg.metrics.psds_types:
            score = compute_psds(
                val_scores_postprocessed,
                ground_truth,
                audio_durations,
                psds_type=psds_type
            )
            logger.info("PSDS scores for %s, %s: %s", test_set_name, psds_type, score)

        logger.info("Done")


# Script entry point: run the Hydra-decorated evaluation only when this file
# is executed directly.
if __name__ == "__main__":
    main()
