import os
import numpy as np
import argparse
from slurm.awsslurm import SlurmJob

SCRIPT_PATH = "aggregate_results.py"

# Values used when --scenarios / --methods are not given on the command line.
DEFAULT_SCENARIOS = (
    "vanilla", "pnl", "unfaithful", "confounded", "measure_err", "timino",
)
DEFAULT_METHODS = (
    "score", "das", "nogam", "cam", "diffan", "pc", "ges", "grandag",
    "resit", "lingam", "random",
)

# Passed unchanged as the `time` argument of every SlurmJob.
SLURM_TIME = "02:00:00"


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this launcher.

    ``--graph_type`` and ``--task`` are required. ``--partition`` defaults to
    ``"cpu"``; ``--scenarios`` and ``--methods`` accept one or more values and
    default to ``DEFAULT_SCENARIOS`` / ``DEFAULT_METHODS`` when omitted.
    """
    parser = argparse.ArgumentParser(
        description="Running jobs on SLURM cluster for results aggregation"
    )
    parser.add_argument(
        "--graph_type",
        help="Graph type for which to aggregate: ER, SF, FC, GRP",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--task",
        help="Type of task: inference, standardized",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--partition",
        help="Slurm cluster partition",
        type=str,
        default="cpu",
    )
    parser.add_argument(
        "--scenarios",
        nargs="+",
        help="Scenarios for which to aggregate",
        type=str,
        # Fresh list so the namespace holds the same type as user-supplied
        # values (argparse yields a list for nargs='+').
        default=list(DEFAULT_SCENARIOS),
    )
    parser.add_argument(
        "--methods",
        nargs="+",
        help="Methods for which to aggregate",
        type=str,
        default=list(DEFAULT_METHODS),
    )
    return parser


def build_script_args(method: str, scenario: str, graph_type: str, task: str) -> str:
    """Return the argument string handed to ``SCRIPT_PATH`` for one job.

    Each job targets exactly one method and one scenario. The trailing space
    is kept so the string is identical to what this launcher always produced.
    """
    return (
        f"--methods {method} "
        f"--scenarios {scenario} "
        f"--graph_type {graph_type} "
        f"--task {task} "
    )


def main(argv=None) -> None:
    """Parse the command line and launch one job per (method, scenario) pair.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 2 (via ``parser.error``) on invalid arguments,
            including the ``linear`` scenario combined with ``--task inference``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Explicit check instead of `assert`: asserts vanish under `python -O`
    # and give the user no explanation when they do fire.
    if args.task == "inference" and "linear" in args.scenarios:
        parser.error("scenario 'linear' cannot be combined with --task inference")

    for method in args.methods:
        for scenario in args.scenarios:
            job = SlurmJob(
                SCRIPT_PATH,
                name=f"aggregate_{method}_{scenario}_{args.graph_type}",
                time=SLURM_TIME,
                gpu=False,
                ngpus=None,
                afterok=None,
                ntasks_per_node=None,
                script_args=build_script_args(
                    method, scenario, args.graph_type, args.task
                ),
                partition=args.partition,
            )
            # NOTE(review): calling the job object is presumably what submits
            # it and returns the Slurm job id - confirm against SlurmJob. The
            # return value was never used here, so it is not kept.
            job()


if __name__ == "__main__":
    main()