from typing import Optional

import numpy as np
import pandas as pd
from sklearn.metrics import pairwise_distances_argmin_min
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from sklearn.utils import check_random_state
from tqdm import tqdm


class DCR:
    """Distance-to-closest-record (DCR) metric for synthetic tabular data.

    Rows are embedded into a common numeric space (one-hot encoded discrete
    columns, min-max scaled numerical columns) and compared with the
    city-block (L1) distance. Two modes are available, selected by
    ``dcr_type``:

    * ``"proportion"``: for each sampled synthetic row, check whether its
      nearest training row is strictly closer than its nearest test row, and
      report the share of rows closer to train / closer to test.
    * any other value (documented as ``"avg"``): report the mean
      nearest-neighbour distance from training rows to synthetic rows and
      from training rows to test rows.

    Because the data may be large, distances are computed on random
    subsamples over several iterations and the per-iteration results are
    averaged.
    """

    name = "dcr"
    data_requirement = "train_and_test"
    needs_discrete_features = True
    needs_random_state = True

    def __init__(
        self,
        discrete_features: Optional[list] = None,
        subsample_test_size: bool = True,
        max_rows: int = 50000,
        dcr_type: str = "proportion",  # "avg" or "proportion",
        random_state: int = 0,
    ):
        """Configure the metric.

        Args:
            discrete_features: Names of the columns to one-hot encode. Every
                other column of the training frame is treated as numerical.
                ``None`` (the default) means "no discrete columns".
            subsample_test_size: If true, the subsample size is the test set
                size capped at ``max_rows``; otherwise it is ``max_rows``.
                In both cases it is further capped at the training set size.
            max_rows: Upper bound on the number of rows per subsample.
            dcr_type: ``"proportion"`` selects the proportion-based score;
                any other value selects the averaged-distance variant.
            random_state: Seed (or random state) passed to scikit-learn's
                ``check_random_state`` and used for all subsampling.
        """
        super().__init__()
        # A ``None`` sentinel avoids sharing one mutable default list between
        # every instance created without an explicit argument.
        self.discrete_features = (
            [] if discrete_features is None else discrete_features
        )
        self.subsample_test_size = subsample_test_size
        self.max_rows = max_rows
        self.dcr_type = dcr_type
        self.rng = check_random_state(random_state)

    def evaluate(self, train: pd.DataFrame, test: pd.DataFrame, sd: pd.DataFrame):
        """Compute the DCR metric for a synthetic dataset.

        Args:
            train: Training data.
            test: Holdout data; must not have more rows than ``train``.
            sd: Synthetic data. Only its first ``len(train)`` rows are used.

        Returns:
            A dict of metric names to values. In ``"proportion"`` mode the
            keys are ``dcr.score``, ``dcr.closer_to_train`` and
            ``dcr.closer_to_test``; otherwise ``dcr.train-to-syn.avg`` and
            ``dcr.train-to-test.avg``.

        Raises:
            AssertionError: If ``test`` has more rows than ``train``.
            ValueError: If there are no feature columns, or if the resulting
                subsample size is smaller than one row (empty data or a
                non-positive ``max_rows``).
        """
        numerical_features = [
            col for col in train.columns if col not in self.discrete_features
        ]

        assert len(test) <= len(
            train
        ), "Test set must be smaller than or equal to train size to compute DCR"

        # The encoder is fitted on all three datasets so that a category that
        # appears in only one of them can still be transformed.
        ohe = None
        if len(self.discrete_features) > 0:
            ohe = OneHotEncoder(sparse_output=False)
            ohe.fit(
                pd.concat(
                    [
                        train[self.discrete_features],
                        test[self.discrete_features],
                        sd[self.discrete_features],
                    ],
                    axis=0,
                )
            )

        # The scaler is fitted on the training data only, so test/synthetic
        # values outside the training range map outside [0, 1].
        scaler = None
        if len(numerical_features) > 0:
            scaler = MinMaxScaler()
            scaler.fit(train[numerical_features])

        if ohe is None and scaler is None:
            raise ValueError("No feature columns available to compute DCR")

        data = {}
        for df, name in zip([train, test, sd[: len(train)]], ["train", "test", "syn"]):
            parts = []
            if ohe is not None:
                # Two differing categories disagree in two one-hot positions,
                # i.e. an L1 distance of 2; halving makes each categorical
                # mismatch contribute exactly 1 (Gower-style distance).
                parts.append(ohe.transform(df[self.discrete_features]) / 2)
            if scaler is not None:
                parts.append(scaler.transform(df[numerical_features]))
            data[name] = parts[0] if len(parts) == 1 else np.concatenate(parts, axis=1)

        # subsample the dataset to either max rows or test size
        num_rows_subsample = (
            min(len(data["test"]), self.max_rows)
            if self.subsample_test_size
            else self.max_rows
        )
        num_rows_subsample = min(num_rows_subsample, len(data["train"]))
        if num_rows_subsample < 1:
            raise ValueError(
                "DCR needs at least one row per subsample; check that the "
                "datasets are non-empty and that max_rows is positive"
            )
        # Enough iterations that, in expectation, the whole training set is
        # covered by the subsamples.
        num_iterations = int(np.ceil(len(data["train"]) / num_rows_subsample))

        if self.dcr_type == "proportion":
            return self._dcr_prop(data, num_iterations, num_rows_subsample)
        return self._dcr_avg(data, num_iterations, num_rows_subsample)

    def _dcr_prop(
        self, data: dict, num_iterations: int, num_rows_subsample: int
    ) -> dict[str, float]:
        """Share of synthetic rows that lie closer to train than to test.

        Args:
            data: Encoded arrays under the keys ``"train"``, ``"test"`` and
                ``"syn"``.
            num_iterations: Number of subsampling rounds to average over.
            num_rows_subsample: Rows drawn from each dataset per round
                (capped at the size of the dataset being sampled).

        Returns:
            Dict with ``dcr.score`` (mean over rounds of
            ``min(1, 2 * closer_to_test)``), ``dcr.closer_to_train`` and
            ``dcr.closer_to_test``.
        """
        closer_to_train = []
        closer_to_test = []
        for _ in tqdm(range(num_iterations)):
            # Sampling order (syn, train, test) is kept fixed so that results
            # for a given seed are reproducible.
            syn_curr = self._choose(data["syn"], num_rows_subsample)

            _, d_s_tr = pairwise_distances_argmin_min(
                syn_curr,
                self._choose(data["train"], num_rows_subsample),
                metric="cityblock",
            )
            # _choose caps the request, so the test sample never exceeds the
            # test set size.
            _, d_s_te = pairwise_distances_argmin_min(
                syn_curr,
                self._choose(data["test"], num_rows_subsample),
                metric="cityblock",
            )

            # Strict comparison: ties count as "closer to test".
            closer_to_train_ = np.mean(d_s_tr < d_s_te)
            closer_to_train.append(closer_to_train_)
            closer_to_test.append(1 - closer_to_train_)

        # A round scores 1 when at least half the synthetic rows are closer
        # to test than to train, and scales linearly down to 0 below that.
        scores = [min(1, x * 2) for x in closer_to_test]

        return {
            "dcr.score": np.mean(scores),
            "dcr.closer_to_train": np.mean(closer_to_train),
            "dcr.closer_to_test": np.mean(closer_to_test),
        }

    def _dcr_avg(
        self, data: dict, num_iterations: int, num_rows_subsample: int
    ) -> dict[str, float]:
        """Mean nearest-neighbour distance from train to syn and to test.

        Args:
            data: Encoded arrays under the keys ``"train"``, ``"test"`` and
                ``"syn"``.
            num_iterations: Number of subsampling rounds to average over.
            num_rows_subsample: Rows drawn from each dataset per round
                (capped at the size of the dataset being sampled).

        Returns:
            Dict with ``dcr.train-to-syn.avg`` and ``dcr.train-to-test.avg``.
        """
        avg_cl_train_to_syn = []
        avg_cl_train_to_test = []
        for _ in tqdm(range(num_iterations)):
            train_curr = self._choose(data["train"], num_rows_subsample)

            # train-synthetic
            _, d_tr_syn = pairwise_distances_argmin_min(
                train_curr,
                self._choose(data["syn"], num_rows_subsample),
                metric="cityblock",
            )
            avg_cl_train_to_syn.append(np.mean(d_tr_syn))

            # train-test; the test set may be smaller than the subsample size
            # (e.g. when subsample_test_size is False), which _choose handles.
            _, d_tr_test = pairwise_distances_argmin_min(
                train_curr,
                self._choose(data["test"], num_rows_subsample),
                metric="cityblock",
            )
            avg_cl_train_to_test.append(np.mean(d_tr_test))

        return {
            "dcr.train-to-syn.avg": np.mean(avg_cl_train_to_syn),
            "dcr.train-to-test.avg": np.mean(avg_cl_train_to_test),
        }

    def _choose(self, x: np.ndarray, n: int) -> np.ndarray:
        """Return up to ``n`` distinct rows of ``x``, sampled uniformly.

        The request is capped at ``len(x)`` because sampling without
        replacement cannot return more rows than exist; callers may ask for
        the common subsample size even when a dataset is smaller.
        """
        n = min(n, len(x))
        return x[self.rng.choice(len(x), n, replace=False)]
