import numpy as np
import torch
from torch.nn import ModuleList, Module
from torch_geometric.nn import MessagePassing, global_add_pool
# Manual Seed for Reproducibility
from tqdm import trange

from models.explanation_generator import ExplanationGenerator
from utils.utils import linear_combo_features, normalize

# Manual seed for reproducibility: seeds torch's global RNG once, when this module is imported.
torch.manual_seed(0)

def normalized(a, axis=-1, order=2):
    """Scale ``a`` so that its ``order``-norm along ``axis`` equals 1.

    Slices whose norm is zero are divided by 1 (left unchanged) instead of by 0.
    Because the norms pass through ``np.atleast_1d``, a 1-D input with the default
    ``axis=-1`` comes back with shape ``(1, n)``.
    """
    norms = np.atleast_1d(np.linalg.norm(a, order, axis))
    safe_norms = np.where(norms == 0, 1, norms)
    return a / np.expand_dims(safe_norms, axis)


def tree_predict(x, tree, num_classes, linear_feature_combinations=False, pooling=False):
    """Run ``tree`` on the rows of tensor ``x`` and one-hot encode its predictions.

    Args:
        x: tensor of input rows; it is detached and copied to host memory first.
        tree: fitted estimator with a ``predict`` method. Its predictions are used
            directly as column indices, so they must be integers in ``[0, num_classes)``.
        num_classes: number of columns of the one-hot result.
        linear_feature_combinations: if True, expand the rows with
            ``linear_combo_features`` before predicting.
        pooling: selects the ``num_features`` handed to ``linear_combo_features``:
            the full row width when True, half of it otherwise.

    Returns:
        float64 tensor of shape ``(n_rows, num_classes)``.
    """
    features = x.cpu().detach().numpy()
    if linear_feature_combinations:
        width = len(features[0])
        features = linear_combo_features(features, width if pooling else width // 2)

    predicted = tree.predict(features)
    encoded = np.zeros((predicted.size, num_classes))
    encoded[np.arange(predicted.size), predicted] = 1
    return torch.from_numpy(encoded)


def tree_predict_explain(x, tree, num_classes, linear_feature_combinations=False, pooling=False):
    """Like ``tree_predict``, but also return an importance-weighted score and raw predictions.

    The rows of ``x`` (after optional ``linear_combo_features`` expansion, with
    ``num_features`` = full width if ``pooling`` else half width) are multiplied
    element-wise by ``tree.feature_importances_`` and summed per row.

    Returns:
        ``(one_hot, weighted_sum, predictions)``: a float64 one-hot tensor of shape
        ``(n_rows, num_classes)``, a numpy array of per-row importance-weighted sums,
        and the raw output of ``tree.predict`` (used as column indices, so it must hold
        integers in ``[0, num_classes)``).
    """
    features = x.cpu().detach().numpy()
    if linear_feature_combinations:
        width = len(features[0])
        features = linear_combo_features(features, width if pooling else width // 2)

    weighted = features * tree.feature_importances_
    predicted = tree.predict(features)
    encoded = np.zeros((predicted.size, num_classes))
    encoded[np.arange(predicted.size), predicted] = 1
    return torch.from_numpy(encoded), weighted.sum(-1), predicted


def check_switched_prediction(estimator, sample, feature_index, threshold, target, message):
    """Flip one feature of ``sample`` across ``threshold`` and test the resulting prediction.

    If ``sample[feature_index]`` is above ``threshold`` the feature is set to 0.
    Otherwise it is set to 1, and when ``message`` is True the second half of the
    (copied) sample is zeroed first. ``sample`` itself is not modified.

    Returns:
        Whether ``estimator.predict`` on the switched sample equals ``target``.
    """
    switched = np.array(sample)
    if sample[feature_index] > threshold:
        new_value = 0
    else:
        new_value = 1
        if message:
            switched[len(switched) // 2:] = 0
    switched[feature_index] = new_value

    return target == estimator.predict([switched])[0]


import shap


def get_decision_path_features(estimator, data, num_classes=0, message=False, num_features=0, pooling_gc=False):
    """Mark, per sample, which features are tested by the internal nodes on its decision path.

    Args:
        estimator: fitted tree exposing ``tree_.feature``, ``decision_path`` (CSR node
            indicator) and ``apply`` (leaf id per sample).
        data: 2-D array-like of samples.
        num_classes, message, num_features, pooling_gc: accepted for call-site
            compatibility; unused.

    Returns:
        float array of shape ``data.shape`` with 1.0 where the feature is split on
        somewhere along that sample's path and 0.0 elsewhere.
    """
    split_feature = np.asarray(estimator.tree_.feature)
    node_indicator = estimator.decision_path(data)
    leaf_ids = np.asarray(estimator.apply(data))
    features_used = np.zeros((np.shape(data)[0], np.shape(data)[1]))

    # CSR layout: sample i visited nodes indices[indptr[i]:indptr[i + 1]].
    indptr = node_indicator.indptr
    path_lengths = np.diff(indptr)
    rows = np.repeat(np.arange(len(path_lengths)), path_lengths)
    nodes = node_indicator.indices[indptr[0]:indptr[-1]]

    # The leaf at the end of each path makes no split, so it contributes no feature.
    internal = nodes != leaf_ids[rows]
    features_used[rows[internal], split_feature[nodes[internal]]] = 1.0
    return features_used

def get_features_used(estimator, data, num_classes=0, message=False, num_features=0, pooling_gc=False):
    """Signed SHAP attributions of ``estimator`` for every sample, class and feature.

    Attributions come from ``shap.TreeExplainer`` with tree-path-dependent
    perturbation. For each sample, the attributions of the class the estimator
    predicted keep their sign; those of every other class are negated.

    Args:
        estimator: fitted tree classifier exposing ``predict`` and ``classes_``.
        data: 2-D array-like of samples.
        num_classes: size of the class axis of the result. Values of
            ``estimator.classes_`` are used directly as indices into that axis.
        message, num_features, pooling_gc: accepted for call-site compatibility; unused.

    Returns:
        float array of shape ``(n_samples, num_classes, n_features)``; class slots not
        present in ``estimator.classes_`` stay zero.
    """
    n_samples, n_features = np.shape(data)[0], np.shape(data)[1]
    features_used = np.zeros((n_samples, num_classes, n_features))
    predictions = np.asarray(estimator.predict(data))
    explainer = shap.TreeExplainer(estimator, feature_perturbation="tree_path_dependent")

    raw_values = explainer.shap_values(data)
    if isinstance(raw_values, list):
        # Older shap releases: one (n_samples, n_features) array per class.
        shap_values = np.array(raw_values)
    else:
        shap_values = np.asarray(raw_values)
        if shap_values.ndim == 3:
            # Newer shap releases (0.45+) return (n_samples, n_features, n_classes);
            # move the class axis first so it matches the list layout above.
            shap_values = np.transpose(shap_values, (2, 0, 1))
    if shap_values.ndim != 3:
        # Single-output explanation: add a leading class axis of length 1.
        shap_values = np.array([shap_values])

    # shap_values is now (n_outputs, n_samples, n_features).
    classes = np.asarray(estimator.classes_)[:len(shap_values)]
    # +1 for the class predicted for that sample, -1 for every other class -> (n_samples, n_outputs).
    signs = np.where(classes[np.newaxis, :] == predictions[:, np.newaxis], 1.0, -1.0)
    features_used[:, classes, :] = signs[:, :, np.newaxis] * np.transpose(shap_values, (1, 0, 2))
    return features_used


class DTModule(Module):
    """Expose a fitted decision tree as a torch ``Module``.

    ``forward`` delegates to ``tree_predict``, which returns one-hot predictions.

    Args:
        tree: fitted estimator with a ``predict`` method.
        name: stored on the instance as ``__name__``.
        state_size: number of one-hot output columns.
        linear_feature_combinations: forwarded to ``tree_predict``.
        pooling: forwarded to ``tree_predict``.
    """

    def __init__(self, tree, name, state_size, linear_feature_combinations=False, pooling=False):
        super().__init__()
        self.__name__ = name
        self.tree = tree
        self.state_size = state_size
        self.linear_feature_combinations = linear_feature_combinations
        self.pooling = pooling

    def forward(self, x):
        return tree_predict(
            x,
            self.tree,
            self.state_size,
            linear_feature_combinations=self.linear_feature_combinations,
            pooling=self.pooling,
        )

class StoneAgeGNNDT(torch.nn.Module):
    """Stone Age GNN whose input, per-layer update and output functions are decision trees.

    ``trees`` maps ``'input'``, ``'output'`` and ``'stone_age.{i}.linear_softmax'`` (one per
    layer ``i``) to fitted trees. ``forward`` runs the plain prediction. ``explain`` runs the
    same pipeline while recording per-layer inputs, outputs, SHAP-based feature usage and
    node importances via ``ExplanationGenerator``.
    """

    def __init__(self, in_channels, out_channels, bounding_parameter, trees, state_size, use_pooling=True,
                 num_layers=1, skip_connection=False, linear_feature_combinations=False):
        super().__init__()

        self.input = DTModule(trees['input'], 'input', state_size)
        # NOTE(review): here the output tree applies linear feature combinations whenever
        # use_pooling is True (that is what forward() uses). explain() applies them only when
        # use_pooling AND linear_feature_combinations are both True. The two paths disagree for
        # use_pooling=True, linear_feature_combinations=False — confirm which matches training.
        self.pooling = DTModule(trees['output'], 'output', out_channels, use_pooling, use_pooling)
        self.tree_depths = [self.input.tree.get_depth()]
        self.stone_age = ModuleList()
        self.out_channels = out_channels
        self.state_size = state_size
        self.use_pooling = use_pooling
        self.skip_connection = skip_connection
        self.trees = trees
        # Recordings filled by explain() and cleared by reset_explanation().
        self.explanation = []
        self.explanation_forward = []
        self.outputs = []
        self.inputs = []
        self.contributions = []
        self.features_used = []
        self.explanation_forward_v2 = []
        self.num_layers = num_layers
        self.num_nodes = 0
        self.edge_index = None
        self.subset = None
        self.backprop_cache = {}
        self.state_neighbours = []
        self.linear_feature_combinations = linear_feature_combinations
        for i in range(num_layers):
            layer_tree = trees[f"stone_age.{i}.linear_softmax"]
            self.tree_depths.append(layer_tree.get_depth())
            self.stone_age.append(
                StoneAgeGNNLayerDT(bounding_parameter=bounding_parameter,
                                   tree=layer_tree,
                                   index=i, state_size=self.state_size,
                                   linear_feature_combinations=linear_feature_combinations))
        self.tree_depths.append(self.pooling.tree.get_depth())

    def reset_explanation(self):
        """Clear everything recorded by a previous explain() call."""
        self.outputs = []
        self.inputs = []
        self.explanation = []
        self.explanation_forward = []
        self.contributions = []
        self.features_used = []
        self.explanation_forward_v2 = []
        self.subset = None
        self.backprop_cache = {}
        self.state_neighbours = []

    def update_trees(self, new_trees):
        """Install a new set of fitted trees (same keys as in __init__) and recompute ``tree_depths``."""
        self.trees = new_trees
        self.input.tree = new_trees['input']
        self.tree_depths = [self.input.tree.get_depth()]
        self.pooling.tree = new_trees['output']
        for i in range(self.num_layers):
            layer_tree = new_trees[f"stone_age.{i}.linear_softmax"]
            self.tree_depths.append(layer_tree.get_depth())
            self.stone_age[i].tree.tree = layer_tree
        self.tree_depths.append(self.pooling.tree.get_depth())

    def get_decision_paths(self):
        """Return, per tree (input, each layer, output), the node ids visited by each recorded input row.

        Requires ``self.inputs`` to be filled by explain().
        """
        decision_paths = []
        for layer_index in range(self.num_layers + 2):
            decision_paths_layer = []
            if layer_index == 0:
                estimator = self.input.tree
            elif layer_index == self.num_layers + 1:
                estimator = self.pooling.tree
            else:
                estimator = self.stone_age[layer_index - 1].tree.tree

            data = self.inputs[layer_index]
            node_indicator = estimator.decision_path(data)
            for sample_id in range(len(data)):
                node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                                    node_indicator.indptr[sample_id + 1]]
                decision_paths_layer.append(node_index)
            decision_paths.append(decision_paths_layer)
        return decision_paths

    def get_neighbours_for_state(self, num_nodes, edge_index, layer_index):
        """Count incoming messages per receiving node, grouped by the sender's state.

        Entry ``[v, s]`` of the returned ``(num_nodes, state_size)`` float array is the
        number of edges ``u -> v`` (``edge_index[0]`` = sender, ``edge_index[1]`` = receiver)
        whose sender was assigned state ``s`` in ``self.outputs[layer_index - 1]``.
        Repeated edges are counted once per occurrence.
        """
        state_neighbours = np.zeros((num_nodes, self.state_size))
        if isinstance(edge_index, torch.Tensor):
            edge_index = edge_index.detach().cpu().numpy()
        senders = np.asarray(edge_index[0], dtype=np.int64)
        receivers = np.asarray(edge_index[1], dtype=np.int64)
        sender_states = np.asarray(self.outputs[layer_index - 1])[senders]
        # Unbuffered add so repeated (receiver, state) pairs accumulate rather than overwrite.
        np.add.at(state_neighbours, (receivers, sender_states), 1)
        return state_neighbours

    def _readout(self, x, xs, batch):
        """Apply the optional sum pooling and skip connection shared by forward() and explain()."""
        if self.use_pooling:
            if batch is None:
                # No batch vector supplied: treat all nodes as a single graph. An explicit
                # long dtype keeps the index valid on every platform and for zero nodes.
                batch = torch.zeros(len(x), dtype=torch.long, device=x.device)
            x = global_add_pool(x, batch)
            xs = [global_add_pool(xi, batch) for xi in xs]
        if self.skip_connection:
            x = torch.cat(xs, dim=1)
        return x

    @staticmethod
    def _node_importance(explanation_generator, explanation_forward_v2, outputs, layer_index, num_nodes,
                         backward, log_progress, desc):
        """Per-node importance for one layer: backpropagated if ``backward``, else from the forward (v2) explanation."""
        if backward:
            return np.array(
                [normalize(explanation_generator.backpropagate_explanation(i, 1.0, layer_index, outputs[i])) for i in
                 trange(num_nodes, desc=desc, disable=(not log_progress))])
        return np.array([normalize(explanation_forward_v2[i][outputs[i]]) for i in range(num_nodes)])

    def explain(self, x, edge_index, batch=None, subset=None, log_progress=False, backward=False):
        """Run the model while recording per-layer data and node importances.

        Appends to ``inputs``, ``outputs``, ``features_used``, ``state_neighbours``,
        ``explanation_forward_v2`` and ``explanation``, and sets ``num_nodes``,
        ``edge_index`` and ``subset``. Nothing is cleared first; call
        reset_explanation() beforehand for a fresh run.

        Args:
            x: node feature tensor.
            edge_index: edge list; row 0 holds senders and row 1 receivers.
            batch: optional graph-assignment vector used when pooling.
            subset: stored on ``self.subset``.
            log_progress: show tqdm progress bars for the backward computation.
            backward: use backpropagated importances instead of forward ones.

        Returns:
            ``(x, importance)``: the one-hot output of the output tree and the importance
            array appended last to ``self.explanation``.
        """
        explanation_generator = ExplanationGenerator(x, edge_index, self)
        self.inputs.append(x.cpu().detach().numpy())
        # No message passing has happened before the input tree, so its neighbour counts are all zero.
        self.state_neighbours.append(np.zeros((len(x), self.state_size)))
        self.num_nodes = len(x)
        self.edge_index = edge_index
        self.subset = subset
        features_used = get_features_used(self.input.tree, x.cpu().detach().numpy(), num_classes=self.state_size)
        self.features_used.append(features_used)
        explanation_forward_v2 = explanation_generator.forward_explain(0)
        self.explanation_forward_v2.append(explanation_forward_v2)
        x, _, outputs = tree_predict_explain(x, self.input.tree, self.state_size)
        self.outputs.append(outputs)
        xs = [x]
        layer_index = 0

        self.explanation.append(self._node_importance(
            explanation_generator, explanation_forward_v2, outputs, layer_index, len(x),
            backward, log_progress, 'Input'))

        for layer in self.stone_age:
            x = layer(x, edge_index, explain=True)
            outputs = x.argmax(dim=-1).cpu().detach().numpy()
            self.outputs.append(outputs)
            self.inputs.append(layer.tree_input)
            self.features_used.append(layer.features_used)

            layer_index += 1

            self.state_neighbours.append(self.get_neighbours_for_state(len(x), edge_index, layer_index))
            explanation_forward_v2 = explanation_generator.forward_explain(layer_index)
            self.explanation_forward_v2.append(explanation_forward_v2)

            self.explanation.append(self._node_importance(
                explanation_generator, explanation_forward_v2, outputs, layer_index, len(x),
                backward, log_progress, f'Layer {layer_index - 1}'))
            xs.append(x)

        self.state_neighbours.append(self.get_neighbours_for_state(len(x), edge_index, layer_index))

        x = self._readout(x, xs, batch)

        pooling_input = x.cpu().detach().numpy()
        if self.linear_feature_combinations and self.use_pooling:
            num_features = len(pooling_input[0])
            pooling_input = linear_combo_features(pooling_input, num_features)
        self.inputs.append(pooling_input)
        features_used = get_features_used(self.pooling.tree, pooling_input, num_classes=self.out_channels,
                                          pooling_gc=self.use_pooling)
        self.features_used.append(features_used)
        x, _, _ = tree_predict_explain(x, self.pooling.tree, self.out_channels,
                                       linear_feature_combinations=(self.use_pooling and self.linear_feature_combinations),
                                       pooling=self.use_pooling)

        outputs = x.argmax(dim=-1).cpu().detach().numpy()
        self.outputs.append(outputs)

        if self.use_pooling:
            explanation_forward_v2 = explanation_generator.get_forward_explanation_gc(skip_connection=self.skip_connection)
        else:
            explanation_forward_v2 = explanation_generator.get_forward_explanation()

        if backward:
            if self.use_pooling:
                # One graph-level importance vector, shared by every node.
                gc_importance = explanation_generator.get_explanation_gc(1.0, skip_connection=self.skip_connection)
                importance = np.array([gc_importance for _ in range(self.num_nodes)])
            else:
                importance = np.array(
                    [explanation_generator.get_explanation(i, 1.0) for i in
                     trange(len(x), desc='Output', disable=(not log_progress))])
        else:
            importance = np.array([normalize(explanation_forward_v2[i]) for i in range(self.num_nodes)])
        self.explanation.append(importance)
        return x, importance

    def forward(self, x, edge_index, batch=None, **kwargs):
        """Predict with the tree pipeline: input tree, message-passing layers, readout, output tree."""
        x = self.input(x)
        xs = [x]
        for layer in self.stone_age:
            x = layer(x, edge_index)
            xs.append(x)

        x = self._readout(x, xs, batch)
        return self.pooling(x)


class StoneAgeGNNLayerDT(MessagePassing):
    """Message-passing layer whose node update is a decision tree.

    Each node receives its neighbours' states (``x_j``). The messages are summed
    per receiver and clamped to ``[0, bounding_parameter]``. The result is
    concatenated with the node's own state and passed through the wrapped tree.

    NOTE: MessagePassing routes arguments to message/aggregate/update by
    parameter name, so those signatures are kept exactly as they are.
    """

    def __init__(self, bounding_parameter, tree, state_size, index=0, linear_feature_combinations=False):
        super().__init__(aggr='add')
        self.__name__ = f'stone-age-{index!s}'
        self.tree = DTModule(tree, 'linear_softmax', state_size, linear_feature_combinations)
        self.bounding_parameter = bounding_parameter
        self.state_size = state_size
        # Snapshots captured only when running with explain=True (tree_input is refreshed on every update).
        self.messages = None
        self.edge_index = None
        self.importance = None
        self.inputs = None
        self.tree_input = None
        self.features_used = None
        self.linear_feature_combinations = linear_feature_combinations

    def forward(self, x, edge_index, explain=False):
        if explain:
            # Keep a host-side copy of the edges for the explanation code.
            self.edge_index = edge_index.cpu().detach().numpy()
        return self.propagate(edge_index, x=x, explain=explain)

    def message(self, x_j, explain=False):
        if explain:
            self.messages = x_j.cpu().detach().numpy()
        # Messages are the neighbours' states, passed through unchanged.
        return x_j

    def aggregate(self, inputs, index, ptr, dim_size, explain=False):
        summed = super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        if explain:
            # Record the sums before clamping.
            self.inputs = summed.cpu().detach().numpy()
        return summed.clamp(min=0, max=self.bounding_parameter)

    def update(self, inputs, x, explain=False):
        # Tree input row = [aggregated neighbour states | own state].
        combined = torch.cat((inputs, x), dim=1)
        self.tree_input = combined.cpu().detach().numpy()

        if self.linear_feature_combinations:
            half_width = len(self.tree_input[0]) // 2
            self.tree_input = linear_combo_features(self.tree_input, half_width)

        if explain:
            fitted_tree = self.tree.tree
            importances = fitted_tree.feature_importances_
            self.features_used = get_features_used(
                fitted_tree,
                self.tree_input,
                num_classes=self.state_size,
                message=True,
                num_features=len(importances),
            )

        return self.tree(combined)
