"""Script called by slurm_inference.py to run inference over a single dataset.

Given that the data generation process is the only distinguishing factor,
a single class (i.e. ExperimentManager) is enough to handle
    - running a single-instance experiment
    - aggregating the results
All the other classes should only implement different data generation processes.
"""

import argparse
import os
from benchmark.experiment_manager import ExperimentManager
from utils._utils import  get_inference_params


def build_parser():
    """Build the command-line parser for single-dataset inference.

    Returns:
        argparse.ArgumentParser: parser exposing data location, output
        location, method, scenario and regularization options used by ``main``.
    """
    parser = argparse.ArgumentParser(description="Inference over a single dataset")

    # NOTE(review): --debug is parsed but not read anywhere in this script.
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Set to True to enable debugging behaviour"
    )

    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="If True, read the data but do not execute inference. Sanity check on data (slurm logs)"
    )

    parser.add_argument(
        "--base_data_dir",
        default=os.path.join(os.sep, "efs", "data", "ER", "gauss"),
        type=str,
        help="Directory with dataset and groundtruth csv files"
    )

    parser.add_argument(
        "--base_output_dir",
        default=os.path.join(os.sep, "efs", "tmp", "causal-benchmark-logs", "tuning", "ER"),
        type=str,
        help="Destination of the output. One for each graph_algorithm, scenario, method"
    )

    parser.add_argument(
        "--data_config",
        default="100_small_sparse",
        type=str,
        help="Configuration of the data in terms of number of samples, size of the graph, density of the graph"
    )

    parser.add_argument(
        "--method",
        default="ges",
        type=str,
        help="Method of inference"
    )

    parser.add_argument(
        "--noise_distr",
        default="gauss",
        type=str,
        help="Distribution of the additive noise terms. Admitted values ['gauss', 'random']"
    )

    parser.add_argument(
        "--scenario",
        default="vanilla",
        type=str,
        help="Experimental scenario"
    )

    parser.add_argument(
        "--scenario_param",
        default="vanilla",  # e.g. linear_0.33, or <scenario>
        type=str,
        help="Specify the parameter for the data generating process. If no parameter is required, this is identical to the scenario name"
    )

    # Required: the main loop iterates over these values, so a missing flag
    # would otherwise surface as a TypeError on `None`.
    parser.add_argument(
        "--reg_params",
        nargs='+',
        type=float,
        required=True,
        help="List of regularization parameters of the solution: usually alpha or lambda"
    )

    parser.add_argument(
        "--reg_param_name",
        default="alpha",
        type=str,
        help="Name of the regularization parameter. Usually alpha or lambda"
    )

    # NOTE(review): this default points at diffan tuning results while
    # --method defaults to "ges" — confirm the intended pairing.
    parser.add_argument(
        "--tuning_results_base_dir",
        default="/efs/tmp/causal-benchmark-logs/tuning/ER/gauss/vanilla/vanilla/diffan/100_small_sparse",
        type=str,
        help="Path to the directory with the tuning results for the current dataset"
    )

    parser.add_argument(
        '--params_file',
        help='Path to the parameters grid of the current method',
        type=str,
        default="/home/ec2-user/causal-benchmark/hyperparameters/params_grid/diffan.json"
    )

    parser.add_argument(
        "--task",
        type=str,
        default="inference",
        help="inference, tuning, standardized"
    )

    return parser


def list_dataset_ids(file_names):
    """Return the sorted integer ids of the ``data_<id>.csv`` entries in *file_names*.

    Entries that are not ``data_*.csv`` files (ground-truth files, hidden
    files, logs, ...) are ignored *before* parsing. The id is the text
    between the first ``_`` and the next ``_`` or ``.`` of the file stem.

    Args:
        file_names: iterable of file names, e.g. the result of ``os.listdir``.

    Returns:
        list[int]: dataset ids in ascending numeric order.

    Raises:
        ValueError: if a ``data_*.csv`` name does not carry an integer id.
    """
    return sorted(
        int(name.split(".")[0].split("_")[1])
        for name in file_names
        if name.startswith("data_") and name.endswith(".csv")
    )


def main(argv=None):
    """Run inference (or a dry-run data check) on every dataset of one configuration.

    For each value in ``--reg_params`` and each ``data_<id>.csv`` file under
    ``<base_data_dir>/<scenario>/<scenario_param>/<data_config>``, the method
    parameters are resolved with ``get_inference_params``. An
    ``ExperimentManager`` then loads the data and, unless ``--dry_run`` is
    set, runs inference.

    Args:
        argv: argument list to parse; ``None`` parses ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    data_dir = os.path.join(
        args.base_data_dir, args.scenario, args.scenario_param, args.data_config
    )
    output_dir = os.path.join(args.base_output_dir, args.scenario, args.scenario_param, args.method)

    # The set of datasets does not depend on the regularization value: list it once.
    dataset_ids = list_dataset_ids(os.listdir(data_dir))

    for reg_param in args.reg_params:
        for dataset_id in dataset_ids:
            parameters = get_inference_params(
                args.tuning_results_base_dir, dataset_id, args.method,
                args.params_file, args.reg_param_name, reg_param
            )

            exp_manager = ExperimentManager(
                data_dir=data_dir,
                output_dir=output_dir,
                dataset_id=dataset_id,
                noise_distr=args.noise_distr,
                method_name=args.method,
                method_parameters=parameters,
                task=args.task
            )
            data_df = exp_manager.get_data()

            if args.dry_run:
                # Data was read successfully; report it and skip inference.
                print(f"[dry_run] dataset {dataset_id} (reg_param={reg_param}) loaded; inference skipped", flush=True)
                continue

            exp_manager.inference(data_df)


if __name__ == "__main__":
    main()