import itertools
import math
import random
from copy import deepcopy
from pathlib import Path

import networkx as nx
import re

import numpy as np
import xxhash

from collections import Counter

from relnet.state.state_generators import TmGenStateGenerator

# Upper bound on retry attempts, used both when searching for a distinct
# variation and when searching for a connectivity-preserving removal.
num_tries = 10_000


def generate_topology_variations(topology_root, ec, cleanup=False):
    """Generate random variations of every ``*.graph`` topology under ``topology_root``.

    For each configured variation type and each original graph file, ``var_count``
    variations are produced with the matching helper:
    ``NR`` -> remove_random_nodes, ``NA`` -> add_random_nodes,
    ``ER`` -> remove_random_edges, ``EA`` -> add_random_edges.
    Each variation is written via ``write_variation_file``. A candidate is kept only
    if the hash of its edge list has not been produced earlier in this run. Each
    variation slot retries up to ``num_tries`` times.

    Args:
        topology_root: directory containing the original ``.graph`` files; the
            generated files are written there as well.
        ec: configuration object whose ``topology_variations`` mapping provides
            ``var_types``, ``var_count`` and ``var_percentage``.
        cleanup: if True, delete every generated file before returning.

    Raises:
        ValueError: if ``var_count`` exceeds 99999, if a variation type is unknown,
            or if no distinct variation could be produced within ``num_tries`` tries.
    """
    var_types = ec.topology_variations["var_types"]

    if len(var_types) == 0:
        print("No topology variations configured and none will be generated.")
        return

    var_count = ec.topology_variations["var_count"]
    if var_count > (1e5 - 1):
        raise ValueError(f"Configuration is only meant to support 10^5 graphs; adjust this limit manually.")

    var_percentage = ec.topology_variations["var_percentage"]

    variation_funcs = {
        "NR": remove_random_nodes,
        "NA": add_random_nodes,
        "ER": remove_random_edges,
        "EA": add_random_edges,
    }
    # Validate all types before any file is written, so a typo in the config does not
    # leave a partially generated set of variations behind.
    for var_type in var_types:
        if var_type not in variation_funcs:
            raise ValueError(f"Topology variation {var_type} not known.")

    rand_instance = random.Random(42)
    # glob() order is filesystem-dependent; sort so the seeded RNG yields the same
    # variations for the same inputs on every machine.
    orig_graph_files = sorted(str(p) for p in Path(topology_root).glob("*.graph"))
    gen_files = []
    # A set keeps the per-candidate duplicate check O(1) instead of O(#generated).
    seen_hashes = set()

    for var_type in var_types:
        make_variation = variation_funcs[var_type]
        for p in orig_graph_files:
            f = Path(p)
            graph_name = f.name.replace(".graph", "")

            num_nodes, np_data, edges, ep_data = TmGenStateGenerator.read_topology_file(f, set(), set())

            for var_num in range(var_count):
                for _ in range(num_tries):
                    # Rebuilt every try: some variation helpers mutate the graph they receive.
                    nx_graph = nx.DiGraph()
                    nx_graph.add_edges_from(edges)

                    num_nodes_var, np_data_var, edges_var, ep_data_var, rand_instance = make_variation(
                        var_percentage, deepcopy(num_nodes), deepcopy(np_data), deepcopy(edges),
                        deepcopy(ep_data), nx_graph, rand_instance)

                    el_hash = hash_edgelist(edges_var)
                    if el_hash in seen_hashes:
                        continue
                    seen_hashes.add(el_hash)
                    gen_files.append(write_variation_file(topology_root, graph_name, var_type, var_num,
                                                          num_nodes_var, np_data_var, edges_var, ep_data_var))
                    break
                else:
                    # The retry loop ran out without producing an unseen edge list.
                    raise ValueError(
                        f"number of tries exceeded in main loop; could not generate enough distinct graphs. Try setting a lower percentage.")

    print(f"generated {len(seen_hashes)} distinct topology variations.")

    if cleanup:
        for gen_file in gen_files:
            gen_file.unlink(missing_ok=True)

def hash_edgelist(edge_list):
    """Return the 32-bit xxHash of ``edge_list`` as an integer.

    The edge list is first packed into a NumPy array and that array is fed to the
    hasher, so the result depends on edge order and on the dtype NumPy infers.
    """
    packed_edges = np.array(edge_list)
    hasher = xxhash.xxh32()
    hasher.update(packed_edges)
    return hasher.intdigest()

def write_variation_file(topology_root, graph_name, var_type, var_num, num_nodes_var, np_data_var, edges_var, ep_data_var):
    """Write one topology variation as a ``.graph`` file and return its path.

    The file is named ``TmGenStateGenerator.get_var_id(graph_name, var_type, var_num)``
    plus ``.graph`` inside ``topology_root``. Its layout is:

    * ``NODES <n>``, a header of node property names, then one row per node;
    * a blank line;
    * ``EDGES <m>``, a header ``label src dest <other props>``, then one row per
      edge holding its label, endpoints and remaining properties.

    Edge properties named ``label``, ``src`` and ``dest`` are not written as extra
    columns. Every ``delay`` value is written as 10000.
    """
    var_id = TmGenStateGenerator.get_var_id(graph_name, var_type, var_num)
    out_file = Path(topology_root) / (var_id + ".graph")
    fixed_edge_cols = ['label', 'src', 'dest']
    extra_edge_props = [name for name in ep_data_var.keys() if name not in fixed_edge_cols]
    # Delay is overridden because some variations push the total delay below
    # tmgen's 10ms threshold.
    forced_delay = 10000

    with open(out_file, "w") as out_fh:
        out_fh.write(f"NODES {num_nodes_var}\n")
        out_fh.write(" ".join(np_data_var) + "\n")
        node_columns = list(np_data_var.values())
        for node_idx in range(num_nodes_var):
            out_fh.write(" ".join(str(column[node_idx]) for column in node_columns) + "\n")

        out_fh.write("\n")
        out_fh.write(f"EDGES {len(edges_var)}\n")
        out_fh.write("label src dest " + " ".join(extra_edge_props) + "\n")

        for edge_idx, edge in enumerate(edges_var):
            cells = []
            for name in extra_edge_props:
                if name == 'delay':
                    cells.append(str(forced_delay))
                else:
                    cells.append(str(ep_data_var[name][edge_idx]))
            row_tail = " ".join(cells)
            out_fh.write(f"{ep_data_var['label'][edge_idx]} {edge[0]} {edge[1]} {row_tail}\n")

    return out_file


def remove_random_nodes(var_percentage, num_nodes_var, np_data_var, edges_var, ep_data_var, nx_graph, rand_instance):
    """Delete a random set of nodes while keeping the topology strongly connected.

    Between 1 and ``ceil(num_nodes_var * var_percentage / 100)`` nodes are removed.
    That count is drawn once. The concrete set of nodes is re-drawn up to
    ``num_tries`` times until ``nx_graph`` without them is strongly connected.
    Afterwards:

    * surviving nodes are renumbered densely from 0;
    * node and edge property lists are filtered to match;
    * numeric parts of node labels and edge labels are rewritten to the new indices.

    Returns:
        ``(num_nodes_var, np_data_var, edges_var, ep_data_var, rand_instance)`` for the
        reduced topology.

    Raises:
        ValueError: if no strongly connected removal is found within ``num_tries`` tries.
    """
    upper_bound = math.ceil((num_nodes_var * var_percentage) / 100)
    assert upper_bound >= 1
    removal_count = rand_instance.randint(1, upper_bound)
    candidates = list(range(num_nodes_var))

    chosen = None
    for _ in range(num_tries):
        trial_graph = deepcopy(nx_graph)
        picks = []
        while len(picks) < removal_count:
            node = rand_instance.choice([c for c in candidates if c not in picks])
            trial_graph.remove_node(node)
            picks.append(node)
        if nx.is_strongly_connected(trial_graph):
            chosen = picks
            break

    if chosen is None:
        raise ValueError(
            f"number of tries exceeded; giving up for operation NR (node removal). Try setting a lower percentage.")

    gone = set(chosen)
    survivors = [n for n in range(num_nodes_var) if n not in gone]
    new_id = {old: new for new, old in enumerate(survivors)}

    def edge_kept(idx):
        endpoints = edges_var[idx]
        return endpoints[0] not in gone and endpoints[1] not in gone

    filtered_np = {name: [v for idx, v in enumerate(column) if idx not in gone]
                   for name, column in np_data_var.items()}
    filtered_ep = {name: [v for idx, v in enumerate(column) if edge_kept(idx)]
                   for name, column in ep_data_var.items()}
    renumbered_edges = [(new_id[e[0]], new_id[e[1]]) for e in edges_var
                        if e[0] not in gone and e[1] not in gone]

    # Rewrite every "_<digits>" and the leading "<digits>_<letter>" of each node label
    # to the node's new index.
    for new_idx in range(len(survivors)):
        text = filtered_np['label'][new_idx]
        text = re.sub(r"_\d+", "_" + str(new_idx), text)
        text = re.sub(r"\d+_([A-Za-z])", str(new_idx) + r"_\1", text)
        filtered_np['label'][new_idx] = text

    for edge_idx in range(len(renumbered_edges)):
        filtered_ep['label'][edge_idx] = re.sub(r"_\d+", f"_{edge_idx}", filtered_ep['label'][edge_idx])

    return num_nodes_var - len(chosen), filtered_np, renumbered_edges, filtered_ep, rand_instance


def add_random_nodes(var_percentage, num_nodes_var, np_data_var, edges_var, ep_data_var, nx_graph, rand_instance):
    """Attach between 1 and ``ceil(num_nodes_var * var_percentage / 100)`` artificial nodes.

    Each new node is inserted at a random index in ``0..num_nodes_var-1``, so later
    node ids and all edge endpoints at or after it shift up by one. The node gets:

    * a label containing ``ART``;
    * ``x``/``y`` values of 0;
    * a bidirectional link to a randomly chosen node whose label does not contain ``ART``.

    Every edge property other than ``label`` for that link is one value sampled from
    the property's existing values, shared by both directions. Finally every node
    label's numeric parts are rewritten to its final index.

    ``nx_graph`` is not used; it is accepted for signature parity with the other
    variation helpers. The lists inside ``np_data_var`` and ``ep_data_var`` are
    modified in place.

    Returns:
        ``(num_nodes_var, np_data_var, edges_var, ep_data_var, rand_instance)``.
    """
    max_nodes = math.ceil((num_nodes_var * var_percentage) / 100)
    assert max_nodes >= 1
    n_to_add = rand_instance.randint(1, max_nodes)

    for i in range(n_to_add):
        node_id = rand_instance.choice(list(range(num_nodes_var)))  # insert position
        np_data_var['label'].insert(node_id, f"{node_id}_ART{i}_{node_id}_{node_id}_{node_id}")
        np_data_var['x'].insert(node_id, 0)
        np_data_var['y'].insert(node_id, 0)

        # Shift edge endpoints at or after the insert position to their new ids.
        edges_var = [(e[0] + 1 if e[0] >= node_id else e[0],
                      e[1] + 1 if e[1] >= node_id else e[1]) for e in edges_var]

        # The label list now has num_nodes_var + 1 entries; scan all of them so the
        # last node is eligible too (and a single-node graph has a candidate).
        candidates = [n for n in range(num_nodes_var + 1) if 'ART' not in np_data_var['label'][n]]
        node_to_connect = rand_instance.choice(candidates)

        for new_edge in ((node_id, node_to_connect), (node_to_connect, node_id)):
            edges_var.append(new_edge)
            ep_data_var['label'].append(f'edge_{len(edges_var) - 1}')

        # Sample every property first, then append, so each value is drawn from the
        # pre-existing edges only.
        sampled_vals = {prop: rand_instance.choice(vals)
                        for prop, vals in ep_data_var.items() if prop != 'label'}
        for prop_name, prop_val in sampled_vals.items():
            ep_data_var[prop_name].append(prop_val)
            ep_data_var[prop_name].append(prop_val)

        num_nodes_var += 1

    # Rewrite every "_<digits>" and the leading "<digits>_<letter>" to the final index.
    new_labels = []
    for i, curr_label in enumerate(np_data_var['label']):
        upd_label = re.sub(r"_\d+", "_" + str(i), curr_label)
        upd_label = re.sub(r"\d+_([A-Za-z])", str(i) + r"_\1", upd_label)
        new_labels.append(upd_label)

    np_data_var['label'] = new_labels
    return num_nodes_var, np_data_var, edges_var, ep_data_var, rand_instance


def remove_random_edges(var_percentage, num_nodes_var, np_data_var, edges_var, ep_data_var, nx_graph, rand_instance):
    """Remove random bidirectional links while keeping the topology strongly connected.

    Between 1 and ``ceil(len(edges_var) * var_percentage / 100)`` links are removed.
    That count is drawn once. Each pick is a directed edge ``(u, v)``, and the link
    is removed in both directions. The picks are re-drawn up to ``num_tries`` times
    until ``nx_graph`` without those links is strongly connected.

    ``edges_var`` and every ``ep_data_var`` list are filtered by dropping both
    directions of every picked link, i.e. exactly the graph that passed the
    connectivity check. Edge labels are then renumbered to their new positions.
    Node data is returned unchanged.

    Raises:
        ValueError: if no strongly connected removal is found within ``num_tries`` tries.
    """
    max_edges = math.ceil((len(edges_var) * var_percentage) / 100)
    assert max_edges >= 1
    n_to_remove = rand_instance.randint(1, max_edges)

    removed_edges = None
    for _ in range(num_tries):
        graph_clone = deepcopy(nx_graph)
        picked = []

        for _ in range(n_to_remove):
            # Exclude links already picked in either direction.
            edge_to_remove = rand_instance.choice(
                [e for e in edges_var if (e not in picked and (e[1], e[0]) not in picked)])

            graph_clone.remove_edge(edge_to_remove[0], edge_to_remove[1])
            graph_clone.remove_edge(edge_to_remove[1], edge_to_remove[0])
            picked.append(edge_to_remove)

        if nx.is_strongly_connected(graph_clone):
            removed_edges = picked
            break

    if removed_edges is None:
        raise ValueError(
            f"number of tries exceeded; giving up for operation ER (edge removal). Try setting a lower percentage.")

    # Drop BOTH directions of every removed link so the output matches the graph whose
    # connectivity was validated above (previously the reverse edge was left behind).
    dropped = set()
    for e in removed_edges:
        dropped.add((e[0], e[1]))
        dropped.add((e[1], e[0]))
    keep = [(e[0], e[1]) not in dropped for e in edges_var]

    ep_data_var = {prop_name: [val for i, val in enumerate(prop_values) if keep[i]]
                   for prop_name, prop_values in ep_data_var.items()}
    edges_var = [e for e, kept in zip(edges_var, keep) if kept]

    for i in range(len(edges_var)):
        ep_data_var['label'][i] = re.sub(r"_\d+", f"_{i}", ep_data_var['label'][i])

    return num_nodes_var, np_data_var, edges_var, ep_data_var, rand_instance


def add_random_edges(var_percentage, num_nodes_var, np_data_var, edges_var, ep_data_var, nx_graph, rand_instance):
    """Add random bidirectional links between node pairs that are not yet connected.

    Between 1 and ``ceil(len(edges_var) * var_percentage / 100)`` links are added.
    For each link a pair ``(u, v)`` is drawn from ``nx.non_edges(nx_graph)``. Both
    ``(u, v)`` and ``(v, u)`` are appended to ``edges_var`` and added to ``nx_graph``,
    labelled ``edge_<index>``. Every other edge property receives one value sampled
    from its existing values, shared by both directions.

    ``edges_var``, the lists in ``ep_data_var`` and ``nx_graph`` are modified in
    place. Node data is returned unchanged.
    """
    limit = math.ceil((len(edges_var) * var_percentage) / 100)
    assert limit >= 1
    how_many = rand_instance.randint(1, limit)
    sampled_props = [key for key in ep_data_var.keys() if key not in ['label']]

    for _ in range(how_many):
        src, dst = rand_instance.choice(list(nx.non_edges(nx_graph)))

        for a, b in ((src, dst), (dst, src)):
            edges_var.append((a, b))
            nx_graph.add_edge(a, b)
            ep_data_var['label'].append(f'edge_{len(edges_var) - 1}')

        # Draw every property value before appending any of them.
        draws = {key: rand_instance.choice(ep_data_var[key]) for key in sampled_props}
        for key, value in draws.items():
            ep_data_var[key].extend((value, value))

    return num_nodes_var, np_data_var, edges_var, ep_data_var, rand_instance


