import argparse
import subprocess
from pathlib import Path

import sys
sys.path.append('/relnet')

import numpy as np
from relnet.state.topology_variations import generate_topology_variations
from relnet.evaluation.experiment_conditions import get_conditions_for_experiment, get_default_file_paths

# External executables launched by call_tmgen / call_tmrun. They are bare
# names (no directory), so subprocess resolves them via PATH.
TM_GEN_BIN = "tm-gen"
TM_RUN_BIN = "tm-run"

def create_topology_variations(args):
    """Generate the topology variations used by the "topvar" experiment.

    A placeholder ("dummy") experiment-conditions object for "topvar" is built
    from the CLI options and handed, together with ``args.topology_root``, to
    ``generate_topology_variations``.

    Args:
        args: parsed CLI namespace; reads ``topology_root``, ``eval_on_train``
            and ``dms_mult``.
    """
    topvar_conditions = get_conditions_for_experiment(
        "topvar", "dummy", args.eval_on_train, True, args.dms_mult
    )
    generate_topology_variations(args.topology_root, topvar_conditions)


def call_tmgen(args, which, *, check=True):
    """Run tm-gen to produce demand matrices for experiment ``which``.

    The number of matrices requested (``--tm_count``) is the number of
    distinct seeds across the train/validation/test seed lists of every
    model seed configured for ``which``.

    Args:
        args: parsed CLI namespace; reads ``tm_root``, ``min_scale_factor``,
            ``locality``, ``eval_on_train``, ``dms_mult``, ``topology_root``
            and ``threads``.
        which: experiment name passed to ``get_conditions_for_experiment``
            (the CLI restricts it to "main" or "topvar").
        check: when True (default), raise if tm-gen exits with a non-zero
            status instead of silently continuing.

    Returns:
        The ``subprocess.CompletedProcess`` for the tm-gen invocation.

    Raises:
        subprocess.CalledProcessError: if ``check`` is True and tm-gen fails.
    """
    tms_dir = (Path(args.tm_root)
               / f"scale_factor_{args.min_scale_factor}"
               / f"locality_{args.locality}")
    # NOTE(review): this directory is created here but never passed to tm-gen
    # (no --tm_root flag below); confirm tm-gen derives the same path itself.
    tms_dir.mkdir(parents=True, exist_ok=True)

    print(f"got eval on train as {args.eval_on_train}")
    dummy_ec = get_conditions_for_experiment(which, "dummy", args.eval_on_train, True, args.dms_mult)

    seed_set = set()
    for model_seed in dummy_ec.experiment_params['model_seeds']:
        # Presumably repopulates the split seed lists for this model seed;
        # they are read immediately afterwards.
        dummy_ec.set_generator_seeds(model_seed)
        seed_set.update(dummy_ec.train_seeds, dummy_ec.validation_seeds, dummy_ec.test_seeds)

    # One demand matrix per distinct seed.
    how_many = len(seed_set)

    command = [TM_GEN_BIN,
               "--topology_root", args.topology_root,
               "--min_scale_factor", str(args.min_scale_factor),
               "--locality", str(args.locality),
               "--tm_count", str(how_many),
               "--threads", str(args.threads),
               ]
    # check=True: a failed tm-gen must not fall through to tm-run, which would
    # otherwise operate on missing or partial demand matrices.
    return subprocess.run(command, check=check)

def call_tmrun(args, *, check=True):
    """Run tm-run over the topology and demand-matrix roots.

    Per the CLI help, tm-run is the tool that generates routing
    configurations.

    Args:
        args: parsed CLI namespace; reads ``topology_root`` and ``tm_root``.
        check: when True (default), raise if tm-run exits with a non-zero
            status instead of silently ignoring the failure.

    Returns:
        The ``subprocess.CompletedProcess`` for the tm-run invocation.

    Raises:
        subprocess.CalledProcessError: if ``check`` is True and tm-run fails.
    """
    command = [TM_RUN_BIN,
               "--topology_root", args.topology_root,
               "--tm_root", args.tm_root]
    return subprocess.run(command, check=check)

def main(argv=None):
    """Command-line entry point: set up demand matrices and routing configs.

    Steps:
      1. for ``--which topvar``, generate topology variations first;
      2. run tm-gen to produce demand matrices;
      3. unless ``--skip_tmrun`` is given, run tm-run to produce routing
         configurations.

    Args:
        argv: argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before.
    """
    parser = argparse.ArgumentParser(description="Generate average demand matrices and resulting routing configurations.")
    parser.add_argument("--topology_root", type=str, required=True, help="Topologies root directory.")
    parser.add_argument("--tm_root", type=str, required=True, help="Demand matrices root directory.")

    parser.add_argument("--dms_mult", type=float, required=False, default=1.0,
                        help="Multiplier for number of DMs to generate (applied to 1000 main / 20 topvar). Default 1.")

    parser.add_argument("--min_scale_factor", type=float, required=True, help="tm-gen min scale factor.")
    # TODO: this assumes an integer locality, which may not strictly be the case.
    parser.add_argument("--locality", type=int, required=True, help="tm-gen locality factor.")
    parser.add_argument('--eval_on_train', dest='eval_on_train', action='store_true', default=False,
                        help="Whether to train/validate/test on the same sets of DMs (if true) or disjoint.")
    parser.add_argument("--which", required=True, type=str,
                        help="Which experiment to setup for",
                        choices=["main", "topvar"])

    parser.add_argument("--threads", type=int, required=True, help="number of threads to use when calling tm-gen.")

    parser.add_argument('--skip_tmrun', dest='skip_tmrun', action='store_true', default=False,
                        help="Whether to skip tm-run tool, which generates routing configurations.")

    args = parser.parse_args(argv)

    # Topology variations must exist before tm-gen runs for the topvar setup.
    if args.which == "topvar":
        create_topology_variations(args)

    call_tmgen(args, args.which)

    if not args.skip_tmrun:
        print("executing tm-run step as configured.")
        call_tmrun(args)

# Script entry point: parse CLI arguments from sys.argv and run the pipeline.
if __name__ == "__main__":
    main()


