import itertools
import math

from relnet.evaluation.eval_utils import get_run_number
from relnet.state.graph_state import GraphState, get_graph_hash
from relnet.utils.config_utils import get_logger_instance


class TmGenStateGenerator(object):
    """Builds ``GraphState`` instances from topology and demand-matrix files.

    Topologies are read from ``file_paths.topologies_dir`` and demand matrices
    from ``file_paths.demand_matrices_dir``. Per-instance copies of the
    topology file carrying the assigned link weights are written to
    ``file_paths.rc_data_dir`` when an objective value is computed.
    """

    # Identifier used to look this generator up by name.
    name = "tmgen"

    # Columns of the topology file that are not stored as node / edge properties.
    NODE_PROPS_DISCARD = {"x", "y"}
    EDGE_PROPS_DISCARD = {"label", "weight"}

    def __init__(self, file_paths):
        """Store the path configuration and set up the logger.

        Args:
            file_paths: object exposing ``construct_log_filepath()`` as well as
                the ``topologies_dir``, ``demand_matrices_dir`` and
                ``rc_data_dir`` directories used by the other methods.
        """
        super().__init__()

        self.file_paths = file_paths
        logs_file = str(file_paths.construct_log_filepath())
        self.logger_instance = get_logger_instance(logs_file)

    def generate(self, graph_name, gen_params, random_seed, cycle_start_seed, cycle_end_seed, var_type=None, var_number=None):
        """Load one topology and one demand matrix and wrap them in a ``GraphState``.

        Args:
            graph_name: base name of the topology.
            gen_params: mapping providing ``'min_scale_factor'`` and
                ``'locality'``, which select the demand-matrix directory.
            random_seed: seed identifying the demand-matrix file.
            cycle_start_seed: forwarded unchanged to ``set_generator_info``.
            cycle_end_seed: forwarded unchanged to ``set_generator_info``.
            var_type: optional variation type of the topology.
            var_number: optional variation index of the topology.

        Returns:
            The constructed ``GraphState`` with its generator info set.
        """
        # get_var_id falls back to the plain graph name when no variation is requested.
        graph_id = self.get_var_id(graph_name, var_type, var_number)

        topology_file = self.file_paths.topologies_dir / f"{graph_id}.graph"
        num_nodes, np_data, edges, ep_data = self.read_topology_file(
            topology_file, self.NODE_PROPS_DISCARD, self.EDGE_PROPS_DISCARD)

        dm_dir = (self.file_paths.demand_matrices_dir
                  / f"scale_factor_{gen_params['min_scale_factor']}"
                  / f"locality_{gen_params['locality']}")
        dm_file = dm_dir / f"{graph_id}_{random_seed}.demands"

        demands = {}
        with open(dm_file.resolve(), "r", encoding='utf-8') as fh:
            # First line: the demand count is the second token.
            num_demands = int(next(fh).strip().split(" ")[1])
            # Attributes line, not needed.
            next(fh)

            for _ in range(num_demands):
                d_values = next(fh).strip().split(" ")
                edge_src, edge_dest = int(d_values[1]), int(d_values[2])
                demands[(edge_src, edge_dest)] = convert(d_values[3])

        instance = GraphState(num_nodes, edges, np_data, ep_data, demands)
        g_hash = get_graph_hash(instance)

        instance.set_generator_info(topology_file, dm_file, graph_name,
                                    random_seed, cycle_start_seed, cycle_end_seed, var_type, var_number, g_hash)

        return instance

    @staticmethod
    def get_var_id(graph_name, var_type, var_num):
        """Return the file identifier of a topology or of one of its variations.

        With no variation (both ``var_type`` and ``var_num`` are ``None``) the
        graph name itself is returned; otherwise the variation type and the
        zero-padded (5 digits) variation number are appended.
        """
        if var_type is None and var_num is None:
            return graph_name
        return f"{graph_name}{var_type}{var_num:05d}"

    @staticmethod
    def read_topology_file(topology_file, node_props_discard, edge_props_discard, encoding='utf-8'):
        """Parse a space-separated topology file.

        Layout, as consumed here: a nodes line whose second token is the node
        count, a line of node property names, one line per node, one separator
        line, an edges line whose second token is the edge count, a line of
        edge property names, then one line per edge whose second and third
        tokens are the source and destination node.

        Args:
            topology_file: path object of the file (must support ``resolve()``).
            node_props_discard: set of node property names to skip.
            edge_props_discard: set of edge property names to skip.
            encoding: text encoding of the file.

        Returns:
            ``(num_nodes, np_data, edges, ep_data)`` where ``np_data`` and
            ``ep_data`` map property names to per-node / per-edge value lists
            and ``edges`` lists the distinct ``(src, dest)`` pairs in file order.
        """
        with open(topology_file.resolve(), "r", encoding=encoding) as fh:
            num_nodes = int(next(fh).strip().split(" ")[1])
            node_props = next(fh).strip().split(" ")
            np_data = {prop_name: [] for prop_name in node_props if prop_name not in node_props_discard}

            for _ in range(num_nodes):
                np_values = next(fh).strip().split(" ")
                for idx, prop_name in enumerate(node_props):
                    if prop_name in node_props_discard:
                        continue
                    np_data[prop_name].append(convert(np_values[idx]))

            # Separator line between the node and edge sections.
            next(fh)
            num_edges = int(next(fh).strip().split(" ")[1])

            edge_props = next(fh).strip().split(" ")
            # Endpoints are stored in ``edges`` rather than as edge properties.
            skipped_edge_props = edge_props_discard.union({"src", "dest"})
            ep_data = {prop_name: [] for prop_name in edge_props if prop_name not in skipped_edge_props}

            edges = []
            seen_edges = set()
            for _ in range(num_edges):
                ep_values = next(fh).strip().split(" ")
                edge = (int(ep_values[1]), int(ep_values[2]))

                # Only the first occurrence of a (src, dest) pair is kept.
                if edge in seen_edges:
                    continue
                seen_edges.add(edge)
                edges.append(edge)

                for idx, prop_name in enumerate(edge_props):
                    # Columns 1 and 2 hold the endpoints parsed above.
                    if idx in (1, 2) or prop_name in edge_props_discard:
                        continue
                    ep_data[prop_name].append(convert(ep_values[idx]))

        return num_nodes, np_data, edges, ep_data

    def generate_many(self, graph_name, gen_params, random_seeds, objective_function, use_ecmp, var_types=None, var_count=None):
        """Generate, normalise and evaluate a batch of graph states.

        One state is generated per seed for the base graph, followed by one
        chunk of states (again one per seed) for every variation type and
        variation index. Demands and edge capacities are then divided by their
        maxima over the whole batch and the objective value of each state is
        computed and stored on it.

        Args:
            graph_name: base name of the topology.
            gen_params: generation parameters, see ``generate``.
            random_seeds: sequence of seeds, or a tuple of seed sequences that
                is flattened in order.
            objective_function: object whose ``compute(graph_state, **kwargs)``
                returns the objective value.
            use_ecmp: forwarded to the objective function.
            var_types: optional iterable of variation types.
            var_count: optional number of variations per type.

        Returns:
            List of graph states: base graphs first, then the variation chunks.
        """
        if isinstance(random_seeds, tuple):
            random_seeds = list(itertools.chain.from_iterable(random_seeds))

        graph_states = [self.generate(graph_name, gen_params, random_seed,
                                      random_seeds[0], random_seeds[-1])
                        for random_seed in random_seeds]
        if var_types is not None and var_count is not None:
            for var_type in var_types:
                for count in range(var_count):
                    graph_states.extend(self.generate(graph_name, gen_params, random_seed,
                                                      random_seeds[0], random_seeds[-1], var_type, count)
                                        for random_seed in random_seeds)

        cap_prop = GraphState.CAPACITY_EPROP_NAME
        # NOTE(review): assumes gs.demands is array-like (supports .max() and division) - confirm.
        max_demand = max(gs.demands.max() for gs in graph_states)
        max_capacity = max(max(gs.get_edge_property(edge, cap_prop) for edge in gs.edge_list)
                           for gs in graph_states)

        # Normalise by the batch-wide maxima.
        for gs in graph_states:
            gs.demands = gs.demands / max_demand
            for edge in gs.edge_list:
                scaled_cap = gs.get_edge_property(edge, cap_prop) / max_capacity
                gs.set_edge_property(edge, cap_prop, scaled_cap)

        # Compute the objective values and set them on the datapoints.
        for gs in graph_states:
            kw = self.prepare_for_obj_eval(gs, use_ecmp)
            gs.set_obj_fun_value(objective_function.compute(gs, **kw))

        return graph_states

    def prepare_for_obj_eval(self, graph_state, use_ecmp):
        """Assemble the keyword arguments passed to ``objective_function.compute``.

        Creates the weighted topology file for ``graph_state`` if it does not
        exist yet (see ``construct_top_file``).
        """
        return {
            'use_ecmp': use_ecmp,
            'top_file': self.construct_top_file(graph_state),
            'dm_file': graph_state.dm_file_path,
            'wopt_name': 'uniform',
        }

    def construct_top_file(self, gs):
        """Return the path of the weighted topology file for ``gs``.

        The file lives in ``file_paths.rc_data_dir`` and is named after the
        graph (or variation) id and the generator seed. It is written only if
        it does not exist already; an existing file is reused as is.
        """
        updated_fname = TmGenStateGenerator.get_var_id(gs.graph_name, gs.var_type, gs.var_number) \
                        + f"_{gs.generator_seed}" + ".graph"

        out_file = self.file_paths.rc_data_dir / updated_fname
        if not out_file.exists():
            self.write_modified_weights(gs, out_file)
        return out_file

    @staticmethod
    def write_modified_weights(graph_state, out_file):
        """Copy the original topology file, replacing the edge weights.

        Everything up to and including the edge header is copied verbatim. For
        each edge line the last column is replaced by the link weight assigned
        in ``graph_state`` multiplied by 1000 and truncated to an integer.
        Lines in the edge section with fewer than three tokens are dropped.

        The output is assembled in memory and written in one go, so a failure
        while parsing never leaves a truncated file behind (which
        ``construct_top_file`` would otherwise reuse as a valid cached copy).

        Args:
            graph_state: object exposing ``top_file_path`` and
                ``get_link_weight((src, dest))``.
            out_file: destination path.
        """
        # Same encoding as read_topology_file uses for these files.
        with open(graph_state.top_file_path, "r", encoding='utf-8') as orig_fh:
            orig_lines = orig_fh.readlines()

        out_lines = []
        edges_hit = False
        for line in orig_lines:
            if line.startswith("EDGES"):
                edges_hit = True
                out_lines.append(line)
                continue
            if not edges_hit or line.startswith("label src dest"):
                out_lines.append(line)
                continue

            ep_values = line.strip().split(" ")
            try:
                edge_src, edge_dest = int(ep_values[1]), int(ep_values[2])
            except IndexError:
                # Not an edge line (e.g. a blank line): skip it.
                continue
            assigned_weight = graph_state.get_link_weight((edge_src, edge_dest))
            # NOTE(review): the weight is assumed to be the last column - confirm against the file format.
            ep_values[-1] = str(int(assigned_weight * 1000))
            # Joining with the newline keeps the historical "<value> \n" line ending.
            ep_values.append("\n")
            out_lines.append(" ".join(ep_values))

        with open(out_file, "w", encoding='utf-8') as out_fh:
            out_fh.writelines(out_lines)

    @staticmethod
    def compute_number_edges(n, edge_percentage):
        """Return ``edge_percentage`` percent of the ``n * (n - 1) / 2`` possible edges, rounded up."""
        total_possible_edges = (n * (n - 1)) / 2
        return int(math.ceil((total_possible_edges * edge_percentage / 100)))

    @staticmethod
    def construct_network_seeds(eval_on_train, model_seed,
                                num_train_graphs, num_validation_graphs, num_test_graphs,
                                separate_graphs_per_model_seed=False):
        """Build the seed lists for the train, validation and test sets.

        Without ``eval_on_train`` the three lists are consecutive, disjoint
        ranges laid out as validation, test, train. With ``eval_on_train`` all
        three lists cover the same range and the three counts must be equal.
        When ``separate_graphs_per_model_seed`` is set, the ranges are shifted
        by the run number of ``model_seed`` times the size of one seed cycle.

        Returns:
            ``(train_seeds, validation_seeds, test_seeds)``.

        Raises:
            AssertionError: if ``eval_on_train`` is set and the counts differ.
        """
        if not eval_on_train:
            dm_cycle_size = num_train_graphs + num_validation_graphs + num_test_graphs
            base_offset = get_run_number(model_seed) * dm_cycle_size if separate_graphs_per_model_seed else 0

            val_end = base_offset + num_validation_graphs
            test_end = val_end + num_test_graphs
            validation_seeds = list(range(base_offset, val_end))
            test_seeds = list(range(val_end, test_end))
            train_seeds = list(range(test_end, test_end + num_train_graphs))
        else:
            # Raised explicitly (same exception type as the former assert) so
            # the check is not stripped when running with -O.
            if not (num_train_graphs == num_validation_graphs == num_test_graphs):
                raise AssertionError("If using --eval_on_train, number of graphs should be the same.")
            base_offset = get_run_number(model_seed) * num_train_graphs if separate_graphs_per_model_seed else 0
            validation_seeds = list(range(base_offset, base_offset + num_validation_graphs))
            test_seeds = list(range(base_offset, base_offset + num_test_graphs))
            train_seeds = list(range(base_offset, base_offset + num_train_graphs))
        return train_seeds, validation_seeds, test_seeds

    @staticmethod
    def split_from_seeds(graph_list, graph_seeds, var_types=None, var_count=None, disjoint_topologies=False):
        """Split a flat list of graphs into train, validation and test sets.

        ``graph_list`` is expected in the layout produced by ``generate_many``:
        one chunk for the base graph followed by one chunk per variation type
        and variation index, each chunk holding the graphs for the train,
        validation and test seeds in that order.

        Args:
            graph_list: flat list of graphs.
            graph_seeds: ``(train_seeds, validation_seeds, test_seeds)``.
            var_types: optional sequence of variation types.
            var_count: optional number of variations per type.
            disjoint_topologies: if false, every chunk is split by seed. If
                true, whole chunks are assigned to one set: the base graph
                goes to training and the variations are distributed according
                to the proportions of the seed lists.

        Returns:
            ``(train_graphs, validation_graphs, test_graphs)``.

        Raises:
            ValueError: if ``disjoint_topologies`` is requested without
                variations.
        """
        if not disjoint_topologies:
            n_train, n_val, n_test = len(graph_seeds[0]), len(graph_seeds[1]), len(graph_seeds[2])
            total_graphs = n_train + n_val + n_test

            # Boundaries of the three sets inside one chunk.
            val_start = n_train
            test_start = n_train + n_val

            train_graphs = graph_list[:val_start]
            validation_graphs = graph_list[val_start: test_start]
            test_graphs = graph_list[test_start: total_graphs]

            if var_types is not None and var_count is not None:
                for i, _ in enumerate(var_types):
                    for count in range(var_count):
                        offset = total_graphs + (i * var_count * total_graphs) + (count * total_graphs)

                        train_graphs.extend(graph_list[offset: offset + val_start])
                        validation_graphs.extend(graph_list[offset + val_start: offset + test_start])
                        test_graphs.extend(graph_list[offset + test_start: offset + total_graphs])
        else:
            # Both pieces of variation info are needed below; fail early with a
            # clear message rather than with a TypeError further down.
            if var_count is None or var_types is None:
                raise ValueError("shouldn't use disjoint_topologies with no variations; check args.")

            total_num_seeds = sum(len(seeds) for seeds in graph_seeds)

            # original graph goes in training set.
            train_graphs = graph_list[:total_num_seeds]
            validation_graphs = []
            test_graphs = []

            nt, nv = len(graph_seeds[0]), len(graph_seeds[1])

            # Variation indices below train_count_end go to training, those
            # below val_count_end to validation, the remainder to test.
            train_count_end = math.ceil((nt / total_num_seeds) * var_count)
            val_count_end = math.ceil(((nt + nv) / total_num_seeds) * var_count)

            offset = total_num_seeds
            for _ in var_types:
                for count in range(var_count):
                    # a chunk of total_num_seeds graphs, have to decide where to put it.
                    graphs = graph_list[offset: offset + total_num_seeds]
                    if count < train_count_end:
                        train_graphs.extend(graphs)
                    elif count < val_count_end:
                        validation_graphs.extend(graphs)
                    else:
                        test_graphs.extend(graphs)
                    offset += total_num_seeds

        return train_graphs, validation_graphs, test_graphs


def get_subclasses(cls):
    """Lazily yield every direct and indirect subclass of ``cls``.

    Classes are produced in post-order: all descendants of a subclass are
    yielded before the subclass itself. ``cls`` is not yielded. A class that is
    reachable through several parents is yielded once per path.
    """
    # Explicit traversal stack of (class to emit once done, its remaining children).
    pending = [(None, iter(cls.__subclasses__()))]
    while pending:
        owner, remaining = pending[-1]
        child = next(remaining, None)
        if child is not None:
            # Descend first so that descendants come out before their parent.
            pending.append((child, iter(child.__subclasses__())))
            continue

        pending.pop()
        if owner is not None:
            yield owner


def retrieve_generator_class(generator_class_name):
    """Return the generator class whose ``name`` equals ``generator_class_name``.

    Subclasses of ``TmGenStateGenerator`` are searched first, in the order
    given by ``get_subclasses``; ``TmGenStateGenerator`` itself is considered
    last so that its own name can be resolved even when no subclass matches.

    Raises:
        IndexError: if no generator class carries the requested name.
    """
    candidates = itertools.chain(get_subclasses(TmGenStateGenerator), [TmGenStateGenerator])
    for candidate in candidates:
        if hasattr(candidate, "name") and generator_class_name == candidate.name:
            return candidate

    # IndexError is what the former ``[...][0]`` lookup raised; the type is kept
    # so that existing handlers continue to work, now with a useful message.
    raise IndexError(f"no generator class named {generator_class_name!r}")


def create_generator_instance(generator_class, file_paths):
    """Instantiate a generator.

    Args:
        generator_class: either a generator class (any callable accepting
            ``file_paths``) or the name of one, which is resolved through
            ``retrieve_generator_class``.
        file_paths: passed to the generator constructor.

    Returns:
        The new generator instance.
    """
    # isinstance also accepts str subclasses, unlike an exact type comparison.
    if isinstance(generator_class, str):
        generator_class = retrieve_generator_class(generator_class)
    return generator_class(file_paths)


def convert(val):
    """Coerce ``val`` to the most specific of ``int``, ``float`` and ``str``.

    The constructors are tried in that order and the first one that does not
    raise ``ValueError`` provides the result. ``None`` is returned only if
    all three raise ``ValueError``.
    """
    try:
        return int(val)
    except ValueError:
        pass

    try:
        return float(val)
    except ValueError:
        pass

    try:
        return str(val)
    except ValueError:
        return None


def extract_node_name(node_label):
    """Extract the plain node name from a topology node label.

    Labels following the ``0_Name_0_0_0`` pattern (e.g. Restena) - first and
    last character are digits and there are at least four underscores - are
    reduced to the part between the leading field and the three trailing
    fields. Any other label (e.g. SNetDifferentCaps), including the empty
    string, is returned unchanged.
    """
    # Slices instead of indexing so that an empty label does not raise IndexError.
    starts_with_digit = node_label[:1].isdigit()
    ends_with_digit = node_label[-1:].isdigit()

    if starts_with_digit and ends_with_digit and node_label.count("_") >= 4:
        parts = node_label.split("_")
        return "_".join(parts[1:-3])
    return node_label
