"""Module to handle parallel inference with Slurm
ASCII tree of data structure

data
└── hyperparameters (or inference)
    ├── confounded
    │   ├── medium_dense
    │   ├── medium_sparse
    │   │   ├── data{i}.csv
    │   │   └── groundtruth{i}.csv
    │   ├── small_dense
    │   └── small_sparse
    └── vanilla
        ├── medium_dense
        ├── medium_sparse
        ├── small_dense
        └── small_sparse
"""
import os
import shutil
import random
import numpy as np
import argparse
from utils._utils import ConsoleManager
from slurm._utils import LaunchLargeExperiments
import json
from slurm.awsslurm import SlurmJob

# Debug by pressing F5

################################# Directories #################################
DATA_DIR="/efs/data" # Base directory for data storage
LOGS_DIR="/efs/tmp/causal-benchmark-logs" # Base directory for inference output storage
WORKSPACE="/home/ec2-user/causal-benchmark" # Project root (PARAMS_DIR and SCRATCH_PATH live below it)
PARAMS_DIR="/home/ec2-user/causal-benchmark/hyperparameters" # Base directory of the per-method parameter files
# Maps each task name to the script handed to SlurmJob for that task
SCRIPT_PATH={"inference" : "dataset_inference.py", "standardized" : "dataset_inference.py", "tuning" : "dataset_tuning.py"}
# Shell script prepended (once per dataset archive) to the job command
SCRATCH_PATH = "/home/ec2-user/causal-benchmark/utilities/extract_to_scratch_and_run.sh"

PARAMS_INPUT_FOLDER = "raw" # Sub-folder of PARAMS_DIR with the <method>.json parameter files
PARAMS_GRID_FOLDER = "params_grid" # Sub-folder of PARAMS_DIR with the <method>.json parameter grids

################################# Utilities ################################

def args_sanity_check(args):
    """Validate the parsed command line arguments.

    Args:
        args: namespace-like object exposing the attributes ``task``,
            ``scenario``, ``method``, ``noise`` and ``graph_type``.

    Raises:
        ValueError: if any of those attributes holds a value outside of the
            corresponding list of allowed values.
    """
    # Check task
    allowed_tasks = ["inference", "tuning", "standardized"]
    if args.task not in allowed_tasks:
        raise ValueError("Wrong task")

    # Check scenarios
    allowed_scenarios = ["vanilla", "confounded", "linear", "measure_err", "timino",  "unfaithful", "pnl"]
    if args.scenario not in allowed_scenarios:
        # Read the value from ``args``: no ``scenario`` name exists in this scope.
        raise ValueError(f"Scenario {args.scenario} not allowed!")

    # Check methods
    allowed_methods = ["ges", "das", "score", "nogam", "cam", "pc", "diffan", "grandag", "lingam", "resit", "random"]
    if args.method not in allowed_methods:
        raise ValueError(f"Method {args.method} not allowed!")

    # Check noise distr
    allowed_noise = ["gauss", "nonlin_weak", "nonlin_mid", "nonlin_strong"]
    if args.noise not in allowed_noise:
        # Same as above: the offending value lives on ``args``.
        raise ValueError(f"Noise distribution {args.noise} not allowed!")

    # Check graph type
    allowed_graphs = ["ER", "SF", "FC", "GRP"]
    if args.graph_type not in allowed_graphs:
        raise ValueError(f"Graph type {args.graph_type} not allowed!")
    

def make_base_script(
    base_data_dir : str,
    base_output_dir : str,
    noise_distr : str,
    scenario : str,
    scenario_param : str,
    data_config : str,
    method : str
):
    """Build the command line options common to inference and tuning jobs.

    Every argument is rendered as ``--<name> <value> `` (note the trailing
    space) and the pieces are concatenated in a fixed order, so the returned
    string always ends with a single space.
    """
    # (flag, value) pairs in the order they must appear in the output.
    flag_values = (
        ("base_data_dir", base_data_dir),
        ("base_output_dir", base_output_dir),
        ("data_config", data_config),
        ("method", method),
        ("noise_distr", noise_distr),
        ("scenario", scenario),
        ("scenario_param", scenario_param),
    )
    return "".join(f"--{flag} {value} " for flag, value in flag_values)

def scratch_and_run_prepend(base_data_dir, scenario, scenario_param, data_config, scratch_path):
    """Build the command prefix made of one ``scratch_path`` call per dataset.

    The directory ``base_data_dir/scenario/scenario_param/data_config`` is
    scanned for entries whose name starts with ``data_``. The dataset id is the
    text between the first underscore and the first dot of the name, and for
    every distinct id the fragment
    ``"<scratch_path> <dir>/data_<id>.csv.tar.gz -- "`` is emitted.

    Args:
        base_data_dir: root directory of the datasets.
        scenario: scenario sub-directory.
        scenario_param: scenario parameter sub-directory.
        data_config: data configuration sub-directory.
        scratch_path: path of the script placed in front of each archive path.

    Returns:
        The concatenated fragments, or an empty string if no ``data_*`` entry
        is found.

    Raises:
        OSError: propagated from ``os.listdir`` if the directory is missing.
    """
    path_to_config = os.path.join(base_data_dir, scenario, scenario_param, data_config)

    # os.listdir returns entries in arbitrary order: sort them so the command
    # is reproducible, and keep each id once (dict preserves insertion order)
    # so that e.g. "data_0.csv" next to "data_0.csv.tar.gz" is not listed twice.
    dataset_ids = dict.fromkeys(
        entry.split(".")[0].split("_")[1]
        for entry in sorted(os.listdir(path_to_config))
        if entry.startswith("data_")
    )

    fragments = []
    for dataset_id in dataset_ids:
        tar_data_path = os.path.join(path_to_config, f"data_{dataset_id}.csv.tar.gz")
        fragments.append(scratch_path + " " + tar_data_path + " -- ")
    return "".join(fragments)


################################# Run program #################################
if __name__ == "__main__":
    # Entry point: parse the command line, assemble the argument string of the
    # task script and hand everything over to SlurmJob.

    # Command line arguments
    parser = argparse.ArgumentParser(description="Running jobs on SLURM cluster")
    parser.add_argument(
        "--seed",
        default=42,
        type=int,
        help="seed for reproducibility"
    )
    parser.add_argument(
        "--task",
        default="inference",
        type=str,
        help="Task to launch. Accepted values ['inference', 'tuning', 'standardized']"
    )
    parser.add_argument(
        "--partition",
        default="cpu",
        type=str,
        help="Name of the slurm partition"
    )
    parser.add_argument(
        "--graph_type",
        type=str,
        help="Algorithm for generation of synthetic graphs. Accepted values are ['ER', 'SF', 'GRP', 'FC']",
        required=True
    )
    parser.add_argument(
        "--noise",
        type=str,
        help="Select datasets by distribution of the noise. Accepted values are ['gauss', 'nonlin_weak', 'nonlin_mid', 'nonlin_strong']",
        required=True
    )
    parser.add_argument(
        '--method',
        help='Methods for the experimental inference',
        type=str,
        required=True
    )
    parser.add_argument(
        '--scenario',
        help='Scenarios over which to run inference',
        type=str,
        required=True
    )
    parser.add_argument(
        '--scenario_param',
        help="Scenario's parameters over which to make inference",
        type=str,
        required=True
    )
    parser.add_argument(
        '--data_config',
        help="E.g. 1000_medium_sparse",
        type=str,
        required=True
    )

    args = parser.parse_args()
    args_sanity_check(args)

    # Set random seed for reproducibility (torch.manual_seed cause error!)
    random.seed(args.seed)
    np.random.seed(args.seed)

    data_dir=DATA_DIR
    logs_dir=LOGS_DIR
    script_path=SCRIPT_PATH[args.task]
    params_dir=PARAMS_DIR
    scratch_path = SCRATCH_PATH
    task = args.task
    method = args.method
    scenario = args.scenario
    scenario_param = args.scenario_param
    noise = args.noise
    graph_type = args.graph_type
    data_config = args.data_config
    partition = args.partition
    param_grids_file = os.path.join(params_dir, PARAMS_GRID_FOLDER, method + ".json")
    base_data_dir = os.path.join(data_dir, graph_type, noise)
    base_output_dir = os.path.join(logs_dir, task, graph_type)

    # Only the "non-tunable" section of the method's parameter file is used here.
    with open(os.path.join(params_dir, PARAMS_INPUT_FOLDER, method + ".json"), "r") as f:
        params = json.load(f)
        reg_params = params["non-tunable"]

    reg_param_name = list(reg_params.keys())[0] # TODO: this is very dirty!

    tuning_results_base_dir = os.path.join(
        logs_dir, "tuning", graph_type, noise, scenario, scenario_param, method, data_config
    )

    script_args = make_base_script(
        base_data_dir, base_output_dir, noise, scenario, scenario_param, data_config, method
    )

    script_args += (" " +
        f"--reg_params {' '.join([str(x) for x in reg_params[reg_param_name]])} " +
        f"--reg_param_name {reg_param_name} " +
        f"--tuning_results_base_dir {tuning_results_base_dir} " +
        f"--params_file {param_grids_file}"
    )

    # Sanity check: the dataset directory must exist and contain something.
    # Explicit exceptions are used instead of ``assert`` so the check still
    # runs when Python is started with -O.
    config_dir = os.path.join(base_data_dir, scenario, scenario_param, data_config)
    print(config_dir)
    if not os.path.isdir(config_dir):
        raise FileNotFoundError(f"Data directory {config_dir} does not exist")
    if not os.listdir(config_dir):
        raise FileNotFoundError(f"Data directory {config_dir} is empty")

    # Resource requests depend on the method and on the size of the graphs.
    set_mem = "large50" in data_config and method in ["resit", "ges", "pc", "diffan", "grandag"]
    num_cpus = 18 if method in ["diffan", "grandag"] else 1

    slurm_time="48:00:00"
    job = SlurmJob(
        script_path,
        name=f"inference_{method}_{scenario_param}_{data_config}",
        time=slurm_time,
        gpu=False,
        ngpus=None,
        afterok=None,
        ntasks_per_node=None,
        script_args=script_args,
        partition=partition,
        scratch_and_run_prepend=scratch_and_run_prepend(base_data_dir, scenario, scenario_param, data_config, scratch_path),
        set_mem=set_mem,
        num_cpus=num_cpus
    )

    # Calling the job object presumably submits it to the cluster -- confirm in slurm.awsslurm.
    _ = job()