import numpy as np


# Deterministic setting
def construct_f_deterministic(A, b, rho=0.5):
    """Build the deterministic cubic-regularized quadratic objective.

    The returned function evaluates

        f(w) = 0.5 * w^T A w + b^T w + rho / 3 * ||w||^3

    Args:
        A: square matrix of the quadratic term (assumed ``(dim, dim)``).
        b: linear-term vector of length ``dim``.
        rho: weight of the cubic regularizer (default 0.5).

    Returns:
        Callable ``f(w) -> float``.
    """
    def f(w):
        return (0.5 * np.dot(w, np.dot(A, w))
                + np.dot(b, w)
                + rho / 3 * np.linalg.norm(w) ** 3)

    return f


def construct_matrix_deterministic(dim, mini_dim):
    """Build a random diagonal matrix with exactly ``mini_dim`` entries equal to -1.

    Diagonal entries are drawn uniformly from [1, 2), then ``mini_dim``
    distinct diagonal positions are overwritten with -1. Because the matrix
    is diagonal, its eigenvalues are exactly these diagonal entries.
    Typical experiment values: ``dim = 1000``, ``mini_dim = 100``.

    Args:
        dim: size of the square matrix.
        mini_dim: number of diagonal entries set to -1; must satisfy
            ``0 <= mini_dim <= dim``.

    Returns:
        ``(dim, dim)`` ndarray.

    Raises:
        ValueError: if ``mini_dim > dim`` (raised by ``np.random.choice``).
    """
    a = np.random.uniform(1, 2, dim)
    # Draw distinct indices over the full range [0, dim): randint's upper
    # bound is exclusive (so dim - 1 would exclude the last entry and fail
    # for dim == 1), and sampling with replacement would allow duplicates,
    # leaving fewer than mini_dim negative entries.
    index = np.random.choice(dim, mini_dim, replace=False)
    a[index] = -1
    return np.diag(a)


def construct_vector_deterministic(dim):
    """Return the linear term ``b`` for the deterministic problem: a zero vector of length ``dim``."""
    return np.zeros(dim)


# Stochastic setting
def construct_random_matrix(num, dim, mini_dim):
    """Return ``num`` noisy copies of one shared base matrix.

    The base matrix is drawn once via ``construct_matrix_deterministic``;
    each copy then adds its own diagonal perturbation whose entries are
    uniform on [-0.1, 0.1).

    Returns:
        list of ``num`` ``(dim, dim)`` arrays, in draw order.
    """
    base = construct_matrix_deterministic(dim, mini_dim)
    return [base + np.diag(np.random.uniform(-0.1, 0.1, dim)) for _ in range(num)]


def construct_random_vector(num, dim):
    """Return ``num`` independent length-``dim`` vectors with entries uniform on [-1, 1).

    Returns:
        list of arrays, in draw order.
    """
    return [np.random.uniform(-1, 1, dim) for _ in range(num)]


def construct_f_stochastic(matrix_cell, vector_cell, rho=0.5):
    """Build a mini-batch cubic-regularized quadratic objective.

    For a batch of sample indices ``index`` the returned ``f(w, index)`` evaluates

        0.5 * w^T A_I w + b_I^T w + rho / 3 * ||w||^3

    where ``A_I`` and ``b_I`` are the averages of ``matrix_cell[i]`` and
    ``vector_cell[i]`` over ``i in index`` (repeated indices count repeatedly).
    An empty ``index`` contributes no data terms, leaving only the cubic term.

    Args:
        matrix_cell: indexable collection of square matrices (assumed
            ``(dim, dim)`` with ``dim == len(w)``).
        vector_cell: indexable collection of length-``dim`` vectors.
        rho: weight of the cubic regularizer (default 0.5).

    Returns:
        Callable ``f(w, index) -> float``.
    """
    def f(w, index):
        batch_size = len(index)
        quad = 0.0
        lin = 0.0
        # By linearity, 0.5 * w^T mean(A_i) w + mean(b_i)^T w equals the mean
        # of the per-sample scalars, so the averaged (dim, dim) matrix and its
        # per-sample temporaries never need to be materialized.
        for i in index:
            quad += np.dot(w, np.dot(matrix_cell[i], w))
            lin += np.dot(vector_cell[i], w)
        if batch_size:
            quad /= batch_size
            lin /= batch_size
        return 0.5 * quad + lin + rho / 3 * np.linalg.norm(w) ** 3

    return f
