from abc import ABC, abstractmethod
import os
import numpy as np
import pandas as pd
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import (
    SplineTransformer,
    StandardScaler,
)
from pathlib import Path
from utils.decorators import register
from functools import partial
from typing import Dict, Type
import datasets
from datasets import load_from_disk
from decision.xp.constants import (
    FINETUNED_DATASET_PATH,
    FINETUNED_MODEL_PATH,
    PROCESSED_DATASET_PATH,
)
from decision.xp.model.base import PretrainedMixin
from sklearn.model_selection import GroupShuffleSplit
from datasets import Dataset

# Configure scikit-learn globally so that transformers output pandas containers.
sklearn.set_config(transform_output="pandas")
# Root data directory: value of the DATA environment variable, "/data" if unset.
data_path = os.environ.get("DATA", "/data")


class BaseDataset(ABC):
    """Base class for tabular datasets.

    Subclasses provide the raw dataframe through ``get_df`` and the name of
    the target column through ``y_name``. The other methods split that
    dataframe into features and target, optionally subsample it, and can
    generate synthetic labels from a model fitted on the real data.
    """

    # Whether the dataset is currently in test mode; set by eval() / train().
    test_mode = False

    @property
    @abstractmethod
    def y_name(self) -> str:
        """Return the name of the target variable."""

    @property
    def has_test_set(self) -> bool:
        """Return whether the dataset has a test set (``False`` by default)."""
        return False

    def eval(self) -> None:
        """Switch the dataset to test mode.

        Raises
        ------
        ValueError
            If the dataset has no test set.

        """
        if not self.has_test_set:
            raise ValueError("Calling 'eval()' on dataset with no test set.")
        self.test_mode = True

    def train(self) -> None:
        """Switch the dataset back to train mode (``test_mode = False``).

        Raises
        ------
        ValueError
            If the dataset has no test set.

        """
        if not self.has_test_set:
            raise ValueError("Calling 'train()' on dataset with no test set.")
        self.test_mode = False

    @abstractmethod
    def get_df(self) -> pd.DataFrame:
        """Return the dataframe."""

    def process_df(self, df: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series]:
        """Split ``df`` into features and target.

        Returns
        -------
        df_X : pd.DataFrame
            ``df`` without the ``y_name`` column.
        df_y : pd.Series
            The ``y_name`` column of ``df``.

        """
        return df.drop(columns=[self.y_name]), df[self.y_name]

    def get_processed_df(self) -> tuple[pd.DataFrame, pd.Series]:
        """Return the result of ``process_df`` applied to ``get_df()``."""
        df = self.get_df()
        return self.process_df(df)

    def subsample_df(
        self, df_X: pd.DataFrame, df_y: pd.Series
    ) -> tuple[pd.DataFrame, pd.Series]:
        """Return a subsample of the data; the base class returns it unchanged."""
        return df_X, df_y

    def get_real_X_y(self) -> tuple[pd.DataFrame, pd.Series]:
        """Return the processed and subsampled features and target."""
        df_X, df_y = self.get_processed_df()
        return self.subsample_df(df_X, df_y)

    def get_synth_X_y(self) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
        """Return the real features together with synthetic labels.

        A pipeline (standard scaling, cubic spline expansion with 10 knots,
        logistic regression) is fitted on the real data and stored in
        ``self.synth_est_``. Column 1 of its ``predict_proba`` output is used
        as the Bernoulli parameter from which the synthetic labels are drawn
        with a generator seeded with 0.

        Returns
        -------
        Xp : pd.DataFrame
            The real features, as returned by ``get_real_X_y``.
        yp : pd.Series
            Sampled binary labels, named ``"yp"``, indexed like ``Xp``.
        Qp : pd.Series
            Probabilities used to sample ``yp``, named ``"Qp"``, indexed like
            ``Xp``.

        """
        X, y = self.get_real_X_y()

        pipeline = make_pipeline(
            StandardScaler(),
            SplineTransformer(n_knots=10, degree=3),
            LogisticRegression(),
        )
        # The model is fitted and evaluated on the same samples.
        pipeline.fit(X, y)
        self.synth_est_ = pipeline
        Q = pipeline.predict_proba(X)[:, 1]

        # Fixed seed: repeated calls on the same data draw the same labels.
        rng = np.random.default_rng(0)
        yp = pd.Series(rng.binomial(n=1, p=Q), index=X.index, name="yp")
        Qp = pd.Series(Q, index=X.index, name="Qp")

        return X, yp, Qp


class BalancedMixin:
    """Mixin whose ``subsample_df`` returns a class-balanced subsample."""

    @staticmethod
    def subsample_df(
        df_X: pd.DataFrame, df_y: pd.Series
    ) -> tuple[pd.DataFrame, pd.Series]:
        """Downsample the negative class to the size of the positive class.

        All rows with label 1 are kept and the same number of rows with
        label 0 is drawn without replacement (fixed ``random_state=0``).
        The result lists the positive rows first, then the negative rows.

        Parameters
        ----------
        df_X : pd.DataFrame
            Features.
        df_y : pd.Series
            Labels (0 or 1), with the same index as ``df_X``.

        Returns
        -------
        subdf_X : pd.DataFrame
            The selected rows of ``df_X``.
        subdf_y : pd.Series
            The labels of the selected rows, looked up by index label.

        Raises
        ------
        ValueError
            If the indices differ, if there is no positive sample, or if
            there are fewer negative than positive samples.

        """
        # check that df_X and df_y have compatible indices
        if not df_X.index.equals(df_y.index):
            raise ValueError("df_X and df_y must have the same index.")

        is_pos = df_y == 1
        is_neg = df_y == 0
        n_pos = int(is_pos.sum())
        n_neg = int(is_neg.sum())

        # Explicit errors rather than asserts, which are stripped under -O.
        if n_pos == 0:
            raise ValueError("Cannot balance the dataset: no sample with label 1.")
        if n_neg < n_pos:
            raise ValueError(
                "Cannot balance the dataset: "
                f"{n_neg} samples with label 0 but {n_pos} with label 1."
            )

        # Sampling n_pos out of n_pos rows keeps every positive row (shuffled).
        subdf_X1 = df_X[is_pos].sample(n=n_pos, replace=False, random_state=0)
        subdf_X0 = df_X[is_neg].sample(n=n_pos, replace=False, random_state=0)
        subdf_X = pd.concat([subdf_X1, subdf_X0])

        # Label-based lookup of the selected rows; assumes unique index labels.
        subdf_y = df_y.loc[subdf_X.index]

        return subdf_X, subdf_y


class ForwardedMixin(BaseDataset, ABC):
    """Forwarded datasets have the confidence score of the model included.

    The data is read from a huggingface dataset saved on disk (see
    ``get_dataset``) whose columns ``"latent_space"``, ``"label"`` and
    ``"probabilities"`` / ``"finetuned_probabilities"`` are turned into a
    dataframe of features, target, confidence score and group id.
    """

    # Test-size fractions passed to GroupShuffleSplit by
    # get_train_test_split and get_train_val_split respectively.
    prop_test = 0.5
    prop_val = 0.5

    # Name of the dataframe column holding the group id of each sample.
    group_name = "__group__"

    @property
    @abstractmethod
    def S_name(self) -> str:
        """Return the name of the confidence variable."""

    @property
    @abstractmethod
    def ds_name(self) -> str:
        """Return the name of the huggingface dataset.

        It is passed to ``datasets.load_dataset`` and used to build the
        on-disk paths of the processed datasets.
        """

    def __getstate__(self) -> object:
        """Return ``None``: no instance state is saved when pickling."""
        pass  # If the dataset has parameters, write them here

    def __setstate__(self, state: object) -> None:
        """Ignore ``state``; there is nothing to restore."""
        pass

    def load_dataset(self):
        """Load the huggingface dataset."""
        return datasets.load_dataset(self.ds_name, split="train")

    @abstractmethod
    def extract(self, examples: dict):
        """Extract the relevant information from the huggingface dataset to give to the model.

        Parameters
        ----------
        examples : dict
            A batch of examples from the huggingface dataset.

        """

    @abstractmethod
    def get_label(self, examples: dict):
        """Retrieve the labels when processing the hugging face dataset."""

    def process_df(
        self, df: pd.DataFrame
    ) -> tuple[pd.DataFrame, pd.Series, pd.Series, pd.Series]:
        """Split ``df`` into features, target, confidence score and groups.

        Returns
        -------
        df_X : pd.DataFrame
            ``df`` without the target, confidence and group columns.
        df_y : pd.Series
            The ``y_name`` column.
        df_S : pd.Series
            The ``S_name`` column.
        df_group : pd.Series
            The ``group_name`` column.

        """
        return (
            df.drop(columns=[self.y_name, self.S_name, self.group_name]),
            df[self.y_name],
            df[self.S_name],
            df[self.group_name],
        )

    def subsample_df(
        self, df_X: pd.DataFrame, df_y: pd.Series, df_S: pd.Series, df_group: pd.Series
    ) -> tuple[pd.DataFrame, pd.Series, pd.Series, pd.Series]:
        """Return a subsample of the data; by default it is returned unchanged."""
        return df_X, df_y, df_S, df_group

    def get_real_X_y(
        self,
        model: PretrainedMixin,
        finetuned: bool = False,
        split: str | None = None,
    ) -> tuple[pd.DataFrame, pd.Series, pd.Series, pd.Series]:
        """Return the processed and subsampled ``(X, y, S, groups)``.

        The arguments are forwarded to ``get_processed_df``.
        """
        df_X, df_y, df_S, df_group = self.get_processed_df(
            model=model, finetuned=finetuned, split=split
        )
        return self.subsample_df(df_X, df_y, df_S, df_group)

    def get_processed_df(
        self,
        model: PretrainedMixin,
        finetuned: bool = False,
        split: str | None = None,
    ) -> tuple[pd.DataFrame, pd.Series, pd.Series, pd.Series]:
        """Return the result of ``process_df`` applied to ``get_df(...)``."""
        df = self.get_df(model=model, finetuned=finetuned, split=split)
        return self.process_df(df)

    def get_dataset(self, model: PretrainedMixin, finetuned: bool = False):
        """Load from disk the dataset saved for this dataset/model pair.

        The finetuned dataset path is used when ``finetuned`` is true, the
        processed dataset path otherwise.
        """
        if finetuned:
            path = get_finetuned_dataset_path(self, model)
        else:
            path = get_processed_dataset_path(self, model)
        return load_from_disk(path)

    def get_df(
        self,
        model: PretrainedMixin,
        finetuned: bool = False,
        split: str | None = None,
    ) -> pd.DataFrame:
        """Build the dataframe of features, target, confidence and groups.

        Parameters
        ----------
        model : PretrainedMixin
            Model whose saved outputs are loaded.
        finetuned : bool
            Whether to read the finetuned dataset (and its
            ``"finetuned_probabilities"`` column) instead of the processed
            one (``"probabilities"`` column).
        split : str or None
            One of ``"train"``, ``"val"``, ``"test"``. Required when
            ``finetuned`` is true, ignored otherwise.

        Returns
        -------
        pd.DataFrame
            Columns ``"Latent 0"``, ``"Latent 1"``, ... followed by the
            ``y_name``, ``S_name`` and ``group_name`` columns.

        Raises
        ------
        ValueError
            If ``finetuned`` is true and ``split`` is missing or invalid.

        """
        dataset = self.get_dataset(model=model, finetuned=finetuned)

        if finetuned:
            if split is None:
                raise ValueError("Must specify split when finetuned=True.")
            if split not in ("train", "val", "test"):
                raise ValueError(
                    f"split must be one of 'train', 'val' or 'test', got {split!r}."
                )
            dataset = dataset[split]

        groups = self.get_groups(dataset)

        X = np.asarray(dataset["latent_space"])
        y = np.asarray(dataset["label"])
        proba_column = "finetuned_probabilities" if finetuned else "probabilities"
        # Convert before slicing so that a column returned as nested lists
        # works too. Column 1 is kept as the confidence score.
        S = np.asarray(dataset[proba_column])[:, 1]

        df_X = pd.DataFrame(X, columns=[f"Latent {i}" for i in range(X.shape[1])])
        df_y = pd.Series(y, name=self.y_name)
        df_S = pd.Series(S, name=self.S_name)
        df_groups = pd.Series(groups, name=self.group_name)

        return pd.concat([df_X, df_y, df_S, df_groups], axis=1)

    def get_arrays(
        self, model: PretrainedMixin, finetuned: bool, rs: int = 0
    ) -> tuple[
        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        tuple[np.ndarray, np.ndarray, np.ndarray],
    ]:
        """Return the data as numpy arrays together with split indices.

        When ``finetuned`` is true, the ``"val"`` and ``"test"`` splits are
        concatenated (validation rows first) and the indices follow that
        layout. Otherwise the whole dataset is split by group with
        ``get_train_test_split``. In both cases the validation indices are
        further split by group with ``get_train_val_split``.

        Parameters
        ----------
        model : PretrainedMixin
            Model whose saved outputs are loaded.
        finetuned : bool
            Whether to use the finetuned dataset.
        rs : int
            Random state of the group splits.

        Returns
        -------
        (X, y, S, G) : tuple of np.ndarray
            Features (float32), target (int32), confidence (float32) and
            groups (float32).
        (idx_val1, idx_val2, idx_test) : tuple of np.ndarray
            Positional indices into the arrays above.

        """
        if finetuned:
            X_val, y_val, S_val, G_val = self.get_real_X_y(
                model=model, finetuned=finetuned, split="val"
            )
            X_test, y_test, S_test, G_test = self.get_real_X_y(
                model=model, finetuned=finetuned, split="test"
            )
            X = pd.concat([X_val, X_test], axis=0, ignore_index=True)
            y = pd.concat([y_val, y_test], axis=0, ignore_index=True)
            S = pd.concat([S_val, S_test], axis=0, ignore_index=True)
            G = pd.concat([G_val, G_test], axis=0, ignore_index=True)
            # Validation rows come first in the concatenation.
            idx_val = np.arange(X_val.shape[0])
            idx_test = np.arange(X_val.shape[0], X.shape[0])

        else:
            X, y, S, G = self.get_real_X_y(model=model, finetuned=finetuned)

            idx_val, idx_test = self.get_train_test_split(groups=G, rs=rs)

        X, y, S, G = X.to_numpy(), y.to_numpy(), S.to_numpy(), G.to_numpy()

        # The split is computed on the validation subset; map its positions
        # back to positions in the full arrays.
        idx_val1, idx_val2 = self.get_train_val_split(groups=G[idx_val], rs=rs)
        idx_val1 = idx_val[idx_val1]
        idx_val2 = idx_val[idx_val2]

        # For scikit-tree
        X = X.astype(np.float32)
        y = y.astype(np.int32)
        S = S.astype(np.float32)
        # NOTE(review): float32 represents integers exactly only up to 2**24;
        # larger group ids would collide after this cast.
        G = G.astype(np.float32)

        return (X, y, S, G), (idx_val1, idx_val2, idx_test)

    def get_groups(self, dataset: Dataset) -> np.ndarray:
        """Return one group id per row; by default each row is its own group."""
        return np.arange(dataset.num_rows)  # no group by default

    def get_split(
        self,
        groups: np.ndarray,
        prop_test: float = 0.5,
        rs: int = 0,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Generate a single group-aware train/test split.

        Parameters
        ----------
        groups : np.ndarray
            Group id of each sample. Samples of a same group end up on the
            same side of the split.
        prop_test : float
            Passed as ``test_size`` to ``GroupShuffleSplit``.
        rs : int
            Random state of the splitter.

        Returns
        -------
        idx_train : np.ndarray
            Positional indices of the train samples.
        idx_test : np.ndarray
            Positional indices of the test samples.

        """
        groups = np.asarray(groups)
        idx = np.arange(groups.shape[0], dtype=int)
        splitter = GroupShuffleSplit(n_splits=1, test_size=prop_test, random_state=rs)
        idx_train, idx_test = next(splitter.split(X=idx, groups=groups))

        return (idx_train, idx_test)

    def get_train_test_split(
        self, groups: np.ndarray, rs: int = 0
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return ``get_split`` with ``prop_test=self.prop_test``."""
        return self.get_split(groups=groups, prop_test=self.prop_test, rs=rs)

    def get_train_val_split(
        self, groups: np.ndarray, rs: int = 0
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return ``get_split`` with ``prop_test=self.prop_val``."""
        return self.get_split(groups=groups, prop_test=self.prop_val, rs=rs)

    def get_train_val_test_split(
        self, groups: np.ndarray, rs: int = 0
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return positional indices of a group-aware train/val/test split.

        The test set is split off first with ``get_train_test_split``; the
        remainder is then split with ``get_train_val_split``.
        """
        # Positional indexing below requires an array, not a pandas Series.
        groups = np.asarray(groups)
        idx_train_val, idx_test = self.get_train_test_split(groups, rs)
        idx_train, idx_val = self.get_train_val_split(groups[idx_train_val], rs)
        # Map positions within the train/val subset back to the full data.
        idx_train = idx_train_val[idx_train]
        idx_val = idx_train_val[idx_val]
        return idx_train, idx_val, idx_test

    def get_params(self):
        """Get the parameters used to instantiate the dataset."""
        return {}


# Registry of dataset classes, keyed by a string name.
# Presumably filled by the `register_ds` decorator below; confirm against
# utils.decorators.register.
ds_registry: Dict[str, Type[BaseDataset | ForwardedMixin]] = {}
# String-to-string mapping handed to `register` as `rename_registry`.
# NOTE(review): its exact meaning is defined by utils.decorators.register.
ds_rename: Dict[str, str] = {}
# `register` with both dataset registries above pre-bound.
register_ds = partial(register, registry=ds_registry, rename_registry=ds_rename)


def get_path(ds: ForwardedMixin, model: PretrainedMixin, root: str) -> Path:
    """Return ``root/<dataset dir>/<model dir>`` for a dataset/model pair.

    The dataset directory is ``ds.ds_name`` and the model directory is
    ``model.model_name`` followed by ``":"`` and the model's class name;
    every ``"/"`` in the two names is replaced by ``":"``.
    """

    def _flatten(name: str) -> str:
        # Keep a name such as "org/name" from creating nested directories.
        return name.replace("/", ":")

    model_part = f"{_flatten(model.model_name)}:{model.__class__.__name__}"
    return Path(root).joinpath(_flatten(ds.ds_name), model_part)


def get_processed_dataset_path(ds: ForwardedMixin, model: PretrainedMixin) -> Path:
    """Return the path built by ``get_path`` under ``PROCESSED_DATASET_PATH``."""
    return get_path(ds=ds, model=model, root=PROCESSED_DATASET_PATH)


def get_finetuned_dataset_path(ds: ForwardedMixin, model: PretrainedMixin) -> Path:
    """Return the path built by ``get_path`` under ``FINETUNED_DATASET_PATH``."""
    return get_path(ds=ds, model=model, root=FINETUNED_DATASET_PATH)


def get_finetuned_model_path(ds: ForwardedMixin, model: PretrainedMixin) -> Path:
    """Return the path built by ``get_path`` under ``FINETUNED_MODEL_PATH``."""
    return get_path(ds=ds, model=model, root=FINETUNED_MODEL_PATH)
