import torch
import torch.nn as nn

import numpy as np

# Softmax over dimension 1 of the model output; used to turn model outputs
# into probabilities before the 'mse'/'huber' losses.
softmax = nn.Softmax(dim=1)
# 10x10 identity matrix: row ``c`` is the one-hot encoding of class index ``c``.
onehot_enc = torch.eye(10, 10)

def projection2simplex(y):
    """Project the 1-D tensor ``y`` onto the probability simplex.

    Sort-based algorithm: walk the entries of ``y`` in descending order. The
    first prefix whose shifted mean ``(sum(prefix) - 1) / len(prefix)`` is
    strictly greater than the next sorted entry supplies the threshold ``t``.
    If no prefix qualifies, ``t = (sum(y) - 1) / len(y)``. The result is
    ``max(y - t, 0)`` elementwise, which is non-negative and sums to one.

    Args:
        y: 1-D tensor.

    Returns:
        Tensor of the same length as ``y``, on ``y``'s device.
    """
    n = len(y)
    desc = torch.sort(y, descending=True).values
    # Fallback threshold, used when the scan below never breaks out early.
    threshold = (torch.sum(y) - 1.0) / n
    running = 0.0
    for k in range(1, n):
        # ``running`` is the sum of the k largest entries.
        running += desc[k - 1]
        candidate = (running - 1) / float(k)
        if candidate > desc[k]:
            threshold = candidate
            break
    return torch.max(y - threshold, torch.zeros(n).to(y.device))

def projection2one(y):
    """Rescale the magnitudes of ``y`` so that they sum to one.

    Each entry becomes ``|y_i| / sum_j |y_j|``. If every entry of ``y`` is
    zero, the division is 0/0 and the result is NaN.

    Args:
        y: tensor.

    Returns:
        Tensor shaped like ``y``, on ``y``'s device.
    """
    magnitudes = torch.abs(y)
    total = magnitudes.sum()
    return (magnitudes / total).to(y.device)

def compute_lambd(grads, modo_kwargs):
    """Take one projected-gradient step on the task weights and combine the gradients.

    With ``G = grads`` (one row per task), the update is

        lambd <- projection2simplex(lambd - gamma * (G @ (G^T @ lambd) + rho * lambd))

    The new weights are written back to ``modo_kwargs['lambd']`` in place.

    Args:
        grads: 2-D tensor whose rows are per-task gradient estimates.
        modo_kwargs: dict holding ``'lambd'`` (current weights, one per row
            of ``grads``), ``'gamma'`` (step size) and ``'rho'`` (coefficient
            of the ``lambd`` regularisation term).

    Returns:
        The weighted combination ``lambd @ grads`` (one entry per column of
        ``grads``).
    """
    lam_prev = modo_kwargs['lambd']
    step = modo_kwargs['gamma']
    reg = modo_kwargs['rho']

    combined = grads.transpose(0, 1) @ lam_prev
    descent = grads @ combined + reg * lam_prev
    lam_new = projection2simplex(lam_prev - step * descent)

    # Persist the updated weights for the next call.
    modo_kwargs['lambd'] = lam_new
    return lam_new @ grads

def compute_gradient(model, adv_image, true_label, loss_dict, modo_kwargs, args, device):
    """Estimate a task-weighted loss gradient w.r.t. ``adv_image`` by forward differences.

    For each loss in ``loss_dict``, every slice ``adv_image[:, i, :, j]`` (all
    positions along dimension 2, for channel ``i`` and index ``j`` of
    dimension 3) is shifted by ``args.fd_eps`` in turn. The forward difference
    ``(L(x + eps) - L(x)) / eps`` is written to ``grad_est[i, :, j]``. Because
    the whole slice moves at once, every entry of that slice receives the same
    directional derivative, not a per-pixel partial derivative.

    The per-loss estimates are stacked and combined by ``compute_lambd``,
    which also updates ``modo_kwargs['lambd']`` in place.

    Args:
        model: callable applied to ``adv_image``. Its output is flattened with
            ``reshape(-1)``, so presumably a batch of one is expected (TODO
            confirm).
        adv_image: 4-D tensor. It is perturbed in place during estimation and
            restored to its exact original values before each next evaluation.
        true_label: label used to build loss targets. For 'nll' it must
            support ``.to(device)``.
        loss_dict: mapping from loss name to loss callable.
            - 'mse'/'huber' receive ``softmax`` of the model output and a
              one-hot target.
            - 'nll' receives the raw output and ``true_label`` moved to
              ``device``.
            - Any other name receives the raw output and ``true_label``
              unchanged.
        modo_kwargs: dict with ``'lambd'``, ``'gamma'`` and ``'rho'`` for
            ``compute_lambd``.
        args: namespace providing ``fd_eps``.
        device: device that targets are moved to.

    Returns:
        Tensor of shape ``(1, *adv_image.shape[1:])``.

    Raises:
        ValueError: if ``loss_dict`` is empty.
    """
    if not loss_dict:
        raise ValueError("loss_dict must contain at least one loss")

    def _loss(name):
        # Evaluate loss ``name`` at the current (possibly perturbed) adv_image,
        # preparing the model output and target the way that loss expects.
        fn = loss_dict[name]
        if name == 'mse' or name == 'huber':
            return fn(softmax(model(adv_image)).reshape(-1),
                      onehot_enc[true_label].to(device))
        if name == 'nll':
            return fn(model(adv_image).reshape(-1), true_label.to(device))
        return fn(model(adv_image).reshape(-1), true_label)

    grad_list = []
    # Zeroth-order estimate: no autograd graph is needed, and recording one per
    # forward pass would keep every pass's activations alive.
    with torch.no_grad():
        for name in loss_dict:
            grad_est = torch.zeros_like(adv_image[0])
            # The image is restored exactly after every perturbation, so the
            # unperturbed loss is the same for every slice; evaluate it once.
            loss_base = _loss(name)
            for i in range(adv_image.shape[1]):
                for j in range(adv_image.shape[3]):
                    saved = adv_image[:, i, :, j].clone()
                    adv_image[:, i, :, j] += args.fd_eps
                    loss_plus = _loss(name)
                    # Restore exact values; subtracting fd_eps back can leave
                    # float rounding residue that accumulates over the loops.
                    adv_image[:, i, :, j] = saved
                    grad_est[i, :, j] = (loss_plus - loss_base) / args.fd_eps
            grad_list.append(grad_est.reshape(-1))
    grads = torch.stack(grad_list)
    return compute_lambd(grads, modo_kwargs).view(adv_image.shape[1:]).unsqueeze(0)

def generate_u(args, device, shape=(1, 3, 32, 32)):
    """Sample a random ``args.s2``-sparse direction with unit Euclidean norm.

    A Gaussian vector of length ``args.s2`` is normalised to unit norm. Its
    entries are then scattered into ``args.s2`` distinct coordinates of an
    otherwise-zero vector of length ``args.d``; the coordinates are chosen
    uniformly without replacement from ``range(args.d)``.

    Args:
        args: namespace providing ``d`` (vector length) and ``s2`` (number of
            non-zero coordinates; must satisfy ``s2 <= d``).
        device: device for the returned tensor.
        shape: shape of the returned tensor. Its element count must equal
            ``args.d``. Defaults to ``(1, 3, 32, 32)``.

    Returns:
        float32 tensor of shape ``shape`` on ``device``.
    """
    rv = np.random.randn(args.s2)
    rv_norm = rv / np.linalg.norm(rv)
    # Draw distinct indices over the full range [0, d). randint(1, d) would
    # never pick coordinate 0, and sampling with replacement lets duplicate
    # indices overwrite each other, leaving fewer than s2 non-zeros and ||u|| < 1.
    index = np.random.choice(args.d, size=args.s2, replace=False)
    u = np.zeros(args.d)
    u[index] = rv_norm
    return torch.tensor(u, dtype=torch.float32).to(device).view(shape)

def compute_gradient_random(model, adv_image, true_label, loss_dict, modo_kwargs, args, device):
    """Estimate a task-weighted loss gradient using random sparse directions.

    For each of ``args.q`` directions ``u`` drawn by ``generate_u``, and for
    each loss in ``loss_dict``, the estimate ``u * (L(x + eps*u) - L(x)) / eps``
    is stored. The per-loss estimates are averaged over the ``q`` directions
    and combined by ``compute_lambd``, which also updates
    ``modo_kwargs['lambd']`` in place.

    Args:
        model: callable applied to ``adv_image``. Its output is flattened with
            ``reshape(-1)``, so presumably a batch of one is expected (TODO
            confirm).
        adv_image: image tensor. It is not modified.
        true_label: label used to build loss targets. For 'nll' it must
            support ``.to(device)``.
        loss_dict: mapping from loss name to loss callable.
            - 'mse'/'huber' receive ``softmax`` of the model output and a
              one-hot target.
            - 'nll' receives the raw output and ``true_label`` moved to
              ``device``.
            - Any other name receives the raw output and ``true_label``
              unchanged.
        modo_kwargs: dict with ``'lambd'``, ``'gamma'`` and ``'rho'`` for
            ``compute_lambd``.
        args: namespace providing ``q`` (number of directions), ``num_task``
            (rows of the estimate; must be at least ``len(loss_dict)``),
            ``d`` (number of elements per direction) and ``fd_eps``, plus
            whatever ``generate_u`` reads.
        device: device for directions, targets and the result.

    Returns:
        Tensor of shape ``(1, *adv_image.shape[1:])``.
    """
    def _loss(name, image):
        # Evaluate loss ``name`` at ``image``, preparing the model output and
        # target the way that loss expects.
        fn = loss_dict[name]
        if name == 'mse' or name == 'huber':
            return fn(softmax(model(image)).reshape(-1),
                      onehot_enc[true_label].to(device))
        if name == 'nll':
            return fn(model(image).reshape(-1), true_label.to(device))
        return fn(model(image).reshape(-1), true_label)

    gradient_i = torch.zeros((args.q, args.num_task, args.d))
    # Zeroth-order estimate: no autograd graph is needed, and recording one per
    # forward pass would keep every pass's activations alive.
    with torch.no_grad():
        # L(x) does not depend on the direction u; evaluate it once per loss
        # instead of once per (direction, loss) pair.
        base_losses = {name: _loss(name, adv_image) for name in loss_dict}
        for i in range(args.q):
            u = generate_u(args, device)
            u_flat = u.reshape(-1)
            for k, name in enumerate(loss_dict):
                loss_plus = _loss(name, adv_image + args.fd_eps * u)
                gradient_i[i, k, :] = u_flat * (loss_plus - base_losses[name]) / args.fd_eps
    gradient = torch.sum(gradient_i, dim=0) / args.q
    # Keep the (num_task, d) shape even when num_task == 1: squeezing it to
    # (d,) would break the transpose inside compute_lambd.
    grads = gradient.to(device)
    return compute_lambd(grads, modo_kwargs).view(1, *adv_image.shape[1:])