import torch
import  numpy as np
# Projection to simplex
def projection2simplex(y):
    """Euclidean projection of ``y`` onto the probability simplex.

    Returns ``p`` with ``p >= 0`` and ``sum(p) == 1`` that is closest to ``y``
    in L2 distance, using the sort-and-threshold algorithm: find a scalar
    threshold ``t`` and return ``max(y - t, 0)``.

    Args:
        y: Tensor of any shape; it is flattened to 1-D first.

    Returns:
        1-D tensor with ``y.numel()`` elements on ``y``'s device.
    """
    # reshape (not view): view(-1) raises on non-contiguous inputs, e.g. a
    # transposed tensor; reshape copies only when it has to.
    y = y.reshape(-1)
    m = len(y)
    sorted_y = torch.sort(y, descending=True)[0]
    tmpsum = 0.0
    # Fallback threshold if the loop never breaks: use all m entries.
    tmax_f = (torch.sum(y) - 1.0) / m
    for i in range(m - 1):
        tmpsum += sorted_y[i]
        tmax = (tmpsum - 1) / (i + 1.0)
        # Stop at the first prefix whose threshold exceeds the next-largest
        # entry; entries beyond it will be clipped to zero.
        if tmax > sorted_y[i + 1]:
            tmax_f = tmax
            break
    # clamp avoids allocating a zeros tensor and keeps y's dtype/device.
    return torch.clamp(y - tmax_f, min=0)


def projection2one(y):
    """Rescale ``y`` so that its absolute values sum to one.

    Computes ``|y| / sum(|y|)``; signs are discarded. There is no guard
    for an all-zero ``y``, where the division yields NaN.
    """
    magnitudes = torch.abs(y)
    normalized = magnitudes / magnitudes.sum()
    return normalized.to(y.device)

def smooth(x, n=20, keep_first=3):
    """Smooth a 1-D sequence with a centred moving average.

    Output element ``i`` (for ``i >= keep_first``) is the mean of
    ``x[i - n : i + n + 1]``, with the window truncated at both ends of the
    sequence. The first ``keep_first`` elements are copied through unchanged.

    Args:
        x: Sequence of numbers (anything ``np.asarray`` can convert to float).
        n: Half-width of the averaging window.
        keep_first: Number of leading elements passed through verbatim
            (defaults to 3, the previously hard-coded value).

    Returns:
        A list with ``len(x)`` elements.
    """
    length = len(x)
    # Convert once instead of re-converting every window slice.
    values = np.asarray(x, dtype=np.float64)
    smoothed = []
    for i in range(length):
        if i < keep_first:
            smoothed.append(x[i])
            continue
        lo = max(0, i - n)
        # Upper edge is inclusive of index i + n (slice end i + n + 1), so the
        # window is symmetric, always contains x[i], can reach the last
        # element, and is never empty (even when n == 0).
        hi = min(i + n + 1, length)
        smoothed.append(values[lo:hi].mean())
    return smoothed

def mean_grad(grads):
    """Return the mean of ``grads`` along axis/dimension 1.

    Relies only on a NumPy/torch-style ``.mean(axis)`` method.
    NOTE(review): presumably columns hold individual gradients to be
    averaged; confirm against callers.
    """
    reduce_axis = 1
    averaged = grads.mean(reduce_axis)
    return averaged

def perform(F, x, lambd, maps):
    """Compute gradient-norm diagnostics at ``x``.

    Calls ``F(x, True, 'emp')`` and ``F(x, True, 'pop')``, keeps the second
    element of each result, and reduces it with ``maps["mgd"]``.
    NOTE(review): 'emp'/'pop' presumably mean empirical/population
    gradients; confirm with the definition of ``F``.

    Returns:
        Tuple ``(emp_norm, pop_norm, ca_error)``:
        ``emp_norm`` is ``np.linalg.norm`` of the reduced 'emp' gradients,
        ``pop_norm`` is the same for the 'pop' gradients, and ``ca_error``
        is the squared norm of ``emp_grads @ lambd - reduced_emp``.
    """
    _, emp_grads = F(x, True, 'emp')
    emp_direction = maps["mgd"](emp_grads)
    # np.linalg.norm default: 2-norm for vectors, Frobenius for matrices.
    emp_norm = np.linalg.norm(emp_direction)

    _, pop_grads = F(x, True, 'pop')
    pop_direction = maps["mgd"](pop_grads)
    pop_norm = np.linalg.norm(pop_direction)

    residual = emp_grads @ lambd - emp_direction
    ca_error = np.linalg.norm(residual) ** 2

    return emp_norm, pop_norm, ca_error





