import os
import sys
import shutil
import json
import math
import numpy as np
import causally.scm.scm as scm
import causally.graph.random_graph as rg
import causally.scm.noise as noise
import causally.scm.causal_mechanism as cm
import causally.scm.context as context
from typing import Dict
from torch import nn
import hashlib
from utils.custom_causally import InvertibleAdditiveNoiseModel


from utils.data import generate_and_store_dataset

# Workspace root, resolved as the parent of the current working directory:
# the location therefore depends on where the script is launched from.
WORKSPACE = os.path.join(os.getcwd(), os.pardir)


def args_sanity_check(args: Dict, invertible=False):
    """Validate the script arguments before any dataset is generated.

    Parameters
    ----------
    args : Dict
        Parsed json arguments. The keys ``p_density``, ``m_density`` and
        ``graph_generation_algorithm`` are always read; ``scm_type``,
        ``mechanisms`` and ``noise_distribution`` are read only when
        ``invertible`` is True.
    invertible : bool
        When True, additionally check that the configuration is the one
        expected for invertible mechanisms (see the assertions below).

    Raises
    ------
    ValueError
        If both or neither of ``p_density`` / ``m_density`` are assigned, or
        if a scale-free ("SF") graph is requested without ``m_density``.
    AssertionError
        If ``invertible`` is True and the SCM type, mechanisms or noise
        distribution do not match the invertible setup.
    """
    p_density = args["p_density"]
    m_density = args["m_density"]

    # Exactly one density parameter must be given. The exception used to be
    # built without `raise` (and on the inverted condition), so it never fired.
    if p_density is not None and m_density is not None:
        raise ValueError("One argument between `-p` and `-m` must be unassigned.")
    if p_density is None and m_density is None:
        raise ValueError("One argument between `-p` and `-m` must be provided.")

    if m_density is None and args["graph_generation_algorithm"] == "SF":
        raise ValueError(
            "SF graphs can not accept `-p` as density parameter. Provide a valid value for `-m`"
        )

    # TODO: add more checks!

    if invertible:
        # NOTE: assertions are stripped when Python runs with -O.
        assert args["scm_type"] in ["linear", "anm"], (
            "Invertible mechanisms require scm_type 'linear' or 'anm'."
        )
        assert args["mechanisms"] == "inv", (
            "Invertible mechanisms require mechanisms == 'inv'."
        )
        assert args["noise_distribution"] == "gumbel", (
            "Invertible mechanisms require noise_distribution == 'gumbel'."
        )


def compute_seed_generator(task: str, initial_seed: int) -> np.random.Generator:
    """Build a numpy random generator whose seed depends on the task name.

    Known tasks ("train", "test", "finetune", "noise") shift ``initial_seed``
    by a fixed offset. Any other task name is hashed with SHA-256 and the
    digest, read as an integer, is used as the offset; in that case the sum is
    reduced modulo 2**32 to keep the seed in a reasonable range.
    """
    fixed_offsets = {"train": 0, "test": 1_000, "finetune": 2_000, "noise": 3_000}

    known_offset = fixed_offsets.get(task)
    if known_offset is not None:
        # Fixed offsets are small, so no modulus is applied on this path.
        return np.random.default_rng(seed=int(initial_seed + known_offset))

    digest = hashlib.sha256(task.encode("utf-8"), usedforsecurity=False).hexdigest()
    hashed_offset = int(digest, base=16)
    wrapped_seed = int(initial_seed + hashed_offset) % int(2**32)
    return np.random.default_rng(seed=wrapped_seed)


def get_seed_from_generator(gen: np.random.Generator) -> int:
    """Draw one integer seed from ``gen``, in the range [1, 2**32 - 1234)."""
    exclusive_upper = (1 << 32) - 1234
    drawn = gen.integers(low=1, high=exclusive_upper)
    return int(drawn)


def create_dir(folder, dataset_name):
    """Return a freshly created, empty directory ``folder/dataset_name``.

    If the directory already exists it is removed together with all of its
    content before being recreated. Missing intermediate directories are
    created as well.
    """
    target = os.path.join(folder, dataset_name)

    # Wipe any previous content so the dataset folder always starts empty.
    if os.path.exists(target):
        shutil.rmtree(target)

    os.makedirs(target, exist_ok=True)  # create also intermediate directories
    return target


if __name__ == "__main__":
    """Json input arguments.

    seed: int
        Random seed for reproducibility.
    task: str
        Name of the task. It is appended to `output_folder` and used, together
        with `seed`, to derive the random generator of the run.
    number_of_datasets: int
        Number of datasets generated.
    samples_per_dataset: int
        Number of samples for each dataset.
    graph_generation_algorithm: str
        Algorithm for random graph generation: Erdos-Renyi, Barabasi-Albert (i.e. Scale-Free),
        Gaussian Random Partition. Accepted values are {"ER", "SF", "GRP"}.
    noise_distribution: str
        Distribution of the SCM noise terms. Accepted values are
        {'gauss', 'uniform', 'exponential', 'gumbel', 'beta', 'gamma', 'mlp', 'mlp2'}
        or 'mlp_<w>', with <w> a positive float used as weight bound of the MLP noise.
    scm_type: str
        Structural causal model (linear, nonlinear additive, post-nonlinear, mixed).
        Accepted values are {'linear', 'linear2', 'anm', 'pnl', 'mixed'} or
        'linear_<min_abs_weight>_<max_weight>'.
    mechanisms: str
        Methods for generation of the causal mechanism (gaussian process, neural net, ...).
        Accepted values are {'gp', 'nn', 'weak-nonlinear', 'linear', 'inv'} or null.
    number_of_nodes: int
        Number of nodes in the graph.
    p_density: float
        `p` density parameter, probability of connecting a pair of nodes. Values in the range [0, 1]
    m_density: int
        `m` density parameter, expected degree of each node.
        One and only one between `m_density` and `p_density` must be provided.
    min_num_edges: int
        Minimum number of edges, passed to the "ER" and "SF" graph generators.
    standardize: bool
        Flag forwarded to `generate_and_store_dataset`; it also selects the
        '-std' / '-nostd' suffix of the automatically generated dataset name.
    seed_offset: int
        Value appended to the automatically generated dataset name.
    output_folder: str
        Base folder for storage of the data.
    dataset_name: str
        Name of the dataset. This is used as name of the folder for the dataset storage.
        If null, a name is built from the other arguments.
    """

    operation_name = os.path.basename(__file__).split(".")[0]

    if len(sys.argv) > 1:
        # loading pre-generated args
        args_dir = os.path.join(WORKSPACE, "script-arguments", "generated")
        args_filename = f"{operation_name}_{sys.argv[1]}.json"

    else:
        args_dir = os.path.join(WORKSPACE, "script-arguments")
        args_filename = operation_name + ".json"

    # Check script directory exists
    if not os.path.isdir(args_dir):
        # A missing directory is a "not found" condition, not IsADirectoryError.
        raise FileNotFoundError(
            f"The arguments directory '{args_dir}' does not exist. "
            f"Please create the directory with the '{args_filename}' file."
        )

    json_args_path = os.path.join(args_dir, args_filename)

    if os.path.exists(json_args_path):
        with open(json_args_path, "r") as file:
            args = json.load(fp=file)
    else:
        raise FileNotFoundError(f"The file '{json_args_path}' does not exist.")

    # Check and read input arguments.
    # The key is "mechanisms" (plural) everywhere else in this script; reading
    # "mechanism" here meant the invertible checks were silently skipped.
    args_sanity_check(args, args.get("mechanisms", "") == "inv")
    standardize = args["standardize"]

    # Update output_folder based on the task
    args["output_folder"] = os.path.join(args["output_folder"], args["task"])

    # Read arguments
    number_of_nodes = args["number_of_nodes"]

    # Name the dataset folder
    if args["dataset_name"] is None:
        if args["m_density"] is not None:
            arg_density = args["m_density"]
        else:
            arg_density = args["p_density"]
        args["dataset_name"] = "-".join(
            [
                str(number_of_nodes),
                args["graph_generation_algorithm"],
                str(arg_density),
                args["scm_type"],
                args["mechanisms"],
                args["noise_distribution"],
            ]
        )
        if standardize:
            args["dataset_name"] += "-std"
        else:
            args["dataset_name"] += "-nostd"

        # Tag the name with the seed offset from the arguments
        # NOTE(review): presumably the raw offset used to derive `seed` - confirm.
        args["dataset_name"] += f"-SO-{args['seed_offset']}"

    dataset_directory = create_dir(args["output_folder"], args["dataset_name"])

    generator = compute_seed_generator(args["task"], args["seed"])

    # Sample and store datasets
    for dataset_id in range(args["number_of_datasets"]):
        # Two seeds are drawn per dataset: one for the SCM itself and one for
        # the generator that samples the noise distribution parameters.
        curr_seed = get_seed_from_generator(generator)

        curr_noise_seed = get_seed_from_generator(generator)
        curr_noise_generator = compute_seed_generator("noise", curr_noise_seed)

        data_file = os.path.join(dataset_directory, f"data_{dataset_id}.npy")
        groundtruth_file = os.path.join(
            dataset_directory, f"groundtruth_{dataset_id}.npy"
        )

        # Noise generator
        if args["noise_distribution"] == "gauss":
            # Sample noise on CPU to avoid NaN
            # (PyTorch bug: https://discuss.pytorch.org/t/why-am-i-getting-a-nan-in-normal-mu-std-rsample/117401/8)
            noise_generator = noise.Normal(
                0,
                curr_noise_generator.uniform(
                    0.4, np.sqrt(2) * 0.4, size=number_of_nodes
                ),
            )  # deviation as in CSiVA experiments
        elif args["noise_distribution"] == "uniform":
            noise_generator = noise.Uniform(
                low=curr_noise_generator.uniform(0.1, 0.4, size=number_of_nodes),
                high=curr_noise_generator.uniform(0.6, 0.9, size=number_of_nodes),
            )
        elif args["noise_distribution"] == "exponential":
            noise_generator = noise.Exponential(
                scale=curr_noise_generator.uniform(0.5, 2, size=number_of_nodes)
            )
        elif args["noise_distribution"] == "gumbel":
            noise_generator = noise.Gumbel(
                loc=curr_noise_generator.uniform(0.5, 2, size=number_of_nodes),
                scale=curr_noise_generator.uniform(0.5, 2, size=number_of_nodes),
            )
        elif args["noise_distribution"] == "beta":
            noise_generator = noise.Beta(
                a=curr_noise_generator.uniform(2, 3, size=number_of_nodes),
                b=curr_noise_generator.uniform(2, 3, size=number_of_nodes),
            )
        elif args["noise_distribution"] == "gamma":
            noise_generator = noise.Gamma(
                shape=curr_noise_generator.uniform(1, 3, size=number_of_nodes),
                scale=curr_noise_generator.uniform(1, 3, size=number_of_nodes),
            )
        elif args["noise_distribution"] == "mlp":
            noise_generator = noise.MLPNoise(
                standardize=True
            )  # standardize=False removes varsortability
        elif args["noise_distribution"] == "mlp2":
            noise_generator = noise.MLPNoise(
                standardize=True, a_weight=-1.5, b_weight=1.5
            )  # standardize=False removes varsortability
        elif args["noise_distribution"].startswith("mlp_"):
            # "mlp_<w>": the suffix is the symmetric weight bound of the MLP
            param = float(args["noise_distribution"][4:])
            assert param > 0
            noise_generator = noise.MLPNoise(
                standardize=True, a_weight=-param, b_weight=+param
            )  # standardize=False removes varsortability

        else:
            raise ValueError(f"Unsupported noise type {args['noise_distribution']}.")

        # Graph generator
        if args["graph_generation_algorithm"] == "ER":
            graph_generator = rg.ErdosRenyi(
                num_nodes=number_of_nodes,
                expected_degree=args["m_density"],
                p_edge=args["p_density"],
                min_num_edges=args["min_num_edges"],
            )
        elif args["graph_generation_algorithm"] == "SF":
            graph_generator = rg.BarabasiAlbert(
                num_nodes=number_of_nodes,
                expected_degree=args["m_density"],
                min_num_edges=args["min_num_edges"],
            )
        elif args["graph_generation_algorithm"] == "GRP":
            p_out = args["p_density"] / 5  # TODO make input script argument
            n_clusters = math.floor(
                number_of_nodes / 3
            )  # TODO make input script argument
            graph_generator = rg.GaussianRandomPartition(
                num_nodes=number_of_nodes,
                p_in=args["p_density"],
                p_out=p_out,
                n_clusters=n_clusters,
            )
        else:
            # Fail here instead of with a NameError on `graph_generator` below.
            raise ValueError(
                f"Unsupported graph generation algorithm {args['graph_generation_algorithm']}."
            )

        # Causal mechanism generator
        # Stays None for "linear", "inv" and unset mechanisms.
        causal_mechanism = None
        if args["mechanisms"] == "nn":
            causal_mechanism = cm.NeuralNetMechanism()
        elif args["mechanisms"] == "gp":
            causal_mechanism = cm.GaussianProcessMechanism()
        elif args["mechanisms"] == "weak-nonlinear":
            causal_mechanism = cm.NeuralNetMechanism(
                weights_std=0.2, activation=nn.Identity(), scaling=5
            )
        elif args["mechanisms"] == "linear":
            pass
        elif args["mechanisms"] == "inv":
            inv_scale = curr_noise_generator.uniform(0.5, 5)
            # overwrite existing noise
            noise_generator = noise.Gumbel(loc=0, scale=inv_scale)

        elif args["mechanisms"] is not None:
            raise ValueError(f"Unsupported causal mechanism {args['mechanisms']}.")

        # These SCM types consume `causal_mechanism`; without this guard a
        # missing mechanism used to surface as a NameError further down.
        scm_type = args["scm_type"]
        requires_mechanism = scm_type in ("pnl", "mixed") or (
            scm_type == "anm" and args["mechanisms"] != "inv"
        )
        if requires_mechanism and causal_mechanism is None:
            raise ValueError(
                f"SCM type {scm_type} requires a causal mechanism generator, "
                f"but mechanisms={args['mechanisms']} does not provide one."
            )

        # Model generator
        if args["scm_type"] == "anm":
            if args["mechanisms"] == "inv":
                model = InvertibleAdditiveNoiseModel(
                    num_samples=args["samples_per_dataset"],
                    graph_generator=graph_generator,
                    noise_scale=inv_scale,
                    seed=curr_seed,
                )
            else:
                model = scm.AdditiveNoiseModel(
                    num_samples=args["samples_per_dataset"],
                    graph_generator=graph_generator,
                    noise_generator=noise_generator,
                    causal_mechanism=causal_mechanism,
                    seed=curr_seed,
                )
        elif args["scm_type"] == "linear":
            model = scm.LinearModel(
                num_samples=args["samples_per_dataset"],
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                seed=curr_seed,
            )
        elif args["scm_type"] == "linear2":
            model = scm.LinearModel(
                num_samples=args["samples_per_dataset"],
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                seed=curr_seed,
                min_abs_weight=0.5,
                min_weight=-3,
                max_weight=3,
            )

        elif args["scm_type"].startswith("linear_"):
            # "linear_<min_abs_weight>_<max_weight>": weights are bounded
            # symmetrically in [-max_weight, max_weight].
            parts = args["scm_type"][7:].split("_")

            model = scm.LinearModel(
                num_samples=args["samples_per_dataset"],
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                seed=curr_seed,
                min_abs_weight=float(parts[0]),  # 1?
                min_weight=-float(parts[1]),
                max_weight=float(parts[1]),
            )
        elif args["scm_type"] == "pnl":
            model = scm.PostNonlinearModel(
                num_samples=args["samples_per_dataset"],
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                causal_mechanism=causal_mechanism,
                invertible_function=lambda x: x**3,  # type: ignore
                seed=curr_seed,
            )
        elif args["scm_type"] == "mixed":
            model = scm.MixedLinearNonlinearModel(
                num_samples=args["samples_per_dataset"],
                graph_generator=graph_generator,
                noise_generator=noise_generator,
                linear_mechanism=cm.LinearMechanism(),
                nonlinear_mechanism=causal_mechanism,
                seed=curr_seed,
            )
        else:
            # Fail here instead of with a NameError on `model` below.
            raise ValueError(f"Unsupported SCM type {args['scm_type']}.")

        generate_and_store_dataset(data_file, groundtruth_file, model, standardize)

        # Progress log: every 500 datasets, or every dataset for "inv".
        if dataset_id % 500 == 0 or args["mechanisms"] == "inv":
            print(f"Done {dataset_id}")

    # Store the generation parameters
    path = os.path.join(dataset_directory, "config.json")
    with open(path, "w") as f:
        json.dump(args, f, indent=6)

    print(f"Stored dataset {args['dataset_name']} in {dataset_directory}")
