# stdlib
from typing import Any, List

# third party
import pandas as pd
import torch

# hyperimpute absolute
import hyperimpute.plugins.core.params as params
import hyperimpute.plugins.imputers.base as base
import hyperimpute.plugins.utils.decorators as decorators
from hyperimpute.plugins.imputers.plugin_gain import GainImputation
from hyperimpute.utils.distributions import enable_reproducible_results

# Small numeric constant. It is not referenced in the visible part of this
# file. NOTE(review): possibly kept for external importers — confirm before
# removing.
EPS = 1e-8

# Torch device used for the tensors built in this module: CUDA when torch
# reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class GainPlugin(base.ImputerPlugin):
    """Imputation plugin for completing missing values using the GAIN strategy.

    Method:
        Details in the GainImputation class implementation.

    Args:
        batch_size: Stored on the plugin and forwarded to GainImputation.
        n_epochs: Stored on the plugin and forwarded to GainImputation.
        hint_rate: Stored on the plugin and forwarded to GainImputation.
        loss_alpha: Stored on the plugin and forwarded to GainImputation.
        random_state: Seed passed to the base plugin constructor and to
            ``enable_reproducible_results``.

    Example:
        >>> import numpy as np
        >>> from hyperimpute.plugins.imputers import Imputers
        >>> plugin = Imputers().get("gain")
        >>> plugin.fit_transform([[1, 1, 1, 1], [np.nan, np.nan, np.nan, np.nan], [1, 2, 2, 1], [2, 2, 2, 2]])
    """

    def __init__(
            self,
            batch_size: int = 128,
            n_epochs: int = 100,
            hint_rate: float = 0.8,
            loss_alpha: int = 10,
            random_state: int = 0,
    ) -> None:
        super().__init__(random_state=random_state)

        # Seed before the model is constructed so that its creation is
        # covered by the reproducibility settings as well.
        enable_reproducible_results(random_state)

        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.hint_rate = hint_rate
        self.loss_alpha = loss_alpha

        self._model = GainImputation(batch_size, n_epochs, hint_rate, loss_alpha)

    @staticmethod
    def name() -> str:
        """Return the identifier under which this plugin is registered."""
        return "gain"

    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[params.Params]:
        """Return the search space for the tunable constructor arguments.

        The positional and keyword arguments are accepted but not used.
        """
        return [
            params.Categorical("batch_size", [64, 128, 256, 512]),
            params.Integer("n_epochs", 100, 1000, 100),
            params.Float("hint_rate", 0.8, 0.99),
            params.Integer("loss_alpha", 10, 100, 10),
        ]

    @decorators.benchmark
    def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "GainPlugin":
        """Train the underlying GAIN model on ``X``.

        Args:
            X: Data to train on; its ``.values`` are converted to a tensor
                and moved to ``DEVICE``.
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            This plugin instance, so calls can be chained.
        """
        self._model.fit(torch.tensor(X.values).to(DEVICE))
        # Return the plugin itself rather than whatever the inner model's
        # ``fit`` returns, as promised by the return annotation.
        return self

    @decorators.benchmark
    def _transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """Impute the missing values of ``X`` with the GAIN model.

        Args:
            X: Data to impute; its ``.values`` are converted to a tensor
                and moved to ``DEVICE``.

        Returns:
            The model output detached, moved to the CPU and converted with
            ``.numpy()``. NOTE(review): this is a NumPy array although the
            annotation says ``pd.DataFrame`` — presumably the base class
            wraps it; confirm against ``base.ImputerPlugin``.
        """
        data = torch.tensor(X.values).to(DEVICE)
        # NOTE(review): the model is re-fitted on the data being imputed, so
        # any training done in ``_fit`` is repeated here. This looks like
        # deliberate transductive use — confirm before changing.
        self._model.fit(data)
        return self._model.transform(data).detach().cpu().numpy()
