from copy import deepcopy
from scipy.optimize import minimize


def func(x, coefficient, group_norm, batch_size, nabla, noise_multiplier, param_numel):
    r"""Objective over per-group clipping thresholds: clipping bias + noise variance.

    Groups are visited in ascending key order of ``group_norm``. The idx-th group
    in that order is paired with threshold ``x[idx]``, so ``x`` must have at least
    as many entries as ``group_norm`` has keys.

    Example values are shown for ten groups.

    :param x: [median(norm_0), median(norm_1), ..., median(norm_9)] -- one
        clipping threshold per group (the optimizer's current iterate)
    :param coefficient: 2(1+1/lr)
    :param group_norm: {0: [norm_0], 1: [norm_1], ..., 9: [norm_9]}
    :param batch_size: n
    :param nabla: nabla
    :param noise_multiplier: sigma
    :param param_numel: |\mathbf{I}|
    :return: clipping bias + noise addition variance
    """
    clip_bias = 0.
    noise_var = 0.
    # Dict keys are unique, so sorting the (key, norms) pairs orders groups by
    # key alone; no explicit ``key=`` function is needed.
    for idx, (_, norms) in enumerate(sorted(group_norm.items())):
        threshold = x[idx]
        for g in norms:
            # Only the amount by which a norm exceeds its threshold contributes.
            clip_bias += max(0, g - threshold)
        # Squared threshold, scaled by this group's fraction of the batch.
        noise_var += (threshold * len(norms) / batch_size) ** 2
    clip_bias /= batch_size
    clip_bias = coefficient * nabla * clip_bias + clip_bias ** 2
    noise_var *= noise_multiplier ** 2 * param_numel / (batch_size ** 2)

    return clip_bias + noise_var


def compute_minimize(l2_norm_clip, lr, group_norm, batch_size, nabla, noise_multiplier, param_numel):
    """Optimize the per-group clipping thresholds by minimizing ``func``.

    :param l2_norm_clip: starting thresholds, one per group; its length fixes
        the number of optimized variables
    :param lr: learning rate, used to form the coefficient 2(1 + 1/lr)
    :param group_norm: per-group norm lists, forwarded unchanged to ``func``
    :param batch_size: forwarded unchanged to ``func``
    :param nabla: forwarded unchanged to ``func``
    :param noise_multiplier: forwarded unchanged to ``func``
    :param param_numel: forwarded unchanged to ``func``
    :return: the ``x`` attribute of the ``scipy.optimize.minimize`` result
        (returned whether or not the optimizer reports success)
    """
    objective_args = (
        (1. + 1. / lr) * 2,
        group_norm,
        batch_size,
        nabla,
        noise_multiplier,
        param_numel,
    )
    # Every threshold is bounded below by 0 and unbounded above. With bounds
    # only and no ``method``, SciPy selects L-BFGS-B (per the SciPy docs).
    non_negative = tuple([(0, None)] * len(l2_norm_clip))

    result = minimize(
        fun=func,
        x0=l2_norm_clip,
        args=objective_args,
        bounds=non_negative,
        tol=1e-5,
    )
    return result.x
