from abc import ABC
from enum import Enum

import numpy as np
from numpy.typing import NDArray

import research.wsl_ece.metric.ece as ece_module
from research.wsl_ece.metric.wandb_util import get_validation_and_test_data_artifacts

# Seeds interpolated into the model-artifact names of the dataset classes below;
# each seed yields one entry in a class's MODEL_ARTIFACT_NAMES list.
SEED_LIST = [42, 1373158606, 239081663, 53710184, 1592467581, 590620971, 525901256, 479341423, 299655412, 1581559892]


class DatasetForECEGapExperiment(ABC):
    """Base class bundling the logits/labels loaded for a set of model artifacts.

    Subclasses supply ``MODEL_ARTIFACT_NAMES``; the constructor hands that list to
    ``get_validation_and_test_data_artifacts`` and stores the five returned
    collections as attributes. Thin wrappers around the ``ece_module`` helpers
    are provided for computing ECE and PU-ECE with the recommended bin count.
    """

    # Artifact names passed to the loader; defined by each concrete subclass.
    MODEL_ARTIFACT_NAMES: list[str]

    def __init__(self):
        """Load the artifacts named in ``MODEL_ARTIFACT_NAMES`` and cache the prior."""
        (
            self.positive_logits,
            self.unlabeled_logits,
            self.unlabeled_labels,
            self.test_logits,
            self.test_labels,
        ) = get_validation_and_test_data_artifacts(self.MODEL_ARTIFACT_NAMES)
        # The prior is taken as the mean of all loaded test labels.
        self._prior = float(np.mean(self.test_labels))

    @property
    def prior(self) -> float:
        """Class prior of the positive class, as computed once in ``__init__``."""
        return self._prior

    def ece(self, logits: NDArray[np.floating], labels: NDArray[np.integer]) -> float:
        """Compute ECE for ``logits``/``labels`` with the recommended number of bins.

        Args:
            logits: Model outputs to evaluate; their count determines the bin count.
            labels: Labels forwarded alongside ``logits`` to ``calculate_ece``.

        Returns:
            Whatever ``ece_module.calculate_ece`` returns for these inputs.
        """
        return ece_module.calculate_ece(
            logits=logits,
            labels=labels,
            n_bins=ece_module.get_recommended_n_bins_for_ece(len(logits)),
        )

    def pu_ece(self, positive_logits: NDArray[np.floating], unlabeled_logits: NDArray[np.floating]) -> float:
        """Compute PU-ECE from positive and unlabeled logits using the stored prior.

        Args:
            positive_logits: Logits of the positive samples.
            unlabeled_logits: Logits of the unlabeled samples.

        Returns:
            Whatever ``ece_module.calculate_pu_ece`` returns for these inputs.
        """
        num_pos, num_unl = len(positive_logits), len(unlabeled_logits)
        # Bin count is chosen from the prior and both sample counts.
        recommended_bins = ece_module.get_recommended_n_bins_for_pu_ece(self.prior, num_pos, num_unl)
        return ece_module.calculate_pu_ece(
            positive_logits=positive_logits,
            unlabeled_logits=unlabeled_logits,
            prior=self.prior,
            n_bins=recommended_bins,
            n_positive=num_pos,
            n_unlabeled=num_unl,
        )


class MNIST(DatasetForECEGapExperiment):
    """MNIST dataset variant: one prediction-table artifact name per seed in ``SEED_LIST``."""

    MODEL_ARTIFACT_NAMES = [
        "wsl-ece/predictions_table_mnist_MLP_300_300_10000_256_100_0_001_"
        f"sigmoid_False_False_cross_entropy_{run_seed}:latest"
        for run_seed in SEED_LIST
    ]


class CIFAR10(DatasetForECEGapExperiment):
    """CIFAR10 dataset variant: one prediction-table artifact name per seed in ``SEED_LIST``."""

    MODEL_ARTIFACT_NAMES = [
        "wsl-ece/predictions_table_cifar10_ResNet18_10000_256_100_1e_05_"
        f"sigmoid_False_False_cross_entropy_{run_seed}:latest"
        for run_seed in SEED_LIST
    ]


class DatasetNamesForECEGapExperiment(Enum):
    """Enum for datasets used in the ECE gap experiment."""

    MNIST = "MNIST"
    CIFAR10 = "CIFAR10"

    def dataset_class(self) -> DatasetForECEGapExperiment:
        match self:
            case DatasetNamesForECEGapExperiment.MNIST:
                return MNIST()
            case DatasetNamesForECEGapExperiment.CIFAR10:
                return CIFAR10()


class ECEGapExperimentRunner:
    """Orchestrates the ECE gap experiment for a given dataset."""

    def __init__(self, dataset_name: "DatasetNamesForECEGapExperiment"):
        """Build the dataset object selected by ``dataset_name``.

        Args:
            dataset_name: Enum member whose ``dataset_class()`` provides the dataset.
        """
        self.dataset_name = dataset_name
        self.dataset = dataset_name.dataset_class()

    def run(self) -> dict[str, NDArray[np.floating]]:
        """Compute test ECE, subsampled ECE and PU-ECE for every dataset entry.

        The dataset attributes are treated as sequences with one entry per
        model artifact; each entry is the array handed to ``ece``/``pu_ece``.

        Returns:
            A dict with keys ``"test_ece"``, ``"ece"`` and ``"pu_ece"``, each an
            array holding one value per entry.

        Raises:
            ValueError: Raised by ``np.random.choice`` when an entry has more
                positive than unlabeled samples.
        """
        dataset = self.dataset
        positive_logits = dataset.positive_logits
        unlabeled_logits = dataset.unlabeled_logits

        # Sample counts must come from an individual entry: the length of the
        # outer container is the number of entries, not the number of samples.
        # NOTE(review): assumes all entries share the first entry's sample
        # counts, since one index set is reused for every entry — confirm.
        n_positive = len(positive_logits[0]) if len(positive_logits) else 0
        n_unlabeled = len(unlabeled_logits[0]) if len(unlabeled_logits) else 0

        # Restrict the number of unlabeled samples to match the number of positive samples for calculating ECE
        subset_idx = np.random.choice(n_unlabeled, n_positive, replace=False)

        # The same indices are applied to logits and labels so pairs stay aligned.
        logits_for_ece = [logits[subset_idx] for logits in unlabeled_logits]
        labels_for_ece = [labels[subset_idx] for labels in dataset.unlabeled_labels]

        return {
            "test_ece": np.array(
                [dataset.ece(logits, labels) for logits, labels in zip(dataset.test_logits, dataset.test_labels)]
            ),
            "ece": np.array([dataset.ece(logits, labels) for logits, labels in zip(logits_for_ece, labels_for_ece)]),
            "pu_ece": np.array(
                [dataset.pu_ece(positive, unlabeled) for positive, unlabeled in zip(positive_logits, unlabeled_logits)]
            ),
        }
