import os
import sys
import json
import math
import causally.scm.scm as scm
import causally.graph.random_graph as rg
from causally.scm import noise
from typing import Dict

from generate_datasets import compute_seed_generator, create_dir, get_seed_from_generator, args_sanity_check
from utils.data import generate_and_store_dataset
from utils.custom_causally import InvertibleAdditiveNoiseModel

# Workspace root, resolved as the parent of the current working directory.
# NOTE: this depends on the directory the script is launched from, not on the
# location of this file; the "script-arguments" folder is looked up under it.
WORKSPACE = os.path.join(os.getcwd(), "..")




if __name__ == "__main__":
    """Generate and store a collection of datasets from invertible SCMs.

    The arguments are read from a JSON file located under
    ``<WORKSPACE>/script-arguments``. When a command line argument is given,
    the file ``generated/<script name>_<argument>.json`` is used instead of
    ``<script name>.json``.

    Json input arguments.

    seed: int
        Random seed for reproducibility.
    task: str
        Name of the task. Used as sub-folder of `output_folder` and passed to
        the seed generator.
    number_of_datasets: int
        Number of datasets generated.
    samples_per_dataset: int
        Number of samples for each dataset.
    graph_generation_algorithm: str
        Algorithm for random graph generation, Erdos-Renyi, Barabasi Albert (i.e. Scale-Free),
        Gaussian Random Partition. Accepted values are {"ER", "SF", "GRP"}.
    noise_distribution: str
        Distribution of the SCM noise terms. Accepted values are {'gauss', 'nn'}.
    scm_type: str
        Structural causal model (linear, nonlinear additive, post-nonlinear).
        Accepted values are {'linear', 'anm', 'pnl'}. Note that 'pnl' is
        rejected by this script (not supported for invertible data generation).
    mechanisms: str
        Methods for generation of the causal mechanism (gaussian process, neural net).
        Accepted values are {'gp', 'nn'}.
    number_of_nodes: int
        Number of nodes in the graph.
    p_density: float
        `p` density parameter, probability of connecting a pair of nodes. Values in the range [0, 1]
    m_density: int
        `m` density parameter, expected degree of each node.
        One and only one between `m_density` and `p_density` must be provided.
    min_num_edges: int
        Minimum number of edges, forwarded to the "ER" and "SF" graph generators.
    standardize: bool
        Flag forwarded to the dataset generation routine. It also selects the
        "-std" / "-nostd" suffix of the default dataset name.
    seed_offset: optional
        If present, it is appended to the default dataset name.
    output_folder: str
        Base folder for storage of the data.
    dataset_name: str
        Name of the dataset. This is used as name of the folder for the dataset storage.
        If null, a name is built from the other arguments.
    """

    operation_name = os.path.basename(__file__).split(".")[0]

    if len(sys.argv) > 1:
        # Pre-generated arguments: the first CLI argument selects the file
        args_dir = os.path.join(WORKSPACE, "script-arguments", "generated")
        args_filename = f"{operation_name}_{sys.argv[1]}.json"
    else:
        args_dir = os.path.join(WORKSPACE, "script-arguments")
        args_filename = operation_name + ".json"

    # Check that the arguments directory exists
    if not os.path.isdir(args_dir):
        raise NotADirectoryError(
            f"The directory '{args_dir}' does not exist or is not a directory. "
            f"Please create the directory with the '{args_filename}' file."
        )

    json_args_path = os.path.join(args_dir, args_filename)
    if not os.path.exists(json_args_path):
        raise FileNotFoundError(f"The file '{json_args_path}' does not exist.")
    with open(json_args_path, "r", encoding="utf-8") as file:
        args = json.load(fp=file)

    # Check and read input arguments
    args_sanity_check(args, invertible=True)

    standardize = args["standardize"]

    # Store the data of each task in its own sub-folder
    args["output_folder"] = os.path.join(args["output_folder"], args["task"])

    # Name the dataset folder when no explicit name is provided
    if args["dataset_name"] is None:
        if args["m_density"] is not None:
            arg_density = args["m_density"]
        else:
            arg_density = args["p_density"]

        name_parts = [
            str(args["number_of_nodes"]),
            args["graph_generation_algorithm"],
            str(arg_density),
            args["scm_type"],
            args["mechanisms"],
            args["noise_distribution"],
            "std" if standardize else "nostd",
        ]
        if "seed_offset" in args:
            name_parts.append(f"SO-{args['seed_offset']}")
        args["dataset_name"] = "-".join(name_parts)

    # Create dataset directory
    dataset_directory = create_dir(args["output_folder"], args["dataset_name"])

    generator = compute_seed_generator(args["task"], args["seed"])

    # Sample and store datasets
    for dataset_idx in range(args["number_of_datasets"]):
        # NOTE(review): `curr_seed` is drawn but never used; the models below
        # are seeded with args["seed"] + dataset_idx. The draw is kept because
        # removing it would presumably shift the noise seed drawn right after.
        # Confirm whether `curr_seed` was meant to seed the model.
        curr_seed = get_seed_from_generator(generator)

        curr_noise_seed = get_seed_from_generator(generator)
        curr_noise_generator = compute_seed_generator("noise", curr_noise_seed)

        data_file = os.path.join(dataset_directory, f"data_{dataset_idx}.npy")
        groundtruth_file = os.path.join(
            dataset_directory, f"groundtruth_{dataset_idx}.npy"
        )

        # Graph generator (a new instance is built for every dataset)
        algorithm = args["graph_generation_algorithm"]
        if algorithm == "ER":
            graph_generator = rg.ErdosRenyi(
                num_nodes=args["number_of_nodes"],
                expected_degree=args["m_density"],
                p_edge=args["p_density"],
                min_num_edges=args["min_num_edges"],
            )
        elif algorithm == "SF":
            graph_generator = rg.BarabasiAlbert(
                num_nodes=args["number_of_nodes"],
                expected_degree=args["m_density"],
                min_num_edges=args["min_num_edges"],
            )
        elif algorithm == "GRP":
            if args["p_density"] is None:
                # Fail with a clear message instead of a TypeError on None / 5
                raise ValueError(
                    "The 'GRP' graph generation algorithm requires `p_density`."
                )
            p_out = args["p_density"] / 5  # TODO make input script argument
            n_clusters = math.floor(
                args["number_of_nodes"] / 3
            )  # TODO make input script argument
            graph_generator = rg.GaussianRandomPartition(
                num_nodes=args["number_of_nodes"],
                p_in=args["p_density"],
                p_out=p_out,
                n_clusters=n_clusters,
            )
        else:
            # Without this branch `graph_generator` would be undefined (or
            # stale from a previous iteration) for an unknown algorithm.
            raise ValueError(
                "Expected graph_generation_algorithm is 'ER', 'SF' or 'GRP'."
                f" Got instead {algorithm!r}"
            )

        # Noise parameter (must be equal for both nodes, in linear case)
        scale = curr_noise_generator.uniform(0.5, 5)

        # Model generator
        if args["scm_type"] == "anm":
            model = InvertibleAdditiveNoiseModel(
                num_samples=args["samples_per_dataset"],
                graph_generator=graph_generator,
                noise_scale=scale,
                seed=args["seed"] + dataset_idx,
            )
        elif args["scm_type"] == "linear":
            noise_generator = noise.Gumbel(loc=0, scale=scale)
            # NOTE(review): min_weight and max_weight are both -1, i.e. a
            # degenerate range. Confirm this is intended (vs. max_weight=1).
            model = scm.LinearModel(
                num_samples=args["samples_per_dataset"],
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                seed=args["seed"] + dataset_idx,
                min_weight=-1,
                max_weight=-1,
            )
        elif args["scm_type"] == "pnl":
            raise ValueError(
                "Expected scm_type is 'linear' or 'anm'. Got instead 'pnl',"
                " which is not supported for invertible data generation."
            )
        else:
            raise ValueError(
                f"Expected scm_type is linear or anm. Got instead {args['scm_type']}"
            )

        generate_and_store_dataset(data_file, groundtruth_file, model, standardize)
        print(f"Dataset {dataset_idx} generated.")

    # Store the generation parameters next to the generated data
    path = os.path.join(dataset_directory, "config.json")
    with open(path, "w") as f:
        json.dump(args, f, indent=6)
