import pdb
import numpy as np
import torch
import math
def mean(client_grads, group_ids, client_participated, gr, group_participated):
    """Average the gradients of the participating clients in group ``gr``.

    Args:
        client_grads: torch tensor whose first axis indexes clients.
        group_ids: numpy array giving the group id of each client.
        client_participated: numpy array of per-client participation flags
            (truthy = participated); assumed 0/1 or bool — TODO confirm.
        gr: id of the group to aggregate.
        group_participated: mutable sequence indexed by group id; entry
            ``gr`` is set to 1 if the group had any participating client,
            else 0.

    Returns:
        The element-wise mean over the participating members' rows of
        ``client_grads``. If no member participated, a zero tensor with the
        per-client shape is returned; callers should consult
        ``group_participated[gr]`` to tell the two cases apart.
    """
    members = np.where(group_ids == gr)[0]
    # Only participating clients contribute, so the numerator matches the
    # denominator (and the filtering done by ``prism``).
    active = members[np.asarray(client_participated[members], dtype=bool)]
    n_active_members = len(active)
    if n_active_members == 0:
        group_participated[gr] = 0
        # Previously this path fell through to an unbound return value
        # (UnboundLocalError); return zeros of the per-client shape instead.
        return client_grads.new_zeros(client_grads.shape[1:])
    group_participated[gr] = 1
    return torch.sum(client_grads[active], dim=0) / n_active_members

def prism(client_grads, group_ids, client_participated, gr, group_grads, group_participated, old_direction, fmax):
    """Trimmed-mean aggregation of group ``gr``'s client gradients, in place.

    Each participating member ``m`` gets a score
    ``sum(sign(d * (d - old_direction.reshape(-1))) * client_grads[m] ** 2)``
    where ``d = sign(client_grads[m])``. If ``old_direction`` holds sign values
    (-1/0/1; presumably the previous aggregate direction — TODO confirm), this
    is the squared magnitude on the coordinates where the client's nonzero sign
    differs from ``old_direction``. Members are sorted by score, and
    ``ceil(fmax * k)`` of the lowest-scoring and of the highest-scoring
    members are dropped (``k`` = number of participating members). The rest
    are averaged.

    Args:
        client_grads: torch tensor whose first axis indexes clients; each
            row must broadcast against ``old_direction.reshape(-1)``.
        group_ids: numpy array giving the group id of each client.
        client_participated: per-client participation flags (truthy =
            participated).
        gr: id of the group to aggregate.
        group_grads: indexable container; ``group_grads[gr]`` receives the
            trimmed mean when the group participates, else is left untouched.
        group_participated: mutable sequence; entry ``gr`` is set to 1 when a
            trimmed mean was written, else 0 (no participants, or trimming
            would remove every member).
        old_direction: tensor flattened and compared element-wise with each
            client's gradient signs.
        fmax: fraction of participating members trimmed from each end.

    Returns:
        None; results are written into ``group_grads`` and
        ``group_participated``.
    """
    members = np.where(group_ids == gr)
    n_active_members = np.sum(client_participated[members])
    if n_active_members == 0:
        group_participated[gr] = 0
        return

    # Loop-invariant: flatten once instead of once per member.
    old_flat = old_direction.reshape(-1)
    fs = {}
    for member in members[0]:
        if not client_participated[member]:
            continue
        direction = torch.sign(client_grads[member])
        flip = torch.sign(direction * (direction - old_flat))
        # Store a Python float so sorting compares scalars rather than 0-d
        # tensors (which would force a device sync on every comparison).
        fs[member] = torch.sum(flip * (client_grads[member] ** 2)).item()
        # Free the per-member temporaries before the next iteration.
        del direction, flip

    sorted_keys = sorted(fs, key=fs.get)
    trim = int(math.ceil(fmax * len(fs)))
    if 2 * trim >= len(fs):
        # Trimming both ends would leave nothing to average.
        group_participated[gr] = 0
        return

    group_participated[gr] = 1
    # Explicit end index: ``sorted_keys[0:-0]`` would be empty when trim == 0.
    filtered = sorted_keys[trim:len(sorted_keys) - trim]
    group_grads[gr] = torch.mean(client_grads[filtered], dim=0)
