import json
import os

import pymongo as pymongo
import torch
from tqdm import trange

import numpy as np

from utils.export_explanation import ExplanationExporter, export_decision_trees, export_stepped_decision_trees, \
    ExplanationExporterStepped

# The connection string can be overridden through the MONGODB_URI environment variable so credentials do not
# have to live in the source; the fallback keeps the previous local default.
client = pymongo.MongoClient(os.environ.get("MONGODB_URI", "mongodb://root:example@localhost:27017"))
db = client.interpretable


# Static UI metadata for every supported dataset:
#   dataset name -> (file stem, display title, is graph classification, class labels)
# Class labels are None when they are derived from the dataset at runtime.
_UI_DATASETS = {
    'MUTAG': ('mutag', "MUTAG", True, ('Not Mutagenic', 'Mutagenic')),
    'PROTEINS': ('proteins', "PROTEINS", True, None),
    'IMDB-BINARY': ('imdb', "IMDB-BINARY", True, ('Action', 'Romance')),
    'REDDIT-BINARY': ('reddit', "REDDIT-BINARY", True, ('Q/A', 'Discussion')),
    'Mutagenicity': ('mutagenicity', "Mutagenicity", True, ('Mutagenic', 'Not Mutagenic')),
    'BBBP': ('bbbp', "BBBP", True, ('Class 0', 'Class 1')),
    'COLLAB': ('collab', "COLLAB", True, ('High Energy', 'Condensed Matter', 'Astro Physics')),
    'Infection': ('infection', "Infection", False, None),
    'Saturation': ('saturation', "Saturation", False, ('Class 0', 'Class 1')),
    'Tree_Cycle': ('tree_cycle', "Tree Cycle", False, ('Not in Cycle', 'In Cycle')),
    'Tree_Grid': ('tree_grid', "Tree Grid", False, ('Not in Grid', 'In Grid')),
    'BA_2Motifs': ('ba_2motifs', "BA 2Motifs", True, ('Cycle', 'House')),
    'BA_shapes': ('ba_shapes', "BA Shapes", False,
                  ('Not in House', 'Middle of House', 'Bottom of House', 'Top of House')),
}

# Node-classification motif datasets:
#   dataset name -> (label of the nodes to explain, num_hops passed to generate(), num_hops stored in the sample)
# NOTE(review): Tree_Cycle generates with 6 hops but stores 3 for the stepped export. This mismatch is kept
# from the previous implementation; confirm whether it is intentional.
_MOTIF_NODE_DATASETS = {
    'Tree_Cycle': (1, 6, 3),
    'Tree_Grid': (1, 5, 5),
    'BA_shapes': (3, 3, 3),
}


def _sample_record(x, y, edge_index, extra_tile, mutag=False, mutagenicity=False, node_idx=None, num_hops=3):
    """Build the dict describing one exported sample, as later handed to ``generate_for_all``."""
    return {
        "x": x,
        "y": y,
        "edge_index": edge_index,
        "mutag": mutag,
        "mutagenicity": mutagenicity,
        "extra_tile": extra_tile,
        "node_idx": node_idx,
        "num_hops": num_hops
    }


def _infection_class_labels(num_classes):
    """Return the Infection labels: "Infected", "Distance i", and "Distance i+" for the last class."""
    labels = []
    for i in range(num_classes):
        label = f'Distance {i}'
        if i == 0:
            label = "Infected"
        if i == num_classes - 1:
            label = f'Distance {i}+'
        labels.append(label)
    return labels


def _score_model_infection(m, d, mask=None):
    """Return the node-level accuracy of model ``m`` over all batches yielded by ``d`` (``mask`` is unused)."""
    m.eval()
    total_correct = 0
    num_samples = 0
    for data in d:
        out = m(data.x, data.edge_index, data.batch)
        total_correct += int((out.argmax(-1) == data.y).sum())
        num_samples += len(data.x)
    return total_correct / num_samples


def _collect_mutag_samples(exporter, export_samples, test_dataset):
    """Export the first ten graphs of the MUTAG test split."""
    for data in test_dataset[:10]:
        print("Generate Sample", int(data.y))
        tile = f'- {"Not Mutagenic" if data.y == 0 else "Mutagenic"}'
        exporter.generate(data.x, data.y, data.edge_index, mutag=True, extra_tile=tile)
        export_samples.append(_sample_record(data.x, data.y, data.edge_index, tile, mutag=True))


def _collect_proteins_samples(exporter, export_samples, test_dataset):
    """Export small PROTEINS graphs (at most 20 nodes) that have a non-zero value in feature column 2."""
    exported = 0
    for data in test_dataset:
        if len(data.x) > 20:
            continue
        if not any(s[2] for s in data.x):
            continue
        print(data.edge_index)
        print(data.x)
        # NOTE(review): generate() is called without extra_tile although the stored sample carries one;
        # kept as in the previous implementation.
        exporter.generate(data.x, data.y, data.edge_index)
        tile = f'- {"Class 0" if data.y == 0 else "Class 1"}'
        export_samples.append(_sample_record(data.x, data.y, data.edge_index, tile))

        exported += 1
        # NOTE(review): this stops after 101 samples, not 100; kept as in the previous implementation.
        if exported > 100:
            break


def _collect_correct_graph_samples(exporter, export_samples, model_dt, test_dataset, class_labels, num_classes,
                                   max_nodes=100, max_edges=None, per_class=5, mutagenicity=False,
                                   generate_kwargs=None):
    """Export up to ``per_class`` correctly classified test graphs for each class.

    Graphs with more than ``max_nodes`` nodes (or more than ``max_edges`` edge columns, when given) are
    skipped, as are graphs the model misclassifies. ``generate_kwargs`` holds extra keyword arguments for
    ``exporter.generate``; ``mutagenicity`` is the flag stored in the sample record.
    """
    generate_kwargs = generate_kwargs or {}
    generated = [0] * num_classes
    quota = num_classes * per_class
    for data in test_dataset:
        if sum(generated) >= quota:
            # Every class is full; nothing later in the split can be exported.
            break
        if len(data.x) > max_nodes:
            continue
        if max_edges is not None and len(data.edge_index[0]) > max_edges:
            continue
        out = model_dt(data.x, data.edge_index, data.batch)
        if out.argmax(-1) != data.y:
            continue

        label_idx = int(data.y)
        if generated[label_idx] >= per_class:
            continue
        generated[label_idx] += 1
        print("Generate Sample", label_idx)
        tile = f'- {class_labels[label_idx]}'
        exporter.generate(data.x, data.y, data.edge_index, extra_tile=tile, **generate_kwargs)
        export_samples.append(_sample_record(data.x, data.y, data.edge_index, tile, mutagenicity=mutagenicity))


def _collect_infection_samples(exporter, export_samples, loaders):
    """Export three hand-built Infection graphs plus up to five class-2 nodes of the first test batch."""
    handcrafted = [
        ("Simple Line",
         [[0, 1], [1, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]],
         [6, 0, 1, 2, 3, 4, 5, 6],
         [[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7]]),
        ("Ring",
         [[1, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]],
         [0, 1, 2, 3, 4, 5, 6],
         [[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 0]]),
        ("Y",
         [[1, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [1, 0], [0, 1], [0, 1]],
         [0, 1, 2, 3, 3, 4, 6, 0, 1, 2],
         [[0, 1, 2, 3, 4, 6, 7, 8, 9], [1, 2, 3, 4, 5, 7, 8, 9, 4]]),
    ]
    for name, x, y, edges in handcrafted:
        x_sample = torch.as_tensor(x)
        y_sample = torch.as_tensor(y)
        edge_index_sample = torch.as_tensor(edges)
        tile = f'- {name}'
        exporter.generate(x_sample, y_sample, edge_index_sample, extra_tile=tile)
        export_samples.append(_sample_record(x_sample, y_sample, edge_index_sample, tile))

    batch = next(iter(loaders[2]))
    node_indices = torch.where(batch.y == 2)[0].tolist()
    for node_idx in node_indices[:5]:
        print(f'Generating Explanation for Subgraph (Node: {node_idx}) ...')
        # NOTE(review): generate() gets neither extra_tile nor num_hops here, while the stored sample
        # carries both; kept as in the previous implementation.
        exporter.generate(batch.x, batch.y, batch.edge_index, node_idx=node_idx)
        export_samples.append(_sample_record(batch.x, batch.y, batch.edge_index, f'- Node: {node_idx}',
                                             node_idx=node_idx))


def _collect_saturation_samples(exporter, export_samples, loaders):
    """Export up to five nodes per class (label 0 vs. any other) with in-degree <= 30 from the first test batch."""
    batch = next(iter(loaders[2]))
    # Count how often each node appears in the second row of edge_index.
    node_degrees = [0] * len(batch.x)
    for target in batch.edge_index[1].tolist():
        node_degrees[target] += 1

    generated = [0, 0]
    for node_idx in range(len(batch.x)):
        if sum(generated) >= 10:
            break
        if node_degrees[node_idx] > 30:
            continue
        bucket = 0 if int(batch.y[node_idx]) == 0 else 1
        if generated[bucket] >= 5:
            continue
        generated[bucket] += 1
        print(f'Generating Explanation for Subgraph (Node: {node_idx}) ...')
        tile = f'- Node: {node_idx}'
        exporter.generate(batch.x, batch.y, batch.edge_index, node_idx=node_idx, num_hops=1, extra_tile=tile)
        export_samples.append(_sample_record(batch.x, batch.y, batch.edge_index, tile, node_idx=node_idx,
                                             num_hops=1))


def _pick_labelled_nodes(data, target_label):
    """Return, in random order, the indices of all nodes of ``data`` whose label equals ``target_label``."""
    candidates = [idx for idx in np.random.permutation(len(data.y)) if data.y[idx] == target_label]
    return torch.tensor(candidates, dtype=torch.long)


def _collect_motif_node_samples(exporter, export_samples, data, node_indices, generate_hops, stored_hops):
    """Export the subgraph explanation of the first five nodes in ``node_indices``."""
    for node_idx in node_indices[:5]:
        print(f'Generating Explanation for Subgraph (Node: {node_idx}) ...')
        tile = f'- Node: {node_idx}'
        exporter.generate(data.x, data.y, data.edge_index, node_idx=node_idx, num_hops=generate_hops,
                          extra_tile=tile)
        export_samples.append(_sample_record(data.x, data.y, data.edge_index, tile, node_idx=node_idx,
                                             num_hops=stored_hops))


def export_for_ui(dataset_name, dataset, params, model_dt, trees, score_model, data_dir, test_mask=None,
                  test_dataset=None, loaders=None):
    """Export explanations and decision trees for one dataset and load them into MongoDB.

    The function selects example samples for ``dataset_name``, writes the explanation graphs, the decision
    trees and their pruned ("stepped") variants below ``{data_dir}/explanations``, reads the generated
    stepped JSON files back and stores them via ``insert_datasets``, ``insert_trees`` and ``insert_samples``.

    :param dataset_name: one of the keys of ``_UI_DATASETS``
    :param dataset: the dataset; ``dataset[0]`` supplies ``num_node_features`` (and the node data for the
        node-level datasets)
    :param params: dict providing ``number_of_layers``, ``state_space`` and ``skip_connection``
    :param model_dt: the model that is explained; called as ``model_dt(x, edge_index, batch)``
    :param trees: mapping from layer name to a fitted tree exposing ``feature_importances_``
    :param score_model: scoring callable forwarded to ``export_stepped_decision_trees`` (replaced by a
        node-accuracy scorer for the Infection dataset)
    :param data_dir: base output directory
    :param test_mask: forwarded unchanged to ``export_stepped_decision_trees``
    :param test_dataset: iterable of test graphs, required for the graph-classification datasets
    :param loaders: sequence of data loaders; index 0 is used as train data and index 2 as test data
    :raises ValueError: if ``dataset_name`` is not supported
    """
    print("-- Exporting Explanations --")
    layer_names = ['input']
    for i in range(params["number_of_layers"]):
        layer_names.append(f"stone_age.{i}.linear_softmax")
    layer_names.append('output')

    export_samples = []
    layers = ["Input"] + [f"GNN Layer {i}" for i in range(params["number_of_layers"])] + ["Output"]

    if dataset_name not in _UI_DATASETS:
        raise ValueError(f'Unsupported dataset for UI export: {dataset_name!r}')
    dataset_filename, dataset_title, is_graph_classification, static_labels = _UI_DATASETS[dataset_name]

    # The first element provides the feature count for the tree export. This assumes every graph of a
    # dataset has the same number of node features.
    data = dataset[0]
    mask_dataset = False
    data_train = loaders[0]
    data_test = loaders[2]

    if dataset_name == 'PROTEINS':
        class_labels = [f'Class {i}' for i in range(dataset.num_classes)]
    elif dataset_name == 'Infection':
        class_labels = _infection_class_labels(data.num_classes)
    else:
        class_labels = list(static_labels)
    legend = [{"index": i, "label": label} for i, label in enumerate(class_labels)]

    # Draw the random node selection before the exporter is built, as the previous implementation did.
    motif_nodes = None
    if dataset_name in _MOTIF_NODE_DATASETS:
        motif_nodes = _pick_labelled_nodes(data, _MOTIF_NODE_DATASETS[dataset_name][0])

    feature_importances = [trees[layer_name].feature_importances_ for layer_name in layer_names]
    explanation_exporter = ExplanationExporter(model_dt, layers, legend, feature_importances)

    if dataset_name == 'MUTAG':
        _collect_mutag_samples(explanation_exporter, export_samples, test_dataset)
    elif dataset_name == 'PROTEINS':
        _collect_proteins_samples(explanation_exporter, export_samples, test_dataset)
    elif dataset_name in ('IMDB-BINARY', 'REDDIT-BINARY', 'BA_2Motifs'):
        _collect_correct_graph_samples(explanation_exporter, export_samples, model_dt, test_dataset,
                                       class_labels, dataset.num_classes,
                                       generate_kwargs={"mutagenicity": False})
    elif dataset_name == 'COLLAB':
        _collect_correct_graph_samples(explanation_exporter, export_samples, model_dt, test_dataset,
                                       class_labels, dataset.num_classes, max_nodes=50, max_edges=250,
                                       generate_kwargs={"mutagenicity": False})
    elif dataset_name == 'Mutagenicity':
        _collect_correct_graph_samples(explanation_exporter, export_samples, model_dt, test_dataset,
                                       class_labels, len(class_labels), mutagenicity=True,
                                       generate_kwargs={"mutagenicity": True})
    elif dataset_name == 'BBBP':
        # No mutagenicity keyword is passed to generate() here, as in the previous implementation.
        _collect_correct_graph_samples(explanation_exporter, export_samples, model_dt, test_dataset,
                                       class_labels, len(class_labels))
    elif dataset_name == 'Infection':
        _collect_infection_samples(explanation_exporter, export_samples, loaders)
        score_model = _score_model_infection
    elif dataset_name == 'Saturation':
        _collect_saturation_samples(explanation_exporter, export_samples, loaders)
    else:
        _, generate_hops, stored_hops = _MOTIF_NODE_DATASETS[dataset_name]
        _collect_motif_node_samples(explanation_exporter, export_samples, data, motif_nodes, generate_hops,
                                    stored_hops)

    for sub_dir in ('graphs', 'trees'):
        os.makedirs(f'{data_dir}/explanations/{sub_dir}/', exist_ok=True)

    explanation_exporter.export(f'{data_dir}/explanations/graphs/data_{dataset_filename}.json')

    print("-- Exporting Decision Trees --")
    export_decision_trees(trees, layer_names, data.num_node_features, params["state_space"], class_labels,
                          params["skip_connection"], params["number_of_layers"], f'{data_dir}/explanations/tree_plots',
                          dataset_name=dataset_filename, data_dir=data_dir, use_pooling=is_graph_classification)

    print("-- Exporting Lossy Decision Trees --")
    _, stepped_trees = export_stepped_decision_trees(trees, layer_names, data.num_node_features, params["state_space"],
                                                     class_labels,
                                                     params["skip_connection"], params["number_of_layers"], model_dt,
                                                     data_train,
                                                     data_test, dataset_name=dataset_filename, prune_steps=10,
                                                     score_model=score_model, train_test_mask=mask_dataset,
                                                     return_trees=True, return_json=True,
                                                     test_mask=test_mask, use_pooling=is_graph_classification,
                                                     data_dir=data_dir)

    print("-- Exporting Lossy Explanations --")
    explanation_exporter = ExplanationExporterStepped(model_dt, layers, legend, feature_importances, stepped_trees)

    explanation_exporter.generate_for_all(export_samples, debug=True)
    explanation_exporter.export(
        f'{data_dir}/explanations/graphs/data_step_{dataset_filename}.json')

    # Read the generated files back; the context managers make sure the handles are closed again.
    with open(f'{data_dir}/explanations/trees/trees_step_{dataset_filename}.json') as f:
        json_data_trees = json.load(f)
    with open(f'{data_dir}/explanations/graphs/data_step_{dataset_filename}.json') as f:
        json_data_graphs = json.load(f)
    print("--- All Explanations generated ---")
    print("--- Writing to MongoDB ---")
    insert_datasets(dataset_filename, dataset_title, json_data_graphs, graph_classif=is_graph_classification)
    insert_trees(dataset_filename, json_data_trees)
    insert_samples(dataset_filename, json_data_graphs)
    print("--- Done ---")


def insert_datasets(dataset_name, dataset_tile, data, graph_classif=False):
    """Replace the metadata document of a dataset in the ``datasets`` collection.

    :param dataset_name: value stored in (and used to match) the ``dataset`` field
    :param dataset_tile: display title stored in the ``title`` field
    :param data: list of exported steps; only ``data[0]`` is read, for its ``layers``, ``legend`` and the
        ``sampleName`` of every entry in ``samples``
    :param graph_classif: stored in the ``graph_classification`` field
    :raises IndexError: if ``data`` is empty
    """
    collection = db.datasets
    first_step = data[0]
    document = {
        "dataset": dataset_name,
        "title": dataset_tile,
        "steps": True,
        "graph_classification": graph_classif,
        "layers": first_step["layers"],
        "legend": first_step["legend"],
        "samples": [s["sampleName"] for s in first_step["samples"]],
    }
    # delete_many simply removes nothing when no document matches, so no prior count is needed.
    collection.delete_many({"dataset": dataset_name})
    collection.insert_one(document)


def insert_trees(dataset_name, data):
    """Replace the decision-tree document of a dataset in the ``trees`` collection.

    :param dataset_name: value stored in (and used to match) the ``dataset`` field
    :param data: list with one entry per prune step; each entry provides ``data`` (stored as ``trees``),
        ``train_score`` and ``test_score``. The list position is stored as ``prune_step``.
    """
    collection = db.trees
    document = {
        "dataset": dataset_name,
        "data": [
            {
                "trees": data[prune_index]["data"],
                "prune_step": prune_index,
                "train_score": data[prune_index]["train_score"],
                "test_score": data[prune_index]["test_score"],
            }
            for prune_index in trange(len(data))
        ],
    }

    # delete_many simply removes nothing when no document matches, so no prior count is needed.
    collection.delete_many({"dataset": dataset_name})
    collection.insert_one(document)


CUTOFF_VALUE = 0.01


def filter_nodes(node):
    """Truncate the per-layer top-node lists of ``node`` in place.

    For every layer, the lists under ``intermediate_top_nodes`` and ``intermediate_top_nodes_score`` are cut
    at the first position whose absolute score is below ``CUTOFF_VALUE``; that entry and everything after
    it are dropped. Layers without such a score are left untouched. Nothing is returned.
    """
    for layer, candidates in enumerate(node["intermediate_top_nodes"]):
        for pos in range(len(candidates)):
            layer_scores = node["intermediate_top_nodes_score"][layer]
            if abs(layer_scores[pos]) >= CUTOFF_VALUE:
                continue
            # First score below the cutoff: keep only what precedes it and move on to the next layer.
            node["intermediate_top_nodes_score"][layer] = layer_scores[:pos]
            node["intermediate_top_nodes"][layer] = candidates[:pos]
            break


def insert_samples(dataset_name, data):
    """Replace all sample documents of a dataset in the ``samples`` collection.

    One document is written per sample of ``data[0]['samples']``. Each document holds the sample's name and
    edges (taken from the first step) plus, for every prune step in ``data``, the filtered node list of
    that sample. A failed insert is reported on stdout and the remaining samples are still processed.

    :param dataset_name: value stored in (and used to match) the ``dataset`` field
    :param data: list with one entry per prune step, each providing a ``samples`` list
    :raises IndexError: if ``data`` is empty
    """
    samples = db.samples
    # delete_many simply removes nothing when no document matches, so no prior count is needed.
    samples.delete_many({"dataset": dataset_name})

    # Node fields copied unchanged into the stored document ("id" is stringified separately).
    copied_fields = ("label", "index", "labelTrue", "nodeIndex", "intermediate_top_nodes",
                     "intermediate_top_nodes_score", "intermediate_outputs", "intermediate_features_used",
                     "decision_paths")

    base_samples = data[0]['samples']
    for sample_index in trange(len(base_samples)):
        base_sample = base_samples[sample_index]
        document = {
            "dataset": dataset_name,
            "index": sample_index+1,
            "name": base_sample["sampleName"],
            "edges": base_sample["edges"],
            "data": [],
        }
        for prune_index, step in enumerate(data):
            nodes = []
            for node in step['samples'][sample_index]["nodes"]:
                new_node = {"id": str(node['id'])}
                new_node.update((field, node[field]) for field in copied_fields)
                filter_nodes(new_node)
                if 'nodeLabel' in node:
                    new_node['nodeLabel'] = node['nodeLabel']
                nodes.append(new_node)
            document["data"].append({
                "prune_step": prune_index,
                "nodes": nodes,
            })
        try:
            samples.insert_one(document)
        except Exception as exc:
            # Best effort: oversized documents are expected to fail here. Catching Exception (instead of a
            # bare except) keeps Ctrl-C working, and printing the error shows what actually went wrong.
            print(
                f'Sample too Large Nodes: {len(base_sample["nodes"])}, Edges: {len(base_sample["edges"])}')
            print(f'Insert failed with {type(exc).__name__}: {exc}')
