import logging
import json
import os
import yaml
import hydra
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import DynamicBucketingSampler, OnTheFlyFeatures
from auden.auto.auto_model import AutoModel
from auden.models.audio_tag.utils import compute_acc
from auden.data.dataset.audio_tag_dataset import AudioTaggingDataset
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model
from sklearn.metrics import average_precision_score

def get_test_dataloaders(cfg):
    """Build one evaluation DataLoader per test set listed in the data config.

    The YAML file at ``cfg.data.test_data_config`` is iterated as a sequence
    of entries, each providing a ``name`` and a ``manifest`` key.

    Args:
        cfg: Configuration object; this function reads
            ``cfg.data.test_data_config``, ``cfg.data.max_duration`` and
            ``cfg.data.num_workers``.

    Returns:
        A ``(test_names, test_dls)`` tuple of two parallel lists: the test set
        names and their DataLoaders, in the order they appear in the config.
        Both lists are empty when the config file contains no entries.
    """
    test_dls = []
    test_names = []
    # YAML is UTF-8 by specification; do not depend on the platform locale.
    with open(cfg.data.test_data_config, 'r', encoding='utf-8') as file:
        # NOTE(review): FullLoader can construct more than plain data types.
        # If this file can ever come from an untrusted source, switch to
        # yaml.safe_load.
        test_data_config = yaml.load(file, Loader=yaml.FullLoader)

    # An empty YAML document loads as None; treat it as "no test sets"
    # instead of failing with an opaque TypeError in the loop below.
    if not test_data_config:
        logging.warning(
            f"No test sets found in {cfg.data.test_data_config}"
        )
        return test_names, test_dls

    def remove_short_utterance(cut):
        # Keep only cuts whose duration is at least 1.0
        # (presumably seconds, per lhotse convention).
        return cut.duration >= 1.0

    for test_set in test_data_config:
        logging.info(f"Getting {test_set['manifest']} cuts")
        cutset = CutSet.from_file(test_set['manifest']).resample(16000)
        cutset = cutset.filter(remove_short_utterance)
        test_name = test_set['name']
        # Features are computed on the fly: 80-bin filterbank.
        testset = AudioTaggingDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=True,
        )
        sampler = DynamicBucketingSampler(
            cutset,
            max_duration=cfg.data.max_duration,
            shuffle=False,
        )
        # batch_size=None disables the DataLoader's automatic batching; each
        # item yielded by the sampler is handed to the dataset as-is.
        test_dl = DataLoader(
            testset,
            batch_size=None,
            sampler=sampler,
            num_workers=cfg.data.num_workers,
        )
        test_dls.append(test_dl)
        test_names.append(test_name)
    return test_names, test_dls


@hydra.main(version_base=None, config_path='configs', config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Evaluate an audio-tagging checkpoint on every configured test set.

    Loads the test dataloaders, resolves the checkpoint (generating an
    averaged model first when the resolved filename starts with
    ``averaged``), runs the model over each test set and logs top-1/top-5
    accuracy, plus mAP when ``model.is_multilabel`` is true.

    Args:
        cfg: Hydra configuration; this function reads ``cfg.exp_dir`` and the
            optional ``pretrained_model``, ``epoch``, ``iter`` and ``avg``
            keys of ``cfg.checkpoint``. The data-related keys are consumed by
            ``get_test_dataloaders``.
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # Get dataloaders
    test_sets, test_dls = get_test_dataloaders(cfg)

    # Initialize model
    epoch = cfg.checkpoint.get("epoch", 0)
    iteration = cfg.checkpoint.get("iter", 0)
    avg = cfg.checkpoint.get("avg", 1)
    checkpoint_filename = resolve_checkpoint_filename(
        checkpoint_filename=cfg.checkpoint.get("pretrained_model", None),
        epoch=epoch,
        iter=iteration,
        avg=avg,
    )
    if checkpoint_filename.startswith('averaged'):
        generate_and_save_averaged_model(
            cfg.exp_dir,
            epoch=epoch,
            iter=iteration,
            avg=avg,
        )
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # do evaluation for the dataset
    for test_set_name, test_dl in zip(test_sets, test_dls):
        num_cuts = 0
        all_logits = []
        all_labels = []

        for batch_idx, batch in enumerate(test_dl):
            supervisions = batch["supervisions"]
            num_cuts += len(supervisions["cut"])

            feature = batch["inputs"].to(device)
            feature_lens = supervisions["num_frames"].to(device)
            audio_tag = supervisions["audio_tag"]

            audio_label = model.tag2multihot(audio_tag)
            audio_logits = model.generate((feature, feature_lens), return_full_logits=True)

            # Move results to CPU right away: the metrics below run on CPU,
            # and this keeps accelerator memory from growing with the size of
            # the test set.
            all_logits.append(audio_logits.cpu())
            all_labels.append(audio_label.cpu())

            if batch_idx % 20 == 1:
                logging.info(f"Processed {num_cuts} cuts already.")
        logging.info("Finish collecting audio logits")

        # torch.cat raises on an empty list, e.g. when every cut of this test
        # set was filtered out; skip the set instead of aborting the whole run.
        if not all_logits:
            logging.warning(
                f"{test_set_name}: no batches were produced, skipping evaluation"
            )
            continue

        all_logits = torch.cat(all_logits, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        if model.is_multilabel:
            mAP = average_precision_score(
                y_true=all_labels.numpy(),
                y_score=all_logits.numpy(),
            )
            logging.info(f"{test_set_name}: mAP: {mAP}")

        top1_acc, top5_acc = compute_acc(all_logits, all_labels)
        logging.info(f"{test_set_name}: Top1 Acc: {top1_acc}, Top5 Acc: {top5_acc}")

        logging.info("Done")


# Script entry point: main() is called without arguments; it is wrapped by
# the hydra.main decorator, which is configured with the "evaluate" config.
if __name__ == "__main__":
    main()
