import math
import torch
import numpy as np
import torch.nn as nn
import time
import pdb
import attack
#import torch.multiprocessing as mp
import group
import utils
from copy import deepcopy
import statistics


def p2p_trmean(W, client_weights, nets, cmax, is_mal, device):
    """Aggregate every honest node's neighbourhood with a trimmed mean.

    For each node ``i`` with ``is_mal[i] == 0`` the rows of
    ``client_weights`` belonging to its neighbours (the positive entries of
    ``W[i]``) are turned into updates relative to the node's own model
    vector, sorted coordinate-wise, stripped of the ``local_cmax`` smallest
    and largest values per coordinate, averaged, and added back onto the
    node's model vector.

    Args:
        W: Mixing/adjacency matrix; ``W[i] > 0`` selects node ``i``'s
            neighbours. Must support ``.cpu()`` on a row.
        client_weights: One row per node (presumably the flattened model
            weights, shape ``(n_nodes, n_params)`` -- TODO confirm).
        nets: Per-node models; ``nets[i]`` is passed to
            ``utils.model_to_vec`` to obtain the reference vector.
        cmax: Number of values to trim network-wide; it is scaled by
            ``k / len(W)`` to get the per-neighbourhood trim count.
        is_mal: Per-node flag; nodes whose flag is not 0 are skipped and
            keep an all-zero row in the result.
        device: Device on which the result tensor is allocated.

    Returns:
        A tensor shaped like ``client_weights`` holding the aggregated row
        of every honest node. If any honest node would have nothing left
        after trimming, a message is printed and ``client_weights`` is
        returned unchanged instead.
    """
    aggregated_wts = torch.zeros(client_weights.shape, device=device)
    n_nodes = len(W)
    for i in range(n_nodes):
        if is_mal[i] != 0:
            continue

        nbr_idx = np.where(W[i].cpu() > 0)[0]
        k = len(nbr_idx)
        local_cmax = int(np.ceil((cmax / n_nodes) * k))
        # Trimming local_cmax values from both ends must leave at least one
        # value; otherwise the mean of the empty slice would be NaN. This
        # also covers a node without any positive-weight neighbour (k == 0).
        if 2 * local_cmax >= k:
            print("Cannot trim so many gradients")
            return client_weights

        ref_vec = utils.model_to_vec(nets[i])
        grads = client_weights[nbr_idx] - ref_vec
        sorted_grads, _ = torch.sort(grads, dim=0)
        kept = sorted_grads[local_cmax:k - local_cmax, :]
        aggregated_wts[i] = ref_vec + torch.mean(kept, dim=0)
    return aggregated_wts

def p2prism(prev_medians, W, nets, client_weights, k_size, nbrs, graph_type, is_mal, device, white=False, model_re=None, lamda=None, deviation=None):

    mask = W.clone()
    w_size = len(client_weights)
    rpt = torch.zeros((w_size, w_size)).to(device)
    aggregated_wts = torch.zeros(client_weights.shape).to(device)
    medians = {}
    FS_record = {}
    mu = 0.75
    for i in range(len(client_weights)):
        if (is_mal[i] == 0):
            FS = {}
            ref_vec = utils.model_to_vec(nets[i])
            ref_dir = torch.sign(client_weights[i] - ref_vec)
            if (graph_type == 'k-regular'):
                for j in range(-k_size[0], k_size[1]+1):
                    idx = i+j
                    if (idx<0): idx += w_size
                    if (idx>=w_size): idx -= w_size
                    #if is_mal[idx]: flag = 1
                    if (j != 0): ##compute FS
                        if (white == False or is_mal[idx] == 0): client_grads = client_weights[idx] - ref_vec
                        elif (white == True and is_mal[idx] == 1): client_grads = client_weights[idx] + (lamda - math.sqrt((1+mu)*prev_medians[i])/torch.sum(deviation).item())*deviation - ref_vec
                        direction = torch.sign(client_grads)
                        FS[idx] = torch.sum(((direction != ref_dir) & (direction != 0))*(client_grads**2)).item()
            elif (graph_type == 'power-law'):
                for nbr in nbrs[i]:
                    client_grads = client_weights[nbr] - ref_vec
                    direction = torch.sign(client_grads)
                    FS[nbr] = torch.sum(((direction != ref_dir) & (direction != 0))*(client_grads**2)).item()

            median = statistics.median(list(FS.values()))
            if (median > 10*min(FS.values())): median = 10*min(FS.values())
            medians[i] = median
            #filtered_idx = sorted(FS)#[:filter_k]
            filtered_idx = []
            rpt_sum = 0
            for key in FS:
                if (FS[key] < median + mu*(median - min(list(FS.values())))): rpt[i][key] = 1.0
                else:
                    if (median > min(list(FS.values()))): rpt[i][key] = -(FS[key] - median) / (median - min(list(FS.values())))
                    else: rpt[i][key] = -(FS[key] - median)
                    if (rpt[i][key] < -5): rpt[i][key] = -5
                rpt[i][key] += W[i][key]
                if (rpt[i][key] > 0): 
                    rpt_sum += rpt[i][key]
                    filtered_idx.append(key)
                #if (FS[key] > median + 0*(median - min(list(FS.values())))):
                #    filtered_idx.remove(key)
            W[i] = rpt[i]/rpt_sum
            rpt[i] = W[i]
            filtered_idx.append(i)
            aggregated_wts[i] = torch.mean(client_weights[filtered_idx], axis=0)
            FS_record[i] = FS
    return rpt, FS_record, aggregated_wts, medians


def p2p(nbd, client_updates, device, msg=None):
    """Mix client updates with a plain matrix product.

    Args:
        nbd: 2-D mixing matrix; row ``i`` holds the weights node ``i``
            applies to each client's update.
        client_updates: 2-D tensor with one row per client.
        device: Unused by this function.
        msg: Unused by this function.

    Returns:
        Tuple ``(0, aggregated)`` where ``aggregated`` is
        ``nbd @ client_updates``; the leading element is always 0.
    """
    mixed = nbd.mm(client_updates)
    return 0, mixed
    
    
