"""Script called by slurm_data.py single dataset generation

Return
------
X : np.array
    Matrix of the data
A : np.array
    Ground truth adjacency matrix
"""

import os
import argparse
import numpy as np
from utils._utils import ConsoleManager
from benchmark.data_generators import ConfundedGenerator, VanillaGenerator, LinearSCMGenerator,\
    MeasureErrorGenerator, TiminoGenerator, UnfaithfulGenerator, PNLGenerator


###################### Utility functions ######################
def get_data_dir(base_dir, graph_type, noise_distribution, scenario, scenario_param, num_samples, graph_size, graph_density):
    """Return the storage directory for one dataset configuration, creating it if needed.

    The directory has the layout
    ``<base_dir>/<graph_type>/<noise_distribution>/<scenario>/<param_folder>/<num_samples>_<graph_size>_<graph_density>``
    where ``<param_folder>`` is ``<scenario>_<scenario_param>`` when
    ``scenario_param`` is non-empty and plain ``<scenario>`` otherwise.

    Parameters
    ----------
    base_dir : str
        Root folder under which the directory tree is created.
    graph_type : str
        First path component below ``base_dir``.
    noise_distribution : str
        Second path component.
    scenario : str
        Third path component, also the prefix of the parameter folder.
    scenario_param : str
        Scenario parameter already converted to a string; the empty string
        means "no parameter".
    num_samples, graph_size, graph_density
        Values joined with underscores to form the leaf folder name.

    Returns
    -------
    str
        Path of the (now existing) directory.

    Raises
    ------
    FileExistsError
        If the target path already exists but is not a directory.
    """
    scenario_param_folder = f"{scenario}_{scenario_param}" if len(scenario_param) > 0 else scenario
    data_dir = os.path.join(
        base_dir,
        graph_type,
        noise_distribution,
        scenario,
        scenario_param_folder,
        f"{num_samples}_{graph_size}_{graph_density}",
    )

    # exist_ok=True makes creation idempotent, so several jobs can ask for the
    # same folder concurrently without a separate existence check.
    os.makedirs(data_dir, exist_ok=True)

    return data_dir

###################### main execution ######################
if __name__ == "__main__":
    # Entry point: parse the CLI options, build the generator matching
    # --scenario, generate one dataset at the largest requested sample size
    # and store a (sub)sample of it for every size in --samples_size_list.
    parser = argparse.ArgumentParser(description="Dataset generation")

    parser.add_argument(
        "--seed",
        default=51,
        type=int,
        help="Random seed for reproducibility"
    )

    parser.add_argument(
        "--graph_type",
        default="ER",
        type=str,
        help="Algorithm for generation of synthetic graphs. Accepted values are ['ER', 'SF', 'GRP', 'FC']"
    )

    parser.add_argument(
        "--noise_distr",
        default="gauss",
        type=str,
        help="Distribution of th noise terms. Accepted values are ['Gauss', 'Random']"
    )

    parser.add_argument(
        "--scenario",
        default="vanilla",
        type=str,
        help="Experimental scenario"
    )
    parser.add_argument(
        "--scenario_params",
        nargs="+",
        type=float,
        help="Parameters for dataset generation. E.g. for linear scenario, the various probabilities of linear functions",
    )

    parser.add_argument(
        "--graph_size",
        default="small",
        type=str,
        help="Graph size: ['small', 'medium', 'large20', 'large30', 'large50']"
    )

    parser.add_argument(
        "--num_nodes",
        default=5,
        type=int,
        help="Number of nodes in the graph"
    )

    parser.add_argument(
        "--graph_density",
        type=str,
        help="Graph density: ['sparse', 'dense', 'full', 'cluster']",
        required=True
    )

    parser.add_argument(
        "--samples_size_list",
        nargs="+",
        type=int,
        help="Required size of the datasets. E.g. [100, 1000]",
        required=True
    )

    parser.add_argument(
        "--dataset_id",
        default=0,
        type=int,
        help="ID to store dataset and groundtruth in their folder"
    )
    # Required: the folder is always joined into the output path, so a missing
    # value must be rejected at parse time rather than fail with a TypeError
    # after the data has already been generated.
    parser.add_argument(
        "--output_folder",
        type=str,
        help="Base folder for storage of the data. Input of the form /efs/data/<'hyperparameters' or 'inference'>",
        required=True
    )

    args = parser.parse_args()

    noise_std_support = (0.5, 1.0)  # Uniform support of noise terms
    max_num_samples = max(args.samples_size_list)  # Generate data for largest sample size

    # Scenarios without parameters still run the loop once; the empty string
    # makes get_data_dir fall back to the plain scenario name for the folder.
    if args.scenario_params is None:
        args.scenario_params = [""]

    # Positional arguments shared, in this order, by every generator.
    common_args = (
        args.graph_type,
        args.num_nodes,
        args.graph_size,
        args.graph_density,
        max_num_samples,
        args.noise_distr,
        noise_std_support,
        args.seed,
    )

    # Generate data for each parameter of the current scenario
    for scenario_param in args.scenario_params:
        if args.scenario == "vanilla":
            generator = VanillaGenerator(*common_args)
        elif args.scenario == "confounded":
            # scenario_param : rho
            generator = ConfundedGenerator(scenario_param, *common_args)
        elif args.scenario == "linear":
            # scenario_param : p_linear
            generator = LinearSCMGenerator(scenario_param, *common_args)
        elif args.scenario == "measure_err":
            # scenario_param : gamma
            generator = MeasureErrorGenerator(scenario_param, *common_args)
        elif args.scenario == "timino":
            generator = TiminoGenerator(*common_args)
        elif args.scenario == "unfaithful":
            generator = UnfaithfulGenerator(scenario_param, *common_args)
        elif args.scenario == "pnl":
            generator = PNLGenerator(float(scenario_param), *common_args)
        else:
            raise ValueError(f"Wrong value {args.scenario} for scenario option.")

        ConsoleManager.data_config_msg(
            {"scenario": args.scenario, "num_samples" : max_num_samples, "graph_size" : args.graph_size, "graph_density": args.graph_density}
        )
        X, A = generator.generate_data()

        # Reproducible subsampling
        np.random.seed(args.seed)

        # Subsample and store data for every sample size in args.samples_size_list
        for num_samples in args.samples_size_list:
            if len(X) > num_samples:
                # Draw num_samples distinct rows of X
                X_sub = X[np.random.choice(len(X), num_samples, replace=False), :]
            else:
                # Not more rows than requested: keep all of them
                X_sub = X[:, :]
            data_dir = get_data_dir(args.output_folder, args.graph_type,  args.noise_distr, args.scenario, str(scenario_param), num_samples, args.graph_size, args.graph_density)
            ConsoleManager.data_storing(data_dir, args.dataset_id)
            np.savetxt(os.path.join(data_dir, f"data_{args.dataset_id}.csv"), X_sub, delimiter=",")
            np.savetxt(os.path.join(data_dir, f"groundtruth_{args.dataset_id}.csv"), A, delimiter=",")
            # Scenario-specific adjacency matrices exposed by the generator
            if args.scenario == "confounded":
                np.savetxt(os.path.join(data_dir, f"confounded_{args.dataset_id}.csv"), generator.confounded_adjacency, delimiter=",")
            if args.scenario == "unfaithful":
                np.savetxt(os.path.join(data_dir, f"unfaithful_{args.dataset_id}.csv"), generator.unfaithful_adj, delimiter=",")
            ConsoleManager.done_msg()