
import numpy as np

import admg_generators, utils
from dagsolver_utils import ExDagDataException


def normalize_data(X, Y):
    """Standardize ``X`` column-wise and apply the same transform to ``Y``.

    The per-column mean and standard deviation are computed from ``X`` only
    (over axis 0). They are then used to shift and scale both ``X`` and every
    element of ``Y``, so ``Y`` ends up on the same scale as ``X``.

    Args:
        X: Array-like data; statistics are taken over axis 0.
        Y: Mutable indexable sequence (e.g. a list of arrays) whose elements
            broadcast against a row of ``X``. Its elements are replaced
            IN PLACE, and the same object is also returned.

    Returns:
        Tuple ``(X_normalized, Y)``.
    """
    mean = np.mean(X, axis=0)
    std = np.std(X, axis=0)
    # A constant column has std == 0. Dividing by it would fill that column
    # (in X and Y) with NaN/inf and poison every downstream computation.
    # Such columns are therefore centred but left unscaled.
    std = np.where(std == 0, 1.0, std)

    X = (X - mean) / std

    # Write back into Y so callers holding a reference see the normalized data.
    for i, y in enumerate(Y):
        Y[i] = (y - mean) / std

    return X, Y

def load_problem(graph_type, d, n, edge_ratio, pdir, pbidir, max_in_arrows,
                 *, tabu_edges_ratio=0.2, hidden_vertices_ratio=0.2):
    """Load or generate a benchmark problem of the given ``graph_type``.

    Supported ``graph_type`` values:
      * ``'cds'``: loads data from ``./CDS_Data`` via
        ``cds_utils.load_data(n, 4, 0, ...)``. The meaning of the literal
        arguments ``4`` and ``0`` is defined in ``cds_utils`` (TODO document).
        The bidirected matrices are all zeros.
      * ``'ermag'``: simulates an ER DAG with ``edge_ratio * d`` expected edges
        and a linear Gaussian SEM with ``n`` samples. It then hides a random
        ``hidden_vertices_ratio`` fraction of the vertices. Among the remaining
        vertex pairs with no directed edge either way, each pair becomes a tabu
        edge with probability ``tabu_edges_ratio``, and that pair is also
        marked as bidirected in ``B_bi_true`` / ``W_bi_true``.
      * ``'bowfree_admg'``: delegates to
        ``admg_generators.generate_graph_and_samples(d, pdir, pbidir,
        max_in_arrows, n)``. The weight matrices equal the adjacency matrices.

    Args:
        graph_type: One of ``'cds'``, ``'ermag'``, ``'bowfree_admg'``.
        d: Number of vertices to generate (``'ermag'``, ``'bowfree_admg'``).
        n: Number of samples.
        edge_ratio: Expected directed edges per vertex (``'ermag'`` only).
        pdir, pbidir, max_in_arrows: Passed through to the ``'bowfree_admg'``
            generator.
        tabu_edges_ratio: Keyword-only. Probability that an edge-free pair
            becomes tabu in ``'ermag'``.
        hidden_vertices_ratio: Keyword-only. Fraction of vertices hidden in
            ``'ermag'``.

    Returns:
        Tuple ``(W_true, W_bi_true, B_true, B_bi_true, A_true, A_true, X, Y,
        tabu_edges, intra_nodes, inter_nodes)``. ``A_true`` is returned twice.
        NOTE(review): the second slot is presumably meant for a separate
        bidirected-lag matrix; kept duplicated for caller compatibility.
        ``intra_nodes`` / ``inter_nodes`` are ``None`` unless ``'cds'``
        provides them.

    Raises:
        ExDagDataException: If ``'ermag'`` DAG simulation fails.
        ValueError: If ``graph_type`` is not recognised.
    """
    tabu_edges = []
    intra_nodes = None
    inter_nodes = None  # TODO: all problems should define this

    if graph_type == 'cds':
        import cds_utils
        (W_true, B_true, A_true, X, Y,
         intra_nodes, inter_nodes, tabu_edges) = cds_utils.load_data(n, 4, 0, './CDS_Data')
        W_bi_true = np.zeros_like(W_true)
        B_bi_true = np.zeros_like(B_true)

    elif graph_type == 'ermag':
        s0 = edge_ratio * d
        sem_type = 'gauss'
        noise_scale = 1.0
        try:
            B_true = utils.simulate_dag(d, s0, "ER")
        except Exception as e:
            print(f'Error: Cannot generate samples data. Exception: {e}')
            # Chain so the original traceback is not lost.
            raise ExDagDataException(e) from e

        W_true = utils.simulate_parameter(B_true)
        X = utils.simulate_linear_sem(W_true, n, sem_type, noise_scale=noise_scale)
        Y = []
        A_true = []

        # Keep only a random subset of vertices (the rest are "hidden").
        # np.random.choice returns the kept indices in random order, so the
        # observed vertices are also permuted, consistently across X, W, and B.
        new_d = int(d * (1 - hidden_vertices_ratio))
        indices = np.random.choice(range(d), size=new_d, replace=False)
        d = new_d
        X = X[:, indices]
        W_true = W_true[np.ix_(indices, indices)]
        B_true = B_true[np.ix_(indices, indices)]

        B_bi_true = np.zeros_like(B_true)
        for i in range(d):
            for j in range(i):
                # Only pairs with no directed edge in either direction are candidates.
                if W_true[i, j] == 0.0 and W_true[j, i] == 0.0:
                    if np.random.rand() < tabu_edges_ratio:
                        tabu_edges.append((i, j))
                        B_bi_true[i, j] = 1.0
                        B_bi_true[j, i] = 1.0  # NOTE(review): symmetric marking may be dubious.

        W_bi_true = np.copy(B_bi_true)

    elif graph_type == 'bowfree_admg':
        B_true, B_bi_true, tabu_edges, X = admg_generators.generate_graph_and_samples(
            d, pdir, pbidir, max_in_arrows, n)
        W_true = B_true
        W_bi_true = B_bi_true
        Y = []
        A_true = []

    else:
        # An explicit raise (not `assert`) so this is never stripped under -O,
        # which would otherwise fall through to an UnboundLocalError below.
        raise ValueError(f'unknown problem: {graph_type!r}')

    # TODO: generate default node names when intra_nodes / inter_nodes are None.
    return W_true, W_bi_true, B_true, B_bi_true, A_true, A_true, X, Y, tabu_edges, intra_nodes, inter_nodes

