"""Module to handle data generation for parallel inference with Slurm
"""
import os
import random
import numpy as np
import argparse
import time
from sklearn.model_selection import ParameterGrid
from utils._utils import init_scancel_script
from slurm.awsslurm import SlurmJob

DATA_DIR="/efs/data" # Base directory for data storage
WORKSPACE="/home/ec2-user/causal-benchmark"
SCRIPT_PATH="generate_dataset.py"


def args_sanity_check(args):
    """Validate the parsed command line arguments.

    Each option is first checked against its list of allowed values, then the
    combination of graph type, graph sizes and graph densities is checked for
    consistency.

    Args:
        args: Namespace exposing ``scenarios``, ``noise_distr``,
            ``graph_size_options`` and ``graph_density_options`` (sequences
            of strings) and ``graph_type`` (a string).

    Raises:
        ValueError: If a value is not allowed, or if the graph type is not
            compatible with the requested sizes or densities.
    """
    # Check scenarios
    allowed_scenarios = ("vanilla", "confounded", "linear", "measure_err", "timino", "unfaithful", "pnl")
    for scenario in args.scenarios:
        if scenario not in allowed_scenarios:
            raise ValueError(f"Scenario {scenario} not allowed!")

    # Check noise distr
    allowed_noise = ("gauss", "nonlin_weak", "nonlin_mid", "nonlin_strong", "all")
    for noise in args.noise_distr:
        if noise not in allowed_noise:
            raise ValueError(f"Noise distribution {noise} not allowed!")

    # Check graph_type
    allowed_types = ("ER", "SF", "FC", "GRP")
    if args.graph_type not in allowed_types:
        raise ValueError(f"Graph type {args.graph_type} not allowed!")

    # Check graph_size_options
    allowed_sizes = ("small", "medium", "large20", "large30", "large50", "all")
    for size in args.graph_size_options:
        if size not in allowed_sizes:
            raise ValueError(f"Size {size} not allowed!")

    # Check graph_density_options ("full" only with FC, "cluster" only with GRP)
    allowed_densities = ("sparse", "dense", "full", "cluster")
    density_only_for = {"full": "FC", "cluster": "GRP"}
    for density in args.graph_density_options:
        if density not in allowed_densities:
            raise ValueError(f"Graph density {density} not allowed!")
        required_type = density_only_for.get(density)
        if required_type is not None and args.graph_type != required_type:
            raise ValueError(f"Graph density {density} not compatible with graph type {args.graph_type}!")

    # Check consistency of args combinations.
    # "all" is expanded by the launcher to every size, "small" included, so it
    # must be rejected for the same graph types as an explicit "small".
    if args.graph_type in ("SF", "GRP"):
        for size in ("small", "all"):
            if size in args.graph_size_options:
                raise ValueError(f"Incompatible graph type '{args.graph_type}' and graph size '{size}'")

    # FC and GRP graphs each accept exactly one, dedicated density option.
    # Explicit exceptions are used instead of `assert`, which is stripped
    # when Python runs with -O.
    mandatory_density = {"FC": "full", "GRP": "cluster"}.get(args.graph_type)
    if mandatory_density is not None:
        if len(args.graph_density_options) != 1:
            raise ValueError(
                f"Graph type '{args.graph_type}' requires exactly one graph density option, "
                f"got {len(args.graph_density_options)}"
            )
        if args.graph_density_options[0] != mandatory_density:
            raise ValueError(f"Incompatible graph type '{args.graph_type}' and graph density option '{args.graph_density_options[0]}'")


if __name__ == "__main__":
    # Entry point: parse and validate the command line, then submit one Slurm
    # job running SCRIPT_PATH for every combination of noise distribution,
    # scenario, graph size, graph density and seed.
    parser = argparse.ArgumentParser(description="Running jobs on SLURM cluster")

    parser.add_argument(
        "--seed",
        default=42,
        type=int,
        help="Seed for reproducibility"
    )
    parser.add_argument(
        "--graph_type",
        default="ER",
        type=str,
        help="Algorithm for generation of synthetic graphs. Accepted values are ['ER', 'SF', 'GRP', 'FC']"
    )
    parser.add_argument(
        "--noise_distr",
        nargs='+',
        type=str,
        help="Select datasets by distribution of the noise. Accepted values are "
             "['gauss', 'nonlin_weak', 'nonlin_mid', 'nonlin_strong', 'all']",
        required=True
    )
    parser.add_argument(
        "--partition",
        default="cpu",
        type=str,
        help="Cluster partition: accepted values ['cpu', 'gpu']. If free, select 'cpu'"
    )
    parser.add_argument(
        "--num_datasets",
        type=int,
        help="Number of seeds for configuration. Suggested 20",
        required=True
    )
    parser.add_argument(
        '--scenarios',
        nargs='+',
        help='Scenarios for which to generate data',
        type=str,
        required=True
    )
    parser.add_argument(
        '--graph_size_options',
        nargs='+',
        help='Size of the graph: small medium large20 large30 large50 all',
        type=str,
        required=True
    )
    parser.add_argument(
        '--graph_density_options',
        nargs='+',
        help='Density of the graph: sparse, dense, full (only for graph_type FC), '
             'cluster (only for graph_type GRP)',
        type=str,
        required=True
    )

    args = parser.parse_args()
    args_sanity_check(args)

    # Seeds are drawn without replacement from range(seed_pool_size), so at
    # most seed_pool_size distinct datasets can be requested.
    seed_pool_size = 100
    if args.num_datasets > seed_pool_size:
        parser.error(f"--num_datasets must be at most {seed_pool_size}, got {args.num_datasets}")

    # Seed both generators used below: numpy draws the per-dataset seeds and
    # `random` picks the number of nodes for each graph size.
    np.random.seed(args.seed)
    random.seed(args.seed)
    run_seeds = np.random.choice(range(seed_pool_size), size=args.num_datasets, replace=False)

    # Scenarios (assumptions): values forwarded as --scenario_params
    scenarios_params = {
        "confounded": [0.1, 0.2],  # rho: probability of confounded pair
        "linear": [0.33, 0.66, 0.99],  # p_linear: probability of linear mechanism
        "measure_err": [0.2, 0.4, 0.6, 0.8],  # gamma: parametrize % of variance explained by measure error
        "unfaithful": [0.25, 0.5, 0.75, 1.0],  # meaning not documented here -- TODO confirm
        "pnl": [3],  # exponent of the nonlinear polynomial
    }

    # DataGenerator Parameters: candidate numbers of nodes for each graph size
    graph_size2nodes = {
        "small": [5],
        "medium": [10],
        "large20": [20],
        "large30": [30],
        "large50": [50],
    }

    # "all" stands for every known graph size
    if "all" in args.graph_size_options:
        args.graph_size_options = list(graph_size2nodes.keys())

    # NOTE: 2000 samples was listed in an earlier version and then dropped
    # (original note: "kill methods that are too slow at 2000").
    data_samples_options = [100, 1000]
    samples_size_arg = ' '.join(str(s) for s in data_samples_options)

    # Initialize script to scancel slurm jobs
    scancel_script = os.path.join(WORKSPACE, "utilities", "scancel_jobs.sh")
    init_scancel_script(scancel_script)

    # Every (graph size, graph density) pair requested on the command line
    data_configs = ParameterGrid({
        "graph_size": args.graph_size_options,
        "graph_density": args.graph_density_options,
    })

    batch_dimension = 500  # pause after this many submissions
    sleep_time = 120  # length of the pause, in seconds
    slurm_time = "24:00:00"  # time limit requested for each job
    num_submissions = 0
    for noise in args.noise_distr:
        for scenario in args.scenarios:
            for config in data_configs:
                graph_size = config["graph_size"]
                graph_density = config["graph_density"]
                for dataset_id, seed in enumerate(run_seeds):
                    num_nodes = random.choice(graph_size2nodes[graph_size])

                    # Command line handed to SCRIPT_PATH
                    script_args = " ".join([
                        f"--seed {seed}",
                        f"--graph_type {args.graph_type}",
                        f"--noise_distr {noise}",
                        f"--scenario {scenario}",
                        f"--graph_size {graph_size}",
                        f"--num_nodes {num_nodes}",
                        f"--graph_density {graph_density}",
                        f"--samples_size_list {samples_size_arg}",
                        f"--dataset_id {dataset_id}",
                        f"--output_folder {DATA_DIR}",
                    ])

                    # Optional arguments
                    if scenario in scenarios_params:
                        script_args += f" --scenario_params {' '.join(str(v) for v in scenarios_params[scenario])} "

                    job = SlurmJob(
                        SCRIPT_PATH,
                        name=f"datagen_{scenario}_{graph_size}_{graph_density}_{dataset_id}",
                        time=slurm_time,
                        gpu=False,
                        ngpus=None,
                        afterok=None,
                        ntasks_per_node=None,
                        script_args=script_args,
                        partition=args.partition
                    )

                    # Calling the job object submits it
                    job()

                    # Pause between batches of submissions
                    num_submissions += 1
                    if num_submissions % batch_dimension == 0:
                        print(f"Num submissions: {num_submissions}: taking a {sleep_time} seconds nap...", end=" ", flush=True)
                        time.sleep(sleep_time)
                        print("Restart!")

    print("All datasets generated! Check whether there are no error")
