import random

import networkx as nx
import numpy as np

from dagsolver_utils import ExDagDataException


def normalize_data(X, Y):
    """Standardise ``X`` column-wise and apply the same transform to ``Y``.

    The per-column mean and (population, ``ddof=0``) standard deviation are
    computed from ``X`` only. Every element of ``Y`` is shifted and scaled
    with those same statistics.

    Args:
        X: 2-D array-like; statistics are taken along axis 0.
        Y: Iterable of arrays that broadcast against the column statistics
            of ``X``.

    Returns:
        Tuple ``(X_normalised, Y_normalised)`` where ``Y_normalised`` is a
        new list; the caller's ``X`` and ``Y`` are left untouched.
    """
    mean = np.mean(X, axis=0)
    std = np.std(X, axis=0)
    # A constant column has zero spread; dividing by it would fill the column
    # with NaN/inf. Use a unit scale instead so the centred column is all
    # zeros.
    scale = np.where(std == 0, 1.0, std)

    X = (X - mean) / scale
    # Build a fresh list rather than overwriting the caller's entries.
    Y = [(lagged - mean) / scale for lagged in Y]

    return X, Y

def load_problem(
    graph_type,
    variant,
    d,
    n,
    p,
    intra_edge_ratio,
    inter_edge_ratio,
    w_max_inter,
    w_min_inter,
    w_decay,
    noise_scale,
    noise_scale_variance,
):
    """Load or generate a structure-learning problem instance.

    Two problem families are supported:

    * ``'cds'`` -- delegates to ``cds_utils.load_data(n, 4, p, './CDS_Data')``.
      ``W_bi_true`` / ``B_bi_true`` are all-zero arrays shaped like
      ``W_true`` / ``B_true``.
    * ``'dynamic'`` -- samples a network and a data frame with
      ``structure.data_generators.gen_stationary_dyn_net_and_df`` and splits
      them into an intra-slice part and one part per lag.
      ``W_bi_true`` / ``B_bi_true`` are ``None``.

    Args:
        graph_type: ``'cds'`` or ``'dynamic'``.
        variant: Intra-slice graph model for ``'dynamic'``: ``'er'``
            (erdos-renyi) or ``'sf'`` (barabasi-albert). Unused for ``'cds'``.
        d: Number of nodes (``num_nodes`` of the generator; ``'dynamic'``
            only).
        n: Number of samples (``n_samples`` of the generator, first argument
            of ``cds_utils.load_data``).
        p: Number of lags.
        intra_edge_ratio: Doubled to obtain the generator's ``degree_intra``.
        inter_edge_ratio: Doubled to obtain the generator's ``degree_inter``
            (forced to 0 when ``p == 0``).
        w_max_inter: Forwarded to the generator.
        w_min_inter: Forwarded to the generator.
        w_decay: Forwarded to the generator.
        noise_scale: Noise scale used for every node, or the centre of the
            sampling interval when ``noise_scale_variance`` is given.
        noise_scale_variance: If not ``None``, each node's noise scale is
            drawn uniformly from
            ``[noise_scale - noise_scale_variance,
            noise_scale + noise_scale_variance]``.

    Returns:
        Tuple ``(W_true, W_bi_true, B_true, B_bi_true, A_true, X, Y,
        tabu_edges, intra_nodes, inter_nodes)``. For ``'dynamic'``, ``A_true``
        and ``Y`` are lists with one entry per lag ``1..p``, ``B_true`` is the
        boolean support of ``W_true`` and ``tabu_edges`` is empty.

    Raises:
        ValueError: If ``graph_type`` or ``variant`` is not recognised.
        ExDagDataException: If anything fails while generating the
            ``'dynamic'`` data.
    """
    tabu_edges = []
    intra_nodes = None
    inter_nodes = None  # TODO: all problems should define this

    if graph_type == 'cds':
        import cds_utils
        W_true, B_true, A_true, X, Y, intra_nodes, inter_nodes, tabu_edges = cds_utils.load_data(n, 4, p, './CDS_Data')
        W_bi_true = np.zeros_like(W_true)
        B_bi_true = np.zeros_like(B_true)

    elif graph_type == 'dynamic':
        import re

        intra_graph_models = {'er': 'erdos-renyi', 'sf': 'barabasi-albert'}
        if variant not in intra_graph_models:
            # Explicit raise: a bare ``assert`` disappears under ``python -O``.
            raise ValueError(f'unknown variant: {variant!r}')
        graph_type_intra = intra_graph_models[variant]
        graph_type_inter = 'erdos-renyi'

        degree_intra = intra_edge_ratio * 2
        # Without lags there are no inter-slice edges to generate.
        degree_inter = 0 if p == 0 else inter_edge_ratio * 2

        try:
            if noise_scale_variance is not None:
                noise_scale_vector = [
                    random.uniform(noise_scale - noise_scale_variance, noise_scale + noise_scale_variance)
                    for _ in range(d)
                ]
            else:
                noise_scale_vector = [noise_scale] * d
            from structure.data_generators import gen_stationary_dyn_net_and_df
            g, df, intra_nodes, inter_nodes = gen_stationary_dyn_net_and_df(
                num_nodes=d, n_samples=n, p=p,
                degree_intra=degree_intra, degree_inter=degree_inter,
                graph_type_intra=graph_type_intra, graph_type_inter=graph_type_inter,
                w_max_intra=2.0, w_min_intra=0.5, w_min_inter=w_min_inter, w_max_inter=w_max_inter,
                w_decay=w_decay, noise_scale=noise_scale_vector, max_data_gen_trials=1000,
                generator='notears',
            )
        except Exception as e:
            # Deliberately broad: any failure here is reported to the caller
            # as a data-generation problem.
            print(f'Error: Cannot generate samples data. Exception: {e}')
            raise ExDagDataException(e) from e

        W_true = nx.to_numpy_array(g, nodelist=intra_nodes)
        B_true = W_true != 0
        W_bi_true = None
        B_bi_true = None

        # Rows: inter-slice (lagged) nodes, columns: intra-slice nodes.
        n_intra = len(intra_nodes)
        a_mat = nx.to_numpy_array(g, nodelist=intra_nodes + inter_nodes)[n_intra:, :n_intra]
        X = df[intra_nodes].to_numpy()

        Y = []
        A_true = []
        for lag in range(1, p + 1):
            # Reject a trailing digit so that e.g. '_lag1' does not also pick
            # up '_lag10', '_lag11', ... when p >= 10.
            lag_pattern = re.compile(rf'_lag{lag}(?!\d)')
            is_lag = [lag_pattern.search(c) is not None for c in inter_nodes]

            lag_cols = [c for c, keep in zip(inter_nodes, is_lag) if keep]
            Y.append(df[lag_cols].to_numpy())
            A_true.append(a_mat[is_lag, :])

    else:
        # Explicit raise: a bare ``assert`` disappears under ``python -O``.
        raise ValueError(f'unknown problem: {graph_type!r}')

    return W_true, W_bi_true, B_true, B_bi_true, A_true, X, Y, tabu_edges, intra_nodes, inter_nodes

