"""Script called by slurm_inference.py for inference over a single dataset

Given that the data generation process is the only distinguishing factor,
a single class (i.e. ExperimentManager) is enough to handle
    - Running a single-instance experiment
    - Aggregating the results
All the other classes should just implement different data generation processes.
"""

import argparse
import json
import os
from benchmark.experiment_manager import ExperimentManager


# Upper bound on the number of datasets (seeds) tuned per parameter combination.
MAX_DATASETS = 10


def build_parser():
    """Build the command-line parser for the tuning script.

    Returns:
        argparse.ArgumentParser: parser exposing the data/output locations,
        the data configuration, the inference method, the noise distribution,
        the scenario (and its parameter) and the path to the parameters grid.
    """
    parser = argparse.ArgumentParser(description="Dataset generation")

    parser.add_argument(
        "--base_data_dir",
        default=os.path.join(os.sep, "efs", "data", "ER", "gauss"),
        type=str,
        help="Directory with dataset and groundtruth csv files",
    )

    parser.add_argument(
        "--base_output_dir",
        default=os.path.join(
            os.sep, "efs", "tmp", "causal-benchmark-logs", "tuning", "ER"
        ),
        type=str,
        help="Destination of the output",
    )

    parser.add_argument(
        "--data_config",
        default="100_small_sparse",
        type=str,
        help="Configuration of the data in terms of number of samples, size of the graph, density of the graph",
    )

    parser.add_argument(
        "--method",
        default="diffan",
        type=str,
        help="Method of inference",
    )

    parser.add_argument(
        "--noise_distr",
        default="gauss",
        type=str,
        help="Distribution of the additive noise terms. Admitted values ['gauss', 'random']",
    )

    parser.add_argument(
        "--scenario",
        default="vanilla",
        type=str,
        help="Experimental scenario",
    )

    parser.add_argument(
        "--scenario_param",
        default="vanilla",  # e.g. linear_0.33, or <scenario>
        type=str,
        help="Specify the parameter for the data generating process. If no parameter is required, this is identical to the scenario name",
    )

    parser.add_argument(
        "--params_file",
        help="Path to the parameters grid of the current method",
        type=str,
        default="/home/ec2-user/causal-benchmark/hyperparameters/params_grid/diffan.json",
    )

    return parser


def select_dataset_ids(filenames, max_datasets=MAX_DATASETS):
    """Pick the dataset ids to process from a directory listing.

    Only names that start with ``data_`` and end with ``csv`` are considered;
    the id is the integer between the first underscore and the next
    underscore or dot (``data_7.csv`` -> ``7``). Names are filtered *before*
    the id is parsed, so unrelated entries in the directory (ground-truth
    files, hidden files, ...) can never break the numeric parsing.

    Args:
        filenames: iterable of file names (not paths).
        max_datasets: maximum number of ids to return.

    Returns:
        list of int: at most ``max_datasets`` ids, in ascending order.
    """
    dataset_ids = []
    for name in filenames:
        if not (name.startswith("data_") and name.endswith("csv")):
            continue
        try:
            dataset_ids.append(int(name.split(".")[0].split("_")[1]))
        except ValueError:
            # Looks like a data file but carries no integer id: skip it
            # rather than aborting the whole tuning run.
            continue
    dataset_ids.sort()
    return dataset_ids[:max_datasets]


def main(argv=None):
    """Tune one method over the datasets of a single data configuration.

    For every parameter combination in the JSON grid, and for each of the
    first ``MAX_DATASETS`` datasets found in the data directory, an
    ``ExperimentManager`` is created with ``task="tuning"`` and its ``tune``
    method is called with the loaded data and the integer combination id.

    Args:
        argv: optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    data_dir = os.path.join(
        args.base_data_dir, args.scenario, args.scenario_param, args.data_config
    )
    print(f"Get data from: {data_dir}")

    with open(args.params_file, "r") as f:
        param_grid = json.load(f)

    # The directory listing does not depend on the parameter combination,
    # so it is read and ordered once, outside the loops.
    dataset_ids = select_dataset_ids(os.listdir(data_dir))
    if not dataset_ids:
        print(f"No data_<id>.csv files found in: {data_dir}")
        return

    method_output_dir = os.path.join(
        args.base_output_dir,
        args.noise_distr,
        args.scenario,
        args.scenario_param,
        args.method,
        args.data_config,
    )

    for raw_params_id, parameters in param_grid.items():
        # JSON object keys are always strings; tune() is given an integer id.
        params_id = int(raw_params_id)
        for dataset_id in dataset_ids:
            output_dir = os.path.join(method_output_dir, f"dataset_{dataset_id}")

            exp_manager = ExperimentManager(
                data_dir=data_dir,
                output_dir=output_dir,
                dataset_id=dataset_id,
                noise_distr=args.noise_distr,
                method_name=args.method,
                method_parameters=parameters,
                task="tuning",
            )
            data_df = exp_manager.get_data()

            exp_manager.tune(data_df, params_id)


if __name__ == "__main__":
    main()