import json
import subprocess
import typing as ty
import warnings
from collections import Counter
from copy import deepcopy
from pathlib import Path

import numpy as np
import sklearn.preprocessing
from ablator import Enum
from category_encoders import LeaveOneOutEncoder
from joblib import Memory
from sklearn.impute import SimpleImputer
from torch.utils.data import dataset

# Mapping from split name (the loaders below use "train", "val", "test") to
# that split's array.
ArrayDict = ty.Dict[str, np.ndarray]


class DatasetType(Enum):
    """Prediction task of a dataset; values match ``info.json``'s ``task_type``."""

    BINCLASS = "binclass"  # binary classification (single output)
    MULTICLASS = "multiclass"  # multi-class classification (``n_classes`` outputs)
    REGRESSION = "regression"  # regression (single output, targets standardized)


class TabDataset(dataset.Dataset):
    """Map-style dataset over one split of a preprocessed tabular benchmark.

    The dataset directory ``<path>/<dataset_name>`` must contain ``info.json``
    plus ``N_<split>.npy`` (numerical features, optional), ``C_<split>.npy``
    (categorical features, optional) and ``y_<split>.npy`` (targets) for the
    splits ``train``, ``val`` and ``test``. If the directory is missing, a data
    archive is downloaded and unpacked first.

    All preprocessing statistics (imputation values, normalizers, encoders) are
    fitted on the ``train`` split and applied to every split. Each item is a
    dict with key ``y`` and, when present, ``x_cat`` and/or ``x_num``.
    """

    def __init__(
        self,
        path: ty.Union[Path, str],
        dataset_name: ty.Literal[
            "year",
            "yahoo",
            "helena",
            "covtype",
            "epsilon",
            "jannis",
            "adult",
            "aloi",
            "higgs_small",
            "microsoft",
            "california_housing",
        ],
        split: ty.Literal["train", "test", "val"],
        normalization: ty.Optional[ty.Literal["standard", "quantile"]],
        cat_nan_policy: ty.Literal["new", "most_frequent"],
        cat_policy: ty.Literal["ohe", "indices", "counter"],
        cat_min_frequency: float = 0.0,
        seed: int = 0,
    ) -> None:
        """Load, preprocess and select one split of a dataset.

        Args:
            path: Root data directory; the dataset lives in ``path/dataset_name``.
            dataset_name: Name of the dataset sub-directory.
            split: Which split this instance exposes.
            normalization: Normalizer fitted on train and applied to numerical
                features (and to counter-encoded categorical features), or
                ``None`` for no normalization.
            cat_nan_policy: Replacement for ``"nan"`` categorical values: a
                dedicated ``"___null___"`` category (``"new"``) or the train
                mode of the column (``"most_frequent"``).
            cat_policy: ``"indices"`` keeps ordinal category indices in
                ``x_cat``; ``"ohe"`` and ``"counter"`` encode categories as
                floats and append them to ``x_num``.
            cat_min_frequency: Categories whose train count is below
                ``round(n_train * cat_min_frequency)`` are merged into
                ``"___rare___"``; ``0`` disables merging.
            seed: Seed for the quantile normalizer and its noise, and for the
                leave-one-out encoder.

        Numerical NaNs are always replaced with the per-column train mean.

        Raises:
            AssertionError: If the dataset has neither numerical nor
                categorical features.
        """
        super().__init__()
        self.numerical = None
        self.categorical = None

        data_dir = Path(path)
        self.root_path = data_dir.joinpath(dataset_name)
        if not self.root_path.exists():
            self._download_data(data_dir)

        self.info = json.loads((self.root_path / "info.json").read_text())
        self.info["size"] = (
            self.info["train_size"] + self.info["val_size"] + self.info["test_size"]
        )
        self.cache = self._make_cache(self.root_path)
        self.cat_policy = cat_policy
        self._categorical, self._numerical = self._make(
            root_path=self.root_path,
            normalization=normalization,
            cat_nan_policy=cat_nan_policy,
            cat_policy=cat_policy,
            cat_min_frequency=cat_min_frequency,
            seed=seed,
        )
        if self._categorical is not None:
            self.categorical = self._categorical[split]
        if self._numerical is not None:
            self.numerical = self._numerical[split]
        assert (
            self.categorical is not None or self.numerical is not None
        ), "Empty dataset."

        y = self._load_y(self.root_path)
        self.y_mean = y["train"].mean()
        self.y_std = y["train"].std()
        if self.is_regression:
            # Regression targets are standardized with train statistics.
            y, (self.y_mean, self.y_std) = self._make_y(
                y, "mean_std", self.is_regression
            )

        self.y = y[split]

    @classmethod
    def _download_data(cls, data_dir: Path) -> None:
        """Download the data archive into ``data_dir`` (if absent) and unpack it.

        Commands are passed as argument lists (no shell), so paths containing
        spaces or shell metacharacters are handled correctly.

        Raises:
            subprocess.CalledProcessError: If ``wget`` or ``tar`` fails.
        """
        url = "https://www.dropbox.com/s/o53umyg6mn3zhxy/data.tar.gz?dl=1"
        data_dir.mkdir(exist_ok=True, parents=True)
        tar_file = data_dir.joinpath("revisiting_models_data.tar.gz")
        if not tar_file.exists():
            # Download under a temporary name and rename only on success, so an
            # interrupted download is not mistaken for a complete archive later.
            partial_file = tar_file.with_name(tar_file.name + ".part")
            subprocess.run(["wget", url, "-O", str(partial_file)], check=True)
            partial_file.replace(tar_file)
        # NOTE(review): the archive is unpacked into the *parent* of data_dir,
        # which presumably relies on the archive's top-level folder having the
        # same name as data_dir (e.g. ``data/``) — TODO confirm.
        subprocess.run(
            ["tar", "-xvf", str(tar_file), "-C", str(data_dir.parent)], check=True
        )

    def _make_cache(self, root_path):
        """Wrap the preprocessing helpers with an on-disk joblib cache.

        The cache lives in ``root_path / ".cache"``. Each wrapper is stored as
        an instance attribute that shadows the classmethod of the same name.
        The data arguments (``y``, ``numerical``, ``categorical``) are excluded
        from the cache keys, so keys consist only of configuration arguments.
        This assumes the files under ``root_path`` never change.

        NOTE(review): only ``self._make_y`` is actually called through the
        instance (in ``__init__``). ``_make`` is a classmethod that calls
        ``cls._make_numerical`` etc., which bypasses these wrappers. Before
        routing ``_make`` through them, note that ``_categorical_ohe`` ignores
        ``numerical`` although its result embeds it. Its key also lacks
        ``normalization``/``seed``, which determine ``numerical``, so cached
        results could be stale.
        """
        memory = Memory(location=root_path.joinpath(".cache"), verbose=0)
        self._make_y = memory.cache(self._make_y, ignore=["cls", "y"])
        self._make_numerical = memory.cache(
            self._make_numerical, ignore=["cls", "numerical"]
        )
        self._replace_nan_categorical = memory.cache(
            self._replace_nan_categorical, ignore=["cls", "categorical"]
        )
        self._replace_unpopular_categories = memory.cache(
            self._replace_unpopular_categories, ignore=["cls", "categorical"]
        )
        self._categorical_ordinal_encoding = memory.cache(
            self._categorical_ordinal_encoding, ignore=["cls", "categorical"]
        )
        self._categorical_ohe = memory.cache(
            self._categorical_ohe, ignore=["cls", "numerical", "categorical"]
        )
        self._categorical_counter = memory.cache(
            self._categorical_counter, ignore=["cls", "numerical", "categorical", "y"]
        )
        return memory

    @property
    def categories(self) -> ty.Optional[ty.List[int]]:
        """Number of distinct train indices per categorical column.

        ``None`` unless categories are kept as indices (``cat_policy="indices"``).

        NOTE(review): ``_categorical_ordinal_encoding`` maps val/test categories
        unseen in train to ``max_train_index + 1``, which equals the count
        returned here. An embedding table sized exactly by these counts cannot
        index such values; confirm whether callers reserve an extra slot.
        """
        return (
            None
            if self._categorical is None
            else [
                len(set(self._categorical["train"][:, i]))
                for i in range(self._categorical["train"].shape[1])
            ]
        )

    @property
    def task_type(self) -> DatasetType:
        """Task type parsed from ``info.json``."""
        return DatasetType(self.info["task_type"])

    @property
    def d_out(self) -> int:
        """Model output size: 1 for regression/binary, ``n_classes`` otherwise."""
        if self.is_regression or self.is_binclass:
            return 1
        else:
            return self.info["n_classes"]

    @property
    def is_binclass(self) -> bool:
        return self.task_type == DatasetType.BINCLASS

    @property
    def is_multiclass(self) -> bool:
        return self.task_type == DatasetType.MULTICLASS

    @property
    def is_regression(self) -> bool:
        return self.task_type == DatasetType.REGRESSION

    @property
    def n_num_features(self) -> int:
        """Number of columns in ``x_num``.

        With ``cat_policy`` ``"ohe"`` or ``"counter"`` this includes the encoded
        categorical columns.
        """
        # shape[1] (rather than the length of the first row) also works for an
        # empty split.
        return self.numerical.shape[1] if self.numerical is not None else 0

    @property
    def n_cat_features(self) -> int:
        """Number of columns in ``x_cat``.

        This is 0 when there are no categorical features, or when they were
        folded into ``x_num`` (``"ohe"``/``"counter"``). In those cases they are
        already counted by ``n_num_features``.
        """
        return self.categorical.shape[1] if self.categorical is not None else 0

    @property
    def n_features(self) -> int:
        return self.n_num_features + self.n_cat_features

    def __len__(self):
        return len(self.y)

    def __getitem__(self, index):
        """Return ``{"y": ..., "x_cat": ..., "x_num": ...}`` (feature keys only if present)."""
        y = self.y[index]
        return_dict = dict(y=y)
        if self.categorical is not None:
            return_dict["x_cat"] = self.categorical[index]
        if self.numerical is not None:
            return_dict["x_num"] = self.numerical[index]
        return return_dict

    @classmethod
    def _make_numerical(
        cls,
        numerical: ty.Optional[ArrayDict],
        num_nan_policy: ty.Literal["mean"],
        normalization: ty.Optional[ty.Literal["standard", "quantile"]],
        seed: int,
    ):
        """Impute NaNs with train column means, then optionally normalize.

        Returns a copy; the input dict is not modified. Returns ``None`` if
        ``numerical`` is ``None``.

        Raises:
            NotImplementedError: If NaNs are present and ``num_nan_policy`` is
                not ``"mean"``.
        """
        if numerical is not None:
            numerical = deepcopy(numerical)

            num_nan_masks = {k: np.isnan(v) for k, v in numerical.items()}
            if any(x.any() for x in num_nan_masks.values()):  # type: ignore[code]
                if num_nan_policy == "mean":
                    num_new_values = np.nanmean(numerical["train"], axis=0)
                else:
                    raise NotImplementedError(f"numerical NaN policy {num_nan_policy}")
                for k, v in numerical.items():
                    num_nan_indices = np.where(num_nan_masks[k])
                    # Column index of each NaN selects that column's train mean.
                    v[num_nan_indices] = np.take(num_new_values, num_nan_indices[1])
            if normalization is not None:
                numerical = cls._normalize(
                    numerical, normalization, seed, is_numerical=True
                )
        return numerical

    @classmethod
    def _load(cls, root_path, item: ty.Literal["N", "C", "y"]) -> ArrayDict:
        """Load ``<item>_<split>.npy`` for every split."""
        return {
            x: ty.cast(np.ndarray, np.load(root_path / f"{item}_{x}.npy"))
            for x in ["train", "test", "val"]
        }

    @classmethod
    def _load_numerical(cls, root_path: Path):
        """Load numerical features, or ``None`` if ``N_train.npy`` is absent."""
        numerical = None
        if (root_path / "N_train.npy").exists():
            numerical = cls._load(root_path, "N")
        return numerical

    @classmethod
    def _load_categorical(cls, root_path: Path):
        """Load categorical features, or ``None`` if ``C_train.npy`` is absent."""
        categorical = None
        if (root_path / "C_train.npy").exists():
            categorical = cls._load(root_path, "C")
        return categorical

    @classmethod
    def _normalize(
        cls,
        X: ArrayDict,
        normalization: ty.Literal["standard", "quantile"],
        seed: int,
        noise: float = 1e-3,
        **cache_args,
    ) -> ArrayDict:
        """Fit a normalizer on ``X["train"]`` and transform every split.

        For ``"quantile"``, small Gaussian noise (scaled per column) is added to
        the training copy before fitting. Extra keyword arguments are accepted
        and ignored.

        Raises:
            NotImplementedError: For an unknown ``normalization``.
        """
        X_train = X["train"].copy()
        if normalization == "standard":
            normalizer = sklearn.preprocessing.StandardScaler()
        elif normalization == "quantile":
            normalizer = sklearn.preprocessing.QuantileTransformer(
                output_distribution="normal",
                n_quantiles=max(min(X["train"].shape[0] // 30, 1000), 10),
                subsample=int(1e9),
                random_state=seed,
            )
            if noise:
                stds = np.std(X_train, axis=0, keepdims=True)
                noise_std = noise / np.maximum(stds, noise)  # type: ignore[code]
                X_train += noise_std * np.random.default_rng(seed).standard_normal(  # type: ignore[code]
                    X_train.shape
                )
        else:
            raise NotImplementedError(f"normalization {normalization}")
        normalizer.fit(X_train)
        return {k: normalizer.transform(v) for k, v in X.items()}  # type: ignore[code]

    @classmethod
    def _load_y(cls, root_path: Path):
        """Load targets for every split."""
        return cls._load(root_path, "y")

    @classmethod
    def _replace_nan_categorical(cls, categorical: ArrayDict, cat_nan_policy):
        """Replace ``"nan"`` entries of the categorical splits.

        Missing values are the literal string ``"nan"``. With ``"new"`` they
        become a dedicated ``"___null___"`` category; with ``"most_frequent"``
        they are imputed with the per-column train mode. A copy is returned.
        If no ``"nan"`` occurs, the splits are returned unchanged.

        Raises:
            NotImplementedError: If NaNs are present and the policy is unknown.
        """
        categorical = deepcopy(categorical)
        cat_nan_masks = {k: v == "nan" for k, v in categorical.items()}
        if not any(x.any() for x in cat_nan_masks.values()):  # type: ignore[code]
            return categorical
        if cat_nan_policy == "new":
            cat_new_value = "___null___"
            replaced = {}
            for k, v in categorical.items():
                # Widen fixed-width string dtypes first. Otherwise numpy silently
                # truncates the placeholder to the array's width (e.g. "<U4"
                # would store "___n"), possibly differently per split.
                widened = v.astype(
                    np.promote_types(v.dtype, np.asarray(cat_new_value).dtype)
                )
                widened[cat_nan_masks[k]] = cat_new_value
                replaced[k] = widened
            return replaced
        if cat_nan_policy == "most_frequent":
            # The missing marker is the string "nan", not np.nan (the imputer's
            # default), so it must be given explicitly or nothing is imputed.
            imputer = SimpleImputer(missing_values="nan", strategy=cat_nan_policy)  # type: ignore[code]
            imputer.fit(categorical["train"].astype(object))
            return {
                k: imputer.transform(v.astype(object)) for k, v in categorical.items()
            }
        raise NotImplementedError(f"cat_nan_policy NaN policy {cat_nan_policy}")

    @classmethod
    def _replace_unpopular_categories(
        cls, categorical: ArrayDict, cat_min_frequency: ty.Optional[float], **cache_args
    ):
        """Merge categories rarer than ``cat_min_frequency`` into ``"___rare___"``.

        Popularity is measured on the train split only. This is a no-op when
        ``cat_min_frequency`` is ``None`` or not positive. Extra keyword
        arguments are accepted and ignored.
        """
        if cat_min_frequency is not None and cat_min_frequency > 0:
            categorical = deepcopy(categorical)
            categorical = ty.cast(ArrayDict, categorical)
            min_count = round(len(categorical["train"]) * cat_min_frequency)
            rare_value = "___rare___"
            C_new = {x: [] for x in categorical}
            for column_idx in range(categorical["train"].shape[1]):
                counter = Counter(categorical["train"][:, column_idx].tolist())
                popular_categories = {k for k, v in counter.items() if v >= min_count}
                for part in C_new:
                    C_new[part].append(
                        [
                            (x if x in popular_categories else rare_value)
                            for x in categorical[part][:, column_idx].tolist()
                        ]
                    )
            # Columns were collected as rows; transpose back to (rows, columns).
            categorical = {k: np.array(v).T for k, v in C_new.items()}
        return categorical

    @classmethod
    def _categorical_ordinal_encoding(cls, categorical: ArrayDict, **cache_args):
        """Map categories to int64 indices fitted on train.

        Val/test categories unseen in train get index ``max_train_index + 1`` of
        their column. Extra keyword arguments are accepted and ignored.
        """
        unknown_value = np.iinfo("int64").max - 3
        encoder = sklearn.preprocessing.OrdinalEncoder(
            handle_unknown="use_encoded_value",  # type: ignore[code]
            unknown_value=unknown_value,  # type: ignore[code]
            dtype="int64",  # type: ignore[code]
        ).fit(categorical["train"])
        categorical = {k: encoder.transform(v) for k, v in categorical.items()}
        max_values = categorical["train"].max(axis=0)
        for part in ["val", "test"]:
            # max_values broadcasts per column across all rows.
            categorical[part] = np.where(
                categorical[part] == unknown_value, max_values + 1, categorical[part]
            )
        return categorical

    @classmethod
    def _categorical_ohe(
        cls, categorical: ArrayDict, numerical: ArrayDict, **cache_args
    ):
        """One-hot encode categories (fitted on train) and append them to ``numerical``.

        Categories unseen in train become all-zero rows. Returns only the
        encoded categories if ``numerical`` is ``None``.
        """
        ohe = sklearn.preprocessing.OneHotEncoder(
            handle_unknown="ignore", sparse_output=False, dtype="float32"  # type: ignore[code]
        )
        ohe.fit(categorical["train"])
        categorical = {k: ohe.transform(v) for k, v in categorical.items()}
        result = (
            categorical
            if numerical is None
            else {x: np.hstack((numerical[x], categorical[x])) for x in numerical}
        )
        return result

    @classmethod
    def _categorical_counter(
        cls, categorical, numerical, y, normalization, seed, **cache_args
    ):
        """Leave-one-out target-encode every categorical column and append to ``numerical``.

        The encoder is fitted on the train split. The encoded values are
        optionally normalized with ``normalization``.
        """
        assert seed is not None
        n_columns = categorical["train"].shape[1]
        loo = LeaveOneOutEncoder(
            sigma=0.1,
            random_state=seed,
            # Explicitly encode every column of the integer-coded categorical
            # arrays (``len(categorical)`` would be the number of splits).
            cols=list(range(n_columns)),
            return_df=False,
        )
        loo.fit(categorical["train"], y["train"])
        categorical = {k: loo.transform(v).astype("float32") for k, v in categorical.items()}  # type: ignore[code]
        if not isinstance(categorical["train"], np.ndarray):
            categorical = {k: v.values for k, v in categorical.items()}  # type: ignore[code]
        if normalization:
            cache_args.update(dict(is_categorical=True))
            categorical = cls._normalize(categorical, normalization, seed, **cache_args)  # type: ignore[code]
        result = (
            categorical
            if numerical is None
            else {x: np.hstack((numerical[x], categorical[x])) for x in numerical}
        )
        return result

    @classmethod
    def _make(
        cls,
        root_path: Path,
        normalization: ty.Optional[ty.Literal["standard", "quantile"]],
        cat_nan_policy: ty.Literal["new", "most_frequent"],
        cat_policy: ty.Literal["ohe", "indices", "counter"],
        cat_min_frequency: float = 0.0,
        seed: int = 0,
    ) -> ty.Tuple[ty.Optional[ArrayDict], ty.Optional[ArrayDict]]:
        """Load and preprocess all splits.

        Returns:
            ``(categorical, numerical)``. With ``"indices"``, categorical holds
            ordinal indices. With ``"ohe"``/``"counter"``, categorical is
            ``None`` and numerical also contains the encoded categories.

        Raises:
            NotImplementedError: For an unknown ``cat_policy``.
        """
        numerical = cls._load_numerical(root_path)
        categorical = cls._load_categorical(root_path)
        y = cls._load_y(root_path)
        numerical = cls._make_numerical(numerical, "mean", normalization, seed)
        if cat_policy == "drop" or categorical is None:
            assert numerical is not None
            return None, numerical
        categorical = cls._replace_nan_categorical(categorical, cat_nan_policy)

        # cache_args only feed the (joblib) cache keys of the helpers below.
        cache_args = dict(cat_nan_policy=cat_nan_policy)
        categorical = cls._replace_unpopular_categories(
            categorical, cat_min_frequency, **cache_args
        )
        cache_args.update(dict(cat_min_frequency=cat_min_frequency))
        categorical = cls._categorical_ordinal_encoding(categorical, **cache_args)

        if cat_policy == "indices":
            return categorical, numerical

        elif cat_policy == "ohe":
            result = cls._categorical_ohe(categorical, numerical, **cache_args)
        elif cat_policy == "counter":
            result = cls._categorical_counter(
                categorical, numerical, y, normalization, seed, **cache_args
            )
        else:
            raise NotImplementedError(f"categorical policy {cat_policy}")
        return None, result  # type: ignore[code]

    @classmethod
    def _make_y(
        cls,
        y: ArrayDict,
        policy: ty.Optional[ty.Literal["mean_std"]],
        is_regression: bool,
    ) -> ty.Tuple[ArrayDict, ty.Tuple[ty.Optional[float], ty.Optional[float]]]:
        """Optionally standardize targets with train mean/std.

        Returns:
            ``(y_out, (mean, std))``. Mean and std are ``None`` unless the
            targets were standardized.
        """
        if is_regression:
            assert policy == "mean_std"
        y_out = deepcopy(y)
        mean = None
        std = None
        if policy is not None:
            if not is_regression:
                warnings.warn("y_policy is not None, but the task is NOT regression")
            elif policy == "mean_std":
                mean, std = y_out["train"].mean(), y_out["train"].std()
                y_out = {k: (v - mean) / std for k, v in y_out.items()}
            else:
                raise NotImplementedError(f"y policy {policy}")
        return y_out, (mean, std)
