import copy
import json

import numpy as np
from dig.xgraph.method.subgraphx import MCTS
from dig.xgraph.method.subgraphx import PlotUtils
from sklearn.tree import _tree
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
from tqdm import tqdm

from utils.prune_tree import get_num_nodes, prune_trees_least_influencial, score_model


class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values.

    Passed as ``cls=`` to every ``json.dump`` call in this module, whose payloads
    contain NumPy integers, floats, booleans and arrays.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Anything that is not a NumPy scalar/array is deferred to
        ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # NumPy booleans are not subclasses of `bool`, so json cannot encode them natively.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


# Fill colours for leaf nodes of exported trees, indexed by the leaf's predicted class.
background_colors = [
    '#f3f301',
    '#1679d6',
    '#d23d3d',
    '#2d862d',
    '#f39e04',
    '#9d239e',
    '#1ff320',
    '#04f3eb',
    '#f33eda',
    '#C8F3A9',
]


def generate_tree_json(tree, feature_names, class_names, layer_idx, num_states, feature_layer_index, threshold_feature):
    """Convert a fitted sklearn decision tree into a nested, JSON-serialisable dict.

    Args:
        tree: fitted tree estimator exposing ``tree_`` and ``classes_``.
        feature_names: display name per feature index, used in split-node titles.
        class_names: display name per class; indexed by the predicted class label
            (assumes integer class labels usable as indices — TODO confirm).
        layer_idx: index of the layer this tree belongs to; prefixed to every node id.
        num_states: number of states; a split node's ``state`` is ``feature % num_states``.
        feature_layer_index: layer index per feature, stored as a split node's ``layer``.
        threshold_feature: per feature, whether a split at threshold 0.5 is still shown
            with its numeric threshold (False -> the title is just the feature name).

    Returns:
        dict for the root node, with child nodes nested under ``"children"``.
    """
    tree_ = tree.tree_
    feature = tree_.feature
    threshold = tree_.threshold
    impurity = tree_.impurity
    tree_classes = tree.classes_

    def recurse(node):
        """Build the dict for ``node`` and, recursively, its subtree."""
        new_node = {
            "id": str(layer_idx) + '-' + str(node),
            "node_id": node,
            "type": 'test',
            # Since scikit-learn 1.4, `tree_.value` stores class fractions for classifiers,
            # so summing it yields 1 for every node; the stored weighted count is
            # version-independent (and equals that sum on older versions).
            "samples": tree_.weighted_n_node_samples[node],
            "gini": "{:.3f}".format(impurity[node]),
            "children": []
        }
        if feature[node] != _tree.TREE_UNDEFINED:
            new_node["title"] = f'{feature_names[feature[node]]} ≤ {threshold[node]}'
            if not threshold_feature[feature[node]] and threshold[node] == 0.5:
                # Binary feature split at 0.5: show just the feature name.
                new_node["title"] = feature_names[feature[node]]
            new_node["children"].append(recurse(tree_.children_left[node]))
            new_node["children"].append(recurse(tree_.children_right[node]))
            new_node["layer"] = feature_layer_index[feature[node]]
            new_node["state"] = feature[node] % num_states
        else:
            out_class = tree_classes[np.argmax(tree_.value[node])]
            new_node["title"] = 'Leaf Node'
            new_node["classLabel"] = class_names[out_class]
            # Wrap around the palette so trees with more classes than colours don't raise IndexError.
            new_node["backgroundColor"] = background_colors[out_class % len(background_colors)]
            new_node["layer"] = -1
        return new_node

    return recurse(0)


def _layer_feature_spec(layer_idx, layer_name, input_features, states, class_labels, skip_connection, num_layers,
                        use_pooling):
    """Return (feature_names, class_names, feature_layer_index, threshold_feature) for one layer's tree."""
    class_names = [f'State {i}' for i in range(states)]

    if layer_name == 'input':
        # Input tree: one feature per raw input state.
        return ([f'Input State {i}' for i in range(input_features)],
                class_names,
                [-1] * input_features,
                [False] * input_features)

    if layer_name == 'output':
        if not skip_connection:
            # Output tree reads only the previous layer's states.
            return ([f'Prev. State {i}' for i in range(states)],
                    class_labels,
                    [layer_idx - 1] * states,
                    [False] * states)
        # Skip connection: the output sees the input states plus every hidden layer's states.
        feature_names = [f'Input: State {i}' for i in range(states)]
        short_names = [f'Inp.: St. {i}' for i in range(states)]
        feature_layer_index = [0] * states
        for layer_index in range(num_layers):
            feature_names += [f'Layer {layer_index}: State {i}' for i in range(states)]
            short_names += [f'L. {layer_index}: St. {i}' for i in range(states)]
            feature_layer_index += [layer_index + 1] * states
        threshold_feature = [False] * len(feature_names)
        if use_pooling:
            # Pairwise comparison features over all (num_layers + 1) * states inputs.
            for first in short_names:
                for second in short_names:
                    feature_names.append(f'{second} ≤ {first}')
                    feature_layer_index.append(1)
                    threshold_feature.append(False)
        return feature_names, class_labels, feature_layer_index, threshold_feature

    # Hidden layer: 'Neighbour' features (shown with numeric thresholds), previous-state
    # features, and pairwise neighbour comparisons, all attributed to the previous layer.
    feature_names = ([f'Neighbour {i}' for i in range(states)]
                     + [f'Prev. State {i}' for i in range(states)]
                     + [f'Neighb. {i2} ≤ Neighb. {i1}' for i1 in range(states) for i2 in range(states)])
    feature_layer_index = [layer_idx - 1] * len(feature_names)
    threshold_feature = [True] * states + [False] * (states + states * states)
    return feature_names, class_names, feature_layer_index, threshold_feature


def export_decision_trees(trees, layer_names, input_features, states, class_labels, skip_connection, num_layers,
                          directory=None, png_directory=None, dataset_name='infection', use_pooling=False, save_json=True, data_dir=None):
    """Convert each layer's decision tree to JSON-ready dicts and save or return them.

    Args:
        trees: mapping from layer name to fitted tree; layers missing from it are skipped.
        layer_names: ordered layer names; the position gives ``layer_idx``.
            ``'input'`` and ``'output'`` get dedicated feature naming.
        input_features: number of raw input features (used for the ``'input'`` layer).
        states: number of hidden states per layer.
        class_labels: class display names used for the ``'output'`` tree.
        skip_connection: whether the output tree sees the input and all hidden layers.
        num_layers: number of hidden layers (used with ``skip_connection``).
        directory, png_directory: accepted for backward compatibility; unused.
        dataset_name: used in the output file name.
        use_pooling: add pairwise comparison features for the skip-connected output tree.
        save_json: write ``trees_{dataset_name}.json`` instead of returning the list.
        data_dir: base directory for the JSON file; defaults to the current directory.

    Returns:
        The list of tree dicts when ``save_json`` is False, otherwise None.
    """
    decision_trees = []
    for layer_idx, layer_name in enumerate(layer_names):
        if layer_name not in trees:
            continue
        feature_names, class_names, feature_layer_index, threshold_feature = _layer_feature_spec(
            layer_idx, layer_name, input_features, states, class_labels, skip_connection, num_layers, use_pooling)
        decision_trees.append(
            generate_tree_json(trees[layer_name], feature_names, class_names, layer_idx, states,
                               feature_layer_index, threshold_feature))

    if not save_json:
        return decision_trees

    # Same default location as export_stepped_decision_trees; without it `path` was
    # unbound when data_dir is None and the write raised UnboundLocalError.
    base_dir = data_dir if data_dir is not None else '.'
    path = f'{base_dir}/explanations/trees/trees_{dataset_name}.json'
    with open(path, "w") as outfile:
        json.dump(decision_trees, outfile, cls=NpEncoder)
    return None

def export_stepped_decision_trees(trees, layer_names, input_features, states, class_labels, skip_connection, num_layers,
                                  model_dt, train_loader, test_loader, dataset_name='infection', prune_steps=10,
                                  score_model=score_model, train_test_mask=False, use_pooling=False, return_trees=False,
                                  test_mask=None, return_json=False, debug=False, data_dir=None):
    """Prune the trees in ``prune_steps`` steps and export a JSON snapshot after each step.

    Step 0 exports the unpruned trees. At step ``i`` (1..``prune_steps``), the target is
    ``original_num_nodes - (original_num_nodes / prune_steps) * i`` nodes, and
    ``prune_trees_least_influencial`` is asked to remove the difference from the current
    count. After each step, ``model_dt`` is updated and scored on both loaders.

    All snapshots (tree dicts plus ``train_score``/``test_score``) are written to
    ``{data_dir or '.'}/explanations/trees/trees_step_{dataset_name}.json``.

    Returns:
        ``(step_pruned_trees, stepped_trees)`` if both ``return_json`` and ``return_trees``
        are set, otherwise whichever one is requested, otherwise None. ``stepped_trees``
        holds a deep copy of the trees at every step.
    """
    original_num_nodes = get_num_nodes(trees, layer_names)
    step_pruned_trees = []
    stepped_trees = []
    for step in range(prune_steps + 1):
        if step > 0:
            max_nodes_to_keep = original_num_nodes - ((original_num_nodes / prune_steps) * step)
            current_num_nodes = get_num_nodes(trees, layer_names)
            nodes_to_remove = int(current_num_nodes - max_nodes_to_keep)
            # NOTE(review): pruning is scored on train_loader but with test_mask — confirm intended.
            trees = prune_trees_least_influencial(trees, model_dt, layer_names, train_loader,
                                                  remove_nodes=nodes_to_remove, score_model=score_model, mask=test_mask)
            model_dt.update_trees(trees)
        if test_mask is not None:
            train_score = score_model(model_dt, train_loader, test_mask)
            test_score = score_model(model_dt, test_loader, test_mask)
        elif train_test_mask:
            train_score = score_model(model_dt, train_loader, mask_index=0)
            test_score = score_model(model_dt, test_loader, mask_index=2)
        else:
            train_score = score_model(model_dt, train_loader)
            test_score = score_model(model_dt, test_loader)
        if debug:
            print('---------- GNN DT Test Accuracy After Pruning ----------')
            print(train_score, test_score)
            # Skip layers without a tree (as export_decision_trees does) instead of raising KeyError.
            print("Tree Depths:", [trees[layer_name].get_depth() for layer_name in layer_names if layer_name in trees])
        dt_json_data = export_decision_trees(trees, layer_names, input_features, states,
                                             class_labels,
                                             skip_connection, num_layers,
                                             'tree_plots/infection', save_json=False, use_pooling=use_pooling)
        step_pruned_trees.append({
            "data": dt_json_data,
            "train_score": train_score,
            "test_score": test_score
        })
        stepped_trees.append(copy.deepcopy(trees))

    export_file_path = f'./explanations/trees/trees_step_{dataset_name}.json'
    if data_dir is not None:
        export_file_path = f'{data_dir}/explanations/trees/trees_step_{dataset_name}.json'

    with open(export_file_path, "w") as outfile:
        json.dump(step_pruned_trees, outfile, cls=NpEncoder)

    if return_trees and return_json:
        return step_pruned_trees, stepped_trees
    if return_json:
        return step_pruned_trees
    if return_trees:
        return stepped_trees

class ExplanationExporter:
    """Collect per-node explanation records from ``model`` and export them as JSON.

    Each :meth:`generate` call appends one sample (``nodes``, ``edges``,
    ``sampleName``) to ``export_dict["samples"]``; :meth:`export` writes the
    whole ``export_dict`` to disk.
    """

    # Element symbols indexed by the argmax of a node's feature vector.
    _MUTAG_LABELS = ["C", "N", "O", "F", "I", "Cl", "Br"]
    _MUTAGENICITY_LABELS = ["C", "O", "Cl", "H", "N", "F", "Br", "S", "P", "I", "Na", "K", "Li", "Ca"]

    def __init__(self, model, layers, legend, feature_importances):
        self.model = model
        self.sample_num = 0
        self.export_dict = {"samples": [], "layers": layers, "legend": legend,
                            "featureImportances": feature_importances}

    def export(self, path):
        """Write ``export_dict`` to ``path`` as JSON (NumPy values converted by ``NpEncoder``)."""
        with open(path, "w") as outfile:
            json.dump(self.export_dict, outfile, cls=NpEncoder)

    @staticmethod
    def _ranked_nonzero(scores):
        """Return (indices, values) of the non-zero entries of ``scores`` by descending magnitude."""
        order = list(np.argsort(list(map(abs, scores)))[::-1])
        values = list(filter(lambda v: abs(v) > 0.0, list(scores[order])))
        # Non-zero magnitudes sort ahead of zeros, so the first len(values) indices line up.
        return order[:len(values)], values

    def _rank_layer_explanations(self, num_nodes):
        """Rank each node's non-zero entries in every explanation recorded by the model."""
        top_nodes = [[] for _ in range(num_nodes)]
        top_scores = [[] for _ in range(num_nodes)]
        for exp in self.model.explanation:
            for i in range(num_nodes):
                index, scores = self._ranked_nonzero(exp[i])
                top_nodes[i].append(index)
                top_scores[i].append(scores)
        return top_nodes, top_scores

    def _broadcast_pooled(self, num_nodes, logits, y_sample):
        """Replicate the single pooled last-layer record (and prediction/label) onto every node."""
        model = self.model
        model.outputs[-1] = np.array([model.outputs[-1][0] for _ in range(num_nodes)])
        model.inputs[-1] = np.array([model.inputs[-1][0] for _ in range(num_nodes)])
        model.features_used[-1] = np.array([model.features_used[-1][0] for _ in range(num_nodes)])
        logits = np.array([logits[0] for _ in range(num_nodes)])
        y_sample = np.array([y_sample[0] for _ in range(num_nodes)])
        return logits, y_sample

    @staticmethod
    def _per_node(per_layer, num_nodes):
        """Regroup a per-layer list of per-node records into a per-node object array."""
        return np.array([[layer[i] for layer in per_layer] for i in range(num_nodes)], dtype=object)

    def generate(self, x_sample, y_sample, edge_index_sample, extra_tile="", plot=False, node_idx=None, num_hops=2, mutag=False, mutagenicity=False):
        """Explain one graph (or the ``num_hops`` neighbourhood of ``node_idx``) and record it.

        Args:
            x_sample: node feature matrix; ``len(x_sample)`` is the node count.
            y_sample: ground-truth label per node (only ``y_sample[0]`` is used with pooling).
            edge_index_sample: two rows holding source and target node indices.
            extra_tile: suffix appended to the sample name.
            plot: if True, draw the (sub)graph with DIG's ``PlotUtils``.
            node_idx: if given, export only the subgraph returned by ``MCTS.__subgraph__``
                around this node.
            num_hops: neighbourhood radius passed to ``MCTS.__subgraph__``.
            mutag, mutagenicity: attach an element-symbol ``nodeLabel`` to each node
                (``mutagenicity`` takes precedence if both are set).
        """
        model = self.model
        model.reset_explanation()
        num_nodes = len(x_sample)

        # By default the sample covers every node of the graph.
        subset = list(range(num_nodes))
        if node_idx is not None:
            subgraph_x, subgraph_edge_index, subset, _, _ = \
                MCTS.__subgraph__(node_idx, x_sample, edge_index_sample, num_hops=num_hops)

        logits, explained = model.explain(x_sample, edge_index_sample, subset=subset)
        logits = logits.argmax(dim=-1)

        top_nodes, top_scores = self._rank_layer_explanations(num_nodes)

        if model.use_pooling:
            logits, y_sample = self._broadcast_pooled(num_nodes, logits, y_sample)
        outputs = self._per_node(model.outputs, num_nodes)
        inputs = self._per_node(model.inputs, num_nodes)
        features_used = self._per_node(model.features_used, num_nodes)
        top_nodes = np.array(top_nodes, dtype=object)
        top_scores = np.array(top_scores, dtype=object)

        # NaNs in the returned explanation are zeroed in place.
        # NOTE(review): `explained` is not exported; kept in case the array is shared with the model.
        explained[np.isnan(explained)] = 0

        decision_paths = model.get_decision_paths()
        node_paths = [[layer_paths[i] for layer_paths in decision_paths] for i in range(num_nodes)]
        subset_paths = [node_paths[sample_id] for sample_id in subset]

        if node_idx is not None:
            # Switch to the relabelled subgraph: local row i is full-graph node subset[i].
            x_sample = subgraph_x
            edge_index_sample = subgraph_edge_index
            logits = logits[subset]
            y_sample = y_sample[subset]

        if plot:
            vis_graph = to_networkx(Data(x=x_sample, edge_index=edge_index_sample))
            plotutils = PlotUtils(dataset_name='ba_shapes')
            plotutils.plot(vis_graph, nodelist=[], figname=None, y=logits, node_idx=0)

        if mutagenicity:
            label_table = self._MUTAGENICITY_LABELS
        elif mutag:
            label_table = self._MUTAG_LABELS
        else:
            label_table = None

        nodes = []
        index_conv = []
        for i in range(len(x_sample)):
            node_index = int(subset[i])
            index_conv.append(node_index)
            node_dict = {
                "id": str(node_index),
                "label": int(logits[i]),
                "index": i,
                "labelTrue": int(y_sample[i]),
                "nodeIndex": node_index,
                # Direct row lookup; equivalent to arr[subset][i] without copying arr per node.
                "intermediate_top_nodes": top_nodes[node_index],
                "intermediate_top_nodes_score": top_scores[node_index],
                "intermediate_outputs": outputs[node_index],
                "intermediate_inputs": inputs[node_index],
                "intermediate_features_used": features_used[node_index],
                "decision_paths": subset_paths[i]
            }
            if label_table is not None:
                node_dict["nodeLabel"] = label_table[np.argmax(x_sample[i])]
            nodes.append(node_dict)

        # Edge endpoints are local indices; map them back to full-graph node ids.
        edges = [{"source": str(int(index_conv[src])), "target": str(int(index_conv[dst]))}
                 for src, dst in zip(edge_index_sample[0], edge_index_sample[1])]

        self.export_dict['samples'].append({
            "nodes": nodes,
            "edges": edges,
            "sampleName": f'Subgraph #{self.sample_num} ({len(x_sample)} Nodes) {extra_tile}',
        })
        self.sample_num += 1

class ExplanationExporterStepped:
    """Export explanations for the same samples under several tree snapshots.

    For every entry of ``stepped_trees``, :meth:`generate_for_all` loads the trees into
    ``model`` and records all samples into a fresh ``export_dict``, which is appended to
    ``export_list``; :meth:`export` writes ``export_list`` to disk.
    """

    # Element symbols indexed by the argmax of a node's feature vector.
    _MUTAG_LABELS = ["C", "N", "O", "F", "I", "Cl", "Br"]
    _MUTAGENICITY_LABELS = ["C", "O", "Cl", "H", "N", "F", "Br", "S", "P", "I", "Na", "K", "Li", "Ca"]

    def __init__(self, model, layers, legend, feature_importances, stepped_trees):
        self.model = model
        self.sample_num = 0
        self.stepped_trees = stepped_trees
        self.feature_importances = feature_importances
        self.layers = layers
        self.legend = legend
        self.export_dict = {"samples": [], "layers": layers, "legend": legend,
                            "featureImportances": feature_importances}
        self.export_list = []

    def export(self, path):
        """Write ``export_list`` (one dict per tree snapshot) to ``path`` as JSON."""
        with open(path, "w") as outfile:
            json.dump(self.export_list, outfile, cls=NpEncoder)

    def generate_for_all(self, export_samples, debug=False):
        """Run :meth:`generate` on every sample once per snapshot in ``stepped_trees``.

        Each snapshot starts a fresh ``export_dict`` (sample numbering restarts at 0)
        that is appended to ``export_list``. ``debug`` shows a tqdm progress bar.
        Each sample dict must provide the keys ``x``, ``y``, ``edge_index``,
        ``extra_tile``, ``mutag``, ``mutagenicity``, ``num_hops`` and ``node_idx``.
        """
        progress = tqdm(self.stepped_trees, disable=(not debug))
        for trees in progress:
            self.sample_num = 0
            self.export_dict = {"samples": [], "layers": self.layers, "legend": self.legend,
                                "featureImportances": self.feature_importances}
            self.model.update_trees(trees)
            for sample in export_samples:
                progress.set_postfix(sample=f'{sample["node_idx"]} - {sample["extra_tile"]}')
                self.generate(sample['x'], sample['y'], sample['edge_index'], extra_tile=sample['extra_tile'],
                              mutag=sample['mutag'], mutagenicity=sample['mutagenicity'],
                              num_hops=sample['num_hops'], node_idx=sample['node_idx'])
            self.export_list.append(self.export_dict)

    @staticmethod
    def _ranked_nonzero(scores):
        """Return (indices, values) of the non-zero entries of ``scores`` by descending magnitude."""
        order = list(np.argsort(list(map(abs, scores)))[::-1])
        values = list(filter(lambda v: abs(v) > 0.0, list(scores[order])))
        # Non-zero magnitudes sort ahead of zeros, so the first len(values) indices line up.
        return order[:len(values)], values

    def _rank_layer_explanations(self, num_nodes):
        """Rank each node's non-zero entries in every explanation recorded by the model."""
        top_nodes = [[] for _ in range(num_nodes)]
        top_scores = [[] for _ in range(num_nodes)]
        for exp in self.model.explanation:
            for i in range(num_nodes):
                index, scores = self._ranked_nonzero(exp[i])
                top_nodes[i].append(index)
                top_scores[i].append(scores)
        return top_nodes, top_scores

    def _broadcast_pooled(self, num_nodes, logits, y_sample):
        """Replicate the single pooled last-layer record (and prediction/label) onto every node."""
        model = self.model
        model.outputs[-1] = np.array([model.outputs[-1][0] for _ in range(num_nodes)])
        model.inputs[-1] = np.array([model.inputs[-1][0] for _ in range(num_nodes)])
        model.features_used[-1] = np.array([model.features_used[-1][0] for _ in range(num_nodes)])
        logits = np.array([logits[0] for _ in range(num_nodes)])
        y_sample = np.array([y_sample[0] for _ in range(num_nodes)])
        return logits, y_sample

    @staticmethod
    def _per_node(per_layer, num_nodes):
        """Regroup a per-layer list of per-node records into a per-node object array."""
        return np.array([[layer[i] for layer in per_layer] for i in range(num_nodes)], dtype=object)

    def generate(self, x_sample, y_sample, edge_index_sample, extra_tile="", plot=False, node_idx=None, num_hops=2, mutag=False, mutagenicity=False):
        """Explain one graph (or the ``num_hops`` neighbourhood of ``node_idx``) and record it.

        Args:
            x_sample: node feature matrix; ``len(x_sample)`` is the node count.
            y_sample: ground-truth label per node (only ``y_sample[0]`` is used with pooling).
            edge_index_sample: two rows holding source and target node indices.
            extra_tile: suffix appended to the sample name.
            plot: if True, draw the (sub)graph with DIG's ``PlotUtils``.
            node_idx: if given, export only the subgraph returned by ``MCTS.__subgraph__``
                around this node.
            num_hops: neighbourhood radius passed to ``MCTS.__subgraph__``.
            mutag, mutagenicity: attach an element-symbol ``nodeLabel`` to each node
                (``mutagenicity`` takes precedence if both are set).
        """
        model = self.model
        model.reset_explanation()
        num_nodes = len(x_sample)

        # By default the sample covers every node of the graph.
        subset = list(range(num_nodes))
        if node_idx is not None:
            subgraph_x, subgraph_edge_index, subset, _, _ = \
                MCTS.__subgraph__(node_idx, x_sample, edge_index_sample, num_hops=num_hops)

        logits, explained = model.explain(x_sample, edge_index_sample, subset=subset)
        logits = logits.argmax(dim=-1)

        top_nodes, top_scores = self._rank_layer_explanations(num_nodes)

        if model.use_pooling:
            logits, y_sample = self._broadcast_pooled(num_nodes, logits, y_sample)
        outputs = self._per_node(model.outputs, num_nodes)
        inputs = self._per_node(model.inputs, num_nodes)
        features_used = self._per_node(model.features_used, num_nodes)
        top_nodes = np.array(top_nodes, dtype=object)
        top_scores = np.array(top_scores, dtype=object)

        # NaNs in the returned explanation are zeroed in place.
        # NOTE(review): `explained` is not exported; kept in case the array is shared with the model.
        explained[np.isnan(explained)] = 0

        decision_paths = model.get_decision_paths()
        node_paths = [[layer_paths[i] for layer_paths in decision_paths] for i in range(num_nodes)]
        subset_paths = [node_paths[sample_id] for sample_id in subset]

        if node_idx is not None:
            # Switch to the relabelled subgraph: local row i is full-graph node subset[i].
            x_sample = subgraph_x
            edge_index_sample = subgraph_edge_index
            logits = logits[subset]
            y_sample = y_sample[subset]

        if plot:
            vis_graph = to_networkx(Data(x=x_sample, edge_index=edge_index_sample))
            plotutils = PlotUtils(dataset_name='ba_shapes')
            plotutils.plot(vis_graph, nodelist=[], figname=None, y=logits, node_idx=0)

        if mutagenicity:
            label_table = self._MUTAGENICITY_LABELS
        elif mutag:
            label_table = self._MUTAG_LABELS
        else:
            label_table = None

        nodes = []
        index_conv = []
        for i in range(len(x_sample)):
            node_index = int(subset[i])
            index_conv.append(node_index)
            node_dict = {
                "id": str(node_index),
                "label": int(logits[i]),
                "index": i,
                "labelTrue": int(y_sample[i]),
                "nodeIndex": node_index,
                # Direct row lookup; equivalent to arr[subset][i] without copying arr per node.
                "intermediate_top_nodes": top_nodes[node_index],
                "intermediate_top_nodes_score": top_scores[node_index],
                "intermediate_outputs": outputs[node_index],
                "intermediate_inputs": inputs[node_index],
                "intermediate_features_used": features_used[node_index],
                "decision_paths": subset_paths[i]
            }
            if label_table is not None:
                node_dict["nodeLabel"] = label_table[np.argmax(x_sample[i])]
            nodes.append(node_dict)

        # Edge endpoints are local indices; map them back to full-graph node ids.
        edges = [{"source": str(int(index_conv[src])), "target": str(int(index_conv[dst]))}
                 for src, dst in zip(edge_index_sample[0], edge_index_sample[1])]

        self.export_dict['samples'].append({
            "nodes": nodes,
            "edges": edges,
            "sampleName": f'Subgraph #{self.sample_num} ({len(x_sample)} Nodes) {extra_tile}',
        })
        self.sample_num += 1
