# -*- coding: utf-8 -*-
import numpy as np
import scipy as sc
import cvxpy as cp

## MSE functions


def GD_step(theta, grad, eta):
    """
    One plain gradient-descent update.

    theta : current parameters
    grad : gradient evaluated at theta (same shape as theta)
    eta : step size
    Returns theta - eta * grad.
    """
    scaled_step = eta * grad
    return theta - scaled_step


def grad_MSE(theta, x, y):
    """
    Gradient of 0.5 * mean((x @ theta - y) ** 2) with respect to theta.

    x : array (n,d)
    y : array (n,) or (n,1); both layouts are accepted
    theta : array (d,)
    Returns array (d,)
    """
    # Reshape y to a column explicitly: y[:, None] on an (n,1) target would
    # give (n,1,1) and broadcast against x into an (n,n,d) array.
    y_col = np.reshape(y, (-1, 1))
    return np.mean(np.dot(x, theta[:, None]) * x - y_col * x, axis=0)


def grad_MSE_allbag(theta, x, y, ridge=1):
    """
    Per-sample gradients of 0.5 * (x . theta - y) ** 2 + 0.5 * ridge * ||theta||^2
    for every bag.

    x : list of array (n_i,d)
    y : list of array (n_i,1) or (n_i,); both layouts are accepted
    theta : array (d,)
    ridge : float, ridge penalty (added as ridge * theta to every row)
    Returns list of array (n_i,d), one per bag
    """
    list_grads = []
    for i, xi in enumerate(x):
        # Force the targets into a column so that a flat (n_i,) array does not
        # silently broadcast along the feature axis of xi.
        yi = np.reshape(y[i], (-1, 1))
        list_grads.append(np.dot(xi, theta)[:, None] * xi - yi * xi + ridge * theta)
    return list_grads


def loss_MSE(theta, x, y):
    """
    Mean squared error mean((x @ theta - y) ** 2).

    x : array (..., d), e.g. (n,d) or (T,n,d)
    y : array with the shape of the predictions x @ theta, or that shape with
        an extra trailing axis of size 1 (e.g. (n,1) or (T,n,1))
    theta : array (d,)
    Returns float
    """
    pred = np.dot(x, theta)
    y = np.asarray(y)
    # Drop a trailing singleton axis; otherwise (n,) - (n,1) broadcasts to an
    # (n,n) grid and the mean is taken over all pairs instead of residuals.
    if y.ndim == pred.ndim + 1 and y.shape[-1] == 1:
        y = y[..., 0]
    return np.mean((pred - y) ** 2)


def gradient_descent_MSE(
    X, y, Xtest, ytest, theta0, NGDstep=200, lr=1, met="local", ridge=0, **kwargs
):
    """
    Gradient descent on the least-squares objective with aggregated bag gradients.

    X : list of array (n_i,d), one design matrix per bag
    y : list of array (n_i,1), one target vector per bag
    Xtest : array (T,ntest,d) (or (ntest,d))
    ytest : array (T,ntest,1) (or (ntest,1) / the prediction shape)
    theta0 : array (d,), starting point
    NGDstep : int, number of gradient steps
    lr : float, step size
    met : str, aggregation method name forwarded to aggregate_grads
    ridge : float, ridge penalty forwarded to grad_MSE_allbag
    **kwargs : extra arguments forwarded to aggregate_grads (e.g. weights)

    Returns (theta, trainloss, testloss):
        theta : array (d,), final parameters
        trainloss : array (NGDstep,), MSE on the first bag (X[0], y[0]) after each step
        testloss : array (NGDstep,), MSE on (Xtest, ytest) after each step
    """
    # Flatten the targets to the prediction shape X @ theta so that a
    # trailing (.,1) axis cannot broadcast against the predictions in the loss.
    ytrain = np.reshape(y[0], np.shape(X[0])[:-1])
    ytest_flat = np.reshape(ytest, np.shape(Xtest)[:-1])

    theta = theta0
    trainloss = np.zeros(NGDstep)
    testloss = np.zeros(NGDstep)
    for step in range(NGDstep):
        grads = grad_MSE_allbag(theta, X, y, ridge=ridge)
        grad = aggregate_grads(met, grads, **kwargs)
        theta = GD_step(theta, grad, lr)
        trainloss[step] = loss_MSE(theta, X[0], ytrain)
        testloss[step] = loss_MSE(theta, Xtest, ytest_flat)
    return theta, trainloss, testloss


## Functions for mean aggregation


def compute_naive_estimator(Z):
    """
    Empirical mean of every bag.

    Z : list of array (ni,d)
    Returns array (T,d) whose row t is the column-wise mean of Z[t].
    """
    _, dim = Z[0].shape
    n_bags = len(Z)
    out = np.zeros((n_bags, dim))
    for t in range(n_bags):
        out[t] = np.mean(Z[t], axis=0)
    return out


def aggregate_grads(met, grads, **kwargs):
    """
    Combine per-bag per-sample gradients into a single gradient vector.

    met : str
        "naive" or "local" : mean gradient of the first bag only (grads[0]);
        "grandmean"        : unweighted average of the per-bag mean gradients;
        "weighted"         : weighted average, weights passed through kwargs.
    grads : list of array (ni,d)
    **kwargs : forwarded to weighted_grads for met == "weighted"

    Returns array (d,)
    Raises ValueError for an unknown method name.
    """
    # "local" is the default used by gradient_descent_MSE; without this branch
    # the function fell through and returned None.
    if met in ("naive", "local"):
        return np.mean(grads[0], axis=0)
    if met == "grandmean":
        return grandmean_grads(grads)
    if met == "weighted":
        return weighted_grads(grads, **kwargs)
    raise ValueError(f"Unknown aggregation method: {met!r}")


def weighted_grads(grads, weights=None):
    """
    Weighted combination of the per-bag mean gradients.

    grads : list of array (ni,d)
    weights : array (T,), one weight per bag; when None, all the weight is
        put on the first bag
    Returns array (d,) equal to sum_t weights[t] * mean(grads[t], axis=0)
    """
    n_bags = len(grads)
    _, dim = grads[0].shape
    if weights is None:
        weights = np.eye(n_bags)[0]
    bag_means = np.zeros((n_bags, dim))
    for t, g in enumerate(grads):
        bag_means[t] = np.mean(g, axis=0)
    return np.dot(weights, bag_means)


def grandmean_grads(grads):
    """
    Unweighted average over bags of the per-bag mean gradients.

    grads : list of array (ni,d)
    Returns array (d,); every bag counts equally regardless of its size ni.
    """
    n_bags = len(grads)
    _, dim = grads[0].shape
    bag_means = np.zeros((n_bags, dim))
    for t, g in enumerate(grads):
        bag_means[t] = np.mean(g, axis=0)
    return bag_means.mean(axis=0)


## Optimization function
def minimize_quadratic_on_simplex(A, b, clipping=True):
    """
    Minimize omega^T A omega + b^T omega
    subject to omega >= 0 and sum(omega) = 1, using cvxpy.

    A : array (B,B), positive semidefinite matrix of the quadratic form
    b : array (B,), linear term
    clipping : bool, if True clip tiny negative solver output to 0 and
        renormalize so the result lies exactly on the simplex

    Returns (omega, value): the minimizer (B,) and the optimal objective value.
    Raises RuntimeError if the solver returns no solution.
    """
    B = b.shape[0]

    # cp.quad_form requires a symmetric matrix. The symmetric part of A defines
    # the same quadratic form, and using it absorbs round-off asymmetry
    # (e.g. from a Gram matrix computed with np.dot).
    A = np.asarray(A)
    A_sym = 0.5 * (A + A.T)

    # Optimization variable
    omega = cp.Variable(B)

    # Objective
    objective = cp.Minimize(cp.quad_form(omega, A_sym) + b @ omega)

    # Constraints (simplex)
    constraints = [omega >= 0, cp.sum(omega) == 1]

    # Solve problem
    problem = cp.Problem(objective, constraints)
    problem.solve()

    # Without this check a failed solve gives omega.value None: a TypeError in
    # the clipping branch, or a silent None return otherwise.
    if omega.value is None:
        raise RuntimeError(
            f"Quadratic program on the simplex failed (status: {problem.status})"
        )

    if clipping:
        output = np.maximum(omega.value, 0)
        output /= np.sum(output)
    else:
        output = omega.value
    return output, problem.value


def egd_learning_rate(A, b, safety=0.5):
    """
    Step size for exponentiated gradient descent on f(w) = w^T A w + b^T w
    over the simplex, intended for symmetric positive (semi)definite A.

    Returns safety / (2 * ||A||_2 + ||b||_inf), where ||A||_2 is the operator
    norm (largest singular value). On the simplex ||w||_2 <= 1, so the
    denominator upper-bounds ||grad f(w)||_inf = ||2 A w + b||_inf.

    If A and b are both zero the gradient vanishes identically and any step
    size is valid. `safety` is then returned instead of dividing by zero
    (which would give inf and NaN weights downstream).
    """
    op_norm = np.linalg.norm(A, 2)  # operator norm (largest singular value)
    b_inf = np.linalg.norm(b, np.inf)

    scale = 2 * op_norm + b_inf
    if scale == 0:
        return safety
    eta = safety / scale
    return eta


def exponentiated_gradient_quadratic(A, b, w0=None, lr=0.1, n_iter=1000, tol=1e-8):
    """
    Minimize f(w) = w^T A w + b^T w
    subject to w in the simplex (w >= 0, sum(w) = 1)
    with exponentiated gradient updates in an extragradient form. Each iteration
    takes a look-ahead step from w, then recomputes the gradient at the
    look-ahead point and re-applies the update from w.

    Parameters
    ----------
    A : ndarray (d, d)
        Matrix of the quadratic form
    b : ndarray (d,)
        Linear term
    w0 : ndarray (d,), optional
        Starting point. It is clipped to >= 1e-12 and renormalized (not
        modified in place). Defaults to the uniform distribution.
    lr : float
        Step size
    n_iter : int
        Maximum number of iterations
    tol : float
        Stop once the L1 change of w over one iteration drops below tol

    Returns
    -------
    w : ndarray (d,)
        Approximate minimizer
    history : list
        Objective value after each iteration
    """

    def exp_step(base, at):
        # Gradient of f evaluated at `at`, applied multiplicatively to `base`.
        g = 2 * np.dot(A, at) + b
        scaled = -lr * g
        # Shift so the largest exponent is 0: avoids overflow in np.exp.
        moved = base * np.exp(scaled - np.max(scaled))
        # Renormalize back onto the simplex.
        moved /= moved.sum()
        return moved

    d = b.shape[0]

    if w0 is not None:
        w = np.clip(w0, 1e-12, None)
        w /= w.sum()
    else:
        w = np.ones(d) / d

    history = []
    for _ in range(n_iter):
        lookahead = exp_step(w, w)
        corrected = exp_step(w, lookahead)

        history.append(
            np.dot(corrected, np.dot(A, corrected)) + np.dot(b, corrected)
        )

        done = np.linalg.norm(corrected - w, 1) < tol
        w = corrected
        if done:
            break

    return w, history


## q aggregation method


def naivemeans(Z):
    """
    Per-bag empirical means.

    Z : list of array (ni,d)
    Returns array (T,d) stacking np.mean(Z[t], axis=0) for every bag t.
    """
    return np.array([np.mean(bag, axis=0) for bag in Z])


def Qaggregation(Z, M, i=0, c0=1.5, cbs=1, metopt="expo"):
    """
    Simplex weights over the bags for estimating the mean of bag ``i``.

    With mu_j the naive mean of bag j, Sigma_i the unbiased sample covariance of
    bag i and N_i its size, the weights minimize over the simplex

        omega^T G omega + b^T omega,   G[j, k] = <mu_j - mu_i, mu_k - mu_i>,
        b_j = c0 * sqrt((mu_j - mu_i)^T Sigma_i (mu_j - mu_i) / N_i)
              + cbs * M * ||mu_j - mu_i||_2 / N_i     (j != i),
        b_i = 2 * tr(Sigma_i) / N_i.

    Z : list of array (ni,d)
    M : positive float, upper bound on data
    i : int, index of the target bag
    c0 : float, multiplier of the covariance-based penalty
    cbs : float, multiplier of the M-based penalty
    metopt : "expo" uses exponentiated gradient; any other value uses the
        cvxpy solver (minimize_quadratic_on_simplex)
    Returns omega : array (T,), weights on the simplex

    Bag i needs at least two samples: the unbiased covariance is undefined
    (NaN) otherwise.
    """
    muNE = naivemeans(Z)
    diff = muNE - muNE[i]
    Gram = np.dot(diff, diff.T)
    Ztarget = Z[i]
    Ni, _ = Ztarget.shape
    # rowvar=False since rows are samples. np.cov returns a 0-d array for a
    # single variable (d == 1), and np.trace below requires a 2-D matrix.
    Sigmai = np.atleast_2d(np.cov(Ztarget, rowvar=False, bias=False))
    # The quadratic form is >= 0 in exact arithmetic. Clip round-off negatives
    # so the square root does not produce NaN.
    quad = np.maximum(np.sum(np.dot(diff, Sigmai) * diff, axis=1), 0)
    Q = c0 * np.sqrt(quad / Ni)
    Q[i] = 2 * np.trace(Sigmai) / Ni
    Qbs = M * np.linalg.norm(diff, axis=1, ord=2) / Ni
    A = Gram
    b = Q + cbs * Qbs
    if metopt == "expo":
        lr = egd_learning_rate(A, b)
        omega, _ = exponentiated_gradient_quadratic(A, b, lr=lr)
    else:
        omega, _ = minimize_quadratic_on_simplex(A, b)
    return omega


## Random Fourier Features


def RFF(X, D, variance=None, rng=None):
    """
    Random Fourier features  f(x) = sqrt(2 / D) * cos(omega x + phi).

    X : list of array (ni,d)
    D : int, dimension of the RFF
    variance : array (d,d), covariance of the Gaussian frequencies omega
        (identity by default)
    rng : optional numpy random Generator / RandomState used to draw omega and
        phi, for reproducible features. The global ``np.random`` state is used
        when None, as before.

    Returns (features, omega, phi):
        features : list of array (ni, D), one per input array
        omega : array (D, d), sampled frequencies
        phi : array (D,), phases drawn uniformly in [0, 2*pi)
    """
    _, d = np.shape(X[0])
    if variance is None:
        variance = np.eye(d)
    sampler = np.random if rng is None else rng
    omega = sampler.multivariate_normal(np.zeros(d), variance, size=D)
    phi = sampler.uniform(0, 2 * np.pi, size=D)
    scale = np.sqrt(2 / D)
    output = [np.cos(np.dot(XX, omega.T) + phi) * scale for XX in X]
    return output, omega, phi
