import json
import pickle

import numpy as np

from relnet.evaluation.eval_utils import find_max_nodes, find_max_edges, find_max_diameter, find_max_colors


class GraphDataset(object):
    """On-disk collection of pickled graphs plus a JSON metadata index.

    Each graph is pickled to ``<fp.graph_ds_dir>/<g_hash>.pkl``. A single
    ``dataset_metadata.json`` file in the same directory holds a
    ``global_metadata`` section and a ``per_graph_metadata`` section keyed by
    graph hash.
    """

    _METADATA_FILENAME = 'dataset_metadata.json'

    def __init__(self, fp, graph_name, obj_fun_name, network_generator_name, graphs=None):
        """Open an existing dataset, or build and persist a new one.

        Args:
            fp: object exposing ``graph_ds_dir``, a directory path that
                supports ``/`` with a string (e.g. ``pathlib.Path``).
            graph_name: name recorded in the global metadata.
            obj_fun_name: objective-function name recorded in the metadata.
            network_generator_name: generator name recorded in the metadata.
            graphs: if ``None``, the metadata is loaded from disk. Otherwise
                these graphs are written to disk and the metadata is built
                from them.
        """
        self.fp = fp
        self.graph_name = graph_name
        self.obj_fun_name = obj_fun_name
        self.network_generator_name = network_generator_name

        if graphs is None:
            # Dataset already created: just load its metadata index.
            self.metadata_dict = self.load_metadata()
        else:
            # New dataset: pickle the graphs and write the metadata index.
            self.construct_and_save(graphs)

    @staticmethod
    def _find_original_graph(graphs):
        """Return the first unvaried graph (no variation, generator seed 0)."""
        for g in graphs:
            if g.var_type is None and g.var_number is None and g.generator_seed == 0:
                return g
        raise ValueError(
            "no original graph found: expected a graph with var_type=None, "
            "var_number=None and generator_seed=0"
        )

    def construct_and_save(self, graphs):
        """Pickle every graph and write the dataset metadata index.

        Sets ``self.metadata_dict`` and writes it to disk.

        Args:
            graphs: sized, re-iterable collection of graph objects. Each one
                must expose ``g_hash``, ``num_nodes``, ``num_edges``,
                ``obj_fun_value``, ``generator_seed``, ``var_type`` and
                ``var_number``.

        Raises:
            ValueError: if no graph has ``var_type`` and ``var_number`` both
                ``None`` and ``generator_seed == 0``.
        """
        original_graph = self._find_original_graph(graphs)

        global_metadata = {
            'graph_name': self.graph_name,
            'original_graph_hash': original_graph.g_hash,
            'num_graphs': len(graphs),
            'obj_fun_name': self.obj_fun_name,
            'network_generator_name': self.network_generator_name,
            'max_num_nodes': find_max_nodes(graphs),
            'max_num_edges': find_max_edges(graphs),
            'max_diameter': find_max_diameter(graphs),
            'max_colors': find_max_colors(graphs),
        }

        per_graph_metadata = {}
        for graph in graphs:
            per_graph_metadata[graph.g_hash] = {
                'num_nodes': graph.num_nodes,
                'num_edges': graph.num_edges,
                'obj_fun_value': graph.obj_fun_value,
                'generator_seed': graph.generator_seed,
                'var_type': graph.var_type,
                'var_number': graph.var_number,
            }
            self.write_graph_file(graph)

        self.metadata_dict = {
            'global_metadata': global_metadata,
            'per_graph_metadata': per_graph_metadata,
        }
        self.write_metadata()

    def get_all_graph_hashes(self):
        """Return a list of every graph hash in the per-graph metadata."""
        # NOTE(review): JSON object keys are always strings. If g_hash values
        # are not str, the hashes returned after load_metadata() will be str,
        # while a freshly constructed dataset keeps the original type. Confirm
        # the g_hash type.
        return list(self.metadata_dict['per_graph_metadata'].keys())

    def get_metadata_for_hash(self, g_hash):
        """Return the per-graph metadata dict for ``g_hash`` (KeyError if absent)."""
        return self.metadata_dict['per_graph_metadata'][g_hash]

    def get_gts_for_hashes(self, g_hashes):
        """Return a numpy array of ``obj_fun_value`` for each hash, in order."""
        return np.array([self.get_metadata_for_hash(g_hash)['obj_fun_value'] for g_hash in g_hashes])

    def _metadata_path(self):
        """Path of the JSON metadata index inside the dataset directory."""
        return self.fp.graph_ds_dir / self._METADATA_FILENAME

    def _graph_path(self, graph_hash):
        """Path of the pickle file holding the graph with ``graph_hash``."""
        return self.fp.graph_ds_dir / f'{graph_hash}.pkl'

    def write_metadata(self):
        """Write ``self.metadata_dict`` as indented, key-sorted JSON."""
        with open(self._metadata_path(), "w", encoding="utf-8") as fh:
            json.dump(self.metadata_dict, fh, indent=4, sort_keys=True)

    def load_metadata(self):
        """Read and return the metadata dict from the dataset directory."""
        with open(self._metadata_path(), "r", encoding="utf-8") as fh:
            return json.load(fh)

    def write_graph_file(self, graph):
        """Pickle ``graph`` to ``<graph_ds_dir>/<graph.g_hash>.pkl``."""
        with open(self._graph_path(graph.g_hash), 'wb') as fh:
            pickle.dump(graph, fh)

    def load_graph_file(self, graph_hash):
        """Unpickle and return the graph stored under ``graph_hash``."""
        # SECURITY: pickle.load can execute arbitrary code. Only load dataset
        # directories produced by trusted runs of this code.
        with open(self._graph_path(graph_hash), 'rb') as fh:
            return pickle.load(fh)

