"""Launch parallel inference or hyperparameter-tuning jobs on a Slurm cluster.

Layout of the data directory (abbreviated ASCII tree):

data
└── hyperparameters (or inference)
    ├── confounded
    │   ├── medium_dense
    │   ├── medium_sparse
    │   │   ├── data{i}.csv
    │   │   └── groundtruth{i}.csv
    │   ├── small_dense
    │   └── small_sparse
    └── vanilla
        ├── medium_dense
        ├── medium_sparse
        ├── small_dense
        └── small_sparse
"""
import os
import shutil
import random
import numpy as np
import argparse
from utils._utils import ConsoleManager
from slurm._utils import LaunchLargeExperiments

# Debug by pressing F5

################################# Directories #################################
# Base directory for data storage.
DATA_DIR = "/efs/data"
# Base directory for inference output storage.
LOGS_DIR = "/efs/tmp/causal-benchmark-logs"
# Project root; PARAMS_DIR and SCRATCH_PATH below are located under it.
WORKSPACE = "/home/ec2-user/causal-benchmark"
# Hyperparameter directory handed to LaunchLargeExperiments as `params_dir`.
PARAMS_DIR = f"{WORKSPACE}/hyperparameters"
# Script run for each task; "standardized" reuses the inference script.
SCRIPT_PATH = {
    "inference": "dataset_inference.py",
    "standardized": "dataset_inference.py",
    "tuning": "dataset_tuning.py",
}
# Shell wrapper handed to LaunchLargeExperiments as `scratch_path`.
SCRATCH_PATH = f"{WORKSPACE}/utilities/extract_to_scratch_and_run.sh"

################################# Utilities ################################
def _clear_entry(entry_path):
    """Remove one entry of a method directory; sub-directories are recreated empty.

    Any OSError is reported on stdout and swallowed, so one failure does not
    abort the rest of the sweep.
    """
    try:
        if os.path.isfile(entry_path) or os.path.islink(entry_path):
            os.unlink(entry_path)
        elif os.path.isdir(entry_path):
            shutil.rmtree(entry_path)
            # Recreate the emptied folder (e.g. predictions/tmp) here, from this single
            # launcher process: creating files/folders from parallel running jobs is unsafe.
            os.mkdir(entry_path)
    except OSError as e:
        print(f'Failed to delete {entry_path}. Reason: {e}')


def delete_past_logs(base_dir, scenarios, methods):
    """Delete predictions adjacency matrices and metadata left by previous experiments.

    For each scenario directory that exists under ``base_dir``, for each entry
    inside it, and for each requested method, the directory
    ``base_dir/<scenario>/<entry>/<method>`` is emptied:

    - files and symlinks are deleted;
    - sub-directories are wiped and recreated empty, so the folder skeleton survives.

    Scenario or method directories that do not exist are skipped silently.
    """
    for scenario in scenarios:
        scenario_dir = os.path.join(base_dir, scenario)
        if not os.path.exists(scenario_dir):
            continue
        for scenario_param in os.listdir(scenario_dir):
            for method in methods:
                method_dir = os.path.join(scenario_dir, scenario_param, method)
                if not os.path.exists(method_dir):
                    continue
                for entry in os.listdir(method_dir):
                    _clear_entry(os.path.join(method_dir, entry))

def _reject_unlisted(values, allowed, label):
    """Raise ``ValueError("<label> <value> not allowed!")`` for the first value not in *allowed*."""
    for value in values:
        if value not in allowed:
            raise ValueError(f"{label} {value} not allowed!")


def args_sanity_check(args):
    """Validate the parsed command-line options.

    Checks run in this order and stop at the first offending value:
    ``task``, ``scenarios``, ``methods``, ``noise_distr``, ``graph_type``.
    The three list options are skipped when they are ``None`` (not given on the
    command line); ``task`` and ``graph_type`` are always checked.

    Raises:
        ValueError: if a value is not among the accepted ones.
    """
    if args.task not in ("inference", "tuning", "standardized"):
        raise ValueError("Wrong task")

    if args.scenarios is not None:
        _reject_unlisted(
            args.scenarios,
            ("vanilla", "confounded", "linear", "measure_err", "timino", "unfaithful", "pnl"),
            "Scenario",
        )

    if args.methods is not None:
        _reject_unlisted(
            args.methods,
            (
                "ges", "das", "score", "nogam", "cam", "pc", "diffan", "grandag",
                "lingam", "resit", "random", "varsort", "scoresort", "fci",
            ),
            "Method",
        )

    if args.noise_distr is not None:
        _reject_unlisted(
            args.noise_distr,
            ("gauss", "nonlin_weak", "nonlin_mid", "nonlin_strong"),
            "Noise distribution",
        )

    _reject_unlisted((args.graph_type,), ("ER", "SF", "FC", "GRP"), "Graph type")


################################# Run program #################################
if __name__ == "__main__":
    # Command line arguments
    parser = argparse.ArgumentParser(description="Running jobs on SLURM cluster")
    parser.add_argument(
        "--seed", 
        default=42, 
        type=int, 
        help="seed for reproducibility"
    )
    parser.add_argument(
        "--task",
        default="inference",
        type=str,
        help="Data can be generated for inference or for hyperparameters tuning. Accepted values ['inference', 'tuning', 'standardized']"
    )
    parser.add_argument(
        "--no_clean_logs",
        action="store_true",
        help="Set to True to forbid deletion of existing predictions in tmp (adjacency matrices) for the methods and scenarios of current experiments",
    )
    parser.add_argument(
        "--partition",
        default="cpu",
        type=str,
        help="Name of the slurm partition"
    )
    parser.add_argument(
        "--graph_type",
        type=str,
        help="Algorithm for generation of synthetic graphs. Accepted values are ['ER', 'SF', 'GRP', 'FC']",
        required=True
    )
    parser.add_argument(
        "--noise_distr",
        nargs='+',
        type=str,
        help="Select datasets by distiribution of the noise. Accepted values are ['gauss', 'nonlin_\{weak,mid,strong\}']",
    )
    parser.add_argument(
        '--methods', 
        nargs='+', 
        help ='Mehods for the experimental inference', 
        type=str
    )
    parser.add_argument(
        '--scenarios', 
        nargs='+', 
        help ='Scenarios over which to run inference', 
        type=str
    )

    args = parser.parse_args()
    args_sanity_check(args)

    if args.noise_distr is None:
        args.noise_distr = ["gauss", "nonlin_weak", "nonlin_mid", "nonlin_strong"]

    if args.scenarios is None:
        args.scenarios = ["vanilla", "confounded", "linear", "measure_err", "timino",  "unfaithful", "pnl"]

    if args.methods is None:
        args.methods = [
            "ges", "das", "score", "nogam", "cam", "pc", "diffan", "grandag",\
                "lingam", "resit", "random", "varsort", "scoresort", "fci"
        ]

    # Set random seed for reproducibility (torch.manual_seed cause error!)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Clean logs from previous runs if required
    args.no_clean_logs = True
    if not args.no_clean_logs:
        ConsoleManager.suspended_msg("Cleaning the logs")
        delete_past_logs(os.path.join(LOGS_DIR, args.task, args.graph_type), args.scenarios, args.methods)
        ConsoleManager.done_msg()

    # Launch the experiments
    experiments_launcher = LaunchLargeExperiments(
        data_dir=DATA_DIR,
        logs_dir=LOGS_DIR,
        script_path=SCRIPT_PATH[args.task],
        params_dir=PARAMS_DIR,
        scratch_path = SCRATCH_PATH,
        task=args.task,
        methods=args.methods,
        scenarios=args.scenarios,
        graph_type=args.graph_type,
        noise_distr=args.noise_distr,
        partition=args.partition
    )

    experiments_launcher.launch()