import numpy as np
import torch


def to_tensor(array):
    """Return a new float64 ``torch`` tensor holding the values of *array*.

    *array* may be anything ``torch.tensor`` accepts (NumPy array, nested lists,
    scalars). Tensor inputs are copied with ``detach().to(...)`` rather than
    ``torch.tensor(tensor)``, which emits a UserWarning about copy-constructing
    from a tensor. Either way the result is a fresh, contiguous copy that does
    not require grad; a tensor input keeps its device.
    """
    if isinstance(array, torch.Tensor):
        return array.detach().to(torch.float64, copy=True).contiguous()
    return torch.tensor(array, dtype=torch.float64)


def _current_limit_constraints(conductance_matrix, connected_node_pairs, n_nodes, v_ref):
    """Build the sample-independent part of the line-current constraints.

    For every connected pair (i, j) with i != j, two rows are produced that bound the
    line current g * (v_i - v_j) from above and from below. Node 0 has a fixed voltage,
    so when it is an endpoint only the other endpoint gets a coefficient and the node-0
    term is carried in the per-row offset.

    Returns:
        A_imax: (n_rows, n_nodes - 1) constraint rows (n_rows == 0 when there are no lines).
        row_i, row_j: integer arrays; row r uses i_max_matrix[row_i[r], row_j[r]] as its limit.
        offsets: per-row slack term; the row's constraint offset is
            (offsets[r] - i_max_matrix[row_i[r], row_j[r]]) / norm_factor.
    """
    n_vars = n_nodes - 1
    rows, row_i, row_j, offsets = [], [], [], []
    for i, j in connected_node_pairs:
        if i == j:
            # A node connected to itself carries no branch current, so it gets no constraint.
            continue
        g = conductance_matrix[i, j]
        row = np.zeros(n_vars)
        if i == 0 or j == 0:
            # Only the non-reference endpoint is a decision variable (variables are nodes 1..n-1).
            other = j if i == 0 else i
            row[other - 1] = -g
            slack_term = v_ref * g
        else:
            row[i - 1] = g
            row[j - 1] = -g
            slack_term = 0
        rows += [row, -row]
        row_i += [i, i]
        row_j += [j, j]
        offsets += [slack_term, -slack_term]
    A_imax = np.array(rows) if rows else np.zeros((0, n_vars))
    return (A_imax, np.array(row_i, dtype=int), np.array(row_j, dtype=int),
            np.array(offsets, dtype=float))


def generate_dataset(conductance_matrix, generator_inds, load_inds, n_samples, i_max_mean=25,
                     norm_factor=25, data_seed=None, vary_constraints=True, features_rank=None, noise_std=0):
    """Sample ``n_samples`` random constrained problem instances on a conductance network.

    Node 0 is the reference node; its voltage is fixed at ``v_ref / norm_factor``. The
    decision variables are the normalized voltages of nodes ``1 .. n_nodes - 1``. Each
    instance has three groups of constraints:

    * the nodal powers ``-v_ref * Y @ v`` lie between per-node lower/upper limits;
    * every free voltage lies in ``[325, 375] / norm_factor``;
    * the current ``g * (v_i - v_j)`` on every connected pair is bounded in absolute value
      by that line's limit.

    The rows appear to follow the convention ``A @ x + b <= 0``. This is inferred from the
    voltage-box rows; confirm it against the consuming solver.

    Args:
        conductance_matrix: square (n_nodes, n_nodes) array. A nonzero entry ``[i, j]`` with
            ``i <= j`` marks nodes i and j as connected with conductance ``[i, j]``. Assumed
            to have a zero diagonal, because the admittance diagonal below sums whole rows.
        generator_inds: nodes that receive a (non-positive) power lower bound.
        load_inds: nodes that receive a (non-negative) power upper bound.
        n_samples: number of instances to generate.
        i_max_mean: mean per-line current limit.
        norm_factor: scale applied to voltages, powers and current limits.
        data_seed: if given, seeds NumPy's *global* RNG before sampling.
        vary_constraints: if True, power and current limits are sampled randomly.
            Otherwise fixed limits are used (-12000, 8000 and ``i_max_mean``).
        features_rank: unused; kept for interface compatibility.
        noise_std: standard deviation of the multiplicative Gaussian noise (mean 1) applied
            to the feature vector.

    Returns:
        A list of ``n_samples`` dicts of float64 tensors, each with a leading dimension of 1:

        * ``A``, ``b``: the constraint matrix and offsets.
        * ``i_max_matrix``: symmetric line limits.
        * ``w_lin``: linear cost on the free nodes.
        * ``W_sq``: quadratic cost, currently all zeros.
        * ``w_lin_0``: linear cost entry of node 0.
        * ``features``: the concatenated, noise-scaled feature vector.
    """
    if data_seed is not None:
        np.random.seed(data_seed)

    v_ref = 350
    v0 = v_ref / norm_factor
    n_nodes = len(conductance_matrix)
    n_vars = n_nodes - 1

    # Admittance matrix: off-diagonals -g_ij, diagonal replaced by the row sum of conductances.
    Y = -np.copy(conductance_matrix)
    Y[np.arange(n_nodes), np.arange(n_nodes)] = -Y.sum(1)
    v_lbs, v_ubs = 325 * np.ones(n_nodes) / norm_factor, 375 * np.ones(n_nodes) / norm_factor

    connection_matrix_triangular = np.zeros((n_nodes, n_nodes))
    connected_node_pairs = []
    for i in range(n_nodes):
        for j in range(i, n_nodes):
            if conductance_matrix[i, j] != 0:
                connected_node_pairs.append((i, j))
                connection_matrix_triangular[i, j] = 1

    # Everything from here to the sampling loop is identical for every sample.
    A_bounds = np.concatenate([np.eye(n_vars), -np.eye(n_vars)])  # upper rows, then lower rows
    A_imax, imax_i, imax_j, imax_offsets = _current_limit_constraints(
        conductance_matrix, connected_node_pairs, n_nodes, v_ref)
    A = np.concatenate([-Y[:, 1:] * v_ref, Y[:, 1:] * v_ref, A_bounds, A_imax])
    # Contribution of the fixed node-0 voltage to each node's power.
    slack_injection = v_ref * v0 * Y[:, 0]
    p_W_sq = np.zeros((n_nodes, n_nodes))
    v_W_sq = -v_ref * Y @ p_W_sq  # identically zero: no quadratic cost term for now

    dataset = []
    for _ in range(n_samples):
        # NOTE: random draws happen in a fixed order so seeded datasets stay reproducible.
        p_lbs = np.zeros(n_nodes)
        p_ubs = np.zeros(n_nodes)
        if vary_constraints:
            p_lbs[generator_inds] = np.minimum(
                0, np.random.normal(loc=-14000, scale=2500, size=len(generator_inds)))
            p_ubs[load_inds] = np.maximum(
                np.random.normal(loc=8000, scale=2500, size=len(load_inds)), 0)
            i_max_matrix = connection_matrix_triangular * np.random.normal(
                loc=i_max_mean, scale=i_max_mean / 5, size=(n_nodes, n_nodes))
        else:
            p_lbs[generator_inds] = -12000
            p_ubs[load_inds] = 8000
            i_max_matrix = i_max_mean * connection_matrix_triangular
        # Mirror the upper-triangular limits so i_max_matrix[i, j] == i_max_matrix[j, i].
        i_max_matrix = i_max_matrix + i_max_matrix.T

        p_w_lin = np.ones(n_nodes)
        p_w_lin[generator_inds] = np.random.normal(loc=.8, scale=.1, size=len(generator_inds))
        p_w_lin[load_inds] = np.random.normal(loc=1.2, scale=.1, size=len(load_inds))
        v_w_lin = -v_ref * Y @ p_w_lin
        w_scale = np.max(np.abs(v_w_lin))
        if w_scale > 0:  # an all-zero cost (e.g. uniform p_w_lin) would otherwise turn into NaN
            v_w_lin /= w_scale

        b_imax = imax_offsets - i_max_matrix[imax_i, imax_j]
        b = np.concatenate([-p_ubs / norm_factor - slack_injection,
                            p_lbs / norm_factor + slack_injection,
                            -v_ubs[1:], v_lbs[1:], b_imax / norm_factor])

        features = to_tensor(np.concatenate([v_w_lin, p_lbs / 1e4, p_ubs / 1e4, i_max_matrix.reshape(-1)]))
        # Multiplicative noise; drawn even when noise_std == 0 to keep the RNG stream aligned.
        features = features * to_tensor(np.random.normal(loc=1, scale=noise_std,
                                                         size=tuple(features.shape)))

        dataset.append({'A': to_tensor(A)[None],
                        'b': to_tensor(b)[None],
                        'i_max_matrix': to_tensor(i_max_matrix)[None],
                        'w_lin': to_tensor(v_w_lin)[None, 1:],
                        'W_sq': to_tensor(v_W_sq)[None, 1:, 1:],
                        'w_lin_0': to_tensor(v_w_lin)[None, :1],
                        'features': features[None]})
    return dataset
