import logging
import json
import os
import yaml
import hydra
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import DynamicBucketingSampler, OnTheFlyFeatures

from auden.auto.auto_model import AutoModel
from auden.models.audio_clap.utils import t2a_metric, a2t_metric, multi_a2t
from auden.data.dataset.audio_caption_dataset import AudioCaptionDataset
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model

def get_test_dataloaders(cfg):
    """Build one evaluation DataLoader per test set listed in the test-data config.

    The YAML file at ``cfg.data.test_data_config`` is iterated as a sequence of
    entries, each providing a ``manifest`` path and a ``name``. Every manifest
    is loaded as a ``CutSet``, filtered, and wrapped in a non-shuffling
    ``DynamicBucketingSampler`` feeding an ``AudioCaptionDataset`` that computes
    80-bin Fbank features on the fly.

    Args:
        cfg: Configuration object. Reads ``cfg.data.test_data_config``,
            ``cfg.data.max_duration``, ``cfg.data.num_workers`` and
            ``cfg.multi_caption_eval``.

    Returns:
        A ``(test_names, test_dls)`` tuple of two parallel lists: the ``name``
        of each test set and its ``DataLoader``.
    """
    min_duration = 1.0  # seconds; shorter cuts are excluded from evaluation

    # NOTE(review): yaml.FullLoader resolves more than plain scalars/lists/dicts.
    # The test-data config is assumed to be a trusted local file; switch to
    # yaml.safe_load if it can ever come from an untrusted source.
    with open(cfg.data.test_data_config, "r", encoding="utf-8") as file:
        test_data_config = yaml.load(file, Loader=yaml.FullLoader)

    def remove_short_utterance(c):
        """Return True for cuts to keep; trims captions in single-caption mode."""
        if c.duration < min_duration:
            return False
        if not c.supervisions:
            # A cut without any supervision has no caption to retrieve against;
            # drop it instead of failing on ``supervisions[0]``.
            logging.warning(f"Skipping cut {c.id}: it has no supervisions")
            return False
        if not cfg.multi_caption_eval:
            # Single-caption evaluation: keep only the first caption. This
            # mutates the cut in place while the filter is being evaluated.
            c.supervisions = c.supervisions[:1]
        return True

    test_names = []
    test_dls = []
    for test_set in test_data_config:
        logging.info(f"Getting {test_set['manifest']} cuts")
        cutset = CutSet.from_file(test_set['manifest'])
        cutset = cutset.filter(remove_short_utterance)
        test_name = test_set['name']

        testset = AudioCaptionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=True,
        )
        sampler = DynamicBucketingSampler(
            cutset,
            max_duration=cfg.data.max_duration,
            shuffle=False,
        )
        # batch_size=None: the sampler is what groups cuts into batches, so the
        # DataLoader's automatic batching is turned off.
        test_dl = DataLoader(
            testset,
            batch_size=None,
            sampler=sampler,
            num_workers=cfg.data.num_workers,
        )
        test_dls.append(test_dl)
        test_names.append(test_name)
    return test_names, test_dls

def _collect_embeddings(model, test_dl, device, multi_caption_eval):
    """Encode every batch of ``test_dl`` and gather the embeddings on the CPU.

    Args:
        model: Object exposing ``encode_audio(x=..., x_lens=...)`` and
            ``encode_text(text=...)``.
        test_dl: Iterable of batches. Each batch provides an ``"inputs"``
            tensor and a ``"supervisions"`` mapping with the keys ``"cut"``,
            ``"audio_caption"`` and ``"num_frames"``.
        device: Device the features and lengths are moved to before encoding.
        multi_caption_eval: When true, audio features are repeated so that the
            audio batch size matches the number of captions in the batch.

    Returns:
        A tuple ``(audio_embeds, text_embeds, num_repeats)``. The two embedding
        tensors are the per-batch outputs concatenated along dim 0, or ``None``
        when the loader produced no batch at all. ``num_repeats`` is the
        captions-per-audio ratio of the last batch in multi-caption mode and
        ``None`` otherwise.
    """
    audio_chunks, text_chunks = [], []
    num_repeats = None
    num_cuts = 0

    for batch_idx, batch in enumerate(test_dl):
        supervisions = batch["supervisions"]
        num_cuts += len(supervisions["cut"])

        feature = batch["inputs"].to(device)
        # at entry, feature is (N, T, C)
        caption = supervisions["audio_caption"]
        feature_lens = supervisions["num_frames"].to(device)

        # match audio-text batch size
        if multi_caption_eval:
            batch_size = feature.size(0)
            repeats = len(caption) // batch_size
            if num_repeats is not None and repeats != num_repeats:
                # The metric below receives a single ratio for the whole set,
                # so a ratio that changes between batches makes it unreliable.
                logging.warning(
                    f"Captions per audio changed from {num_repeats} to {repeats} "
                    f"at batch {batch_idx}"
                )
            num_repeats = repeats
            feature = feature.repeat_interleave(num_repeats, dim=0)
            # NOTE(review): if ``num_frames`` is given per audio rather than per
            # caption, it has to be repeated like the features so that x and
            # x_lens stay aligned; it is left alone when already per caption.
            if feature_lens.size(0) == batch_size:
                feature_lens = feature_lens.repeat_interleave(num_repeats, dim=0)

        audio_embeds = model.encode_audio(x=feature, x_lens=feature_lens)
        text_embeds = model.encode_text(text=caption)

        # Move results off the device right away so only one batch of
        # embeddings lives there at a time.
        audio_chunks.append(audio_embeds.cpu())
        text_chunks.append(text_embeds.cpu())

        if batch_idx % 20 == 1:
            logging.info(f"Processed {num_cuts} cuts already.")

    if not audio_chunks:
        return None, None, num_repeats
    return torch.cat(audio_chunks, dim=0), torch.cat(text_chunks, dim=0), num_repeats


@hydra.main(version_base=None, config_path="configs", config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Evaluate audio-text retrieval on every configured test set.

    Loads the checkpoint described by ``cfg.checkpoint`` from ``cfg.exp_dir``
    (generating the averaged model first when the resolved filename starts with
    ``"averaged"``), encodes all audio/caption pairs of each test set, and logs
    the retrieval metrics. With ``cfg.multi_caption_eval`` only the
    audio-to-text metric is reported; otherwise both text-to-audio and
    audio-to-text metrics are logged.

    Args:
        cfg: Hydra configuration. Reads ``cfg.checkpoint`` (``pretrained_model``,
            ``epoch``, ``iter``, ``avg``), ``cfg.exp_dir``,
            ``cfg.multi_caption_eval`` and whatever ``get_test_dataloaders``
            needs.
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # initialize dataloader
    test_sets, test_dls = get_test_dataloaders(cfg)

    # Initialize model
    epoch = cfg.checkpoint.get("epoch", 0)
    iteration = cfg.checkpoint.get("iter", 0)
    avg = cfg.checkpoint.get("avg", 1)
    checkpoint_filename = resolve_checkpoint_filename(
        checkpoint_filename=cfg.checkpoint.get("pretrained_model", None),
        epoch=epoch,
        iter=iteration,
        avg=avg,
    )
    if checkpoint_filename.startswith('averaged'):
        generate_and_save_averaged_model(
            cfg.exp_dir,
            epoch=epoch,
            iter=iteration,
            avg=avg,
        )
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    for test_set_name, test_dl in zip(test_sets, test_dls):
        audio_embeds_all, text_embeds_all, num_repeats = _collect_embeddings(
            model, test_dl, device, cfg.multi_caption_eval
        )
        logging.info("Finish collecting audio logits")

        if audio_embeds_all is None:
            # Nothing to concatenate (torch.cat rejects an empty list); skip
            # this set so the remaining test sets are still evaluated.
            logging.warning(f"No batches produced for {test_set_name}; skipping metrics")
            continue

        # only update a2t metric for multi-caption, since t2a is still 1-to-1
        if cfg.multi_caption_eval:
            r1_a, r5_a, r10_a, medr_a, meanr_a, mAP10_a = multi_a2t(audio_embeds_all, text_embeds_all, num_repeats)
            logging.info(f"a2t for {test_set_name} multi-caption eval is: r1: {r1_a}, r5: {r5_a}, r10: {r10_a}, medr: {medr_a}, meanr: {meanr_a}, mAP10: {mAP10_a}")
        else:
            r1, r5, r10, medr, meanr, mAP10 = t2a_metric(text_embeds_all, audio_embeds_all)
            r1_a, r5_a, r10_a, medr_a, meanr_a, mAP10_a = a2t_metric(audio_embeds_all, text_embeds_all)

            logging.info(f"t2a for {test_set_name} eval is: r1: {r1}, r5: {r5}, r10: {r10}, medr: {medr}, meanr: {meanr}, mAP10: {mAP10}")
            logging.info(f"a2t for {test_set_name} eval is: r1: {r1_a}, r5: {r5_a}, r10: {r10_a}, medr: {medr_a}, meanr: {meanr_a}, mAP10: {mAP10_a}")

        logging.info("Done")


# Script entry point: run the evaluation only when executed directly, not on
# import. main() is called without arguments because it is wrapped by the
# @hydra.main decorator.
if __name__ == "__main__":
    main()
