import random
import warnings
from abc import abstractmethod
from typing import Any
from scipy import stats

import pandas as pd
import numpy as np

from .gain import GainPlugin
from hyperimpute.plugins.imputers.plugin_hyperimpute import HyperImputePlugin
from hyperimpute.plugins.imputers.plugin_mice import MicePlugin
from hyperimpute.plugins.imputers.plugin_EM import EMPlugin
from hyperimpute.plugins.imputers.plugin_missforest import MissForestPlugin
from hyperimpute.plugins.imputers.plugin_miwae import MIWAEPlugin
from hyperimpute.plugins.imputers.plugin_mean import MeanPlugin
from hyperimpute.plugins.imputers.plugin_median import MedianPlugin
from ...base import *
from ...utils import get_kwargs, SeedContext


class WrapperBase(BaseImputerMixIn, Base):
    """
    Adapter that runs a hyperimpute imputation plugin behind this project's
    imputer interface.

    Subclasses implement :meth:`method` to return the plugin class; an
    instance of it is built in ``__init__`` from ``plugin_params`` and stored
    as ``self.model``.

    https://github.com/vanderschaarlab/hyperimpute
    Jarrett, D., Cebere, B. C., Liu, T., Curth, A., & van der Schaar, M. (2022, June).
    Hyperimpute: Generalized iterative imputation with automatic model selection.
    In International Conference on Machine Learning (pp. 9916-9937). PMLR.
    """

    def __init__(self, num_average=1, plugin_params=None, **kwargs):
        """
        Args:
            num_average: number of times the plugin's ``transform`` is run in
                :meth:`impute`; the runs are aggregated into one result.
            plugin_params: keyword arguments forwarded to the plugin
                constructor. ``None`` (the default) means no arguments.
            **kwargs: forwarded, through ``get_kwargs``, to the base classes.
        """
        super().__init__(**get_kwargs(**kwargs))
        self._is_fit = False
        # A fresh dict per instance; a ``dict()`` default would be one object
        # shared by every instance created without explicit params.
        self.plugin_params = {} if plugin_params is None else plugin_params
        self.model = self.method()(**self.plugin_params)
        self.num_average = num_average

    @abstractmethod
    def method(self):
        """Return the hyperimpute plugin class to instantiate."""
        pass

    def fit(self, scenario=lambda x:x):
        """
        Fit the data transform and the wrapped plugin on the training CSV.

        When ``self.model_flags['in_sample_only']`` is set, only the base-class
        ``fit`` runs and the plugin is not fitted here.

        Args:
            scenario: callable applied to the DataFrame read from
                ``self._cfg.dataset.train_path`` before fitting; identity by
                default.

        Returns:
            Whatever the base-class ``fit`` returns.
        """
        if self.model_flags['in_sample_only']:
            return super().fit(scenario)
        df = scenario(pd.read_csv(self._cfg.dataset.train_path))
        # NOTE(review): fits ``self._transform`` but encodes with
        # ``self.tabular_transform``; presumably the same object - confirm.
        self._transform.fit(df)
        X = self.tabular_transform.transform(df, return_as_tensor=True).numpy()
        self.model.fit(pd.DataFrame(X))
        return super().fit(scenario)

    def impute(self, df: pd.DataFrame, seed=None, **kwargs) -> pd.DataFrame:
        """
        Fill the missing entries of ``df`` using the wrapped plugin.

        ``df`` is encoded with ``self.tabular_transform`` and the plugin's
        ``transform`` is run ``self.num_average`` times. The first
        ``self.numerical_dim`` encoded columns are averaged across runs. The
        remaining columns are averaged when ``self.model_flags['onehot']`` is
        set and otherwise take the per-cell mode across runs. Only cells that
        were NaN in the encoded matrix are overwritten, and the result is then
        decoded with ``inverse_transform``.

        Args:
            df: frame to impute; it is deep-copied, not modified in place.
            seed: accepted for interface compatibility; not used here.
            **kwargs: ignored.

        Returns:
            A copy of ``df`` whose columns produced by ``inverse_transform``
            are replaced with the decoded, imputed values.
        """
        df = df.copy(deep=True)
        num_tokens = []
        cat_tokens = []
        with warnings.catch_warnings():
            # Suppress every warning raised while encoding, imputing, decoding.
            warnings.simplefilter("ignore")
            X = self.tabular_transform.transform(df, return_as_tensor=True).numpy()
            if not self.model_flags['onehot']:
                # NOTE(review): assigning a str array back into a numeric
                # ndarray casts the values back to numbers, so this looks like
                # a no-op round-trip - confirm the intent.
                X[:, self.numerical_dim:] = X[:, self.numerical_dim:].astype(str)
            nan_mask = np.isnan(X)
            for _ in range(self.num_average):
                imputed = self.model.transform(X)
                if isinstance(imputed, pd.DataFrame):
                    imputed = imputed.to_numpy()
                num_tokens.append(imputed[:, :self.numerical_dim])
                cat_tokens.append(imputed[:, self.numerical_dim:])

            num = np.stack(num_tokens).mean(axis=0)
            if self.model_flags['onehot']:
                # numpy reductions take ``axis``; ``dim`` (torch spelling)
                # raises TypeError.
                cat = np.stack(cat_tokens).mean(axis=0)
            else:
                # Per-cell most frequent value across the runs.
                cat = stats.mode(np.stack(cat_tokens), axis=0, keepdims=False).mode
            imputed = np.concatenate([num, cat], axis=1)
            # Keep observed values; fill only the cells that were missing.
            X[nan_mask] = imputed[nan_mask]
            df_ = self.tabular_transform.inverse_transform(X)
        df_.index = df.index
        df[df_.columns] = df_
        return df

    def _impute(self, df: pd.DataFrame, **kwargs):
        """Unused: imputation is implemented directly in :meth:`impute`."""
        pass

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Store the fit flag and the plugin instance in ``checkpoint``."""
        checkpoint['_is_fit'] = self._is_fit
        checkpoint['model'] = self.model

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Restore the fit flag and the plugin instance from ``checkpoint``."""
        self._is_fit = checkpoint['_is_fit']
        # Must restore ``self.model``: it is the attribute ``impute`` uses.
        self.model = checkpoint['model']


class HyperImpute(WrapperBase):
    """Imputer backed by hyperimpute's ``HyperImputePlugin``."""

    def method(self):
        """Return the plugin class this wrapper instantiates."""
        plugin_class = HyperImputePlugin
        return plugin_class


class GAIN(WrapperBase):
    """Imputer backed by ``GainPlugin`` from the sibling ``gain`` module."""

    def method(self):
        """Return the plugin class this wrapper instantiates."""
        plugin_class = GainPlugin
        return plugin_class


class MICE(WrapperBase):
    """Imputer backed by hyperimpute's ``MicePlugin``."""

    def method(self):
        """Return the plugin class this wrapper instantiates."""
        plugin_class = MicePlugin
        return plugin_class


class MissForest(WrapperBase):
    """Imputer backed by hyperimpute's ``MissForestPlugin``."""

    def method(self):
        """Return the plugin class this wrapper instantiates."""
        plugin_class = MissForestPlugin
        return plugin_class


class EM(WrapperBase):
    """Imputer backed by hyperimpute's ``EMPlugin``."""

    def method(self):
        """Return the plugin class this wrapper instantiates."""
        plugin_class = EMPlugin
        return plugin_class
