
import numpy as np

# Hard-coded device string (looks like a PyTorch-style "cuda:<index>" spec).
# NOTE(review): not referenced by the functions visible here; presumably used
# elsewhere in the file/project. The fixed GPU index 5 will fail on machines
# with fewer GPUs -- consider making it configurable.
device = "cuda:5"

def get_prototype(features, targets):
    """Return one mean feature vector ("prototype") per class.

    Args:
        features: 2-D array-like of shape (N, D); row n is the feature
            vector of sample n.
        targets: 1-D array-like of N class labels. Labels are expected to be
            the contiguous integers 0..K-1, where K is the number of distinct
            labels (NOTE(review): confirm this holds for all callers).

    Returns:
        np.ndarray of shape (K, D) with the same dtype as ``features``; row i
        is the column-wise mean of the samples labelled i. If some i in
        0..K-1 never occurs (non-contiguous labels), its row is NaN for
        floating-point features (NumPy warns about the empty slice), and
        labels >= K are ignored.
    """
    features = np.asarray(features)
    targets = np.asarray(targets)
    num_classes = len(np.unique(targets, axis=0))
    prot = np.zeros((num_classes, features.shape[-1]), dtype=features.dtype)

    for i in range(num_classes):
        # Boolean masking keeps the (n_i, D) shape even when n_i == 1, so the
        # mean is always taken over samples. (Index-tuple selection followed by
        # .squeeze() collapsed single-sample classes to shape (D,) and averaged
        # across features instead.)
        prot[i] = np.mean(features[targets == i], axis=0)
    return prot
    
def get_prototype_median(features, targets):
    """Return one median feature vector ("prototype") per class.

    Args:
        features: 2-D array-like of shape (N, D); row n is the feature
            vector of sample n.
        targets: 1-D array-like of N class labels. Labels are expected to be
            the contiguous integers 0..K-1, where K is the number of distinct
            labels (NOTE(review): confirm this holds for all callers).

    Returns:
        np.ndarray of shape (K, D) with the same dtype as ``features``; row i
        is the column-wise median of the samples labelled i. If some i in
        0..K-1 never occurs (non-contiguous labels), its row is NaN for
        floating-point features (NumPy warns about the empty slice), and
        labels >= K are ignored.
    """
    features = np.asarray(features)
    targets = np.asarray(targets)
    num_classes = len(np.unique(targets, axis=0))
    prot = np.zeros((num_classes, features.shape[-1]), dtype=features.dtype)

    for i in range(num_classes):
        # Boolean masking keeps the (n_i, D) shape even when n_i == 1, so the
        # median is always taken over samples. (Index-tuple selection followed
        # by .squeeze() collapsed single-sample classes to shape (D,) and took
        # the median across features instead.)
        prot[i] = np.median(features[targets == i], axis=0)
    return prot