import logging
import json
import os
import yaml
import hydra
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, OnTheFlyFeatures
from auden.auto.auto_model import AutoModel
from auden.models.audio_tag.utils import compute_acc
from auden.data.dataset.audio_tag_dataset import AudioTaggingDataset
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
import numpy as np

def get_test_dataloaders(cfg):
    """Build one evaluation DataLoader per entry of the test-data YAML file.

    The YAML file at ``cfg.data.test_data_config`` is iterated entry by
    entry; each entry must provide a ``'manifest'`` and a ``'name'`` key.
    For every entry the manifest is loaded lazily, resampled to 16000 and
    filtered to cuts whose ``duration`` is at least 1.0. Features are
    80-bin Fbank computed on the fly, and batches are formed by a
    non-shuffling ``DynamicBucketingSampler`` bounded by
    ``cfg.data.max_duration``.

    Args:
        cfg: Configuration object exposing ``data.test_data_config``,
            ``data.max_duration`` and ``data.num_workers``.

    Returns:
        A ``(names, dataloaders)`` tuple of two parallel lists, in the
        order the test sets appear in the YAML file.
    """
    # NOTE(review): FullLoader can construct more than plain scalars/containers;
    # only point this at trusted, locally maintained config files.
    with open(cfg.data.test_data_config, 'r') as config_file:
        dataset_specs = yaml.load(config_file, Loader=yaml.FullLoader)

    def _is_long_enough(cut):
        # Keep only cuts with a duration of at least 1.0.
        return cut.duration >= 1.0

    names = []
    loaders = []
    for spec in dataset_specs:
        logging.info(f"Getting {spec['manifest']} cuts")

        # Lazily load the manifest, then resample and drop short cuts.
        cuts = load_manifest_lazy(spec['manifest'])
        cuts = cuts.resample(16000)
        cuts = cuts.filter(_is_long_enough)
        set_name = spec['name']

        # Features are extracted on the fly; cuts are returned with each batch.
        fbank = Fbank(FbankConfig(num_mel_bins=80))
        dataset = AudioTaggingDataset(
            input_strategy=OnTheFlyFeatures(fbank),
            return_cuts=True,
        )
        bucketing_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=cfg.data.max_duration,
            shuffle=False,
        )
        # batch_size=None: the sampler already yields whole batches.
        loader = DataLoader(
            dataset,
            batch_size=None,
            sampler=bucketing_sampler,
            num_workers=cfg.data.num_workers,
        )

        loaders.append(loader)
        names.append(set_name)

    return names, loaders


def _best_f1_thresholds(labels, probs, beta=1.0):
    """Find, per class, the score threshold that maximizes the F-beta score.

    Args:
        labels: Integer NumPy array indexed as ``[sample, class]`` holding
            the binary ground truth.
        probs: Float NumPy array with the same layout holding the scores.
        beta: F-beta weighting; the default 1.0 yields the F1 score.

    Returns:
        A ``(thresholds, best_scores)`` tuple of dicts keyed by the class
        index rendered as a string (the form JSON object keys take).
        Classes without any positive sample get threshold 1.0 and score 0.0.
    """
    beta_sq = beta * beta
    thresholds, best_scores = {}, {}

    for class_idx in range(probs.shape[1]):
        key = str(class_idx)
        y_true = labels[:, class_idx]
        y_score = probs[:, class_idx]

        # No positives: the PR curve is undefined for this class.
        if y_true.sum() == 0:
            thresholds[key], best_scores[key] = 1.0, 0.0
            continue

        precision, recall, cutoffs = precision_recall_curve(y_true, y_score)
        f_scores = (
            (1 + beta_sq) * (precision * recall)
            / (beta_sq * precision + recall + 1e-12)
        )

        # precision/recall carry one trailing point that has no matching
        # threshold; drop it so indices line up with `cutoffs`.
        f_aligned = f_scores[:-1]
        if f_aligned.size == 0:
            thresholds[key], best_scores[key] = 0.5, 0.0
            continue

        best = int(np.nanargmax(f_aligned))
        thresholds[key] = float(cutoffs[best])
        best_scores[key] = float(f_aligned[best])

    return thresholds, best_scores


@hydra.main(version_base=None, config_path='configs', config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Search per-class F1-optimal thresholds on every configured test set.

    For each test set the model's full logits are collected, turned into
    probabilities with a sigmoid, and the best threshold / F1 per class is
    written to ``best_thresholds_<name>.json`` and ``best_f1_<name>.json``
    inside ``cfg.exp_dir``. Test sets that yield no batches are skipped
    with a warning.

    Args:
        cfg: Hydra configuration; this function reads ``exp_dir`` and
            ``checkpoint.filename`` (plus whatever ``get_test_dataloaders``
            reads).
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # Get dataloaders
    test_sets, test_dls = get_test_dataloaders(cfg)

    # Initialize model
    checkpoint_filename = cfg.checkpoint.filename
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # do evaluation for the dataset
    for test_set_name, test_dl in zip(test_sets, test_dls):
        num_cuts = 0
        prob_chunks = []
        label_chunks = []

        for batch_idx, batch in enumerate(test_dl):
            supervisions = batch["supervisions"]
            num_cuts += len(supervisions["cut"])

            feature = batch["inputs"].to(device)
            feature_lens = supervisions["num_frames"].to(device)
            audio_tag = supervisions["audio_tag"]

            audio_label = model.tag2multihot(audio_tag)
            audio_logits = model.generate((feature, feature_lens), return_full_logits=True)

            # Sigmoid is element-wise, so applying it per batch on the compute
            # device gives the same values as applying it after concatenation;
            # moving results to CPU right away keeps device memory bounded.
            prob_chunks.append(torch.sigmoid(audio_logits).cpu())
            label_chunks.append(audio_label.cpu())

            if batch_idx % 20 == 1:
                logging.info(f"Processed {num_cuts} cuts already.")
        logging.info("Finish collecting audio logits")

        # torch.cat rejects an empty list; skip this set instead of aborting
        # the evaluation of the remaining ones.
        if not prob_chunks:
            logging.warning(
                f"Test set {test_set_name} produced no batches; skipping threshold search."
            )
            continue

        # Convert once instead of once per class.
        all_probs = torch.cat(prob_chunks, dim=0).numpy().astype(float)
        all_labels = torch.cat(label_chunks, dim=0).numpy().astype(int)

        thr, f_best = _best_f1_thresholds(all_labels, all_probs)

        with open(os.path.join(cfg.exp_dir, f"best_thresholds_{test_set_name}.json"), 'w') as out_file:
            json.dump(thr, out_file, indent=4)
        with open(os.path.join(cfg.exp_dir, f"best_f1_{test_set_name}.json"), 'w') as out_file:
            json.dump(f_best, out_file, indent=4)


# Run the evaluation only when executed as a script. main() is called without
# arguments because it is wrapped by @hydra.main, which supplies `cfg`.
if __name__ == "__main__":
    main()
