"""
Script to generate synthetic datasets and save them to disk as HDF5 files.
"""
import argparse

import dill
import numpy as np
from tqdm import tqdm, trange

from ml2_meta_causal_discovery.datasets.dataset_generators import \
    ClassifyDatasetGenerator
from ml2_meta_causal_discovery.utils.datautils import \
    turn_bivariate_causal_graph_to_label

import h5py
from pathlib import Path
import json


def hpc_classify_main(args: argparse.Namespace) -> None:
    """Generate synthetic datasets and save them as HDF5 files.

    For every index ``i`` in ``range(args.data_start, args.data_end)`` the
    NumPy global RNG is seeded with ``i``, one dataset is drawn from a
    ``ClassifyDatasetGenerator`` and written to ``<name>_<i>.hdf5`` containing
    two HDF5 datasets, ``"data"`` and ``"label"``. The generator settings are
    stored once next to the data files in ``graph_args.json``.

    Files are written to
    ``<work_dir>/datasets/data/synth_training_data/<name>/<folder_name>``.
    The folder and ``graph_args.json`` are created before generation starts,
    so they exist even when the index range is empty.

    Args:
        args: Parsed command-line arguments. The attributes read are
            ``work_dir``, ``folder_name``, ``data_start``, ``data_end``,
            ``batch_size``, ``exp_edges_lower`` and ``exp_edges_upper``.
    """
    num_vars = 20
    function_gen = "gplvm"
    usecase = args.folder_name
    num_samples = 1000
    graph_type = ["ER"]
    exp_edges_upper = args.exp_edges_upper
    exp_edges_lower = args.exp_edges_lower

    # The dataset name encodes the degree setting: a single value when both
    # bounds coincide, otherwise the lower/upper pair.
    if exp_edges_upper == exp_edges_lower:
        name = f"{function_gen}_{num_vars}var_ER{exp_edges_lower}"
    else:
        name = f"{function_gen}_{num_vars}var_ERL{exp_edges_lower}U{exp_edges_upper}"

    dataset_generator = ClassifyDatasetGenerator(
        num_variables=num_vars,
        function_generator=function_gen,
        batch_size=args.batch_size,
        num_samples=num_samples,
        kernel_sum=True,
        mean_function="latent",
        graph_type=graph_type,
        # Inclusive range of degrees between the two bounds.
        graph_degrees=list(range(exp_edges_lower, exp_edges_upper + 1)),
    )

    # The output location and the metadata do not depend on the dataset
    # index, so create/write them once instead of on every iteration.
    save_folder = (
        Path(args.work_dir)
        / "datasets"
        / "data"
        / "synth_training_data"
        / name
        / usecase
    )
    save_folder.mkdir(exist_ok=True, parents=True)

    graph_args = {
        "graph_type": graph_type,
        "graph_degrees_upper": exp_edges_upper,
        "graph_degrees_lower": exp_edges_lower,
        "num_variables": num_vars,
        "num_samples": num_samples,
        "function_generator": function_gen,
    }
    with open(save_folder / "graph_args.json", "w") as f:
        json.dump(graph_args, f)

    for i in tqdm(range(args.data_start, args.data_end)):
        # Seed NumPy's global RNG with the dataset index before drawing.
        # NOTE(review): this only makes a file reproducible if the generator
        # draws exclusively from NumPy's global RNG -- confirm.
        np.random.seed(i)
        target_data, causal_graphs = next(
            dataset_generator.generate_next_dataset()
        )
        # One HDF5 file per index, holding the data and its labels.
        with h5py.File(save_folder / f'{name}_{i}.hdf5', 'w') as f:
            f.create_dataset("data", data=target_data)
            f.create_dataset("label", data=causal_graphs)


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the generation.
    parser = argparse.ArgumentParser(
        description="Generate synthetic datasets and save them as HDF5 files.",
    )
    parser.add_argument(
        "--work_dir",
        "-wd",
        type=str,
        default="./",
        help="Folder where the Neural Process Family is stored.",
    )
    parser.add_argument(
        "--data_start",
        "-ds",
        type=int,
        default=0,
        help=(
            "Index of the first dataset file to generate (inclusive). "
            "Each index is also used as the NumPy random seed."
        ),
    )
    parser.add_argument(
        "--data_end",
        "-de",
        type=int,
        default=1,
        help="Index at which to stop generating dataset files (exclusive).",
    )
    parser.add_argument(
        "--batch_size",
        "-bs",
        type=int,
        default=50000,
        help="Batch size passed to the dataset generator.",
    )
    parser.add_argument(
        "--exp_edges_upper",
        "-eeu",
        type=int,
        default=20,
        help="Upper bound (inclusive) of the graph degrees to generate.",
    )
    parser.add_argument(
        "--exp_edges_lower",
        "-eel",
        type=int,
        default=20,
        help="Lower bound (inclusive) of the graph degrees to generate.",
    )
    parser.add_argument(
        "--folder_name",
        "-fn",
        type=str,
        default="train",
        help="Name of the sub-folder the files are written to.",
    )

    args = parser.parse_args()
    hpc_classify_main(args)
