"""
Utilities to generate, store, and load datasets.
"""
import os

import numpy as np

from .sampling import (
    sample_linear_abstracted_models,
    sample_linear_realizations,
)


def linear_dataset(
    n_nodes: int = 3,
    n_edges: int = 3,
    graph_type: str = "ER",
    min_readout: int = 1,
    max_readout: int = 5,
    alpha: float = 1.0,
    marginalize_ratio: float = 0.2,
    internal: bool = True,
    n_samples: int = 1000,
    noise_term: str = "gaussian",
):
    """
    Draws a random pair of abstracted linear models and a dataset from them.

    The first eight arguments (``n_nodes`` through ``internal``) are
    forwarded positionally, in declaration order, to
    ``sample_linear_abstracted_models``; ``n_samples`` and ``noise_term``
    are forwarded to ``sample_linear_realizations`` together with the
    sampled models.

    Returns the tuple ``(concrete, abstract, tau, gamma, pi, dataset)``,
    where the first five entries come from the model sampler and
    ``dataset`` is whatever the realization sampler returns.
    """
    # Model sampler: yields both models plus the abstraction that links them
    sampled_models = sample_linear_abstracted_models(
        n_nodes,
        n_edges,
        graph_type,
        min_readout,
        max_readout,
        alpha,
        marginalize_ratio,
        internal,
    )
    concrete, abstract, tau, gamma, pi = sampled_models

    # Realization sampler: note that pi is not part of its inputs
    dataset = sample_linear_realizations(
        concrete, abstract, tau, gamma, n_samples, noise_term
    )

    return concrete, abstract, tau, gamma, pi, dataset


# Maps every dataset parameter name to the prefix that tags its value
# inside a dataset signature (e.g. "n_nodes" -> "d" gives the token "d3").
# signature_to_config() matches tokens with str.startswith() and strips
# exactly one leading character, so each ID must be a single character and
# no two IDs may be equal (matching is case-sensitive: "n" != "N").
PARAM_TO_ID = {
    "n_nodes": "d",
    "n_edges": "e",
    "graph_type": "t",
    "min_readout": "m",
    "max_readout": "M",
    "alpha": "h",
    "marginalize_ratio": "p",
    "internal": "i",
    "n_samples": "n",
    "noise_term": "N",
}


def config_to_signature(dset_params: dict) -> str:
    """
    Given the configuration used to generate
    a random dataset, it returns the signature
    that is then used to store the dataset.

    The signature is the "_"-joined list of ``<id><value>`` tokens, one per
    parameter, where ``<id>`` is taken from ``PARAM_TO_ID`` and ``<value>``
    is the parameter value formatted by an f-string. Tokens are always
    emitted in the order of ``PARAM_TO_ID``, so two configurations with the
    same content map to the same signature regardless of the insertion
    order of their keys.

    Note that a value whose string form contains "_" cannot be recovered
    by splitting the signature on "_".

    Raises:
        ValueError: if the keys of ``dset_params`` are not exactly the keys
            of ``PARAM_TO_ID``.
    """
    missing = [param for param in PARAM_TO_ID if param not in dset_params]
    unexpected = [key for key in dset_params if key not in PARAM_TO_ID]
    if missing or unexpected:
        raise ValueError(
            "The configuration must contain all the parameters "
            f"{list(PARAM_TO_ID)} and no others "
            f"(missing: {missing}, unexpected: {unexpected})."
        )

    # Iterate over PARAM_TO_ID rather than dset_params: the signature is
    # used as a file name key and must not depend on the caller's dict order.
    return "_".join(
        f"{param_id}{dset_params[param]}"
        for param, param_id in PARAM_TO_ID.items()
    )


def signature_to_config(signature: str) -> dict:
    """
    Rebuilds the generation configuration encoded in a dataset signature.

    The signature is split on "_" and, for every parameter of
    ``PARAM_TO_ID``, each token starting with that parameter's ID is
    decoded (the first character is dropped, the rest is the raw value).
    When several tokens match the same parameter, the last one wins.
    Parameters without a matching token are absent from the result.

    Decoding rules:
      * float parameters: ``float(raw)``, or ``None`` if that fails;
      * boolean parameters: ``True`` only for the exact text "True";
      * "graph_type" and "noise_term": kept as the raw string;
      * everything else: ``int(raw)``, or ``None`` if that fails.
    """
    # NOTE(review): "value_interventions", "abs_interventions" and
    # "joint_interventions" are not keys of PARAM_TO_ID as defined in this
    # file; they are kept here unchanged.
    float_params = ("value_interventions", "marginalize_ratio", "alpha")
    bool_params = ("abs_interventions", "joint_interventions", "internal")
    text_params = ("graph_type", "noise_term")

    config = {}
    pieces = signature.split("_")
    for name, prefix in PARAM_TO_ID.items():
        for piece in pieces:
            if not piece.startswith(prefix):
                continue

            raw = piece[1:]
            if name in float_params:
                try:
                    decoded = float(raw)
                except ValueError:
                    decoded = None
            elif name in bool_params:
                decoded = raw == "True"
            elif name in text_params:
                decoded = raw
            else:
                try:
                    decoded = int(raw)
                except ValueError:
                    decoded = None
            config[name] = decoded
    return config


def load_dataset(data_path: str, signature: str, num: int):
    """
    Loads one stored repetition of a dataset from disk.

    The archive is read from ``<data_path>/<signature>_run<num>.npz`` and
    must contain the arrays "concrete", "abstract", "tau", "gamma", "pi",
    "samples_x" and "samples_y".

    Returns the tuple ``(concrete, abstract, tau, gamma, pi, dataset)``
    where ``dataset`` is the pair ``(samples_x, samples_y)``.

    Errors raised by ``np.load`` (e.g. for a missing file) or by a missing
    array name are propagated unchanged.
    """
    # filename
    fname = f"{signature}_run{num}.npz"
    path = os.path.join(data_path, fname)

    # np.load on an .npz returns a lazy archive object that keeps the file
    # open; read every array inside the context manager so the handle is
    # released as soon as loading is done.
    with np.load(path) as data:
        concrete = data["concrete"]
        abstract = data["abstract"]
        tau = data["tau"]
        gamma = data["gamma"]
        pi = data["pi"]
        dataset = (data["samples_x"], data["samples_y"])

    return concrete, abstract, tau, gamma, pi, dataset


def check_dataset(data_path: str, signature: str, num: int):
    """
    Tells whether the archive of a given repetition is present on disk.

    The path tested is ``<data_path>/<signature>_run<num>.npz``; the
    result is that of ``os.path.exists`` on it.
    """
    return os.path.exists(
        os.path.join(data_path, f"{signature}_run{num}.npz")
    )


def generate_datasets(
    dset_params: dict, data_path: str, n_repetitions: int, force: bool = False
) -> str:
    """
    Generates and stores ``n_repetitions`` simulated linear datasets.

    Each repetition is sampled with ``linear_dataset(**dset_params)`` and
    saved with ``np.savez`` to ``<data_path>/<signature>_run<k>.npz`` for
    ``k`` in ``range(n_repetitions)``, where ``signature`` is computed by
    ``config_to_signature(dset_params)``.

    Args:
        dset_params: keyword arguments for ``linear_dataset``; also
            determines the signature.
        data_path: output directory, created if it does not exist.
        n_repetitions: number of repetitions to generate.
        force: when False, a repetition whose file already exists is
            skipped; when True, it is regenerated and overwritten.

    Returns:
        The signature shared by all the repetitions.
    """
    # exist_ok replaces the former "check, then create" sequence, which
    # fails if the directory appears between the check and the creation
    os.makedirs(data_path, exist_ok=True)

    # generates signature
    signature = config_to_signature(dset_params)

    # iterate over number of repetitions
    for rep_num in range(n_repetitions):
        # filename, shared by the messages below and by the stored archive
        fname = f"{signature}_run{rep_num}.npz"

        # skip repetitions already on disk unless regeneration is forced
        if check_dataset(data_path, signature, rep_num):
            print(f"Dataset {fname} already exists.")
            if not force:
                continue

        # generate dataset
        (
            concrete,
            abstract,
            tau,
            gamma,
            pi,
            dataset,
        ) = linear_dataset(**dset_params)

        samples_x, samples_y = dataset

        # save dataset
        np.savez(
            os.path.join(data_path, fname),
            concrete=concrete,
            abstract=abstract,
            tau=tau,
            gamma=gamma,
            pi=pi,
            samples_x=samples_x,
            samples_y=samples_y,
        )

        print(f"Dataset {fname} generated.")

    return signature
