import numpy as np

import pandas as pd
from pandas.core.frame import DataFrame

from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn import metrics, svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier

# enable_iterative_imputer must be imported before IterativeImputer
# (IterativeImputer is an experimental sklearn API).
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.linear_model import BayesianRidge

from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment

from sklearn.metrics.pairwise import rbf_kernel

import tensorflow as tf
from sklearn.neighbors import KNeighborsClassifier


def compute_weights(labels):
    """Return per-sample weights that balance classes 0 and 1.

    Samples of class 0 and of class 1 each receive a total weight of
    ``len(labels) / 2``; samples with any other label keep weight
    ``len(labels) / 2``. When every label is 0 or 1 the weights therefore
    sum to ``len(labels)``.

    Args:
      - labels: array of class labels (compared element-wise with 0 and 1)

    Returns:
      - weights: float array with one weight per sample
    """
    n_samples = len(labels)
    balanced = np.ones(n_samples)
    for cls in (0, 1):
        members = labels == cls
        # Divide by the class size so every class contributes the same mass.
        balanced[members] /= sum(members)
    return balanced / 2.0 * n_samples


def accuracy(probs, target):
    """Return 1 minus the mean squared difference between thresholded predictions and targets.

    Predictions are ``probs > 0.5`` (after squeezing singleton dimensions).
    For 0/1 targets this is the fraction of correct predictions.

    Args:
      - probs: predicted probabilities of the positive class, e.g. shape (n,) or (n, 1)
      - target: labels; numeric or boolean

    Returns:
      - accuracy as a float
    """
    # Cast both sides to float: numpy does not support subtracting two
    # boolean arrays, so a boolean ``target`` would otherwise raise TypeError.
    preds = (np.asarray(probs).squeeze() > 0.5).astype(float)
    target = np.asarray(target, dtype=float)
    return 1 - np.sum((preds - target) ** 2) / len(target)


def normalization(data, norm_params=None, norm_type="standard"):
    """Normalize each column of ``data``.

    Args:
      - data: 2-D array or DataFrame (rows are samples, columns are features).
        NaNs are ignored when parameters are estimated.
      - norm_params: parameters returned by an earlier call with the same
        ``norm_type``. When None, they are estimated from ``data``.
      - norm_type: "minmax" computes (x - min_val) / (max_val - min_val + 1e-6)
        per column, i.e. maps into roughly [0, 1]. "standard" computes
        (x - mu) / std with std = nanstd + 1e-6. Any other value returns the
        data unchanged (as a float array).

    Returns:
      - norm_data: normalized float array. This is a new array; ``data`` is
        not modified.
      - norm_params: {"min_val", "max_val"} or {"mu", "std"}. When
        ``norm_params`` was supplied, the same object is returned.
    """
    _, dim = data.shape
    # Work on a float copy. This keeps the caller's array/DataFrame intact and
    # keeps integer inputs from being truncated by the division below.
    norm_data = np.array(data, dtype=float)

    if norm_type == "minmax":
        if norm_params is None:
            if dim == 784:
                # NOTE(review): presumably flattened 28x28 images (e.g. MNIST)
                # with a fixed 0-255 pixel range -- confirm.
                min_val = np.zeros(dim)
                max_val = np.full(dim, 255.0)
            else:
                min_val = np.nanmin(norm_data, axis=0)
                max_val = np.nanmax(norm_data, axis=0)
            norm_params = {"min_val": min_val, "max_val": max_val}

        min_val = np.asarray(norm_params["min_val"], dtype=float)
        max_val = np.asarray(norm_params["max_val"], dtype=float)
        # Fitted and supplied parameters go through the same formula, so data
        # normalized with returned params is on the same scale as the fit data.
        norm_data = (norm_data - min_val) / (max_val - min_val + 1e-6)

    elif norm_type == "standard":
        if norm_params is None:
            mu = np.nanmean(norm_data, axis=0)
            std = np.nanstd(norm_data, axis=0) + 1e-6
            norm_params = {"mu": mu, "std": std}

        norm_data = (norm_data - norm_params["mu"]) / norm_params["std"]

    return norm_data, norm_params


def rmse_loss(imputed_data, ori_data, data_m, norm=True):
    """Root mean squared error between imputed and original values at missing entries.

    Note that the imputed data comes first in the argument list.

    Args:
      - imputed_data: data after imputation
      - ori_data: reference data; NaNs in it are treated as 0
      - data_m: missingness indicator, 1 (or True) where an entry was missing
      - norm: when True, both inputs are min-max normalized with parameters
        fitted on ``ori_data`` before the error is computed

    Returns:
      - rmse: sqrt(sum of squared errors over missing entries / number of
        missing entries), with 1e-7 added to the count to avoid division by 0
    """
    if norm:
        ori_data, fitted_params = normalization(ori_data.copy(), None, "minmax")
        imputed_data, _ = normalization(imputed_data.copy(), fitted_params, "minmax")

    # Masking zeroes out every observed entry, so only missing ones contribute.
    masked_truth = data_m * np.nan_to_num(ori_data)
    masked_guess = data_m * imputed_data
    squared_error = np.sum((masked_truth - masked_guess) ** 2)
    n_missing = np.sum(data_m) + 1e-7

    return np.sqrt(squared_error / float(n_missing))


def downstream_classification(
    df_trn,
    df_tst,
    method="logreg",
    iw_weight=None,
    beta=None,
    run=1,
):
    """Train a classifier on ``df_trn`` and evaluate it on ``df_tst``.

    In both DataFrames the last column is the label (rounded before training)
    and every other column is a feature.

    Args:
      - df_trn: training DataFrame (e.g. private/synthetic data).
      - df_tst: evaluation DataFrame (true test data).
      - method: "logreg", "svm", "rf", "knn", "nn" or "beta_mse".
      - iw_weight: optional per-row weights for ``df_trn``. They are passed as
        ``sample_weight`` where the model supports it. For "knn", whose fit
        takes no weights, training rows are instead resampled in proportion
        to them.
      - beta: if given with method="logreg", the squared distance between the
        first row of the fitted coefficients and ``beta`` is returned instead
        of the AUC.
      - run: unused; kept for backward compatibility.

    Returns:
      - ROC AUC on ``df_tst`` (one-vs-rest when the test set has more than two
        classes). For "beta_mse", returns the mean squared difference between
        the coefficients/intercepts of a balanced logistic regression fitted on
        ``df_tst`` and one fitted on ``df_trn``. With ``beta`` and "logreg",
        returns the coefficient distance described above.

    Raises:
      - KeyError: unknown ``method``.
      - IndexError: the test labels are binary but the model produced a single
        probability column (e.g. only one class in the training labels).
    """
    df_trn = df_trn.values
    df_tst = df_tst.values
    num_classes = len(np.unique(df_tst[:, -1]))

    if method == "logreg":
        model = LogisticRegression(solver="lbfgs", max_iter=4000)

    elif method == "svm":
        model = svm.SVC(probability=True, C=1)

    elif method == "rf":
        model = RandomForestClassifier(min_samples_leaf=3)

    elif method == "knn":
        # KNeighborsClassifier.fit takes no sample weights, so the weights are
        # emulated by drawing (with replacement) twice as many training rows.
        if iw_weight is None:
            probs = None  # uniform resampling
        else:
            weights = np.asarray(iw_weight, dtype=float)
            # Normalize by the sum so the probabilities add up to 1 for any
            # weight scale, not only when the mean weight is 1.
            probs = weights / weights.sum()
        resample_idx = np.random.choice(
            len(df_trn), size=len(df_trn) * 2, replace=True, p=probs
        )
        df_trn = df_trn[resample_idx]
        model = KNeighborsClassifier(n_neighbors=5, weights="distance")

    elif method == "nn":
        model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(df_trn.shape[1], activation="relu"),
                tf.keras.layers.Dense(num_classes),
                tf.keras.layers.Softmax(),
            ]
        )
        model.compile(
            optimizer="adam",
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
            metrics=["accuracy"],
        )

    elif method == "beta_mse":
        model_true = LogisticRegression(
            solver="lbfgs", max_iter=4000, class_weight="balanced"
        )
        model_true.fit(df_tst[:, :-1], df_tst[:, -1].round())
        model_synth = LogisticRegression(
            solver="lbfgs", max_iter=4000, class_weight="balanced"
        )
        model_synth.fit(df_trn[:, :-1], df_trn[:, -1].round(), sample_weight=iw_weight)
        # intercept_[:, None] gives one intercept per coefficient row, so this
        # works for binary (1 row) and multiclass (K rows) models alike.
        coef0 = np.concatenate((model_true.coef_, model_true.intercept_[:, None]), 1)
        coef1 = np.concatenate((model_synth.coef_, model_synth.intercept_[:, None]), 1)
        # Distance between the model fitted on test data and the one fitted
        # on training data.
        return ((coef0 - coef1) ** 2).mean()

    else:
        raise KeyError("Model not defined")

    x_trn, y_trn = df_trn[:, :-1], df_trn[:, -1].round()
    if method == "knn":
        model.fit(x_trn, y_trn)
    elif method == "nn":
        model.fit(x_trn, y_trn, sample_weight=iw_weight, epochs=10)
    else:
        model.fit(x_trn, y_trn, sample_weight=iw_weight)

    if beta is not None and method == "logreg":
        # NOTE(review): the sign convention is unverified; this might need
        # to compare against -model.coef_[0].
        return sum((model.coef_[0] - beta) ** 2)

    if method == "nn":
        # Keras models have no predict_proba; the final Softmax layer already
        # outputs class probabilities.
        pred_proba = model.predict(df_tst[:, :-1])
    else:
        pred_proba = model.predict_proba(df_tst[:, :-1])

    y_tst = df_tst[:, -1].astype("int")
    if num_classes > 2:
        auc = metrics.roc_auc_score(y_tst, pred_proba, multi_class="ovr")
    else:
        try:
            auc = metrics.roc_auc_score(y_tst, pred_proba[:, 1])
        except IndexError as err:
            raise IndexError(
                f"Expected two probability columns, got shape {pred_proba.shape}; "
                f"unique training labels: {np.unique(df_trn[:, -1])}"
            ) from err
    print(f"Auc is {auc:.2f}")

    return auc


def downstream_regression(
    df_trn,
    df_tst,
    method="linear",
    iw_weight=None,
    beta=None,
    run=1,
):
    """Train a regressor on ``df_trn`` and report its MSE on ``df_tst``.

    In both DataFrames the last column is the target and every other column is
    a feature.

    Args:
      - df_trn: training DataFrame (e.g. private/synthetic data).
      - df_tst: evaluation DataFrame (true test data).
      - method: "linear" (ordinary least squares), "nn" (small Keras MLP) or
        "beta_mse".
      - iw_weight: optional per-row sample weights for ``df_trn``.
      - beta: currently unused; kept for signature parity with
        ``downstream_classification``.
      - run: unused; kept for backward compatibility.

    Returns:
      - Mean squared error of the model's predictions on ``df_tst``. For
        "beta_mse", returns the mean squared difference between the
        coefficients+intercept of a linear regression fitted on ``df_tst`` and
        one fitted on ``df_trn``.

    Raises:
      - KeyError: unknown ``method``.
    """
    df_trn = df_trn.values
    df_tst = df_tst.values
    x_trn, y_trn = df_trn[:, :-1], df_trn[:, -1]
    x_tst, y_tst = df_tst[:, :-1], df_tst[:, -1]

    if method == "linear":
        model = LinearRegression()
        model.fit(x_trn, y_trn, sample_weight=iw_weight)

    elif method == "nn":
        model = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(df_trn.shape[1], activation="relu"),
                tf.keras.layers.Dense(1),
            ]
        )
        model.compile(
            optimizer="adam",
            loss=tf.keras.losses.MeanSquaredError(),
            metrics=[tf.keras.metrics.MeanSquaredError()],
        )
        model.fit(x_trn, y_trn, sample_weight=iw_weight, epochs=50)

    elif method == "beta_mse":
        model_true = LinearRegression().fit(x_tst, y_tst)
        model_synth = LinearRegression().fit(x_trn, y_trn, sample_weight=iw_weight)
        # With a 1-D target, coef_ is (n_features,) and intercept_ is a scalar,
        # so np.append yields a flat parameter vector for each model.
        coef0 = np.append(model_true.coef_, model_true.intercept_)
        coef1 = np.append(model_synth.coef_, model_synth.intercept_)
        return ((coef0 - coef1) ** 2).mean()

    else:
        raise KeyError("Model not defined")

    # Keras returns shape (n, 1); flatten so both model types score alike.
    preds = np.asarray(model.predict(x_tst)).reshape(-1)
    mse = metrics.mean_squared_error(y_tst, preds)
    print(f"mse is {mse:.2f}")

    return mse



def downstream_MCAR_imputation(
    df_trn, df_tst, method="logreg", iw_weight=None, seed=42
):
    """Fit an imputer on ``df_trn`` and score it on MCAR-masked ``df_tst``.

    About 20% of the entries of a copy of ``df_tst`` are hidden completely at
    random. The imputer, fitted on the fully observed rows of ``df_trn``,
    fills them in, and ``rmse_loss`` is computed over the hidden entries
    against their true values.

    Args:
      - df_trn: DataFrame (or 2-D array) used to fit the imputer. Rows with
        any NaN are ignored. It is not modified.
      - df_tst: DataFrame (or 2-D array) with the same columns. It is not
        modified.
      - method: "mice" (BayesianRidge estimator) or "missForest"
        (ExtraTreesRegressor estimator). The default "logreg" is not an
        imputer and raises KeyError.
      - iw_weight: optional per-row weights for ``df_trn``.
        IterativeImputer.fit accepts no sample weights, so the complete
        training rows are resampled (with replacement) in proportion to them.
      - seed: seeds numpy's global RNG and the imputer's random_state.

    Returns:
      - RMSE over the masked entries. Returns np.nan when there are no
        complete training rows or when fitting/imputation fails with a
        ValueError.

    Raises:
      - KeyError: unknown ``method``.
    """
    # The global seed drives the mask and the optional resampling. It
    # presumably also drives ExtraTreesRegressor, which has no random_state
    # set here.
    np.random.seed(seed)
    M = np.random.uniform(size=df_tst.shape) < 0.2

    if method == "mice":
        imputer = IterativeImputer(
            random_state=seed, max_iter=10, estimator=BayesianRidge()
        )
    elif method == "missForest":
        imputer = IterativeImputer(
            random_state=seed,
            max_iter=10,
            estimator=ExtraTreesRegressor(n_estimators=10, n_jobs=2),
        )
    else:
        raise KeyError("Model not defined")

    try:
        # Keep the true test values; only the copy gets NaNs at M.
        ori_tst = np.array(df_tst, dtype=float)
        masked_tst = ori_tst.copy()
        masked_tst[M] = np.nan

        # Drop incomplete rows (like DataFrame.dropna) on a copy, keeping a
        # mask so the weights stay aligned with the remaining rows.
        trn = np.array(df_trn, dtype=float)
        complete = ~np.isnan(trn).any(axis=1)
        trn = trn[complete]
        if len(trn) == 0:
            return np.nan

        if iw_weight is not None:
            weights = np.asarray(iw_weight, dtype=float)[complete]
            resample_idx = np.random.choice(
                len(trn), size=len(trn), replace=True, p=weights / weights.sum()
            )
            trn = trn[resample_idx]

        imputer.fit(trn)
        imp_tst = imputer.transform(masked_tst)
        return rmse_loss(imp_tst, ori_tst, M)
    except (ValueError, np.linalg.LinAlgError):
        # Best-effort evaluation: degenerate data yields NaN rather than
        # aborting the surrounding experiment.
        return np.nan


def wasserstein_distance(x1, x2, w_x2=None):
    """Optimal-matching transport cost between point sets ``x1`` and ``x2``.

    Euclidean distances between all pairs are computed. When ``w_x2`` is
    given, column j of the cost matrix is scaled by ``w_x2[j]``. A minimum-cost
    one-to-one assignment (Hungarian algorithm) is then found, and its total
    cost is divided by ``len(x1)``.

    Args:
      - x1: (n1, d) points
      - x2: (n2, d) points
      - w_x2: optional per-point weights for ``x2``

    Returns:
      - average matched cost (float)
    """
    cost = cdist(x1, x2)
    if w_x2 is not None:
        # Broadcasting multiplies every column j by w_x2[j].
        cost = cost * np.asarray(w_x2)

    # linear_sum_assignment is the expensive step for large inputs.
    rows, cols = linear_sum_assignment(cost)
    total = cost[rows, cols].sum()
    return total / len(x1)


def compute_kernel(x, y, gamma=None):
    """RBF (Gaussian) kernel matrix between the rows of ``x`` and ``y``.

    Args:
      - x: (n, d) samples
      - y: (m, d) samples
      - gamma: kernel bandwidth parameter passed to sklearn's ``rbf_kernel``.
        None keeps sklearn's default of 1 / n_features.

    Returns:
      - (n, m) kernel matrix
    """
    return rbf_kernel(x, y, gamma=gamma)


def compute_mmd(x, y, weights_y=None):
    """Squared-MMD estimate between samples ``x`` and (optionally weighted) ``y``.

    The mean of K(x, x), plus the mean of w_i K(y_i, y_j) w_j, minus twice
    the mean of K(x_i, y_j) w_j. Every mean runs over all pairs, including
    i == j. Kernel matrices come from ``compute_kernel``.

    Args:
      - x: samples from the first distribution
      - y: samples from the second distribution
      - weights_y: optional per-sample weights for ``y`` (all ones if None)

    Returns:
      - the MMD estimate (float)
    """
    w = np.ones(len(y)) if weights_y is None else weights_y

    k_xx = compute_kernel(x, x)
    k_yy = compute_kernel(y, y)
    k_xy = compute_kernel(x, y)

    term_xx = k_xx.mean()
    term_yy = np.einsum("i, ij, j -> ij", w, k_yy, w).mean()
    term_xy = np.einsum("ij, j -> ij", k_xy, w).mean()
    return term_xx + term_yy - 2 * term_xy
