import argparse
import math
from collections import Counter
from pathlib import Path

import json
import networkx as nx
import numpy as np
import traceback

from relnet.evaluation.experiment_conditions import get_default_file_paths
from relnet.state.graph_state import GraphState
from relnet.state.state_generators import TmGenStateGenerator, extract_node_name


def _select_candidate_graphs(all_tops_dir, min_size, max_size):
    """Classify the ``.graph`` topologies whose node count lies within bounds.

    Args:
        all_tops_dir: Directory (``pathlib.Path``) holding the topology files.
        min_size: Smallest accepted node count (inclusive).
        max_size: Largest accepted node count (inclusive).

    Returns:
        ``(fits_size_and_diff_caps, fits_size_only)``: two lists of
        ``(graph_name, num_nodes, diameter)`` tuples, each sorted by graph name.
        The first holds graphs with more than one distinct link capacity, the
        second graphs whose links all share a single capacity value.
    """
    fits_size_and_diff_caps = []
    fits_size_only = []

    for top_file in all_tops_dir.iterdir():
        if not (top_file.is_file() and top_file.suffix == ".graph"):
            continue

        num_nodes, _np_data, edges, ep_data = TmGenStateGenerator.read_topology_file(
            top_file, set(), set(), encoding='utf-8')
        if not (min_size <= num_nodes <= max_size):
            continue

        G = nx.DiGraph()
        G.add_nodes_from(range(num_nodes))
        G.add_edges_from(edges)
        # NOTE: networkx's diameter raises NetworkXError if the digraph is not
        # strongly connected, which aborts the whole selection run.
        diam = nx.diameter(G)

        entry = (top_file.stem, num_nodes, diam)
        if len(set(ep_data[GraphState.CAPACITY_EPROP_NAME])) > 1:
            fits_size_and_diff_caps.append(entry)
        else:
            fits_size_only.append(entry)

    return (sorted(fits_size_and_diff_caps, key=lambda item: item[0]),
            sorted(fits_size_only, key=lambda item: item[0]))


def _read_node_labels(top_file):
    """Return the label (first space-separated field) of every node row of a topology file.

    The layout read here is: line 0 ends with the node count as its second
    space-separated field, line 1 is skipped, and the next ``num_nodes`` lines
    are node rows.
    """
    with open(top_file, "r", encoding='utf-8') as fh:
        num_nodes = int(next(fh).strip().split(" ")[1])
        next(fh)
        return [next(fh).strip().split(" ")[0] for _ in range(num_nodes)]


def _scale_edge_delay(line, min_digits=4):
    """Return an edge row whose trailing delay field has at least ``min_digits`` digits.

    The delay (last space-separated field) is multiplied by 10 until it reaches
    ``min_digits`` digits, e.g. ``5 -> 5000`` and ``42 -> 4200``; longer values are
    left as-is. This artificially increases delays to get around a tm-gen limit.
    The field is stripped before its digits are counted, so a missing trailing
    newline (e.g. on the last line of a file) does not skew the result.
    The returned row always ends with a newline.
    """
    parts = line.split(" ")
    delay_str = parts[-1].strip()
    delay_int = int(delay_str) * 10 ** max(0, min_digits - len(delay_str))
    return " ".join(parts[:-1] + [str(delay_int)]) + "\n"


def _fix_topology_lines(lines):
    """Return a fixed copy of a topology file's lines.

    - Node rows (lines 2 .. num_nodes + 1) whose extracted node name occurs more
      than once get that name suffixed with its occurrence index (``name0``,
      ``name1``, ...). Node rows are always re-joined with single spaces.
    - Lines after the first that start with ``"edge"`` get their delay scaled by
      ``_scale_edge_delay``.
    - All other lines are copied unchanged.
    """
    num_nodes = int(lines[0].strip().split(" ")[1])
    node_rows = lines[2:num_nodes + 2]
    name_counts = Counter(extract_node_name(row.strip().split(" ")[0]) for row in node_rows)
    seen = Counter()

    fixed = []
    for i, line in enumerate(lines):
        if 1 < i < num_nodes + 2:
            line_parts = line.strip().split(" ")
            label = line_parts[0]
            node_name = extract_node_name(label)

            if name_counts[node_name] > 1:
                # k-th occurrence (0-based) of a duplicated name gets suffix k.
                true_name = node_name + str(seen[node_name])
                seen[node_name] += 1
            else:
                true_name = node_name

            true_label = label.replace(node_name, true_name)
            fixed.append(" ".join([true_label] + line_parts[1:]) + "\n")
        elif i > 0 and line.startswith("edge"):
            fixed.append(_scale_edge_delay(line))
        else:
            fixed.append(line)
    return fixed


def run_initial_part():
    """Select, clean up and rewrite the topologies in ``topologies_dir`` in place.

    Steps (each one mutates the directory):
      1. keep ``.graph`` files with 20-40 nodes and more than one distinct link
         capacity; every file whose stem is not among those is deleted;
      2. delete topologies whose node ids are not exactly ``0..n-1`` in file
         order, since the intended node id would be ambiguous;
      3. write a ``<name>_fixed<suffix>`` copy of each remaining topology with
         duplicated node names made unique and edge delays padded to at least
         four digits;
      4. delete the originals and rename the fixed copies back to the original names.

    Prints the selection summary and, finally, the surviving graph names as a
    quoted, space-separated list in parentheses.
    """
    fp = get_default_file_paths()
    all_tops_dir = fp.topologies_dir

    min_size = 20
    max_size = 40
    fixed_marker = "_fixed"

    fits_size_and_diff_caps, fits_size_only = _select_candidate_graphs(all_tops_dir, min_size, max_size)
    print(f"fits size constraints and has different capacities: <<{len(fits_size_and_diff_caps)}>> graphs.")
    print(f"fits size but has uniform capacities: <<{len(fits_size_only)}>> graphs.")

    fsdc = [g[0] for g in fits_size_and_diff_caps]
    fso = [g[0] for g in fits_size_only]

    print(fsdc)
    print(f"({' '.join(fsdc)})")

    print(fso)
    print(f"({' '.join(fso)})")

    # Get rid of graphs that don't fit the size / capacity restrictions.
    selected = set(fsdc)
    for top_file in list(all_tops_dir.iterdir()):
        if top_file.stem not in selected:
            print(f"deleting {top_file.stem}")
            top_file.unlink(missing_ok=True)

    # Get rid of graphs whose node ids are not numerically ordered (e.g. lexicographic),
    # since it's ambiguous what the actual node id should be.
    for top_file in list(all_tops_dir.iterdir()):
        labels = _read_node_labels(top_file)
        node_ids = [int(label.split("_")[0]) for label in labels]
        if node_ids != list(range(len(labels))):
            top_file.unlink()

    # Write a fixed copy of every remaining topology. The listing is materialised
    # first so files created here are never picked up by this same loop.
    for top_file in list(all_tops_dir.iterdir()):
        with open(top_file, "r", encoding='utf-8') as fh:
            lines = fh.readlines()
        fixed_lines = _fix_topology_lines(lines)

        fixed_top_file = all_tops_dir / (top_file.stem + fixed_marker + top_file.suffix)
        # Explicit UTF-8 (matching the reads) so non-ASCII node names round-trip
        # regardless of the platform's default encoding.
        with open(fixed_top_file, "w", encoding='utf-8') as out_fh:
            out_fh.writelines(fixed_lines)

    # Replace the originals with the fixed copies. Match the exact suffix so that
    # an original topology whose name merely contains "fixed" is not kept.
    for top_file in list(all_tops_dir.iterdir()):
        if not top_file.stem.endswith(fixed_marker):
            top_file.unlink()

    for top_file in list(all_tops_dir.iterdir()):
        original_stem = top_file.stem[:-len(fixed_marker)]
        top_file.rename(all_tops_dir / (original_stem + top_file.suffix))

    final_graphs = ["\"" + top_file.stem + "\"" for top_file in all_tops_dir.iterdir()]
    print(f"({' '.join(final_graphs)})")


def run_filter_part():
    """Delete topologies whose MLU values barely vary across demand matrices.

    For every file currently in the topologies directory (by stem ``g``), reads
    ``<parent_dir>/<g>_ssp_selection/graph_ds/dataset_metadata.json`` and collects
    the ``obj_fun_value`` of each entry in ``per_graph_metadata``.

    - If the minimum and the 90th percentile of those values differ by less than
      0.001, ``<g>.graph`` is deleted and the graph counts as dropped.
    - Otherwise the graph is kept.
    - Graphs whose metadata cannot be processed are reported with a traceback.
      They count as neither kept nor dropped.

    Prints keep/drop statistics and the kept graph names as a quoted,
    space-separated list in parentheses.
    """
    fp = get_default_file_paths()
    all_tops_dir = fp.topologies_dir

    initial_graphs = [top_file.stem for top_file in all_tops_dir.iterdir()]
    routing_model_to_filter = "ssp"
    selection_suffix = "selection"
    # Max allowed gap between the minimum and the 90th percentile of the MLU values
    # for a graph to be considered "uninteresting" and dropped.
    spread_tolerance = 1e-3

    root_expdir = Path(fp.parent_dir)

    num_keep = 0
    num_drop = 0
    to_keep = []
    for g in initial_graphs:
        try:
            exp_dir = root_expdir / f"{g}_{routing_model_to_filter}_{selection_suffix}"
            metadata_file = exp_dir / "graph_ds" / "dataset_metadata.json"

            with open(metadata_file, "r", encoding='utf-8') as fh:
                as_dict = json.load(fh)
            graph_mlu_values = [v['obj_fun_value'] for v in as_dict["per_graph_metadata"].values()]

            print(graph_mlu_values)

            zeropc = np.percentile(graph_mlu_values, 0)
            ninetypc = np.percentile(graph_mlu_values, 90)

            if abs(zeropc - ninetypc) < spread_tolerance:
                print(f"getting rid of {g}")
                top_file = all_tops_dir / f"{g}.graph"
                top_file.unlink(missing_ok=True)
                num_drop += 1
            else:
                print(f"keeping {g}")
                num_keep += 1
                to_keep.append(g)

        except Exception:
            # Best-effort: one broken experiment directory must not abort the pass.
            print(f"<<NOTE>> failed for graph {g}!")
            traceback.print_exc()

    num_total = num_keep + num_drop
    # Guard against ZeroDivisionError when no graph was processed successfully.
    keep_pct = (num_keep / num_total) * 100 if num_total else 0.0
    drop_pct = (num_drop / num_total) * 100 if num_total else 0.0
    print(f"total graphs: {num_total}")
    print(f"keep: {num_keep} ({keep_pct:.3f}%)")
    print(f"drop: {num_drop} ({drop_pct:.3f}%)")

    final_graphs = ["\"" + g + "\"" for g in to_keep]
    print(f"({' '.join(final_graphs)})")


def main():
    """Parse the ``--part`` CLI option and run the corresponding pipeline stage.

    ``initial`` runs the topology selection/cleanup step; ``filter`` runs the
    MLU-based filtering step. ``argparse`` rejects any other value.
    """
    cli = argparse.ArgumentParser(description="Run graph selection and cleanup code.")
    cli.add_argument(
        "--part",
        required=True,
        type=str,
        choices=["initial", "filter"],
        help="Whether to run initial part or filtering by MLU for different DMs.",
    )
    selected_part = cli.parse_args().part

    # `choices` guarantees the key exists, so a plain dispatch table suffices.
    stage_runners = {
        "initial": run_initial_part,
        "filter": run_filter_part,
    }
    stage_runners[selected_part]()



if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()