from __future__ import annotations

from abc import ABC
import os
from typing import Any, Final, Generic, Optional, Type, TypeVar

import numpy as np
import scipy.io as sp_io
import sklearn.preprocessing as skl_preprocessing
import torch as th
import torch.utils.data as th_data

import models.classifiers

from .. import base, common, protocols, utils

# Type parameter for the concrete subset-feature classifier an env manager
# wraps; restricted to subclasses of models.classifiers.SubsetFeatureClassifier.
SubsetFeatureClassifierType = TypeVar(
    "SubsetFeatureClassifierType",
    bound=models.classifiers.SubsetFeatureClassifier,
)


class MatlabClassificationDataset(
    th_data.Dataset[tuple[th.Tensor, th.Tensor, th.Tensor]],
    protocols.DatasetHasFeaturizer,
):
    """Classification dataset read from a MATLAB ``.mat`` file.

    The file ``<datasets root>/MatlabClassificationDataset/<name>.mat`` is
    loaded with ``scipy.io.loadmat``; its ``"X"`` entry supplies the features
    and its ``"Y"`` entry (flattened) the labels. A ``StandardScaler`` is fit
    on the raw features and kept as ``featurizer``.

    Each item is a ``(features, mask, label)`` triple where ``features`` is the
    standardized row when ``to_return_zs`` is true and the raw row otherwise,
    and ``mask`` is an all-ones ``long`` tensor shaped like the features.
    """

    featurizer: skl_preprocessing.StandardScaler

    def __init__(self, name: str, to_return_zs: bool = True) -> None:
        super().__init__()
        mat_path: str = os.path.join(
            common.get_datasets_files_root_dir(), "MatlabClassificationDataset", f"{name}.mat"
        )
        mat_contents: dict[str, Any] = sp_io.loadmat(mat_path)
        raw_feats: np.ndarray = mat_contents["X"]
        raw_labels: np.ndarray = mat_contents["Y"].flatten()

        # Standardize every column to zero mean / unit variance (z-scores);
        # note this is StandardScaler, not a [0, 1] min-max scaling.
        self.featurizer = skl_preprocessing.StandardScaler()
        std_feats: np.ndarray = self.featurizer.fit_transform(raw_feats)

        # Tensors are moved to shared memory so DataLoader workers can reuse them.
        self.xs = th.as_tensor(raw_feats, dtype=th.float32).share_memory_()
        self.masks = th.ones_like(self.xs, dtype=th.long).share_memory_()
        self.ys = th.as_tensor(raw_labels, dtype=th.long).share_memory_()
        self.zs = th.as_tensor(std_feats, dtype=th.float32).share_memory_()
        self._dataset = th_data.TensorDataset(self.xs, self.zs, self.masks, self.ys)
        self._to_ret_zs = to_return_zs

    @property
    def to_return_zs(self) -> bool:
        """Whether items carry standardized (True) or raw (False) features."""
        return self._to_ret_zs

    @to_return_zs.setter
    def to_return_zs(self, to_return_zs: bool):
        self._to_ret_zs = to_return_zs

    def __getitem__(self, index) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        raw_row, std_row, mask_row, label = self._dataset[index]
        chosen = std_row if self.to_return_zs else raw_row
        return chosen, mask_row, label

    def __len__(self):
        return len(self._dataset)


class MatlabClassificationEnvManagerBase(
    base.EnvManager, th.nn.Module, ABC, Generic[SubsetFeatureClassifierType]
):
    """Shared setup for environment managers over a MATLAB classification dataset.

    On construction this loads :class:`MatlabClassificationDataset`, shuffles
    the sample indices with a seeded permutation and splits them into three
    disjoint subsets:

    * ``_extrain_set``: the first ``(n_samples - n_tests) // 2`` permuted
      samples; concrete subclasses fit their classifier on it.
    * ``_train_set``: the remaining non-held-out samples; backs ``train_env``.
    * ``_val_set``: the last ``n_tests`` permuted samples; backs ``val_env``
      (``test_env`` returns the same object).

    It also enumerates all feature-subset actions via
    ``SubsetFeatureClassifier.make_full_action_features``. Subclasses must
    assign ``self._classifier`` in their ``__init__``; the environments are
    built lazily on first access and need it.
    """

    alpha: float

    _n_experts_per_fcomb: int
    _cinds_f: th.Tensor
    _xacts_f: th.Tensor

    _train_env: base.TorchDatasetEnv | None
    _val_env: base.TorchDatasetEnv | None

    _extrain_set: Final[th_data.Subset]
    _train_set: Final[th_data.Subset]
    _val_set: Final[th_data.Subset]

    _n_covs: Final[int]
    _n_labels: Final[int]

    _n_acts_avail: Optional[int]
    _classifier: SubsetFeatureClassifierType | None

    @property
    def n_acts_avail(self) -> int:
        """Number of available actions as reported by ``train_env`` (built on demand)."""
        return self.train_env.n_acts_avail

    @property
    def classifier(self) -> SubsetFeatureClassifierType:
        """The subset-feature classifier installed by the concrete subclass.

        Raises:
            RuntimeError: if ``_classifier`` has not been set yet.
        """
        if self._classifier is None:
            # Explicit raise rather than ``assert`` so the guard survives ``python -O``.
            raise RuntimeError(
                f"{type(self).__name__} has no classifier; subclasses must set "
                "`self._classifier` in __init__"
            )
        return self._classifier

    @property
    def n_covs(self) -> int:
        """Number of feature columns in the dataset."""
        return self._n_covs

    @property
    def n_labels(self) -> int:
        """Number of distinct label values in the dataset."""
        return self._n_labels

    @property
    def n_experts_per_fcomb(self) -> int:
        """Number of experts per feature combination."""
        return self._n_experts_per_fcomb

    def _build_split_env(
        self, dataset: th_data.Subset, is_train: bool
    ) -> base.TorchDatasetEnv:
        """Create a ``TorchDatasetEnv`` over ``dataset`` with this manager's settings."""
        return base.TorchDatasetEnv(
            n_covs=self.n_covs,
            n_experts_per_comb=self.n_experts_per_fcomb,
            cinds_f=self._cinds_f,
            xacts_f=self._xacts_f,
            dataset=dataset,
            classifier=self.classifier,
            alpha=self.alpha,
            is_train=is_train,
            n_acts_avail=self._n_acts_avail,
        )

    @property
    def train_env(self) -> base.TorchDatasetEnv:
        """Environment over ``_train_set`` (``is_train=True``), built on first access."""
        if self._train_env is None:
            self._train_env = self._build_split_env(self._train_set, is_train=True)
        return self._train_env

    @property
    def val_env(self) -> base.TorchDatasetEnv:
        """Environment over ``_val_set`` (``is_train=False``), built on first access."""
        if self._val_env is None:
            self._val_env = self._build_split_env(self._val_set, is_train=False)
        return self._val_env

    @property
    def test_env(self) -> base.TorchDatasetEnv:
        """Alias of :attr:`val_env`; no separate test split is kept."""
        return self.val_env

    def __init__(
        self,
        name: str,
        n_experts_per_fcomb: int,
        to_return_zs: bool = True,
        min_features: int = 1,
        max_features: int | None = None,
        alpha: float = 0,
        n_acts_avail: Optional[int] = None,
        n_tests: int = 2000,
        split_seed: int = 42,
    ) -> None:
        """Load the dataset, split it, and enumerate the feature-subset actions.

        Args:
            name: Stem of the ``.mat`` file to load (see MatlabClassificationDataset).
            n_experts_per_fcomb: Experts per feature combination; forwarded to
                ``make_full_action_features`` and to the environments.
            to_return_zs: Whether dataset items carry standardized (True) or
                raw (False) features.
            min_features: Forwarded to ``make_full_action_features``.
            max_features: Forwarded to ``make_full_action_features``.
            alpha: Forwarded unchanged to each ``TorchDatasetEnv``.
            n_acts_avail: Forwarded unchanged to each ``TorchDatasetEnv``.
            n_tests: Number of samples held out for the validation/test split;
                must lie in ``[0, n_samples]``.
            split_seed: Seed for the split permutation. The default of 42
                reproduces the historical split.

        Raises:
            ValueError: if ``n_tests`` is negative or larger than the dataset.
        """
        super().__init__()
        self.orig_dataset = MatlabClassificationDataset(
            name=name, to_return_zs=to_return_zs
        )
        self._n_experts_per_fcomb = n_experts_per_fcomb
        self.alpha = alpha
        self._n_acts_avail = n_acts_avail

        # --- split dataset ---
        xs, _, ys = utils.th_dataset_to_ndarrays(self.orig_dataset)
        n_samps: int = len(xs)
        self._n_covs = xs.shape[1]
        self._n_labels = len(np.unique(ys))
        if not 0 <= n_tests <= n_samps:
            # Negative-index slicing would otherwise silently produce overlapping
            # (leaky) or empty splits for out-of-range values.
            raise ValueError(f"n_tests must be in [0, {n_samps}], got {n_tests}")
        # Positive offsets (not ``[-n_tests:]``) so ``n_tests == 0`` yields an
        # empty val split instead of the whole dataset.
        n_nontest = n_samps - n_tests
        n_extrain = n_nontest // 2
        rprm: np.ndarray = np.random.RandomState(split_seed).permutation(n_samps)
        extrain_idxs = sorted(rprm[:n_extrain].tolist())
        train_idxs = sorted(rprm[n_extrain:n_nontest].tolist())
        val_idxs = sorted(rprm[n_nontest:].tolist())
        self._extrain_set = th_data.Subset(self.orig_dataset, extrain_idxs)
        self._train_set = th_data.Subset(self.orig_dataset, train_idxs)
        self._val_set = th_data.Subset(self.orig_dataset, val_idxs)

        # --- all possible feature combinations ---
        self._cinds_f, self._xacts_f, _, _ = (
            models.classifiers.SubsetFeatureClassifier.make_full_action_features(
                n_covs=self.n_covs,
                min_features=min_features,
                max_features=max_features,
                n_experts_per_comb=n_experts_per_fcomb,
            )
        )

        # The classifier is provided by the concrete subclass after this call;
        # environments are created lazily once it exists.
        self._classifier = None
        self._train_env = None
        self._val_env = None

    def expert_func(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Return ``self.classifier.predict_proba(ctxs, acts)``."""
        return self.classifier.predict_proba(ctxs, acts)


class MatlabClassificationXGBCEnvManager(
    MatlabClassificationEnvManagerBase[models.classifiers.SubsetFeatureXGBClassifier]
):
    """Env manager whose classifier is a ``SubsetFeatureXGBClassifier``.

    Uses one expert per feature combination and fits the classifier on the
    base class's ``_extrain_set``.
    """

    # Extra keyword arguments handed to SubsetFeatureXGBClassifier as ``xgbc_kwargs``.
    xgbc_kwargs: dict[str, Any]

    def __init__(
        self,
        name: str,
        to_return_zs: bool = True,
        min_features: int = 1,
        max_features: int | None = None,
        alpha: float = 0,
        xgbc_kwargs: dict[str, Any] | None = None,
        n_acts_avail: Optional[int] = None,
        n_tests: int = 2000,
        to_cache_model: bool = True,
    ) -> None:
        """Build the base manager, then fit the XGBoost subset-feature classifier.

        Args:
            xgbc_kwargs: Forwarded as ``xgbc_kwargs`` to
                SubsetFeatureXGBClassifier; ``None`` means an empty dict.
            to_cache_model: Forwarded to SubsetFeatureXGBClassifier.
            Remaining arguments are forwarded to MatlabClassificationEnvManagerBase.
        """
        super().__init__(
            name=name,
            n_experts_per_fcomb=1,
            to_return_zs=to_return_zs,
            min_features=min_features,
            max_features=max_features,
            alpha=alpha,
            n_acts_avail=n_acts_avail,
            n_tests=n_tests,
        )
        # A ``{}`` default would be one dict shared by every call, so any
        # downstream mutation would leak between managers; build a fresh one.
        if xgbc_kwargs is None:
            xgbc_kwargs = {}
        # Store it: the attribute is declared on the class but was never set.
        self.xgbc_kwargs = xgbc_kwargs
        # construct classifiers
        xs_extrain, _, ys_extrain = utils.th_dataset_to_ndarrays(self._extrain_set)
        self._classifier = models.classifiers.SubsetFeatureXGBClassifier(
            xacts_f=self._xacts_f,
            xs_train=xs_extrain,
            ys_train=ys_extrain,
            to_cache_model=to_cache_model,
            xgbc_kwargs=xgbc_kwargs,
        )


class MatlabClassificationLRCEnvManager(
    MatlabClassificationEnvManagerBase[
        models.classifiers.SubsetFeatureLogisticRegressionClassifier
    ]
):
    """Env manager whose classifier is a ``SubsetFeatureLogisticRegressionClassifier``.

    Uses one expert per feature combination and fits the classifier on the
    base class's ``_extrain_set``.
    """

    def __init__(
        self,
        name: str,
        to_return_zs: bool = True,
        min_features: int = 1,
        max_features: int | None = None,
        alpha: float = 0,
        lrc_kwargs: dict[str, Any] | None = None,
        n_acts_avail: Optional[int] = None,
        n_tests: int = 2000,
        to_cache_model: bool = True,
    ) -> None:
        """Build the base manager, then fit the logistic-regression classifier.

        Args:
            lrc_kwargs: Forwarded as ``lrc_kwargs`` to
                SubsetFeatureLogisticRegressionClassifier; ``None`` means an
                empty dict.
            to_cache_model: Forwarded to the classifier.
            Remaining arguments are forwarded to MatlabClassificationEnvManagerBase.
        """
        super().__init__(
            name=name,
            n_experts_per_fcomb=1,
            to_return_zs=to_return_zs,
            min_features=min_features,
            max_features=max_features,
            alpha=alpha,
            n_acts_avail=n_acts_avail,
            n_tests=n_tests,
        )
        # Fresh dict per call instead of a shared mutable ``{}`` default.
        if lrc_kwargs is None:
            lrc_kwargs = {}
        # construct classifiers
        xs_extrain, _, ys_extrain = utils.th_dataset_to_ndarrays(self._extrain_set)
        self._classifier = models.classifiers.SubsetFeatureLogisticRegressionClassifier(
            xacts_f=self._xacts_f,
            xs_train=xs_extrain,
            ys_train=ys_extrain,
            lrc_kwargs=lrc_kwargs,
            to_cache_model=to_cache_model,
        )


class MatlabClassificationKNNLRCEnvManager(
    MatlabClassificationEnvManagerBase[
        models.classifiers.SubsetFeatureKNNLogisticRegressionClassifier
    ]
):
    """Env manager whose classifier is a ``SubsetFeatureKNNLogisticRegressionClassifier``.

    Uses one expert per feature combination and fits the classifier on the
    base class's ``_extrain_set``.
    """

    def __init__(
        self,
        name: str,
        to_return_zs: bool = True,
        min_features: int = 1,
        max_features: int | None = None,
        alpha: float = 0,
        knn_kwargs: dict[str, Any] | None = None,
        lrc_kwargs: dict[str, Any] | None = None,
        n_acts_avail: Optional[int] = None,
        n_tests: int = 2000,
    ) -> None:
        """Build the base manager, then fit the KNN + logistic-regression classifier.

        Args:
            knn_kwargs: Forwarded as ``knn_kwargs`` to the classifier; ``None``
                means an empty dict.
            lrc_kwargs: Forwarded as ``lrc_kwargs`` to the classifier; ``None``
                means an empty dict.
            Remaining arguments are forwarded to MatlabClassificationEnvManagerBase.
        """
        super().__init__(
            name=name,
            n_experts_per_fcomb=1,
            to_return_zs=to_return_zs,
            min_features=min_features,
            max_features=max_features,
            alpha=alpha,
            n_acts_avail=n_acts_avail,
            n_tests=n_tests,
        )
        # Fresh dicts per call instead of shared mutable ``{}`` defaults.
        if knn_kwargs is None:
            knn_kwargs = {}
        if lrc_kwargs is None:
            lrc_kwargs = {}
        # construct classifiers
        xs_extrain, _, ys_extrain = utils.th_dataset_to_ndarrays(self._extrain_set)
        self._classifier = (
            models.classifiers.SubsetFeatureKNNLogisticRegressionClassifier(
                xacts_f=self._xacts_f,
                xs_train=xs_extrain,
                ys_train=ys_extrain,
                knn_kwargs=knn_kwargs,
                lrc_kwargs=lrc_kwargs,
            )
        )


class MatlabClassificationNWCEnvManager(
    MatlabClassificationEnvManagerBase[
        models.classifiers.SubsetFeatureNadarayaWatsonClassifier
    ]
):
    """Env manager whose classifier is a ``SubsetFeatureNadarayaWatsonClassifier``.

    One expert per feature combination; the classifier is built from the base
    class's ``_extrain_set`` with the given ``sig_mult``.
    """

    def __init__(
        self,
        name: str,
        to_return_zs: bool = True,
        min_features: int = 1,
        max_features: int | None = None,
        alpha: float = 0,
        sig_mult: float = 0.15,
        n_acts_avail: Optional[int] = None,
        n_tests: int = 2000,
    ) -> None:
        """Build the base manager, then the Nadaraya-Watson classifier.

        ``sig_mult`` is forwarded to SubsetFeatureNadarayaWatsonClassifier; all
        other arguments go to MatlabClassificationEnvManagerBase.
        """
        super().__init__(
            name=name,
            n_experts_per_fcomb=1,
            to_return_zs=to_return_zs,
            min_features=min_features,
            max_features=max_features,
            alpha=alpha,
            n_acts_avail=n_acts_avail,
            n_tests=n_tests,
        )
        fit_xs, _, fit_ys = utils.th_dataset_to_ndarrays(self._extrain_set)
        nw_classifier_cls = models.classifiers.SubsetFeatureNadarayaWatsonClassifier
        self._classifier = nw_classifier_cls(
            xacts_f=self._xacts_f,
            xs_train=fit_xs,
            ys_train=fit_ys,
            sig_mult=sig_mult,
        )


class MatlabClassificationMultiExpertNWCEnvManager(
    MatlabClassificationEnvManagerBase[
        models.classifiers.SubsetFeatureMultiExpertNadarayaWatsonClassifier
    ]
):
    """Env manager whose classifier is a ``SubsetFeatureMultiExpertNadarayaWatsonClassifier``.

    Unlike the single-expert managers, ``n_experts_per_fcomb`` is caller-chosen
    and forwarded both to the base class and (as ``n_experts_per_act``) to the
    classifier, which is built from the base class's ``_extrain_set``.
    """

    def __init__(
        self,
        name: str,
        n_experts_per_fcomb: int,
        to_return_zs: bool = True,
        min_features: int = 1,
        max_features: int | None = None,
        alpha: float = 0,
        kmeans_kwargs: dict[str, Any] | None = None,
        sig_mult: float = 0.15,
        n_acts_avail: Optional[int] = None,
        n_tests: int = 2000,
    ) -> None:
        """Build the base manager, then the multi-expert Nadaraya-Watson classifier.

        Args:
            n_experts_per_fcomb: Experts per feature combination.
            kmeans_kwargs: Forwarded as ``kmeans_kwargs`` to the classifier;
                ``None`` means an empty dict.
            sig_mult: Forwarded to the classifier.
            Remaining arguments are forwarded to MatlabClassificationEnvManagerBase.
        """
        super().__init__(
            name=name,
            n_experts_per_fcomb=n_experts_per_fcomb,
            to_return_zs=to_return_zs,
            min_features=min_features,
            max_features=max_features,
            alpha=alpha,
            n_acts_avail=n_acts_avail,
            n_tests=n_tests,
        )
        # Fresh dict per call instead of a shared mutable ``dict()`` default.
        if kmeans_kwargs is None:
            kmeans_kwargs = {}
        # construct classifiers
        xs_extrain, _, ys_extrain = utils.th_dataset_to_ndarrays(self._extrain_set)
        self._classifier = (
            models.classifiers.SubsetFeatureMultiExpertNadarayaWatsonClassifier(
                n_experts_per_act=n_experts_per_fcomb,
                cinds_f=self._cinds_f,
                xacts_f=self._xacts_f,
                xs_train=xs_extrain,
                ys_train=ys_extrain,
                kmeans_kwargs=kmeans_kwargs,
                sig_mult=sig_mult,
            )
        )


class MatlabClassificationBiasedWrapperEnvManager(base.EnvManager, th.nn.Module):
    """Env manager that reuses another manager's data but swaps in a wrapped classifier.

    Everything except the classifier comes from ``base_manager``: feature count,
    label count, action tables, data splits, ``alpha`` and ``n_acts_avail``.
    The classifier is ``classifier_wrapper_class`` instantiated with
    ``base_classifier=base_manager.classifier`` plus the given keyword
    arguments. Environments are created lazily on first access.
    """

    base_manager: MatlabClassificationEnvManagerBase

    _train_env: base.TorchDatasetEnv | None
    _val_env: base.TorchDatasetEnv | None

    classifier: models.classifiers.SubsetFeatureBiasedClassifierWrapperBase

    def __init__(
        self,
        base_manager: MatlabClassificationEnvManagerBase[SubsetFeatureClassifierType],
        classifier_wrapper_class: Type[
            models.classifiers.SubsetFeatureBiasedClassifierWrapperBase
        ],
        classifier_wrapper_kwargs: dict[str, Any],
    ) -> None:
        super().__init__()
        self.base_manager = base_manager
        inner_classifier = base_manager.classifier
        self.classifier = classifier_wrapper_class(
            base_classifier=inner_classifier, **classifier_wrapper_kwargs
        )
        self._train_env = None
        self._val_env = None

    @property
    def n_covs(self) -> int:
        """Feature count of the wrapped manager."""
        return self.base_manager.n_covs

    @property
    def n_labels(self) -> int:
        """Label count of the wrapped manager."""
        return self.base_manager.n_labels

    def _wrapped_split_env(
        self, split: th_data.Subset, is_train: bool
    ) -> base.TorchDatasetEnv:
        """Create a ``TorchDatasetEnv`` over ``split`` using the wrapped classifier."""
        mgr = self.base_manager
        return base.TorchDatasetEnv(
            n_covs=self.n_covs,
            n_experts_per_comb=mgr.n_experts_per_fcomb,
            cinds_f=mgr._cinds_f,
            xacts_f=mgr._xacts_f,
            dataset=split,
            classifier=self.classifier,
            alpha=mgr.alpha,
            is_train=is_train,
            n_acts_avail=mgr._n_acts_avail,
        )

    @property
    def train_env(self):
        """Training environment over the wrapped manager's ``_train_set``."""
        if self._train_env is None:
            self._train_env = self._wrapped_split_env(
                self.base_manager._train_set, is_train=True
            )
        return self._train_env

    @property
    def val_env(self):
        """Evaluation environment over the wrapped manager's ``_val_set``."""
        if self._val_env is None:
            self._val_env = self._wrapped_split_env(
                self.base_manager._val_set, is_train=False
            )
        return self._val_env

    @property
    def test_env(self):
        """Same object as :attr:`val_env`."""
        return self.val_env

    def expert_func(self, ctxs: th.Tensor, acts: th.Tensor) -> th.Tensor:
        """Delegate to the wrapped classifier's ``predict_proba``."""
        probs = self.classifier.predict_proba(ctxs, acts)
        return probs
