import torch
from copy import deepcopy
import math
import numpy as np
import copy
import pdb

##No attack
def benign(device, lr, param_list, cmax):
    """No-op "attack": hand the client parameters back untouched.

    ``device``, ``lr`` and ``cmax`` are accepted but ignored (presumably so this
    can be swapped in wherever an attack function is expected -- confirm with
    callers). The very same ``param_list`` object is returned, not a copy.
    """
    del device, lr, cmax  # intentionally unused
    untouched = param_list
    return untouched


def shejwalkar(device, past_weights, client_grads, client_participated, is_mal, dev_type='unit_vec', capabilities='ben', agr='p2p'):
    """Overwrite the malicious clients' rows of ``client_grads`` with a poisoned update.

    The poisoned update is ``model_re - lamda_succ * deviation``. Here ``model_re``
    is the mean of the weights visible to the attacker, and ``lamda_succ`` is
    tuned by a step-halving search. The construction appears to follow
    Shejwalkar & Houmansadr, "Manipulating the Byzantine" (NDSS 2021).
    NOTE(review): confirm against the paper.

    Only the ``p2p`` setting is implemented. In it, every row of
    ``client_grads`` holds a client's model weights, not a gradient.

    Args:
        device: Device on which the scalar ``lamda`` tensor is created.
        past_weights: Subtracted from ``model_re`` to form ``grads_re``. Must
            broadcast against a single row of ``client_grads``. Presumably the
            previous global model; confirm with callers.
        client_grads: Per-client weights, one client per entry along dim 0.
            Modified in place.
        client_participated: Per-client mask (1 = participated this round).
        is_mal: Per-client mask (1 = malicious, 0 = benign). Must be the same
            array type as ``client_participated``, since the two are combined
            with ``&``.
        dev_type: Deviation direction. ``'unit_vec'`` is ``grads_re`` divided by
            its norm, ``'sign'`` is the sign of ``grads_re``, and ``'std'`` is
            the per-coordinate std of the attacker-visible weights.
        capabilities: Which participating clients' weights are averaged into
            ``model_re``: ``'ben'`` (benign only), ``'all'``, or ``'mal'``
            (malicious only).
        agr: Aggregation-rule name; must contain ``'p2p'``.

    Returns:
        ``(None, None, None)``. The result is delivered by mutating
        ``client_grads``. Every row with ``is_mal == 1`` is overwritten,
        including non-participating ones.

    Raises:
        NotImplementedError: If ``agr`` does not contain ``'p2p'``.
        ValueError: If ``capabilities`` or ``dev_type`` is unknown.
    """
    if 'p2p' not in agr:
        raise NotImplementedError(
            "shejwalkar attack only supports p2p aggregation, got agr=%r" % (agr,))
    if capabilities not in ('ben', 'all', 'mal'):
        raise ValueError(
            "unknown capabilities %r; expected 'ben', 'all' or 'mal'" % (capabilities,))
    if dev_type not in ('unit_vec', 'sign', 'std'):
        raise ValueError(
            "unknown dev_type %r; expected 'unit_vec', 'sign' or 'std'" % (dev_type,))

    participated = client_participated == 1
    malicious = is_mal == 1
    # Works for numpy arrays and for (CPU or GPU) torch tensors, unlike np.where.
    n_attackers = int(malicious.sum())
    if n_attackers == 0:
        # Nothing to poison. torch.stack([]) below would otherwise raise.
        return None, None, None

    # Weights the attacker is assumed to observe this round.
    if capabilities == 'ben':
        attack_cap = client_grads[participated & (is_mal == 0)].clone()
    elif capabilities == 'all':
        attack_cap = client_grads[participated].clone()
    else:  # 'mal'
        attack_cap = client_grads[participated & malicious].clone()

    model_re = torch.mean(attack_cap, 0)  # mean of attacker-visible model weights
    grads_re = model_re - past_weights

    if dev_type == 'unit_vec':
        deviation = grads_re / torch.norm(grads_re)
    elif dev_type == 'sign':
        deviation = torch.sign(grads_re)
    else:  # 'std'
        # Unbiased std: a single-row attack_cap yields NaN (guarded by max_iters).
        deviation = torch.std(attack_cap, 0)
    del attack_cap  # free before cloning all_updates to keep peak memory down

    all_updates = client_grads[participated].clone()
    # is_mal spans every client, but all_updates only holds participants.
    # The mask must be restricted to participants so its length matches.
    mal_in_round = malicious[participated]
    n_mal_in_round = int(mal_in_round.sum())

    lamda = torch.Tensor([10.0]).to(device)
    threshold_diff = 1e-5
    # step halves on every pass, so the search normally stops within a few
    # dozen iterations. The cap only matters when the loss is NaN, where the
    # comparison below is always False and the loop would never end.
    max_iters = 1000
    prev_loss = -1
    step = lamda / 2
    lamda_succ = 0
    iters = 0
    while torch.abs(lamda_succ - lamda) > threshold_diff and iters < max_iters:
        mal_update = past_weights + (grads_re - lamda * deviation)
        if n_mal_in_round:
            all_updates[mal_in_round] = torch.stack([mal_update] * n_mal_in_round)

        agg_grads = torch.mean(all_updates, 0)
        loss = torch.norm(agg_grads - model_re)
        # Oracle: if the aggregate moved further from model_re than on the
        # previous pass, accept lamda and push it higher. Otherwise back off.
        if prev_loss < loss:
            lamda_succ = lamda
            lamda = lamda + step / 2
        else:
            lamda = lamda - step / 2
        step = step / 2
        prev_loss = loss
        iters += 1

    mal_update = model_re - lamda_succ * deviation
    client_grads[malicious] = torch.stack([mal_update] * n_attackers)
    del all_updates

    return None, None, None

    

