"""Script called by slurm_inference.py for inference over a single dataset.

Given that the data generation process is the only distinguishing factor,
a single class (i.e. ExperimentManager) is enough to handle
    - running a single-instance experiment
    - aggregating the results
All the other classes only need to implement a different data generation process.
"""

import argparse
import os
from benchmark.experiment_manager import ExperimentManager
from utils._utils import  get_inference_params


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Dataset generation")

    parser.add_argument(
        "--debug",
        action="store_true",
        help="Set to True to enable debugging behaviour"
    )

    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="If True, read the data but do not execute inference. Sanity check on data (slurm logs)"
    )

    parser.add_argument(
        "--base_data_dir", 
        default=os.path.join(os.sep, "efs", "data", "ER", "gauss"), 
        type=str, 
        help="Directory with dataset and groundtruth csv files"
    )

    parser.add_argument(
        "--base_output_dir", 
        default=os.path.join(os.sep, "efs", "tmp", "causal-benchmark-logs", "tuning", "ER"), 
        type=str, 
        help="Destination of the output. One for each graph_algorithm, scenario, method"
    )

    parser.add_argument(
        "--data_config", 
        default="100_small_sparse", 
        type=str, 
        help="Configuration of the data in terms of number of samples, size of the graph, density of the graph"
    )

    parser.add_argument(
        "--method", 
        default=f"ges", 
        type=str, 
        help="Method of inference"
    )

    parser.add_argument(
        "--noise_distr", 
        default="gauss", 
        type=str, 
        help="Distribution of the additive noise terms. Admitted values ['gauss', 'random']"
    )

    parser.add_argument(
        "--scenario", 
        default="vanilla", 
        type=str, 
        help="Experimental scenario"
    )

    parser.add_argument(
        "--scenario_param", 
        default="vanilla", # e.g. linear_0.33, or <scenario>
        type=str, 
        help="Specify the parameter for the data generating process. If no parameter is required, this is identical to the scenario name"
    )

    parser.add_argument(
        "--reg_params", 
        nargs='+',
        type=float, 
        help="List of regularization parameters of the solution: usually alpha or lambda"
    )

    parser.add_argument(
        "--reg_param_name", 
        default="alpha", # e.g. linear_0.33, or <scenario>
        type=str, 
        help="Name of the regularization parameter. Usually alpha or lambda"
    )

    parser.add_argument(
        "--tuning_results_base_dir", 
        default="/efs/tmp/causal-benchmark-logs/tuning/ER/gauss/vanilla/vanilla/diffan/100_small_sparse", # e.g. linear_0.33, or <scenario>
        type=str, 
        help="Path to the directory with the tuning results for the current dataset"
    )

    parser.add_argument(
        '--params_file', 
        help='Path to the parameters grid of the current method', 
        type=str,
        default="/home/ec2-user/causal-benchmark/hyperparameters/params_grid/diffan.json"
    )

    parser.add_argument(
        "--task",
        type=str,
        default="inference",
        help="inference, tuning, standardized"
    )
    parser.add_argument(
        "--seed",
        type=int,
        # default="inference",
        help="seed of the dataset"
    )

    args = parser.parse_args()
    data_dir = os.path.join(
        args.base_data_dir, args.scenario, args.scenario_param, args.data_config
    )
    output_dir = os.path.join(args.base_output_dir, args.scenario, args.scenario_param, args.method)

    for reg_param in args.reg_params:
        files = os.listdir(data_dir)
        files = sorted(files, key=lambda x: (x.split(".")[0].split("_")[0], int(x.split(".")[0].split("_")[1])))
        dataset_id = args.seed
        parameters = get_inference_params(args.tuning_results_base_dir, dataset_id, args.method, args.params_file, args.reg_param_name, reg_param)

        exp_manager = ExperimentManager(
            data_dir=data_dir, 
            output_dir=output_dir,
            dataset_id=dataset_id,
            noise_distr=args.noise_distr,
            method_name=args.method, 
            method_parameters=parameters,  
            task=args.task
        )
        data_df = exp_manager.get_data()

        exp_manager.inference(data_df)