import logging
import json
import os
import yaml
import hydra
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler, OnTheFlyFeatures

from auden.auto.auto_model import AutoModel
from auden.models.audio_clap.utils import t2a_metric, a2t_metric
from auden.data.dataset.audio_caption_dataset import AudioCaptionDataset
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model

from sklearn.metrics import average_precision_score
import torch.nn.functional as F
import numpy as np
import random
import csv
from collections import defaultdict


def tag2multihot(tag_strings, label2id):
    """Convert ';'-separated tag strings into a multi-hot float32 matrix.

    Each token in a tag string is either a purely numeric class index (used as-is)
    or a class name looked up in ``label2id``. Empty tokens, e.g. from a trailing
    or doubled separator ("a;" or "a;;b") or an empty string, are ignored.

    Example:
        tag2multihot(['sand;rub', 'butterfly'],
                     {'sand': '0', 'rub': '1', 'butterfly': '2'})
        -> tensor([[1., 1., 0.],
                   [0., 0., 1.]])

    Args:
        tag_strings: sequence of tag strings, one per sample.
        label2id: mapping from class name to class index (an int or a numeric string).

    Returns:
        torch.float32 tensor of shape (len(tag_strings), len(label2id)).

    Raises:
        KeyError: a non-numeric token is not a key of ``label2id``.
        IndexError: a resolved index is >= len(label2id).
    """
    num_classes = len(label2id)
    multihot = torch.zeros((len(tag_strings), num_classes), dtype=torch.float32)

    for row, tag_str in enumerate(tag_strings):
        for tag in tag_str.split(";"):
            # "".split(";") and "a;".split(";") yield '' tokens; these previously raised KeyError('').
            if not tag:
                continue
            index = int(tag) if tag.isdigit() else int(label2id[tag])
            multihot[row, index] = 1.0
    return multihot


def get_description(description_path, dataset_name, class_names, add_tag=False):
    """Return one text description per class of ``dataset_name``, aligned with ``class_names``.

    The file at ``description_path`` is read as tab-separated values with a header
    row containing at least the columns ``dataset``, ``class_name`` and
    ``base_description``. Only rows whose ``dataset`` equals ``dataset_name`` are
    used. If a class appears in several such rows, the last one wins.

    Args:
        description_path: path to the UTF-8 TSV description file.
        dataset_name: value to match against the ``dataset`` column.
        class_names: class names, in the order the caller will index them.
        add_tag: if True, each description is prefixed with ``"<class_name>: "``.

    Returns:
        A list with exactly ``len(class_names)`` entries, where entry i describes
        ``class_names[i]``. A class without a description falls back to its bare
        name (with a logged warning). Dropping it instead would shift every later
        entry and break positional alignment with the class list.
    """
    descriptions_map = {}
    # newline='' lets the csv module handle line endings itself, as its docs recommend.
    with open(description_path, mode='r', encoding='utf-8', newline='') as file:
        reader = csv.DictReader(file, delimiter='\t')
        for row in reader:
            if row['dataset'] != dataset_name:
                continue
            description = row['base_description']
            if add_tag:
                description = row['class_name'] + ': ' + description
            descriptions_map[row['class_name']] = description

    missing = [name for name in class_names if name not in descriptions_map]
    if missing:
        logging.warning(
            f"{dataset_name}: no description for {len(missing)} class(es), "
            f"falling back to class names: {missing[:10]}"
        )
    return [descriptions_map.get(name, name) for name in class_names]


def get_test_dataloaders(cfg):
    """Create one evaluation DataLoader for each test set listed in ``cfg.data.test_data_config``.

    The YAML file is read as a list of entries, each providing a ``manifest`` (a file
    loadable by ``CutSet.from_file``) and a ``name``. Cuts shorter than 1.0 s or
    longer than 30.0 s are filtered out. Features are 80-bin fbank computed on the
    fly. Cuts are drawn by a non-shuffling ``DynamicBucketingSampler`` bounded by
    ``cfg.data.max_duration``.

    Returns:
        (names, loaders): index-aligned lists of test-set names and DataLoaders.
    """
    with open(cfg.data.test_data_config, 'r') as file:
        # NOTE(review): yaml.safe_load would suffice for a plain list of mappings.
        entries = yaml.load(file, Loader=yaml.FullLoader)

    def keep_cut(cut):
        # Side effect: every cut inspected by the filter is trimmed to its first supervision.
        cut.supervisions = [cut.supervisions[0]]
        return not (cut.duration < 1.0 or cut.duration > 30.0)

    names, loaders = [], []
    for entry in entries:
        manifest = entry['manifest']
        logging.info(f"Getting {manifest} cuts")
        cuts = CutSet.from_file(manifest).filter(keep_cut)
        name = entry['name']

        dataset = AudioCaptionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=True,
        )
        sampler = DynamicBucketingSampler(
            cuts, max_duration=cfg.data.max_duration, shuffle=False
        )
        loaders.append(
            DataLoader(
                dataset,
                batch_size=None,  # the sampler yields whole batches
                sampler=sampler,
                num_workers=cfg.data.num_workers,
            )
        )
        names.append(name)
    return names, loaders

def _load_model(cfg: DictConfig, device: torch.device):
    """Resolve the checkpoint selected by ``cfg.checkpoint``, load it on ``device`` and switch to eval mode.

    When the resolved filename starts with ``averaged``, the file is first produced
    by ``generate_and_save_averaged_model``, using the same epoch/iter/avg settings.
    """
    epoch = cfg.checkpoint.get("epoch", 0)
    iteration = cfg.checkpoint.get("iter", 0)
    avg = cfg.checkpoint.get("avg", 1)
    checkpoint_filename = resolve_checkpoint_filename(
        checkpoint_filename=cfg.checkpoint.get("pretrained_model", None),
        epoch=epoch,
        iter=iteration,
        avg=avg,
    )
    if checkpoint_filename.startswith('averaged'):
        generate_and_save_averaged_model(cfg.exp_dir, epoch=epoch, iter=iteration, avg=avg)
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )
    model.to(device)
    model.eval()
    return model


@hydra.main(version_base=None, config_path="configs", config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Zero-shot evaluation of an audio-text embedding model on audio tagging test sets.

    For each test set from ``get_test_dataloaders``:
    - Class names are read from ``../audio_tag/configs/<base>/id2label_<base>.json``.
      Either the names themselves or, when ``cfg.using_description`` is true, their
      descriptions from ``cfg.description_path`` are embedded with ``model.encode_text``.
    - Every clip is scored by the dot product of its audio embedding with each class
      embedding.
    - Single-label sets report top-1/top-5 accuracy. Multi-label sets ('fsd50k',
      'audioset') report mAP from sklearn's ``average_precision_score``.

    Test sets are grouped by the part of their name before the first '-'. A "Total"
    line per group aggregates all of its splits: pooled accuracy counts, or mAP over
    the union of the splits' predictions.
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # initialize dataloader
    test_sets, test_dls = get_test_dataloaders(cfg)

    # Initialize model
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model = _load_model(cfg, device)
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    test_set_categories = {'acc': ['esc50', 'urbansound', 'vggsound'], 'map': ['fsd50k', 'audioset']}
    # Keyed by base name, so all splits of one dataset contribute to one "Total" row.
    accumulators = defaultdict(lambda: {"correct_top1": 0, "correct_top5": 0, "map": 0.0, "total_samples": 0})
    # Per-split multi-label scores/targets. They are kept so a dataset with several splits
    # gets a "Total" mAP over all of them rather than the last split's value.
    pooled_scores = defaultdict(list)
    pooled_targets = defaultdict(list)
    # Opt-in via config: embed per-class descriptions instead of bare class names.
    using_description = cfg.get("using_description", False)

    for test_set_name, test_dl in zip(test_sets, test_dls):
        test_base_name = test_set_name.split('-')[0]
        is_multilabel = test_base_name in test_set_categories['map']

        label_path = f'../audio_tag/configs/{test_base_name}/id2label_{test_base_name}.json'
        with open(label_path, 'r') as f:
            label_dict = json.load(f)
        label2id = {value: key for key, value in label_dict.items()}
        label_list = list(label_dict.values())

        if using_description:
            caption_list = get_description(cfg.description_path, test_base_name, label_list, add_tag=False)
            # Row i of label_embeds must correspond to label_list[i]. A shorter caption list
            # would silently shift every later class and corrupt the predictions.
            if len(caption_list) != len(label_list):
                raise ValueError(
                    f"{test_base_name}: got {len(caption_list)} descriptions for {len(label_list)} classes"
                )
            label_embeds = model.encode_text(text=caption_list, device=device)
        else:
            label_embeds = model.encode_text(text=label_list, device=device)

        # torch.topk raises if k exceeds the number of classes. With fewer than 5 classes,
        # "Top5" therefore means top-N over all classes.
        top_k = min(5, len(label_list))

        correct_top1, correct_top5, total_samples = 0, 0, 0
        all_scores, all_targets = [], []

        for batch in test_dl:
            supervisions = batch["supervisions"]
            labels = supervisions["audio_tag"]
            feature = batch["inputs"].to(device)
            feature_lens = supervisions["num_frames"].to(device)
            audio_embeds = model.encode_audio(x=feature, x_lens=feature_lens)  # [B, D]
            similarity = audio_embeds @ label_embeds.T  # [B, num_classes]

            if is_multilabel:
                hot_labels = tag2multihot(labels, label2id)
                all_scores.append(torch.sigmoid(similarity).cpu())
                all_targets.append(hot_labels)
                total_samples += hot_labels.shape[0]
                continue

            topk_indices = similarity.topk(k=top_k, dim=-1).indices.cpu().tolist()
            for true_label, indices in zip(labels, topk_indices):
                predicted_labels = [label_list[idx] for idx in indices]
                total_samples += 1
                if true_label == predicted_labels[0]:
                    correct_top1 += 1
                if true_label in predicted_labels:
                    correct_top5 += 1

        metrics = accumulators[test_base_name]
        metrics["total_samples"] += total_samples

        if is_multilabel:
            if all_scores:
                scores = torch.cat(all_scores, dim=0)
                targets = torch.cat(all_targets, dim=0)
                mAP = average_precision_score(
                    y_true=targets.numpy(),
                    y_score=scores.numpy(),
                )
                logging.info(f"{test_set_name}: mAP: {mAP}")
                metrics["map"] = mAP
                pooled_scores[test_base_name].append(scores)
                pooled_targets[test_base_name].append(targets)
            else:
                logging.warning(f"{test_set_name}: no samples, mAP not computed")
        else:
            metrics["correct_top1"] += correct_top1
            metrics["correct_top5"] += correct_top5
            if total_samples:
                top1_acc = correct_top1 / total_samples
                top5_acc = correct_top5 / total_samples
                logging.info(f"{test_set_name}: Top1 Acc: {top1_acc}, Top5 Acc: {top5_acc}, Total Samples: {total_samples}")
            else:
                logging.warning(f"{test_set_name}: no samples, accuracy not computed")

        torch.cuda.empty_cache()

    # With several non-empty splits, recompute mAP over their union. A single split keeps its own value.
    for test_base_name, score_chunks in pooled_scores.items():
        if len(score_chunks) > 1:
            accumulators[test_base_name]["map"] = average_precision_score(
                y_true=torch.cat(pooled_targets[test_base_name], dim=0).numpy(),
                y_score=torch.cat(score_chunks, dim=0).numpy(),
            )

    for test_base_name, metrics in accumulators.items():
        logging.info("=========================================")
        if not metrics["total_samples"]:
            logging.warning(f"Total: {test_base_name}: no samples")
        elif test_base_name in test_set_categories['map']:
            logging.info(f"Total: {test_base_name}: mAP: {metrics['map']}, Total Samples: {metrics['total_samples']}")
        else:
            all_top1_acc = metrics["correct_top1"] / metrics["total_samples"]
            all_top5_acc = metrics["correct_top5"] / metrics["total_samples"]
            logging.info(f"Total: {test_base_name}: Top1 Acc: {all_top1_acc}, Top5 Acc: {all_top5_acc}, Total Samples: {metrics['total_samples']}")

    logging.info("Done")


if __name__ == "__main__":
    # The hydra.main decorator on main() supplies `cfg` (config_path="configs",
    # config_name="evaluate"), so it is called here without arguments.
    main()
