"""Qualitative evaluation of fmri2music models."""

from collections import defaultdict
import dataclasses
from typing import Iterable

from himalaya.scoring import correlation_score
import numpy as np

from fmri2music import audioset_metrics, emb_loader, fmri_loader, predictor, utils


def retrieve_fma_clip_name(
    fma_size: str,
    emb_name: str,
    pred: np.ndarray,
    intersect_w_emb_name: str | None,
    min_num_slices_intersection: int | None,
) -> str:
    """Return the name of the FMA clip that best matches ``pred``.

    Each candidate FMA clip is scored slice by slice with the dot product
    between its slice embedding and the corresponding row of ``pred``; the
    per-slice scores are averaged and the clip with the highest average wins.

    Args:
        fma_size: Forwarded to ``emb_loader.get_grouped_fma_matrix``.
        emb_name: Name of the embedding type to load the FMA candidates in.
        pred: Predicted embeddings of shape [num_slices, emb_dim].
        intersect_w_emb_name: Forwarded to the loader unchanged.
        min_num_slices_intersection: Forwarded to the loader unchanged.

    Returns:
        The first component of the parsed long key of the winning clip's first slice.
    """
    # Unpacking (rather than indexing) also enforces that pred is 2-D.
    slice_count, _emb_dim = pred.shape

    fma_embs, slice_names_by_clip = emb_loader.get_grouped_fma_matrix(
        fma_size,
        emb_name,
        min_num_slices=slice_count,
        intersect_w_emb_name=intersect_w_emb_name,
        min_num_slices_intersection=min_num_slices_intersection,
    )
    # Every candidate must provide exactly one embedding per predicted slice.
    assert fma_embs.shape[1:] == pred.shape, f"{fma_embs.shape=}, {pred.shape=}"

    # [num_clips, num_slices]: dot product of each candidate slice with its prediction.
    per_slice_scores = np.einsum("ijk,jk->ij", fma_embs, pred)
    per_clip_scores = per_slice_scores.mean(axis=1)
    best_clip_idx = per_clip_scores.argmax()

    first_slice_key = slice_names_by_clip[best_clip_idx][0]
    return utils.parse_long_key(first_slice_key)[0]


def retrieve_fma_embedding(
    fma_size: str,
    emb_name: str,
    pred: np.ndarray,
    eval_emb_name: str,
    expected_num_eval_slices: int,
) -> tuple[np.ndarray, list[str]]:
    """Retrieve the eval-space embeddings of the FMA clip closest to ``pred``.

    The clip is selected by ``retrieve_fma_clip_name`` in the ``emb_name`` space,
    matching slice against slice so that the temporal order of the prediction is
    taken into account. The embeddings returned for that clip are the ones stored
    under ``eval_emb_name``.

    Args:
        fma_size: Forwarded to the FMA loaders.
        emb_name: Embedding type ``pred`` is expressed in; used for the retrieval.
        pred: Predicted embeddings, one row per slice.
        eval_emb_name: Embedding type of the returned embeddings.
        expected_num_eval_slices: Minimum number of eval slices a candidate needs.

    Returns:
        A pair ``(embeddings, long_keys)`` for the retrieved clip, as stored by
        ``emb_loader.get_grouped_fma_embs`` under the clip's name.
    """
    eval_embs_by_clip, long_keys_by_clip = emb_loader.get_grouped_fma_embs(
        fma_size, eval_emb_name, min_num_slices=expected_num_eval_slices
    )
    # Restrict the retrieval to clips that also exist in the eval embedding space.
    best_clip_name = retrieve_fma_clip_name(
        fma_size=fma_size,
        emb_name=emb_name,
        pred=pred,
        intersect_w_emb_name=eval_emb_name,
        min_num_slices_intersection=expected_num_eval_slices,
    )
    retrieved_embs = eval_embs_by_clip[best_clip_name]
    retrieved_long_keys = long_keys_by_clip[best_clip_name]
    return retrieved_embs, retrieved_long_keys


def cos_sim(vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
    """Return the cosine of the angle between ``vec1`` and ``vec2``.

    No special handling is done for zero-norm inputs.
    """
    inner = np.dot(vec1, vec2)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return inner / norm_product


def mse(vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray:
    """Return the mean of the element-wise squared differences of two vectors."""
    residual = vec1 - vec2
    return np.mean(np.square(residual))


def identification(mat_true: np.ndarray, mat_pred: np.ndarray) -> np.ndarray:
    """Compute per-sample identification accuracies between two matrices.

    Row ``i`` of ``mat_true`` and row ``i`` of ``mat_pred`` form a matching pair.
    For every row ``j`` of ``mat_pred``, the result is the fraction of the other
    rows of ``mat_true`` that correlate less with it than its own match does.

    Args:
        mat_true: Reference embeddings, one row per sample.
        mat_pred: Embeddings to identify; must have the same shape as ``mat_true``.

    Returns:
        One accuracy value per row.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    if mat_true.shape != mat_pred.shape:
        raise ValueError(
            f"The matrices have different shapes: {mat_true.shape=} and {mat_pred.shape=}"
        )

    num_samples = mat_true.shape[0]

    # np.corrcoef stacks both inputs, so the full matrix is [2n, 2n]; this is
    # memory hungry for large inputs. Only the true-vs-pred quadrant is needed.
    cross_corr = np.corrcoef(mat_true, mat_pred)[:num_samples, num_samples:]

    # Correlation of each prediction with its own ground truth.
    matched_corr = np.diag(cross_corr)
    # beaten[i, j] is True when true row i correlates less with pred j than row j does.
    beaten = cross_corr < matched_corr

    # The matching row itself never counts, hence the "- 1" in the denominator.
    return beaten.sum(axis=0) / (len(beaten) - 1)


@dataclasses.dataclass(frozen=True)
class QuantEvalConfig:
    """Configuration for quantitative evaluation.

    Instances are immutable (the dataclass is frozen).
    """

    # Identifier of the FMA subset; passed through to the predictors and loaders.
    fma_size: str
    # Names of the embedding types to evaluate with; one result is produced per name.
    eval_emb_names: list[str]


@dataclasses.dataclass
class ModelQuantEvalResult:
    """Summary of the quantitative evaluation results for a model."""

    model_name: str
    emb_name: str
    eval_emb_name: str
    gtzan_pred_mses: list[float]  # MSEs between prediction and GTZAN emb.
    gtzan_pred_corr: list[float]  # Correlations between prediction and GTZAN emb.
    retrieved_fma_slice_names: list[str]
    retrieved_fma_cos_sims: list[float]
    retrieved_fma_identification: list[float]  # Identification accuracy on FMA.
    gtzan_identification: list[float]  # Identification accuracy on GTZAN.
    gtzan_slice_names: list[str]  # Slices included in the evaluation.
    gtzan_clip_names: list[str]  # Clip names included in the evaluation.
    fma_clip_names: list[str]  # Retrieved FMA clip names.
    skip_ctr: int  # Number of clips that were excluded from the evaluation.
    class_overlaps: dict[str, float]  # Percentage AudioSet class overlap.
    class_spearman_corr: dict[str, float]  # Spearman corr in class groups.

    def __post_init__(self):
        """Post-initialization checks."""
        num_slices = len(self.retrieved_fma_slice_names)
        assert num_slices == len(self.retrieved_fma_cos_sims)
        assert num_slices == len(self.retrieved_fma_identification)
        assert num_slices == len(self.gtzan_slice_names)

    def get_mean_identification_accuracy(self) -> float:
        """Mean identification accuracy across all GTZAN embeddings."""
        return np.mean(self.retrieved_fma_identification).item()

    def get_stddev_identification_accuracy(self) -> float:
        return np.std(self.retrieved_fma_identification).item()

    def get_mean_id_acc_gtzan(self) -> float:
        return np.mean(self.gtzan_identification).item()

    def get_stddev_id_acc_gtzan(self) -> float:
        return np.std(self.gtzan_identification).item()

    def get_mean_mse_gtzan(self) -> float:
        return np.mean(self.gtzan_pred_mses).item()

    def get_stddev_mse_gtzan(self) -> float:
        return np.std(self.gtzan_pred_mses).item()
    
    def get_mean_corr_gtzan(self) -> float:
        return np.mean(self.gtzan_pred_corr).item()

    def get_stddev_corr_gtzan(self) -> float:
        return np.std(self.gtzan_pred_corr).item()

    def get_mean_identification_accuracy_by_category(self) -> dict[str, float]:
        """Mean identification accuracy by each GTZAN category."""

        accs_per_category = defaultdict(list)
        for idx, gtzan_slice_name in enumerate(self.gtzan_slice_names):
            category = gtzan_slice_name.split(".")[0]
            accs_per_category[category].append(self.retrieved_fma_identification[idx])
        return {k: np.mean(v).item() for k, v in accs_per_category.items()}

    def get_stddev_identification_accuracy_by_category(self) -> dict[str, float]:
        accs_per_category = defaultdict(list)
        for idx, gtzan_slice_name in enumerate(self.gtzan_slice_names):
            category = gtzan_slice_name.split(".")[0]
            accs_per_category[category].append(self.retrieved_fma_identification[idx])
        return {k: np.std(v).item() for k, v in accs_per_category.items()}

    def get_mean_cos_sim(self) -> float:
        """Mean cosine similarity between the retrieved FMA embeddings and the GTZAN embeddings."""
        return np.mean(self.retrieved_fma_cos_sims).item()

    def get_stddev_cos_sim(self) -> float:
        return np.std(self.retrieved_fma_cos_sims).item()

    def get_mean_cos_sim_by_category(self) -> dict[str, float]:
        """Mean cosine similarity between the retrieved FMA embeddings and the GTZAN embeddings by category."""

        cos_sims_per_category = defaultdict(list)
        for idx, gtzan_slice_name in enumerate(self.gtzan_slice_names):
            category = gtzan_slice_name.split(".")[0]
            cos_sims_per_category[category].append(self.retrieved_fma_cos_sims[idx])
        return {k: np.mean(v).item() for k, v in cos_sims_per_category.items()}

    def get_stddev_cos_sim_by_category(self) -> dict[str, float]:
        cos_sims_per_category = defaultdict(list)
        for idx, gtzan_slice_name in enumerate(self.gtzan_slice_names):
            category = gtzan_slice_name.split(".")[0]
            cos_sims_per_category[category].append(self.retrieved_fma_cos_sims[idx])
        return {k: np.std(v).item() for k, v in cos_sims_per_category.items()}

    def get_diversity(self) -> float:
        """Ratio of unique retrieved FMA keys to the total number of retrieved keys."""
        return len(set(self.retrieved_fma_slice_names)) / len(
            self.retrieved_fma_slice_names
        )

    def report(self) -> str:
        """Returns a human-readable summary of the evaluation."""
        return f"""
        Model: {self.model_name}
        Embedding names: {self.emb_name} (pre) {self.eval_emb_name} (eval)
        GTZAN pred MSE: {round(self.get_mean_mse_gtzan(), 4)} (stddev: {round(self.get_stddev_mse_gtzan(), 4)})
        GTZAN pred corr: {round(self.get_mean_corr_gtzan(), 4)} (stddev: {round(self.get_stddev_corr_gtzan(), 4)})
        Mean cosine similarity: {round(self.get_mean_cos_sim(), 4)} (stddev: {round(self.get_stddev_cos_sim(), 4)})
        Mean cosine similarity (by category): {utils.str_from_float_dict(self.get_mean_cos_sim_by_category())} (stddev: {utils.str_from_float_dict(self.get_stddev_cos_sim_by_category())})
        Mean id acc (FMA): {round(self.get_mean_identification_accuracy(), 4)} (stddev: {round(self.get_stddev_identification_accuracy(), 4)}
        Mean id acc (FMA) (by category): {utils.str_from_float_dict(self.get_mean_identification_accuracy_by_category())} (stddev: {utils.str_from_float_dict(self.get_stddev_identification_accuracy_by_category())})
        Mean id acc (GTZAN): {round(self.get_mean_id_acc_gtzan(), 4)} (stddev: {round(self.get_stddev_id_acc_gtzan(), 4)})
        Diversity: {round(self.get_diversity(), 4)}
        # examples: {len(self.retrieved_fma_slice_names)}
        # eval clips: {len(self.gtzan_clip_names)} (and {self.skip_ctr} skipped)
        # eval slices: {len(self.gtzan_slice_names)}
        AudioSet class overlaps: {utils.str_from_float_dict(self.class_overlaps)}
        AudioSet top-n: {audioset_metrics.TOP_N_PER_CLASS}
        AudioSet spearman corr: {utils.str_from_float_dict(self.class_spearman_corr)}
        """

    def get_result_dict(self) -> dict[str, float]:
        """Returns a flat dictionary of the evaluation results."""
        return {
            **{
                "model-name": self.model_name,
                "emb-name": self.emb_name,
                "gtzan-pred-mse": self.get_mean_mse_gtzan(),
                "gtzan-pred-mse-stddev": self.get_stddev_mse_gtzan(),
                "gtzan-pred-corr": self.get_mean_corr_gtzan(),
                "gtzan-pred-corr-stddev": self.get_stddev_corr_gtzan(),
            },
            **utils.add_key_prefix(
                f"{self.eval_emb_name}-",
                {
                    "mean-cos-sim": self.get_mean_cos_sim(),
                    "stddev-cos-sim": self.get_stddev_cos_sim(),
                    "mean-id-acc": self.get_mean_identification_accuracy(),
                    "stddev-id-acc": self.get_stddev_identification_accuracy(),
                    "mean-id-acc-gtzan": self.get_mean_id_acc_gtzan(),
                    "diversity": self.get_diversity(),
                    **utils.add_key_prefix(
                        "mean-cos-sim-by-category-", self.get_mean_cos_sim_by_category()
                    ),
                    **utils.add_key_prefix(
                        "mean-id-acc-by-category-",
                        self.get_mean_identification_accuracy_by_category(),
                    ),
                    **utils.add_key_prefix(
                        "stddev-id-acc-by-category-",
                        self.get_stddev_identification_accuracy_by_category(),
                    ),
                    "num-eval-gtzan-clips": len(self.gtzan_clip_names),
                    "num-skipped-gtzan-clips": self.skip_ctr,
                    **utils.add_key_prefix(
                        "audioset-class-overlap-",
                        {
                            f"{class_group}": v
                            for class_group, v in self.class_overlaps.items()
                        },
                    ),
                    **utils.add_key_prefix(
                        "audioset-class-overlap-",
                        {
                            f"{class_group}-n-{audioset_metrics.TOP_N_PER_CLASS[class_group]}": v
                            for class_group, v in audioset_metrics.TOP_N_PER_CLASS.items()
                        },
                    ),
                    **utils.add_key_prefix(
                        "audioset-spearman-corr-",
                        {
                            f"{class_group}": v
                            for class_group, v in self.class_spearman_corr.items()
                        },
                    ),
                },
            ),
        }


def evaluate_models(
    path_from_model: dict[str, str],
    gtzan_clip_names: list[str],
    config: QuantEvalConfig,
) -> Iterable[ModelQuantEvalResult]:
    """Lazily evaluate a set of predictors and yield their results.

    For every eval embedding in ``config`` a ``RandomPredictor`` and a
    ``BestPredictor`` are evaluated, followed by one ``ExportedPredictor`` per
    entry of ``path_from_model``. This is a generator, so nothing is built or
    evaluated before the first result is requested.

    Args:
        path_from_model: Maps a model name to the file holding its predictions.
        gtzan_clip_names: GTZAN clips to evaluate on.
        config: FMA size and eval embedding names to use.

    Yields:
        One result per (predictor, eval embedding) combination.
    """
    reference_classes = (predictor.RandomPredictor, predictor.BestPredictor)
    candidates = [
        reference_cls(eval_emb_name, config.fma_size)
        for eval_emb_name in config.eval_emb_names
        for reference_cls in reference_classes
    ]
    candidates += [
        predictor.ExportedPredictor(
            model_name, file_name, emb_loader.load_pred_emb_name(file_name)
        )
        for model_name, file_name in path_from_model.items()
    ]

    for candidate in candidates:
        yield from evaluate_model(candidate, gtzan_clip_names, config)


def evaluate_model(
    model: predictor.EmbeddingPredictor,
    gtzan_clip_names: list[str],
    config: QuantEvalConfig,
) -> list[ModelQuantEvalResult]:
    """Evaluate ``model`` once for every eval embedding named in ``config``.

    Returns:
        The results in the order of ``config.eval_emb_names``.
    """
    # TODO: Flip embedding and GTZAN loop, to avoid performing the retrieval based on the
    # predicted embedding multiple times.
    results = []
    for eval_emb_name in config.eval_emb_names:
        single_result = evaluate_model_single_emb(
            model, gtzan_clip_names, eval_emb_name, config.fma_size
        )
        results.append(single_result)
    return results


def evaluate_model_single_emb(
    model: predictor.EmbeddingPredictor,
    gtzan_clip_names: list[str],
    eval_emb_name: str,
    fma_size: str,
) -> ModelQuantEvalResult:
    """Runs the quantitative evaluation for the given model.

    The quant eval for a model is executed for every one of the provided GTZAN clip names.
    There are two embedding *types* at play here, the one of the model ("model_emb") and
    the one used in the eval ("eval_emb"). They may be the same, but don't have to be.

    Some models are trained to predict multiple embeddings per GTZAN clip. We refer to the
    clip parts as "slices".

    Algorithmically we follow these steps:
    1. Predict the embeddings for the GTZAN clip slices using the model.
    2. Retrieve the closest FMA track for the predicted embeddings. This step simulates
       music generation and could be replaced by a generative model. The retrieval is done based on
       the predicted embeddings (type: model_emb).
    3. Get the eval embeddings of the retrieved FMA track.
    4. Get the eval embedding of the current GTZAN clip.
    5. Compare the similarity of target audio (from step 4.) to the retrieved audio from step 2./3.
       Critically, the similarity is computed in the eval embedding space, which is model
       independent.

    The evaluation includes an identification analysis:
    1. Get the embeddings of the GTZAN clip.
    2. Get predicted embeddings and retrieve FMA embeddings for them (simulating the generation of music).
    3. Calculate identification accuracy using above two.

    Args:
        model: Predictor to evaluate; its ``emb_name`` defines the model embedding type.
        gtzan_clip_names: GTZAN clips to evaluate on. Clips with at least one slice
            the model cannot predict are skipped and counted in ``skip_ctr``.
        eval_emb_name: Embedding type used to compare retrieved and target audio.
        fma_size: Forwarded to the FMA retrieval and the AudioSet metrics.

    Returns:
        The collected per-slice and per-clip metrics.

    Raises:
        ValueError: If the loaded embeddings and slice names of a clip disagree in
            length, if the retrieved FMA clip has fewer slices than the GTZAN clip,
            or if no clip at all could be evaluated.
    """
    # The two embedding types: what the model was trained on and what we evaluate with.
    train_emb_name = model.emb_name

    print(f"Evaluating model {model.name}; {train_emb_name=} {eval_emb_name=}")

    # Slice names of the FMA clips that were retrieved ("generated").
    all_retrieved_fma_music_slice_names = []
    all_gtzan_slice_names = []

    all_pred_mse = []  # MSE between predicted and GTZAN embeddings.
    all_pred_corr = []  # Correlation between predicted and GTZAN embeddings.

    # Cos sims between retrieved FMA audio and GTZAN audio.
    all_retrieval_cos_sims = []
    all_eval_gtzan_embs = []  # GTZAN embeddings (of eval_emb type).
    # FMA embeddings (of eval_emb type); embs of the retrieved music.
    all_eval_fma_embs = []

    # GTZAN ground truth embs and predicted embs (in train emb space).
    all_train_gtzan_embs, all_gtzan_preds = [], []

    audioset_class_overlaps: dict[str, list[float]] = defaultdict(list)
    audioset_spearman_corr: dict[str, list[float]] = defaultdict(list)

    skip_ctr = 0
    included_gtzan_clip_names = []
    retrieved_fma_clip_names = []

    for gtzan_clip_name in gtzan_clip_names:
        (
            gtzan_clip_train_emb_mat,
            gtzan_train_slice_names,
        ) = emb_loader.get_grouped_gtzan_embs_for_clip_name(
            train_emb_name, gtzan_clip_name
        )
        if any(not model.can_predict(s) for s in gtzan_train_slice_names):
            # Models may not predict keys because of excluded genres or shortened training data.
            skip_ctr += 1
            continue

        included_gtzan_clip_names.append(gtzan_clip_name)
        predicted_gtzan_embs = model.predict_embs(gtzan_train_slice_names)
        all_train_gtzan_embs.extend(gtzan_clip_train_emb_mat)
        all_gtzan_preds.extend(predicted_gtzan_embs)

        for gtzan_ground_truth_emb, pred_emb in zip(
            gtzan_clip_train_emb_mat, predicted_gtzan_embs
        ):
            all_pred_mse.append(mse(gtzan_ground_truth_emb, pred_emb))
            all_pred_corr.append(correlation_score(gtzan_ground_truth_emb, pred_emb))

        # Drop the train-space data so it cannot be confused with the eval-space
        # data of the same clip that is loaded next.
        del gtzan_train_slice_names
        del gtzan_clip_train_emb_mat

        (
            gtzan_slice_embs,
            gtzan_slice_names,
        ) = emb_loader.get_grouped_gtzan_embs_for_clip_name(
            eval_emb_name, gtzan_clip_name
        )
        if len(gtzan_slice_embs) != len(gtzan_slice_names):
            raise ValueError(
                f"Got {len(gtzan_slice_embs)=}, {len(gtzan_slice_names)=}; "
                f"Emb name: {eval_emb_name}; GTZAN clip name: {gtzan_clip_name}"
            )
        all_gtzan_slice_names.extend(gtzan_slice_names)

        # Simulate music generation via retrieval from the FMA dataset.
        fma_slice_embs, fma_slice_names = retrieve_fma_embedding(
            fma_size,
            train_emb_name,
            predicted_gtzan_embs,
            eval_emb_name,
            expected_num_eval_slices=len(gtzan_slice_embs),
        )
        if len(gtzan_slice_embs) > len(fma_slice_embs):
            raise ValueError(
                f"GTZAN slice embs ({len(gtzan_slice_embs)}) > FMA slice embs ({len(fma_slice_embs)}); "
                f"Emb name: {eval_emb_name}; GTZAN clip name: {gtzan_clip_name}; {fma_slice_names=}"
            )

        for gtzan_slice_emb, fma_slice_emb, fma_slice_name in zip(
            gtzan_slice_embs, fma_slice_embs, fma_slice_names
        ):
            # Zip is shortening to the length of the first iterable, which is (slices of) 15s of GTZAN.
            # We ignore the second 15s of FMA intentionally,
            # because the retrieval was also only done based on the first 15s.

            all_retrieved_fma_music_slice_names.append(fma_slice_name)
            all_retrieval_cos_sims.append(cos_sim(fma_slice_emb, gtzan_slice_emb))
            all_eval_gtzan_embs.append(gtzan_slice_emb)
            all_eval_fma_embs.append(fma_slice_emb)

        fma_clip_name = utils.parse_long_key(fma_slice_names[0])[0]
        retrieved_fma_clip_names.append(fma_clip_name)

        for class_group, overlap in audioset_metrics.get_class_overlap(
            gtzan_clip_name, fma_clip_name, fma_size
        ).items():
            audioset_class_overlaps[class_group].append(overlap)

        for class_group, spearman_corr in audioset_metrics.get_spearman_corr(
            gtzan_clip_name, fma_clip_name, fma_size
        ).items():
            audioset_spearman_corr[class_group].append(spearman_corr)

    if not included_gtzan_clip_names:
        # The identification analysis below needs at least one evaluated clip;
        # fail with an explicit message instead of an obscure numpy error.
        raise ValueError(
            f"No GTZAN clip could be evaluated for model {model.name} "
            f"({skip_ctr} of {len(gtzan_clip_names)} clips skipped)."
        )

    identification_accuracies = identification(
        np.array(all_eval_gtzan_embs), np.array(all_eval_fma_embs)
    )
    gtzan_id_accs = identification(
        np.array(all_train_gtzan_embs), np.array(all_gtzan_preds)
    )

    return ModelQuantEvalResult(
        model_name=model.name,
        emb_name=train_emb_name,
        eval_emb_name=eval_emb_name,
        gtzan_pred_mses=all_pred_mse,
        gtzan_pred_corr=all_pred_corr,
        retrieved_fma_slice_names=all_retrieved_fma_music_slice_names,
        retrieved_fma_cos_sims=all_retrieval_cos_sims,
        retrieved_fma_identification=identification_accuracies,
        gtzan_identification=gtzan_id_accs,
        gtzan_slice_names=all_gtzan_slice_names,
        gtzan_clip_names=included_gtzan_clip_names,
        fma_clip_names=retrieved_fma_clip_names,
        skip_ctr=skip_ctr,
        class_overlaps=utils.avg_dict_of_lists(audioset_class_overlaps),
        class_spearman_corr=utils.avg_dict_of_lists(audioset_spearman_corr),
    )


def evaluate_musiclm_model(
    model_name: str,
    eval_emb_name: str,
    slice_name_to_gen_emb: dict[str, np.ndarray],
    clip_name_to_audioset_probs: dict[str, np.ndarray],
) -> ModelQuantEvalResult:
    """Runs the quantitative evaluation for the given predictions.

    Here the music has already been generated and embedded again, so no retrieval is required.
    Slice names are from the GTZAN dataset.
    In our evals we call this only with validation examples, hence no filtering with
    fmri_loader.get_val_clip_names() is done below; a check is in place.

    Args:
        model_name: Name stored in the result.
        eval_emb_name: Embedding type of the generated and the target embeddings.
        slice_name_to_gen_emb: Embedding of the generated audio per GTZAN slice name.
        clip_name_to_audioset_probs: AudioSet vector of the generated audio per
            GTZAN clip name.

    Returns:
        The collected metrics. Fields that need predicted embeddings or FMA
        retrieval are left empty; ``retrieved_fma_slice_names`` holds the GTZAN
        slice names because there is no retrieved FMA audio.

    Raises:
        AssertionError: If the clips in ``slice_name_to_gen_emb`` are not exactly
            the validation clips.
        KeyError: If a generated embedding or AudioSet vector is missing for a
            slice or clip.
    """

    # Cos sims between generated audio and GTZAN audio.
    all_gen_cos_sims = []

    all_gen_embs, all_target_embs = [], []

    audioset_class_overlaps: dict[str, list[float]] = defaultdict(list)
    audioset_spearman_corr: dict[str, list[float]] = defaultdict(list)

    all_gtzan_slice_names = []

    # Sorted so that the order of the per-slice results is reproducible across
    # runs, and so that the result holds a list as its field annotation says.
    gtzan_clip_names = sorted(
        {utils.parse_long_key(k)[0] for k in slice_name_to_gen_emb}
    )
    assert set(gtzan_clip_names) == set(
        fmri_loader.get_val_clip_names()
    ), "The generated clips must be exactly the validation clips."
    print(f"Number of GTZAN clips: #{len(gtzan_clip_names)}")

    # Does not depend on the clip, so it is looked up once instead of per clip;
    # assumed to return the same mapping on every call.
    target_audioset_probs = emb_loader.get_audioset_probs_for_gtzan()

    for gtzan_clip_name in gtzan_clip_names:
        (
            gtzan_slice_embs,  # Matrix [num_slices, emb_dim].
            gtzan_slice_names,
        ) = emb_loader.get_grouped_gtzan_embs_for_clip_name(
            eval_emb_name, gtzan_clip_name
        )

        for target_emb, slice_name in zip(gtzan_slice_embs, gtzan_slice_names):
            gen_emb = slice_name_to_gen_emb[slice_name]

            all_gen_cos_sims.append(cos_sim(target_emb, gen_emb))
            all_gen_embs.append(gen_emb)
            all_target_embs.append(target_emb)
            all_gtzan_slice_names.append(slice_name)

        gen_audioset_vec = clip_name_to_audioset_probs[gtzan_clip_name]
        target_audioset_vec = target_audioset_probs[gtzan_clip_name]

        for class_group, overlap in audioset_metrics.get_class_overlap_for_vecs(
            gen_audioset_vec, target_audioset_vec
        ).items():
            audioset_class_overlaps[class_group].append(overlap)

        for (
            class_group,
            spearman_corr,
        ) in audioset_metrics.get_all_spearman_corrs_for_vecs(
            gen_audioset_vec, target_audioset_vec
        ).items():
            audioset_spearman_corr[class_group].append(spearman_corr)

    # NOTE(review): the generated embeddings are passed in the `mat_true` position
    # and the targets as `mat_pred`, the reverse of evaluate_model_single_emb.
    # The metric is not symmetric in its arguments; confirm this is intended.
    identification_accuracies = identification(
        np.array(all_gen_embs), np.array(all_target_embs)
    )

    return ModelQuantEvalResult(
        model_name=model_name,
        emb_name="N/A",
        eval_emb_name=eval_emb_name,
        gtzan_pred_mses=[],  # Not applicable because predicted embeddings are not available.
        gtzan_pred_corr=[],  # Not applicable because predicted embeddings are not available.
        retrieved_fma_slice_names=all_gtzan_slice_names,  # No FMA, because the audio is generated.
        retrieved_fma_cos_sims=all_gen_cos_sims,
        retrieved_fma_identification=identification_accuracies,
        gtzan_identification=[],  # Not applicable because predicted embeddings are not available.
        gtzan_slice_names=all_gtzan_slice_names,
        gtzan_clip_names=gtzan_clip_names,  # All are included currently.
        fma_clip_names=[],
        skip_ctr=0,
        class_overlaps=utils.avg_dict_of_lists(audioset_class_overlaps),
        class_spearman_corr=utils.avg_dict_of_lists(audioset_spearman_corr),
    )
