from dataclasses import dataclass, field

import numpy as np
import shap
import torch
from gpytorch.kernels import RBFKernel
from gpytorch.lazy import lazify
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor
from torch import FloatTensor

from src.explanation_algorithms.BayesGPSHAP import BayesGPSHAP
from src.gp_model.VariationalGPRegression import VariationalGPRegression
from src.predictive_explanation.ShapleyKernel import ShapleyKernel


def generate_data():
    """Load the sklearn diabetes data, shuffle the rows and standardise the target.

    The shuffle uses NumPy's global RNG, which is not seeded here, so the row
    order differs between runs unless the caller seeds NumPy beforehand.

    Returns:
        X: float32 feature tensor with rows in shuffled order.
        y: float32 target tensor, shifted to zero mean and divided by its std.
        scale: std of the raw (un-standardised) target as a 0-d numpy array,
            i.e. the factor that maps standardised targets back to raw units.
        feature_names: the dataset's feature names.
    """
    diabetes = load_diabetes()
    n_samples = diabetes.data.shape[0]
    # A sample of all indices without replacement is a random permutation.
    order = np.random.choice(range(n_samples), size=n_samples, replace=False)

    features = torch.tensor(diabetes.data[order]).float()
    targets = torch.tensor(diabetes.target[order]).float()

    target_std = targets.std()
    scale = target_std.numpy()
    targets = (targets - targets.mean()) / target_std

    return features, targets, scale, diabetes.feature_names


def _compute_rf_explanations(X, y):
    """Fit a 1000-tree random forest on (X, y) and return its TreeSHAP values for X.

    X and y are torch tensors; both are converted to numpy for sklearn / shap.
    """
    X_np = X.numpy()
    forest = RandomForestRegressor(n_estimators=1000)
    forest.fit(X_np, y.numpy())
    return shap.TreeExplainer(forest).shap_values(X_np)


def loss_function(pred, true):
    """Return the mean squared error between ``pred`` and ``true`` as a 0-d tensor."""
    residual = true - pred
    return residual.pow(2).mean()


def _compute_deep_explanations(X, y):
    """Fit an MLP regressor on (X, y) and return its KernelSHAP values for X.

    Args:
        X: feature matrix (a CPU torch tensor or a numpy array).
        y: regression targets (a CPU torch tensor or a numpy array).

    Returns:
        The KernelSHAP values that ``shap.KernelExplainer.shap_values`` computes
        for every row of X.
    """
    # Convert once, as _compute_rf_explanations does: shap.kmeans and
    # KernelExplainer are documented for numpy arrays / DataFrames, not tensors.
    X_np = np.asarray(X)
    y_np = np.asarray(y)

    mlp = MLPRegressor(solver='lbfgs', alpha=1e-1, hidden_layer_sizes=(10, 10))
    mlp.fit(X_np, y_np)

    # Summarise the background distribution with 10 k-means centroids so that
    # KernelSHAP does not integrate over every training row.
    background = shap.kmeans(X_np, 10)
    explainer = shap.KernelExplainer(mlp.predict, background)
    return explainer.shap_values(X_np)


def _predict_explanations_using_rf(X, num_train, target_train):
    """Fit a random forest from features to explanations on the first ``num_train``
    rows of ``X`` and return its predictions for the remaining rows.

    ``target_train`` holds the training explanations (one row per training sample).
    """
    forest = RandomForestRegressor(n_estimators=1000)
    forest.fit(X[:num_train], target_train)
    test_rows = X[num_train:]
    return forest.predict(test_rows)


def _predict_explanations_using_deep(X, num_train, target_train):
    """Fit an MLP from features to explanations on the first ``num_train`` rows
    of ``X`` and return its predictions for the remaining rows.

    ``target_train`` holds the training explanations (one row per training sample).
    """
    mlp = MLPRegressor(solver='lbfgs', alpha=1e-1, hidden_layer_sizes=(10, 10))
    mlp.fit(X[:num_train], target_train)
    test_rows = X[num_train:]
    return mlp.predict(test_rows)


@dataclass
class PredictiveExplanationExperiment(object):
    """Predict held-out Shapley explanations from features on the diabetes data.

    ``run()`` computes "target" explanations for every sample. The source is
    chosen by ``target_explanation_type``:

    * ``"gp"``: a GP via BayesGPSHAP;
    * ``"rf"``: a random forest via TreeSHAP;
    * ``"deep"``: an MLP via KernelSHAP.

    The first ``train_ratio`` fraction of the (shuffled) samples is used for
    training and the rest for testing. Three explanation predictors are fitted
    on the training part:

    * ``explanation_prediction_gp``: kernel ridge regression with a learned
      Shapley-prior kernel, shape ``[num_test, num_features]``;
    * ``explanation_prediction_rf``: a multi-output random forest;
    * ``explanation_prediction_deep``: a multi-output MLP.

    The held-out targets are stored flattened in ``target_test``.
    """

    # Accepted values of ``target_explanation_type``. This is a plain class
    # attribute, not a dataclass field, because it has no annotation.
    _TARGET_TYPES = ("gp", "rf", "deep")

    train_ratio: float = field(default=0.7)
    target_explanation_type: str = field(default="gp")

    # Populated by run(). gp_explanations is always set; rf_explanations and
    # deep_explanations are set only when that target type is requested.
    gp_explanations: FloatTensor = field(init=False)
    rf_explanations: np.ndarray = field(init=False)
    deep_explanations: np.ndarray = field(init=False)

    # NOTE(review): never assigned in this class; the GP-based predictions are
    # stored as ``explanation_prediction_gp`` instead.
    gp_explanations_gp_predictions: FloatTensor = field(init=False)

    def _validate_config(self):
        """Raise ValueError for an unsupported target type or a train_ratio outside (0, 1)."""
        if self.target_explanation_type not in self._TARGET_TYPES:
            raise ValueError(
                f"target_explanation_type must be one of {self._TARGET_TYPES}, "
                f"got {self.target_explanation_type!r}"
            )
        if not 0.0 < self.train_ratio < 1.0:
            raise ValueError(f"train_ratio must be in (0, 1), got {self.train_ratio}")

    def run(self):
        """Run the full experiment and store every result on ``self``.

        Sets ``X``, ``y``, ``gp_explanations`` (plus ``rf_explanations`` or
        ``deep_explanations`` when requested), ``target_test``,
        ``explanation_prediction_gp``, ``explanation_prediction_rf`` and
        ``explanation_prediction_deep``.

        Raises:
            ValueError: if ``target_explanation_type`` is not one of "gp",
                "rf" or "deep", or if ``train_ratio`` leaves the train or the
                test split empty.
        """
        # Validate before the expensive data loading / model fitting.
        self._validate_config()

        X, y, scale, _ = generate_data()
        self.X, self.y = X, y
        n, d = X.shape
        num_train = int(self.train_ratio * n)
        if not 0 < num_train < n:
            raise ValueError(
                f"train_ratio={self.train_ratio} leaves an empty train or test split for {n} samples"
            )

        # GP explanations are always computed. Fitting the GP also sets
        # self.lengthscales / self.inducing_points, which the Shapley kernel uses.
        self.gp_explanations = self._compute_gp_explanations(X, y, scale)

        if self.target_explanation_type == "gp":
            target = self.gp_explanations.reshape(-1, 1).detach()
        elif self.target_explanation_type == "rf":
            self.rf_explanations = _compute_rf_explanations(X, y)
            target = self.rf_explanations.reshape(-1, 1)
        else:  # "deep": any other value was rejected by _validate_config.
            self.deep_explanations = _compute_deep_explanations(X, y)
            target = self.deep_explanations.reshape(-1, 1)

        # target is the row-major flattening of a [num_samples, num_features]
        # matrix (assumed shape for the shap outputs; documented for the GP
        # ones), so the first num_train * d rows belong to the training samples.
        self.target_test = target[num_train * d:]

        # The same targets as a [num_samples, num_features] matrix for the
        # multi-output sklearn predictors. This must come from `target`, not
        # always from the GP explanations, so that all three predictors learn
        # the quantity that target_test holds.
        target_matrix = np.asarray(target).reshape(-1, d)

        self.explanation_prediction_gp = self._predict_explanations_using_gp(
            X, target, num_train=num_train, target_type=self.target_explanation_type)
        self.explanation_prediction_rf = _predict_explanations_using_rf(X, num_train, target_matrix[:num_train])
        self.explanation_prediction_deep = _predict_explanations_using_deep(X, num_train, target_matrix[:num_train])

    @staticmethod
    def _shapley_gram(shapley_kernel, X, size):
        """Build the ``size x size`` Gram matrix over flattened (sample, feature) pairs.

        Args:
            shapley_kernel: the ShapleyKernel module. Calling it on X gives Psi.
            X: feature tensor passed to ``shapley_kernel``.
            size: number of flattened targets (num_samples * num_features).
        """
        Psi = shapley_kernel(X)
        # NOTE(review): in "ijk,lmn->imkn" neither j nor l appears in the output,
        # so the two are summed independently: K[i,m,k,n] = (sum_j Psi[i,j,k]) *
        # (sum_l Psi^T[l,m,n]). Confirm that a shared contraction
        # ("ijk,jmn->imkn") was not intended.
        K = torch.einsum("ijk,lmn->imkn", Psi, Psi.transpose(0, 1))
        # reshape == contiguous().view(); it replaces the deprecated
        # non-inplace Tensor.resize used originally.
        return K.permute(2, 0, 3, 1).reshape(size, size)

    def _predict_explanations_using_gp(self, X, target, num_train, target_type="rf"):
        """Learn a Shapley-prior kernel and predict test explanations by kernel ridge regression.

        Args:
            X: feature tensor, ``[num_samples, num_features]``.
            target: flattened explanations, ``[num_samples * num_features, 1]``
                (torch tensor or numpy array).
            num_train: number of leading samples used for training.
            target_type: unused; kept for interface compatibility.

        Returns:
            Predicted explanations for samples ``num_train:``, reshaped to
            ``[num_test, num_features]``. The result is computed without
            autograd tracking.
        """
        d = X.shape[1]
        # as_tensor avoids the copy-construct warning that torch.tensor(tensor) raises.
        target = torch.as_tensor(target, dtype=torch.float32)
        size = len(target)
        num_train_flat = num_train * d
        target_train = target[:num_train_flat]

        # Shapley prior, initialised from the GP fitted in _compute_gp_explanations.
        shapley_kernel = ShapleyKernel(
            train_X=X, kernel=RBFKernel(), lengthscales=self.lengthscales,
            inducing_points=self.inducing_points, num_coalitions=2 ** (d - 1), sampling_method="subsampling",
            verbose=False
        )

        learning_rate = 1e-3
        iter_rounds = 200

        optim = torch.optim.Adam(shapley_kernel.parameters(), lr=learning_rate)
        for rd in range(iter_rounds):
            optim.zero_grad()
            K = self._shapley_gram(shapley_kernel, X, size)
            K_train = K[:num_train_flat, :num_train_flat]
            # In-sample KRR fit: K_train (K_train + reg * I)^{-1} y_train.
            prediction = K_train @ lazify(K_train).add_diag(
                shapley_kernel.krr_regularisation).inv_matmul(target_train)
            loss = loss_function(prediction, target_train)
            if rd % 20 == 0:
                print(loss)
            loss.backward()
            optim.step()

        # Rebuild the kernel with the parameters from after the final optimiser
        # step. The K computed inside the loop predates that update.
        with torch.no_grad():
            K = self._shapley_gram(shapley_kernel, X, size)
            K_train = K[:num_train_flat, :num_train_flat]
            K_test_train = K[num_train_flat:, :num_train_flat]
            weights = lazify(K_train).add_diag(
                shapley_kernel.krr_regularisation).inv_matmul(target_train)
            return (K_test_train @ weights).reshape(-1, d)

    def _compute_gp_explanations(self, X, y, scale):
        """Fit a variational GP to (X, y) and return its BayesGPSHAP mean Shapley values.

        As a side effect, this stores the fitted GP's ``lengthscale`` and
        ``inducing_points`` on ``self`` as ``lengthscales`` / ``inducing_points``.
        ``_predict_explanations_using_gp`` uses them to build the Shapley kernel.

        Returns:
            The transpose of ``gp_shap.mean_shapley_values``, shape
            ``[num_data, num_feature]``.
        """
        d = X.shape[1]
        gp_regression = VariationalGPRegression(X, y, kernel=RBFKernel, num_inducing_points=200,
                                                batch_size=128)
        gp_regression.fit(learning_rate=1e-2, training_iteration=300)

        gp_shap = BayesGPSHAP(train_X=X, kernel=RBFKernel(), gp_model=gp_regression,
                              include_likelihood_noise_for_explanation=False, scale=scale)
        # 2 ** d is the number of feature subsets over d features.
        gp_shap.run_bayesSHAP(X=X, num_coalitions=2 ** d)

        self.lengthscales = gp_regression.lengthscale
        self.inducing_points = gp_regression.inducing_points

        return gp_shap.mean_shapley_values.t()  # of shape [num_data, num_feature]
