import numpy as np
import torch


def compute_gradients(model, loss):
    """Backpropagate ``loss`` through ``model`` and return its flattened gradient.

    The model's existing gradients are cleared first. ``loss.backward`` is then
    called with ``retain_graph=True``, so the caller may backpropagate through
    the same graph again.

    Args:
        model: ``torch.nn.Module`` whose parameters receive the gradients.
        loss: Scalar tensor computed from ``model``'s parameters.

    Returns:
        1-D tensor with every parameter's gradient flattened and concatenated
        in ``model.parameters()`` order. Parameters that received no gradient
        (``grad is None``, e.g. frozen or unused in the forward pass) contribute
        zeros, so the layout always matches the flattened parameter layout. The
        result is a fresh tensor and never aliases ``param.grad``.
    """
    model.zero_grad()
    loss.backward(retain_graph=True)

    flat_grads = []
    for param in model.parameters():
        if param.grad is None:
            # Keep positions aligned with the parameter vector instead of crashing.
            flat_grads.append(
                torch.zeros(param.numel(), dtype=param.dtype, device=param.device)
            )
        else:
            # detach(): no autograd history; reshape(): tolerates non-contiguous grads.
            flat_grads.append(param.grad.detach().reshape(-1))

    if not flat_grads:
        return torch.empty(0)
    # A single cat copies each gradient once (the old incremental cat was
    # quadratic). It also always allocates new storage, so a later in-place
    # zero_grad cannot clobber the returned vector.
    return torch.cat(flat_grads)


# def random_poison_gradients(gradients, poison_factor=10):
#     return [grad + torch.randn_like(grad) * poison_factor for grad in gradients]


def get_malicious_updates_fang_trmean(all_updates, deviation, n_attackers, epoch_num, q_level=2, norm='inf', *, b=2):
    """Craft Fang-style malicious updates targeting trimmed-mean aggregation.

    Let ``max_j`` and ``min_j`` be the per-coordinate extremes of ``all_updates``.
    Each of the ``n_attackers`` malicious rows gets, for coordinate ``j``, a value
    sampled uniformly:

    * where ``deviation[j] > 0``: between ``max_j`` and ``max_j * b``
      (if ``max_j > 0``) or ``max_j / b`` (otherwise);
    * where ``deviation[j] < 0``: between ``min_j * b`` (if ``min_j < 0``)
      or ``min_j / b`` (otherwise) and ``min_j``;
    * where ``deviation[j] == 0``: exactly 0.

    Args:
        all_updates: 2-D tensor with one flattened update per row.
        deviation: 1-D tensor with one entry per coordinate; only its sign is used.
        n_attackers: Number of malicious rows to generate.
        epoch_num, q_level, norm: Unused; kept for interface compatibility.
        b: Scaling factor for how far past the benign extremes samples may go.
            Defaults to 2.

    Returns:
        Tensor of shape ``(n_attackers + all_updates.shape[0], d)``: the
        malicious rows first, followed by ``all_updates`` unchanged.

    Randomness comes from ``np.random``, so ``np.random.seed`` controls it.
    All work happens on ``all_updates.device``.

    NOTE(review): whether ``deviation > 0`` should map to the max side depends
    on the caller's convention (sign of benign gradient vs. of model change);
    confirm against the caller.
    """
    device = all_updates.device

    max_vector = torch.max(all_updates, 0)[0]
    min_vector = torch.min(all_updates, 0)[0]

    # Push each extreme outward: scale by b when it has the "outward" sign,
    # by 1/b otherwise.
    max_scale = torch.full(max_vector.shape, 1 / b, dtype=torch.float32, device=device)
    max_scale[max_vector > 0] = b
    min_scale = torch.full(min_vector.shape, 1 / b, dtype=torch.float32, device=device)
    min_scale[min_vector < 0] = b

    max_lo, max_hi = max_vector, max_vector * max_scale
    min_lo, min_hi = min_vector * min_scale, min_vector

    # (d, n_attackers) uniform samples in [0, 1).
    rand = torch.from_numpy(
        np.random.uniform(0, 1, [len(deviation), n_attackers])
    ).to(device=device, dtype=torch.float32)

    # Broadcast per-coordinate intervals across the attacker columns.
    max_rand = max_lo[:, None] + rand * (max_hi - max_lo)[:, None]
    min_rand = min_lo[:, None] + rand * (min_hi - min_lo)[:, None]

    # Complementary masks. The previous code used (deviation > 0) for both
    # terms, which summed both samples on positive coordinates and zeroed
    # every negative one.
    pos_mask = (deviation > 0).to(device=device, dtype=torch.float32)[:, None]
    neg_mask = (deviation < 0).to(device=device, dtype=torch.float32)[:, None]
    mal_vec = (pos_mask * max_rand + neg_mask * min_rand).T

    return torch.cat((mal_vec, all_updates), 0)

