"""Goal:
1. For inference, recover the best parameters for each scenario-config pair.
2. For inference, run experiments with only one parameter for each scenario-config pair
   (this requires removing one for loop).
"""
import json
import os
import time
from typing import List, Union

from sklearn.model_selection import ParameterGrid

from slurm.awsslurm import SlurmJob


# Sub-folder of params_dir holding the raw per-method hyperparameter specs (<method>.json).
PARAMS_INPUT_FOLDER = "raw"
# Sub-folder of params_dir where the expanded per-method parameter grids (<method>.json) live.
PARAMS_GRID_FOLDER = "params_grid"

class LaunchLargeExperiments:
    """Interface with the slurm cluster manager for launching multiple experiments.

    Execution differs based on the type of task: 'tuning', 'inference' or
    'standardized'. For 'tuning', the parameter grid of every method is
    materialized on construction and one job per scenario/config/method
    combination is submitted. For 'inference' and 'standardized', each job is
    given the directory of the tuning results, presumably so that the job
    script can pick the best hyperparameters of a method for the combination of
    scenario and data configuration (TODO confirm against the job script).
    """

    # Number of submissions after which the launcher pauses.
    _BATCH_DIMENSION = 500
    # Pause (seconds) after every batch; also defines the set of valid tasks.
    # Inference waits 6 minutes for data loading and allocation.
    _SLEEP_TIME_BY_TASK = {"tuning": 1200, "inference": 360, "standardized": 360}

    def __init__(
        self,
        data_dir: str,
        logs_dir: str,
        script_path: str,
        params_dir: str,
        scratch_path: str,
        task: str,
        methods: List[str],
        scenarios: List[str],
        graph_type: str,
        noise_distr: Union[str, List[str]],
        partition: str
    ) -> None:
        """
        Parameters
        ----------
        data_dir : str
            Base directory of storage of the data. E.g. /efs/data
        logs_dir : str
            Base directory for experiment logs storage. E.g. /efs/tmp/causal-benchmark-logs
        script_path : str
            Path of the script for inference on a dataset. E.g. dataset_inference.py
        params_dir : str
            Base directory with the tunable hyperparameters.
            The raw json for each method is params_dir/raw/<method>.json and the
            expanded grid is written to params_dir/params_grid/<method>.json.
        scratch_path : str
            Path to extract_to_scratch_and_run.sh
        task : str
            'tuning', 'inference' or 'standardized'
        methods : List[str]
            The methods for the inference or hyperparameter-tuning task
        scenarios : List[str]
            The scenarios of inference or hyperparameter tuning
        graph_type : str
            Graph simulation algorithm. One of ['ER', 'SF', 'GRP', 'FC']
        noise_distr : str or List[str]
            Distribution(s) of the noise term: 'gauss' or the type of
            non-gaussianity. Jobs are launched for every listed distribution;
            a single string is treated as a one-element list.
        partition : str
            The slurm partition. One of ['cpu', 'gpu']
        """
        self.data_dir = data_dir
        self.logs_dir = logs_dir
        self.script_path = script_path
        self.params_dir = params_dir
        self.scratch_path = scratch_path
        self.task = task
        self.methods = methods
        self.graph_type = graph_type
        self.scenarios = scenarios
        self.noise_distr = noise_distr
        self.partition = partition

        if task == "tuning":
            self.make_params_grid()

    def param_grids_file(self, method):
        """Return the path of the expanded parameter-grid json of ``method``."""
        return os.path.join(self.params_dir, PARAMS_GRID_FOLDER, method + ".json")

    def methods_with_tuning(self):
        """Return the names of the methods that have tunable hyperparameters."""
        return ["diffan", "grandag"]

    def make_params_grid(self):
        """Write, for each method, the full grid of hyperparameter configurations.

        Reads ``params_dir/raw/<method>.json`` and expands the union of its
        "tunable" and "non-tunable-default" entries with ``ParameterGrid``. The
        result is written to ``params_dir/params_grid/<method>.json`` as a
        mapping ``{config_id: config}``. Methods whose grid file already
        exists are skipped.
        """
        for method in self.methods:
            grid_file = self.param_grids_file(method)
            if os.path.exists(grid_file):
                print(f"No worries: {grid_file} was already created!")
                continue

            raw_file = os.path.join(self.params_dir, PARAMS_INPUT_FOLDER, method + ".json")
            with open(raw_file, "r") as f:
                params = json.load(f)
            # dict union: on duplicate keys the "non-tunable-default" value wins.
            param_grid = ParameterGrid(params["tunable"] | params["non-tunable-default"])
            # Add an integer identifier to each configuration (json stores keys as strings).
            id_param_grid = dict(enumerate(param_grid))

            # The grid folder may not exist yet on a fresh params_dir.
            os.makedirs(os.path.dirname(grid_file), exist_ok=True)
            with open(grid_file, "w") as f:
                json.dump(id_param_grid, f)
            print(f"Written {method} param grid!")

    def make_base_script(
        self,
        base_data_dir: str,
        base_output_dir: str,
        noise_distr: str,
        scenario: str,
        scenario_param: str,
        data_config: str,
        method: str
    ):
        """Build the part of the slurm job arguments shared by inference and tuning tasks.

        Returns a string of ``--flag value`` pairs, each followed by a space.
        """
        flags = (
            ("base_data_dir", base_data_dir),
            ("base_output_dir", base_output_dir),
            ("data_config", data_config),
            ("method", method),
            ("noise_distr", noise_distr),
            ("scenario", scenario),
            ("scenario_param", scenario_param),
        )
        return "".join(f"--{flag} {value} " for flag, value in flags)

    def scratch_and_run_prepend(self, base_data_dir, scenario, scenario_param, data_config):
        """Build the chain of ``<scratch_path> <tarball> -- `` wrappers for one data config.

        Every file starting with ``data_`` in the config directory contributes the
        tarball ``data_<id>.csv.tar.gz`` of its dataset id. Files are visited in
        sorted order and duplicate tarballs are emitted only once.
        """
        path_to_config = os.path.join(base_data_dir, scenario, scenario_param, data_config)
        tar_paths = []
        for file in sorted(os.listdir(path_to_config)):
            if not file.startswith("data_"):
                continue
            dataset_id = file.split(".")[0].split("_")[1]
            tar_paths.append(os.path.join(path_to_config, f"data_{dataset_id}.csv.tar.gz"))
        # dict.fromkeys removes duplicates while preserving order.
        return "".join(f"{self.scratch_path} {tar} -- " for tar in dict.fromkeys(tar_paths))

    def launch_tuning(
            self,
            base_data_dir: str,
            base_output_dir: str,
            noise_distr: str,
            scenario: str,
            scenario_param: str,
            data_config: str,
            method: str,
    ):
        """Launch a slurm job for the tuning of hyperparameters.

        The job receives the whole parameter-grid file of the method.

        Parameters
        ----------
        base_data_dir : str
            Directory from which to read the data specified in data_config
        base_output_dir : str
            Directory under which the job writes its outputs
        noise_distr : str
            Distribution of the noise terms
        scenario : str
            Data generation scenario
        scenario_param : str
            <scenario_name>_<param_value> to distinguish parameters of the current scenario.
            E.g. confounded_0.2
        data_config : str
            Configuration of data generation, e.g. small_sparse_100
        method : str
            Method of inference
        """
        param_grids_file = self.param_grids_file(method)
        script_args = self.make_base_script(
            base_data_dir, base_output_dir, noise_distr, scenario, scenario_param, data_config, method
        )
        script_args += f" --params_file {param_grids_file}"

        set_mem = "large50" in data_config

        # TODO: check if time is enough!
        # If it is interrupted, how do we handle that?
        slurm_time = "48:00:00"
        job = SlurmJob(
            self.script_path,
            name=f"tuning_{method}_{scenario_param}_{data_config}",
            time=slurm_time,
            gpu=False,
            ngpus=None,
            afterok=None,
            ntasks_per_node=None,
            script_args=script_args,
            partition=self.partition,
            scratch_and_run_prepend=self.scratch_and_run_prepend(base_data_dir, scenario, scenario_param, data_config),
            set_mem=set_mem
        )

        _ = job()

    def launch_inference(
        self,
        base_data_dir: str,
        base_output_dir: str,
        noise_distr: str,
        scenario: str,
        scenario_param: str,
        data_config: str,
        method: str
    ):
        """Launch a slurm inference job pointed at the tuning results of ``method``.

        The job receives the values of the method's "non-tunable" regularization
        parameter, the parameter-grid file and the tuning-results directory.

        Parameters are the same as for ``launch_tuning``.

        Raises
        ------
        FileNotFoundError
            If the data directory of the configuration does not exist.
        ValueError
            If that directory is empty, or the method declares no "non-tunable"
            parameter.
        """
        config_dir = os.path.join(base_data_dir, scenario, scenario_param, data_config)
        # Sanity check before reading params or submitting anything.
        if not os.path.exists(config_dir):
            raise FileNotFoundError(f"Data directory not found: {config_dir}")
        if not os.listdir(config_dir):
            raise ValueError(f"Data directory is empty: {config_dir}")

        with open(os.path.join(self.params_dir, PARAMS_INPUT_FOLDER, method + ".json"), "r") as f:
            reg_params = json.load(f)["non-tunable"]
        if not reg_params:
            raise ValueError(f"No 'non-tunable' parameter declared for method {method}")

        param_grids_file = self.param_grids_file(method)  # for non-reg params, we pass the file address
        # TODO: this is very dirty! Only the first non-tunable parameter is used.
        reg_param_name = next(iter(reg_params))
        reg_values = " ".join(str(x) for x in reg_params[reg_param_name])

        tuning_results_base_dir = os.path.join(
            self.logs_dir, "tuning", self.graph_type,
            noise_distr, scenario, scenario_param, method, data_config
        )

        script_args = self.make_base_script(
            base_data_dir, base_output_dir, noise_distr, scenario, scenario_param, data_config, method
        )
        script_args += (
            f" --reg_params {reg_values} "
            f"--reg_param_name {reg_param_name} "
            f"--tuning_results_base_dir {tuning_results_base_dir} "
            f"--params_file {param_grids_file} "
            f"--task {self.task}"
        )

        heavy_methods = ("diffan", "grandag")
        set_mem = "large50" in data_config and method in ("resit", *heavy_methods)
        num_cpus = 18 if method in heavy_methods else 1

        slurm_time = "48:00:00"
        job = SlurmJob(
            self.script_path,
            name=f"inference_{method}_{scenario_param}_{data_config}",
            time=slurm_time,
            gpu=False,
            ngpus=None,
            afterok=None,
            ntasks_per_node=None,
            script_args=script_args,
            partition=self.partition,
            scratch_and_run_prepend=self.scratch_and_run_prepend(base_data_dir, scenario, scenario_param, data_config),
            set_mem=set_mem,
            num_cpus=num_cpus
        )

        _ = job()

    @staticmethod
    def _is_excluded(config, method):
        """Return True when the (data config, method) pair must not be submitted."""
        # TODO: remove check on large50!
        if "large30" in config or "large50" in config:
            return True
        # NOTE: do not compare pc/ges/fci with the others on 50 nodes, the process gets killed.
        # Currently redundant with the check above; kept for when that check is lifted.
        return method in ("pc", "ges", "fci") and "large50" in config

    def launch(self):
        """Submit one job per (noise, scenario, scenario_param, config, method) combination.

        Scenario parameters and data configs are discovered by listing the data
        directory and visited in sorted order. After every ``_BATCH_DIMENSION``
        submissions the launcher sleeps for a task-dependent amount of time.

        Raises
        ------
        ValueError
            If ``self.task`` is not 'tuning', 'inference' or 'standardized'.
        """
        if self.task not in self._SLEEP_TIME_BY_TASK:
            raise ValueError(
                f"Unknown task {self.task!r}; expected one of {sorted(self._SLEEP_TIME_BY_TASK)}"
            )
        batch_dimension = self._BATCH_DIMENSION
        sleep_time = self._SLEEP_TIME_BY_TASK[self.task]
        launcher = self.launch_tuning if self.task == "tuning" else self.launch_inference

        # A single distribution given as a string must not be iterated char by char.
        noises = [self.noise_distr] if isinstance(self.noise_distr, str) else self.noise_distr
        # The output directory does not depend on the noise distribution.
        base_output_dir = os.path.join(self.logs_dir, self.task, self.graph_type)

        num_submissions = 0
        for noise in noises:
            base_data_dir = os.path.join(self.data_dir, self.graph_type, noise)
            for scenario in self.scenarios:
                scenario_dir = os.path.join(base_data_dir, scenario)
                for scenario_param in sorted(os.listdir(scenario_dir)):  # e.g. confounded_0.2
                    param_dir = os.path.join(scenario_dir, scenario_param)
                    for config in sorted(os.listdir(param_dir)):  # e.g. small_sparse_100
                        for method_name in self.methods:
                            if self._is_excluded(config, method_name):
                                continue
                            launcher(
                                base_data_dir, base_output_dir, noise, scenario, scenario_param, config, method_name
                            )
                            num_submissions += 1
                            if num_submissions % batch_dimension == 0:
                                print(f"Num submissions: {num_submissions}: taking a {sleep_time // 60} minutes nap...", end=" ", flush=True)
                                time.sleep(sleep_time)
                                print("Restart!")

        print(f"Submitted all {num_submissions} jobs!")