import numpy
from data_generation.transform_functions import *
from data_generation.mediator_shapes import *

import os
from sacred import Ingredient

data_ingredient = Ingredient("data_ingredient")


@data_ingredient.config
def cfg():
    # Default configuration of the data ingredient.
    #
    # ``load_data`` inspects the three possible sources in the order
    # ``file_path`` -> ``name`` -> ``dictionary`` and uses the first one that
    # is not None. With the defaults below only ``dictionary`` is set, so data
    # is generated synthetically by ``data_from_dict``.
    #
    # NOTE(review): ``load_data`` also reads ``config["n_samples"]`` when
    # ``name`` is set, but no default for it is defined in this function, so
    # it has to be supplied by the caller for the named datasets.

    # Path of a text file that ``load_data`` reads with ``numpy.loadtxt``.
    file_path: str = None
    # Name of a dataset to load; ``load_data`` handles "tuebingen" and "sachs".
    name: str = None
    # Selects the variable pair inside a named dataset. For "tuebingen" it is
    # used as a positional row index; for "sachs" ``load_data`` calls
    # ``.split(",")`` on it, so it must be a "colA,colB" string there despite
    # the ``int`` annotation.
    id: int = None
    # Specification of a synthetic dataset, consumed by ``data_from_dict``.
    dictionary = {
        # Noise type and length handed to ``generate_noise`` to draw X.
        "X": {
            "type": "normal",
            "length": 100,
        },
        # Which transform generator to use and the arguments forwarded to it
        # as ``transformation_args``.
        "transformation": {
            "type": "neural_network",
            "args": {
                "num_hidden": 10,
                "num_parents": 1,
            },
        },
        # Mediator structure: "sequence", "one_mediator",
        # "multiple_mediators" or "confounded".
        "shape": "sequence",
        # Forwarded as ``depth`` to the selected mediator function.
        "depth": 10,
        # Forwarded as ``noise_type`` to the selected mediator function.
        "noise_type": "uniform",
    }


def _standardize_and_subsample(a_list, b_list, n_samples):
    """Stack two variables column-wise, z-score each column, and subsample rows.

    Args:
        a_list: Values of the first variable (becomes column 0).
        b_list: Values of the second variable (becomes column 1).
        n_samples: Maximum number of rows to keep; capped at the number of
            available rows.

    Returns:
        A float32 array holding the randomly selected, standardized rows.
    """
    # Stack them into a (n_samples, 2) array
    data = numpy.float32(numpy.column_stack((a_list, b_list)))
    data = (data - data.mean(axis=0)) / data.std(axis=0)

    # Take a random selection of the points. The generator is unseeded, so
    # the selection differs from call to call.
    rng = numpy.random.default_rng()
    indices = rng.choice(
        data.shape[0],
        size=min(n_samples, data.shape[0]),
        replace=False,
    )
    return data[indices]


def load_data(config):
    """Load or generate the dataset described by ``config``.

    The sources are tried in this order and the first one that is not None
    is used:

    1. ``config["file_path"]``: a text file read with ``numpy.loadtxt``.
    2. ``config["name"]``: a named dataset loaded through
       ``cdt.data.load_dataset``. For ``"tuebingen"``, ``config["id"]`` is the
       positional row index of the pair and its ``"A"``/``"B"`` entries are
       used. For ``"sachs"``, ``config["id"]`` is a ``"colA,colB"`` string
       naming the two columns. In both cases the columns are standardized and
       at most ``config["n_samples"]`` rows are drawn at random without
       replacement.
    3. ``config["dictionary"]``: a synthetic-data specification passed to
       ``data_from_dict``.

    Args:
        config: Mapping with the keys ``file_path``, ``name``, ``id`` and
            ``dictionary``; ``n_samples`` is additionally required when
            ``name`` is set.

    Returns:
        A float32 numpy array with the loaded or generated data.

    Raises:
        FileNotFoundError: If ``file_path`` is set but does not exist.
        ValueError: If ``name`` is set to an unsupported dataset, or if none
            of the three sources is specified.
    """
    if config["file_path"] is not None:
        file_path = config["file_path"]
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        return numpy.loadtxt(file_path, dtype=numpy.float32)
    elif config["name"] is not None:
        name = config["name"]
        # Reject unsupported names explicitly (and before importing cdt)
        # instead of silently falling through and returning None.
        if name not in ("tuebingen", "sachs"):
            raise ValueError(f"Unknown dataset name: {name!r}")

        # Imported lazily so cdt is only required for the named datasets.
        from cdt.data import load_dataset

        if name == "tuebingen":
            t_data, _ = load_dataset("tuebingen")
            a_list = t_data.iloc[config["id"]]["A"]
            b_list = t_data.iloc[config["id"]]["B"]
        else:
            sachs_data, _ = load_dataset("sachs")
            a_id, b_id = config["id"].split(",")
            a_list = sachs_data[a_id]
            b_list = sachs_data[b_id]

        return _standardize_and_subsample(a_list, b_list, config["n_samples"])
    elif config["dictionary"] is not None:
        return numpy.float32(data_from_dict(config["dictionary"]))
    else:
        raise ValueError("You must specify how to generate the data")


def data_from_dict(params: dict) -> numpy.array:
    """Generate a synthetic (X, Y) dataset from a specification dictionary.

    Steps: create a random generator (seeded when ``params`` contains
    ``"seed"``), draw X with ``generate_noise``, pick the transform generator
    named by ``params["transformation"]["type"]``, and build Y with the
    mediator function named by ``params["shape"]``. For the ``"confounded"``
    shape the mediator function returns a new X together with Y.

    Args:
        params: Specification with the keys ``X`` (``type``, ``length``),
            ``transformation`` (``type``, ``args``), ``shape``, ``depth`` and
            ``noise_type``. Optional keys: ``seed``, ``noise_parameters``,
            ``mediator_noise_type``, ``mediator_noise_parameters`` and
            ``standardize`` (defaults to True).

    Returns:
        X and Y stacked column-wise with ``numpy.column_stack``. Unless
        ``standardize`` is falsy, each is first passed through
        ``standardize_vectors``.

    Raises:
        ValueError: If the transformation type or the shape is not recognised.
    """
    # Only seed the generator when the caller asked for it.
    rng = (
        numpy.random.default_rng(params["seed"])
        if "seed" in params
        else numpy.random.default_rng()
    )

    # Draw the input variable X.
    x_data = generate_noise(rng, params["X"]["type"], params["X"]["length"])

    # Resolve the transform generator. Kept as a chain so that each helper
    # name is only looked up on the branch that actually uses it.
    kind = params["transformation"]["type"]
    if kind == "neural_network":
        make_transform = neural_network_transform
    elif kind == "quadratic":
        make_transform = quadratic_transform
    elif kind == "linear":
        make_transform = linear_transform
    elif kind == "cubic":
        make_transform = cubic_transform
    elif kind == "tanh":
        make_transform = tanh_transform
    elif kind == "negative_tanh":
        make_transform = negative_tanh_transform
    elif kind == "prelu":
        make_transform = prelu_transform
    else:
        raise ValueError("Unknown transformation function")

    # Resolve the mediator function and the optional keys it is given.
    mediator_keys = (
        "noise_parameters",
        "mediator_noise_type",
        "mediator_noise_parameters",
    )
    shape = params["shape"]
    if shape == "sequence":
        build, optional_keys = sequential_mediator, ("noise_parameters",)
    elif shape == "one_mediator":
        build, optional_keys = one_mediator, mediator_keys
    elif shape == "multiple_mediators":
        build, optional_keys = multiple_mediators, mediator_keys
    elif shape == "confounded":
        build, optional_keys = confounded, ()
    else:
        raise ValueError("Unknown mediator shape")

    # Arguments common to every mediator function, followed by the optional
    # ones (None when absent from params).
    call_kwargs = {
        "depth": params["depth"],
        "rng": rng,
        "transformation_generator": make_transform,
        "transformation_args": params["transformation"]["args"],
        "noise_type": params["noise_type"],
    }
    for key in optional_keys:
        call_kwargs[key] = params.get(key, None)

    # Generate Y; the confounded variant also replaces X.
    outcome = build(x_data, **call_kwargs)
    if shape == "confounded":
        x_data, y_data = outcome
    else:
        y_data = outcome

    # Put everything together
    if not params.get("standardize", True):
        return numpy.column_stack([x_data, y_data])

    x_standardized = standardize_vectors(x_data)[0]
    y_standardized = standardize_vectors(y_data)[0]
    return numpy.column_stack([x_standardized, y_standardized])
