import sys
import argparse
import subprocess

# Port the spanning_tree process is started on and that every VW node is
# pointed at via --span_server_port.
SPANNING_TREE_PORT = 26545

# File the last node saves its model to (-f) and the test run loads (-i).
# The path is relative, so it resolves against the working directory.
FINAL_MODEL_FILE = "final.model"


def parse_args(argv=None):
    """Parse the command line options of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vw", help="Path to VW binary to use", type=str, required=True
    )
    parser.add_argument(
        "--spanning_tree",
        help="Path to spanning tree binary to use",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--data_files",
        help="Data files to use, one per node",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--test_file",
        help="Test input file to feed to final model",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--vw_args", help="Extra vw arguments to use", type=str, default=""
    )
    parser.add_argument(
        "--prediction_file",
        help="File passed to VW with -p during the test run (optional)",
        type=str,
        default=None,
    )
    return parser.parse_args(argv)


def build_train_command(vw, data_file, node_index, total_nodes, extra_args, save_model):
    """Build the argv list for one VW node of the training run.

    Args:
        vw: Path to the VW binary.
        data_file: Data file passed to this node with ``-d``.
        node_index: Value passed with ``--node``.
        total_nodes: Value passed with ``--total``.
        extra_args: Additional VW arguments appended after the fixed ones.
        save_model: When true, ``-f FINAL_MODEL_FILE`` is appended.

    Returns:
        A new list of strings suitable for ``subprocess.Popen``.
    """
    cmd_args = [
        vw,
        "--span_server",
        "localhost",
        "--total",
        str(total_nodes),
        "--node",
        str(node_index),
        "--unique_id",
        "1234",
        "-d",
        data_file,
        "--span_server_port",
        str(SPANNING_TREE_PORT),
    ]
    cmd_args.extend(extra_args)
    if save_model:
        cmd_args.extend(["-f", FINAL_MODEL_FILE])
    return cmd_args


def build_test_command(vw, test_file, extra_args, prediction_file=None):
    """Build the argv list for the test run against the saved model.

    The extra arguments are appended exactly once, so the command that is
    printed is the command that is executed.

    Args:
        vw: Path to the VW binary.
        test_file: Input file passed with ``-d``.
        extra_args: Additional VW arguments.
        prediction_file: If not ``None``, passed with ``-p``.

    Returns:
        A new list of strings suitable for ``subprocess.call``.
    """
    cmd_args = [vw, "-d", test_file, "-i", FINAL_MODEL_FILE, "-t"]
    cmd_args.extend(extra_args)
    if prediction_file is not None:
        cmd_args.extend(["-p", prediction_file])
    return cmd_args


def _print_captured(label, stream):
    """Print everything written to the captured output file ``stream``."""
    stream.seek(0)
    print(label + ": \n" + stream.read().decode("utf-8", errors="replace"))


def main(argv=None):
    """Run the spanning tree, one VW node per data file, then the test run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 when every VW node and the test run exit with code 0, otherwise 1.
    """
    # Imported here so this block does not depend on the file-level imports.
    import tempfile

    args = parse_args(argv)
    split_vw_args = args.vw_args.split()

    spanning_tree_args = [
        args.spanning_tree,
        "--nondaemon",
        "-p",
        str(SPANNING_TREE_PORT),
    ]
    print("Starting spanning_tree with args: " + " ".join(spanning_tree_args[1:]))
    # The spanning tree output is never inspected. Discard it instead of
    # piping it: an unread pipe blocks the child once its buffer is full.
    spanning_tree_proc = subprocess.Popen(
        spanning_tree_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )

    capture_files = []
    vw_procs = []  # (process, stdout file, stderr file)
    try:
        total_nodes = len(args.data_files)
        for index, data_file in enumerate(args.data_files):
            cmd_args = build_train_command(
                args.vw,
                data_file,
                index,
                total_nodes,
                split_vw_args,
                # Only the last node saves the model.
                save_model=(index == total_nodes - 1),
            )
            print("Starting VW with args: " + " ".join(cmd_args[1:]))
            # Capture output in temporary files rather than pipes. wait() on
            # a process whose pipes nobody drains can deadlock, and all nodes
            # run concurrently while they are waited on one at a time.
            out_file = tempfile.TemporaryFile()
            capture_files.append(out_file)
            err_file = tempfile.TemporaryFile()
            capture_files.append(err_file)
            vw_procs.append(
                (
                    subprocess.Popen(cmd_args, stdout=out_file, stderr=err_file),
                    out_file,
                    err_file,
                )
            )

        for proc, out_file, err_file in vw_procs:
            return_code = proc.wait()
            if return_code != 0:
                print("VW failed:")
                _print_captured("STDOUT", out_file)
                _print_captured("STDERR", err_file)
                return 1
            print("VW succeeded")
    finally:
        # Runs on success, on failure and on exceptions. No VW node or
        # spanning tree process is left behind, and killed children are reaped.
        for proc, _, _ in vw_procs:
            if proc.poll() is None:
                proc.kill()
                proc.wait()
        for capture_file in capture_files:
            capture_file.close()
        spanning_tree_proc.kill()
        spanning_tree_proc.wait()

    cmd_args = build_test_command(
        args.vw, args.test_file, split_vw_args, args.prediction_file
    )
    print("Running test on produced model...")
    print("Running VW with args: " + " ".join(cmd_args[1:]))

    test_return_code = subprocess.call(cmd_args)
    if test_return_code != 0:
        print("VW test run failed with return code " + str(test_return_code))
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
