import numpy as np
import numpy.random as rand
from daosim import standardize, simulate, corr
from causalAssembly.models_dag import ProductionLineGraph

# from causalAssembly.drf_fitting import fit_drf
from dagma.utils import simulate_nonlinear_sem

from graph import CausalGraph


def sample(id, g, rng=None):
    """Produce a data matrix for the dataset/simulation described by ``id``.

    ``id`` is either the literal ``"sachs"`` or a dash-separated spec whose
    first token is the number of samples:

    * ``"<n>-causalAssembly"`` -- rows resampled from
      ``ProductionLineGraph.get_data()``;
    * ``"<n>-mlp"`` / ``"<n>-gp"`` -- dagma's ``simulate_nonlinear_sem``;
    * ``"<n>-<noise>-onion"`` -- ``sample_onion`` (the noise token is ignored);
    * ``"<n>-<noise>-<correction>"`` -- ``sample_anm``.

    Args:
        id: Dataset spec as above. (Shadows the ``id`` builtin; the name is
            kept so keyword callers keep working.)
        g: Ground-truth graph; unused by the ``"sachs"`` and
            ``causalAssembly`` branches.
        rng: numpy ``Generator``. A fresh default generator is created when
            omitted; it is forwarded to every branch that accepts one.

    Returns:
        Array of samples, one row per sample.

    Raises:
        ValueError: if ``id`` does not have enough dash-separated tokens or
            its first token is not an integer.
    """
    if rng is None:
        rng = rand.default_rng()

    if id == "sachs":
        # Bootstrap resample (with replacement) of the Sachs data, same size as
        # the file; the data is deliberately left unstandardized.
        X = read_from_file("external_data/sachs_data.csv")
        n_obs = X.shape[0]
        return X[rng.choice(n_obs, size=n_obs, replace=True), :]

    tokens = id.split("-")
    if len(tokens) < 2:
        raise ValueError(f"malformed sample id {id!r}")
    num_samples = int(tokens[0])
    family = tokens[1]

    if family == "causalAssembly":
        # Resample rows (with replacement) of the causalAssembly dataset.
        # An earlier DRF-based generator (fit_drf / sample_from_drf) was disabled.
        X = ProductionLineGraph.get_data().to_numpy()
        return X[rng.choice(X.shape[0], size=num_samples, replace=True), :]

    if family in ("mlp", "gp"):
        # NOTE(review): ``rng`` is not forwarded here, so this branch is not
        # controlled by the caller's generator (presumably dagma uses numpy's
        # global RNG) -- confirm if reproducibility matters.
        p = g.num_nodes()
        gm = np.zeros((p, p), dtype=np.int8)
        for u in range(p):
            for v in g.dir_neighbors[u]:
                gm[u, v] = 1
        return simulate_nonlinear_sem(gm, num_samples, family)

    if len(tokens) < 3:
        raise ValueError(
            f"malformed sample id {id!r}: expected '<n>-<noise>-<correction>'"
        )
    noise, correction = tokens[1], tokens[2]

    if correction == "onion":
        # Forward rng so the onion branch is reproducible under a seeded generator.
        return sample_onion(g, num_samples, rng=rng)
    return sample_anm(g, num_samples, noise, correction, rng)


def sample_onion(g: CausalGraph, n, rng=None):
    """Draw ``n`` standardized samples for graph ``g`` via daosim.

    Args:
        g: Causal graph; the matrix of its transpose is handed to daosim's
            ``corr``.
        n: Number of samples to simulate.
        rng: numpy ``Generator`` forwarded to ``corr`` and ``simulate``. A
            fresh default generator is created per call when omitted, rather
            than one shared generator built at import time.

    Returns:
        The result of ``standardize`` applied to the simulated data.
    """
    if rng is None:
        rng = rand.default_rng()
    # NOTE(review): daosim receives the *transposed* graph matrix, presumably
    # because it uses the opposite row/column edge convention -- confirm.
    gm = g.transpose_graph().to_matrix()
    # The first value returned by corr is unused; B and O are presumably edge
    # weights and noise parameters respectively -- TODO confirm against daosim.
    _, B, O = corr(gm, rng=rng)
    X = simulate(B, O, n, rng=rng)
    X = standardize(X)
    return X


def sample_anm(g: CausalGraph, n, noise, correction, rng=None):
    """Draw ``n`` samples from a random linear additive-noise model over ``g``.

    Every edge gets a weight with a random sign and magnitude drawn from
    ``U[0.25, 1.0)``. Variables are generated in ``g``'s topological order:
    each column starts as noise and then accumulates its parents' weighted
    values. The parents of ``v`` are ``g.transpose_graph().dir_neighbors[v]``.

    Args:
        g: Causal graph with ``p = g.num_nodes()`` variables.
        n: Number of samples (rows).
        noise: ``"gaussian"`` -- one standard deviation drawn from
            ``U[0.5, 2.0)`` and shared by all columns -- or ``"uniform"``,
            which uses ``U[-1, 1)``.
        correction: ``"standardized"`` (each column centred and divided by its
            std; needs ``n >= 2`` to avoid division by zero) or ``"raw"``.
        rng: numpy ``Generator``; a fresh default one is created when omitted.

    Returns:
        Float array of shape ``(n, p)``.

    Raises:
        ValueError: if ``noise`` or ``correction`` is not supported.
    """
    # Validate up front so bad arguments fail before any sampling work.
    if noise not in ("gaussian", "uniform"):
        raise ValueError(f"noise {noise} not supported")
    if correction not in ("standardized", "raw"):
        raise ValueError(f"correction {correction} not supported")
    if rng is None:
        rng = rand.default_rng()

    p = g.num_nodes()
    tg = g.transpose_graph()
    # w[v][i] is the weight of the edge from the i-th parent of v to v.
    w = [
        [rng.choice([-1, 1]) * rng.uniform(0.25, 1.0) for _ in tg.dir_neighbors[u]]
        for u in range(p)
    ]
    # NOTE: this draw is discarded. It is overwritten in the gaussian branch and
    # unused for uniform noise. It is kept only so seeded generators reproduce
    # previously generated data; drop it once that no longer matters.
    rng.uniform(0.5, 2.0)
    if noise == "gaussian":
        stddev = rng.uniform(0.5, 2.0)
        X = rng.normal(0.0, stddev, (n, p))
    else:
        X = rng.uniform(-1.0, 1.0, (n, p))

    # Topological order guarantees every parent column is final before use.
    for v in g.topological_order():
        for i, u in enumerate(tg.dir_neighbors[v]):
            X[:, v] += w[v][i] * X[:, u]

    if correction == "standardized":
        X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
    return X


def write_to_file(X, data_file, *, fmt="%f"):
    """Write a 2-D sample matrix as comma-separated text with an ``X0..X{p-1}`` header.

    Args:
        X: 2-D array, one row per sample and ``p = X.shape[1]`` columns.
        data_file: Destination path; an existing file is overwritten.
        fmt: Per-value format string passed to ``numpy.savetxt``. The default
            ``"%f"`` keeps six decimal places, so magnitudes below ``5e-7``
            are written as zero. Pass e.g. ``"%.17g"`` for a lossless
            float64 round trip.
    """
    p = X.shape[1]
    header = ",".join(f"X{i}" for i in range(p))
    with open(data_file, "w") as file:
        np.savetxt(
            file,
            X,
            delimiter=",",
            header=header,
            # comments="" writes the header verbatim (no "# " prefix), so
            # readers can simply skip one line.
            comments="",
            fmt=fmt,
        )


def read_from_file(data_path):
    """Load a comma-separated numeric file whose first line is a header.

    This matches the layout produced by ``write_to_file``. Note that
    ``numpy.genfromtxt`` squeezes its result, so a single-row or single-column
    file comes back as a 1-D array.
    """
    loaded = np.genfromtxt(
        data_path,
        delimiter=",",
        skip_header=1,  # drop the column-name row
    )
    return loaded
