"""Utilities for evaluating on AudioSet class probabilities."""

import numpy as np
from scipy.stats import spearmanr

from fmri2music import data_const, emb_loader

# AudioSet class names grouped by the musical facet they describe. The
# per-group helpers in this module iterate this dict, so its key order
# determines the key order of their result dicts.
CLASS_GROUPS: dict[str, list[str]] = dict(
    genres=data_const.AUDIOSET_GENRES,
    instruments=data_const.AUDIOSET_INSTRUMENTS,
    moods=data_const.AUDIOSET_MOODS,
)

# Number of top-ranked classes compared per group when measuring top-n
# overlap; keys mirror those of CLASS_GROUPS.
TOP_N_PER_CLASS: dict[str, int] = dict(genres=2, instruments=2, moods=1)


def get_spearman_corr(
    gtzan_clip_name: str, fma_clip_name: str, fma_size: str
) -> dict[str, float]:
    """Spearman correlation between a GTZAN clip and an FMA clip, per class group.

    Looks up both clips' AudioSet probability vectors and correlates them
    separately over each group in ``CLASS_GROUPS``.

    Args:
        gtzan_clip_name: Key into the GTZAN AudioSet-probability table.
        fma_clip_name: Key into the FMA AudioSet-probability table.
        fma_size: FMA dataset variant passed to the loader.

    Returns:
        Mapping of class-group name to Spearman correlation coefficient.
    """
    gtzan_probs = emb_loader.get_audioset_probs_for_gtzan()[gtzan_clip_name]
    fma_probs = emb_loader.get_audioset_probs_for_fma(fma_size)[fma_clip_name]
    # Reuse the vector-level helper rather than duplicating the per-group loop.
    return get_all_spearman_corrs_for_vecs(gtzan_probs, fma_probs)


def get_all_spearman_corrs_for_vecs(
    p1_vec: np.ndarray,
    p2_vec: np.ndarray,
) -> dict[str, float]:
    """Correlate two AudioSet probability vectors within every class group.

    Returns:
        Mapping of each ``CLASS_GROUPS`` key to the Spearman coefficient of
        the two vectors restricted to that group's classes.
    """
    corrs: dict[str, float] = {}
    for group_name, group_classes in CLASS_GROUPS.items():
        corrs[group_name] = get_spearman_corr_for_vecs(p1_vec, p2_vec, group_classes)
    return corrs


def get_spearman_corr_for_vecs(
    p1_vec: np.ndarray, p2_vec: np.ndarray, class_subset: list[str]
) -> float:
    """Spearman rank correlation of two probability vectors over a class subset.

    Args:
        p1_vec: AudioSet probability vector, indexed by the positions returned
            from ``emb_loader.get_audioset_class_map()``.
        p2_vec: Second probability vector, indexed the same way.
        class_subset: Class names whose entries are compared.

    Returns:
        The Spearman correlation coefficient. ``scipy.stats.spearmanr``
        yields NaN when either restricted vector is constant.

    Raises:
        KeyError: If any name in ``class_subset`` is not in the class map.
    """
    class_map = emb_loader.get_audioset_class_map()
    # Validate explicitly so the error lists every unknown name; a plain
    # lookup would only surface the first one (and an assert is stripped
    # under ``python -O``).
    invalid = [name for name in class_subset if name not in class_map]
    if invalid:
        raise KeyError(f"Found invalid classes in class_subset: {invalid}")
    class_indices = [class_map[name] for name in class_subset]

    coef, _ = spearmanr(p1_vec[class_indices], p2_vec[class_indices])
    return float(coef)


def get_class_overlap(
    gtzan_clip_name: str, fma_clip_name: str, fma_size: str
) -> dict[str, float]:
    """Top-n class overlap between a GTZAN clip and an FMA clip, per class group.

    Each probability vector is fetched once and then compared over every
    group in ``CLASS_GROUPS``, using the group's ``TOP_N_PER_CLASS`` value.

    Args:
        gtzan_clip_name: Key into the GTZAN AudioSet-probability table.
        fma_clip_name: Key into the FMA AudioSet-probability table.
        fma_size: FMA dataset variant passed to the loader.

    Returns:
        Mapping of class-group name to the fraction of shared top-n classes.
    """
    gtzan_probs = emb_loader.get_audioset_probs_for_gtzan()[gtzan_clip_name]
    fma_probs = emb_loader.get_audioset_probs_for_fma(fma_size)[fma_clip_name]
    # Delegate to the vector-level helper instead of re-fetching the
    # probability tables once per class group.
    return get_class_overlap_for_vecs(gtzan_probs, fma_probs)


def get_class_overlap_for_vecs(
    p1_vec: np.ndarray, p2_vec: np.ndarray
) -> dict[str, float]:
    """Top-n overlap between two probability vectors for every class group.

    Returns:
        Mapping of each ``CLASS_GROUPS`` key to the fraction of that group's
        top ``TOP_N_PER_CLASS[group]`` classes the two vectors share.
    """
    overlaps: dict[str, float] = {}
    for group_name, group_classes in CLASS_GROUPS.items():
        top_n = TOP_N_PER_CLASS[group_name]
        overlaps[group_name] = get_top_n_overlap_for_vecs(
            p1_vec, p2_vec, group_classes, top_n
        )
    return overlaps


def get_top_n_overlap(
    gtzan_clip_name: str,
    fma_clip_name: str,
    fma_size: str,
    class_subset: list[str],
    n: int,
) -> float:
    """Fraction of top-``n`` classes (within ``class_subset``) shared by a GTZAN and an FMA clip.

    The result is the count of class names that appear in both clips'
    top-``n`` lists, divided by ``n``.
    """
    top_fma = get_top_n_for_fma(class_subset, fma_clip_name, fma_size, n)
    top_gtzan = get_top_n_for_gtzan(class_subset, gtzan_clip_name, n)
    shared = intersection_size(top_fma, top_gtzan)
    return shared / n


def get_top_n_overlap_for_vecs(
    p1_vec: np.ndarray, p2_vec: np.ndarray, class_subset: list[str], n: int
) -> float:
    """Fraction of top-``n`` classes (within ``class_subset``) two vectors share.

    Each vector is ranked over ``class_subset`` only. The result is the count
    of common top-``n`` class names divided by ``n``.
    """
    first_top, second_top = (
        get_top_classes(vec, class_subset, n) for vec in (p1_vec, p2_vec)
    )
    return intersection_size(first_top, second_top) / n


def intersection_size(s1: list[str], s2: list[str]) -> int:
    """Count the distinct class names present in both ``s1`` and ``s2``."""
    common = set(s1) & set(s2)
    return len(common)


def get_top_n_for_gtzan(
    class_subset: list[str], gtzan_clip_name: str, n: int
) -> list[str]:
    """Return a GTZAN clip's ``n`` most probable classes among ``class_subset``.

    The list is ordered from highest to lowest probability.
    """
    probs_by_clip = emb_loader.get_audioset_probs_for_gtzan()
    clip_probs = probs_by_clip[gtzan_clip_name]
    return get_top_classes(clip_probs, class_subset, n)


def get_top_n_for_fma(
    class_subset: list[str], fma_clip_name: str, fma_size: str, n: int
) -> list[str]:
    """Return an FMA clip's ``n`` most probable classes among ``class_subset``.

    ``fma_size`` selects which FMA probability table the loader reads. The
    list is ordered from highest to lowest probability.
    """
    probs_by_clip = emb_loader.get_audioset_probs_for_fma(fma_size)
    clip_probs = probs_by_clip[fma_clip_name]
    return get_top_classes(clip_probs, class_subset, n)


def get_top_classes(probs: np.ndarray, class_subset: list[str], n: int) -> list[str]:
    """Return the ``n`` highest-probability classes from ``class_subset``.

    Args:
        probs: Probability vector indexed by the positions returned from
            ``emb_loader.get_audioset_class_map()``.
        class_subset: Class names to rank.
        n: Number of classes to return; ``0`` yields an empty list.

    Returns:
        Class names ordered from highest to lowest probability.

    Raises:
        ValueError: If ``n`` is negative or exceeds ``len(class_subset)``.
        KeyError: If a name in ``class_subset`` is missing from the class map.
    """
    if n < 0 or n > len(class_subset):
        raise ValueError(f"n must be in [0, {len(class_subset)}], got {n}")
    if n == 0:
        # ``arr[-0:]`` is the whole array, so the slicing below would return
        # every class; handle the empty request explicitly.
        return []

    class_map = emb_loader.get_audioset_class_map()
    class_indices = [class_map[name] for name in class_subset]
    class_probs = probs[class_indices]
    # argpartition isolates the top-n without a full sort; only those n
    # entries are then sorted in descending order.
    top_indices = np.argpartition(class_probs, -n)[-n:]
    top_indices = top_indices[np.argsort(class_probs[top_indices])[::-1]]
    return [class_subset[i] for i in top_indices]
