"""Predictor classes for the computation/retrieval of embeddings for GTZAN music."""

import abc

import numpy as np

from fmri2music import emb_loader


class EmbeddingPredictor(abc.ABC):
    """Common interface for objects that map a GTZAN key to an embedding.

    Concrete subclasses supply `predict_emb` and `can_predict`; the batch
    helper `predict_embs` is built on top of `predict_emb`.
    """

    def __init__(self, name: str, emb_name: str) -> None:
        # Both values are stored verbatim; subclasses read `emb_name` when
        # they need to look up embeddings.
        self.name = name
        self.emb_name = emb_name

    @abc.abstractmethod
    def predict_emb(self, gtzan_key: str) -> np.ndarray:
        """Return the predicted embedding for a single GTZAN key."""
        raise NotImplementedError

    def predict_embs(self, key_group: list[str]) -> np.ndarray:
        """Return the predictions for every key in `key_group`, stacked.

        `predict_emb` is invoked once per key, in order, and the results are
        joined along a new leading axis with `np.stack`.
        """
        per_key = list(map(self.predict_emb, key_group))
        return np.stack(per_key)

    @abc.abstractmethod
    def can_predict(self, gtzan_key: str) -> bool:
        """Tell whether this predictor is able to handle the given GTZAN key."""
        raise NotImplementedError


class RandomPredictor(EmbeddingPredictor):
    """Baseline that answers every query with a randomly drawn FMA embedding."""

    def __init__(self, emb_name: str, fma_size: str):
        super().__init__("random", emb_name)
        # Forwarded to `emb_loader.get_fma_emb` on every prediction.
        self.fma_size = fma_size

    def predict_emb(self, gtzan_key: str) -> np.ndarray:
        """Return one entry of the FMA embeddings, chosen uniformly at random.

        The FMA embeddings are the second value returned by
        `emb_loader.get_fma_emb`; an index along their first axis is drawn
        with `np.random.randint`.
        """
        del gtzan_key  # The key has no influence on a random prediction.
        _unused, candidates = emb_loader.get_fma_emb(self.fma_size, self.emb_name)
        n_candidates = candidates.shape[0]
        picked = np.random.randint(0, n_candidates)
        return candidates[picked, :]

    def can_predict(self, gtzan_key: str) -> bool:
        """Always True: the prediction does not depend on the key."""
        return True


class BestPredictor(EmbeddingPredictor):
    """Retrieves the FMA embedding nearest to the true GTZAN one (an upper bound)."""

    def __init__(self, emb_name: str, fma_size: str):
        super().__init__("retrieval", emb_name)
        # Forwarded to `emb_loader.get_fma_emb` on every retrieval.
        self.fma_size = fma_size

    def retrieve_closest_fma_emb(self, emb: np.ndarray) -> np.ndarray:
        """Return the FMA embedding with the largest dot product against `emb`."""
        _unused, candidates = emb_loader.get_fma_emb(self.fma_size, self.emb_name)
        # Score each candidate by its dot product with the query; the highest
        # score wins.
        best_idx = np.argmax(np.dot(candidates, emb))
        return candidates[best_idx]

    def predict_emb(self, gtzan_key: str) -> np.ndarray:
        """Return the FMA embedding closest to the stored embedding of `gtzan_key`."""
        # The stored GTZAN embedding is the ideal answer: predicting it would
        # be exactly right. Among the many FMA embeddings we hand back the one
        # nearest to it.
        target = emb_loader.get_gtzan_emb(self.emb_name)[gtzan_key]
        return self.retrieve_closest_fma_emb(target)

    def can_predict(self, gtzan_key: str) -> bool:
        """Always True, regardless of the key."""
        # NOTE(review): `predict_emb` indexes the GTZAN embeddings by key, so
        # this presumably relies on every queried key being present there --
        # confirm against callers.
        return True


class ExportedPredictor(EmbeddingPredictor):
    """Serves predictions that were exported earlier and are loaded from a file."""

    def __init__(self, name: str, file_name: str, emb_name: str):
        super().__init__(name, emb_name)
        self.file_name = file_name
        # Loaded once, at construction time. Presumably a mapping from GTZAN
        # key to embedding -- it is indexed and membership-tested by key below.
        self.pred = emb_loader.load_predictions(file_name)

    def predict_emb(self, gtzan_key: str) -> np.ndarray:
        """Look up the exported prediction stored under `gtzan_key`."""
        exported = self.pred
        return exported[gtzan_key]

    def can_predict(self, gtzan_key: str) -> bool:
        """True when the loaded predictions contain `gtzan_key`."""
        exported = self.pred
        return gtzan_key in exported


class OnlinePredictor(EmbeddingPredictor):
    """Serves predictions handed over in memory, e.g. from model inference.

    This class runs no inference itself: it receives the keys and an array of
    predictions, and returns `preds[i]` for the key `gtzan_keys[i]`.
    """

    def __init__(
        self, name: str, emb_name: str, gtzan_keys: list[str], preds: np.ndarray
    ) -> None:
        """Store the predictions and index the keys for fast lookup.

        Args:
            name: Name of the predictor.
            emb_name: Name of the embedding type.
            gtzan_keys: Keys of the predictions; position `i` corresponds to
                `preds[i]`.
            preds: Predictions, indexed along the first axis by key position.

        The lookup structures are built once here, so later mutations of
        `gtzan_keys` are not reflected in `predict_emb` or `can_predict`.
        """
        super().__init__(name, emb_name)
        self.gtzan_keys = gtzan_keys
        self.gtzan_keys_set = set(gtzan_keys)
        self.preds = preds
        # Key -> position table so that each lookup is a dict access instead
        # of a linear `list.index` scan (which made batch prediction
        # quadratic). `setdefault` keeps the FIRST position of a duplicated
        # key, matching what `list.index` returns.
        self._key_to_row: dict[str, int] = {}
        for row, key in enumerate(gtzan_keys):
            self._key_to_row.setdefault(key, row)

    def predict_emb(self, gtzan_key: str) -> np.ndarray:
        """Return the prediction stored for `gtzan_key`.

        Raises:
            ValueError: If `gtzan_key` is not among the keys given at
                construction (the same exception type `list.index` raises).
        """
        try:
            row = self._key_to_row[gtzan_key]
        except KeyError:
            # Keep the exception type callers saw from `list.index`.
            raise ValueError(f"{gtzan_key!r} is not in list") from None
        return self.preds[row]

    def can_predict(self, gtzan_key: str) -> bool:
        """True when `gtzan_key` was among the keys given at construction."""
        return gtzan_key in self.gtzan_keys_set
