from typing import Dict, List
import torch

def agg_func(protos):
    """
    Average each label's list of tensors element-wise.

    Args:
        protos: mapping from label to a non-empty sequence of tensors that
            share a shape.

    Returns:
        dict mapping each label to the detached element-wise mean of its
        tensors. A label with a single tensor gets a detached copy of that
        tensor (not divided).
    """
    agg_proto = {}
    with torch.no_grad():
        for label, proto_list in protos.items():
            if len(proto_list) > 1:
                # Seed the sum with a copy of the first tensor instead of
                # ``0 * proto_list[0]``: 0 * inf is NaN, which would poison the
                # average whenever the first tensor contains an infinity.
                proto = proto_list[0].detach().clone()
                for p in proto_list[1:]:
                    proto += p.detach()
                agg_proto[label] = proto / len(proto_list)
            else:
                agg_proto[label] = proto_list[0].detach().clone()
    return agg_proto

def proto_aggregation(local_protos_list: Dict):
    """
    Merge per-client prototype dicts into one averaged tensor per label.

    Args:
        local_protos_list: mapping from a client key to that client's
            ``{label: tensor}`` dict.

    Returns:
        dict mapping each label to the element-wise mean of every client's
        tensor for that label, detached from autograd. Labels reported by a
        single client get a detached copy of that client's tensor (not
        divided). Labels appear in first-seen order.
    """
    # Group tensors by label across all clients.
    grouped = {}
    for local_protos in local_protos_list.values():
        for label, proto in local_protos.items():
            grouped.setdefault(label, []).append(proto)

    agg_protos_label = {}
    for label, proto_list in grouped.items():
        # Seed with a copy of the first tensor rather than scaling it by 0
        # (0 * inf is NaN). Copying also keeps the single-client result from
        # aliasing the client's storage, as ``.data`` would.
        proto = proto_list[0].detach().clone()
        for p in proto_list[1:]:
            proto += p.detach()
        if len(proto_list) > 1:
            proto = proto / len(proto_list)
        agg_protos_label[label] = proto
    return agg_protos_label

def protos_by_class(local_protos: List[Dict]):
    """
    Regroup a list of ``{class: proto}`` dicts into ``{class: [proto, ...]}``.

    Each class's list holds its entries in the order the dicts appear in
    ``local_protos``; classes are keyed in first-seen order.
    """
    grouped = {}
    for proto_dict in local_protos:
        for cls, proto in proto_dict.items():
            grouped.setdefault(cls, []).append(proto)
    return grouped


def get_covariance(feature_dicts):
    """
    Compute the sample covariance matrix of each class's features.

    Args:
        feature_dicts: mapping from label to a sequence of feature tensors of
            a common shape. Each tensor is flattened to a vector, so
            multi-dimensional features (e.g. feature maps) are accepted.

    Returns:
        dict mapping each label to ``torch.cov`` of its features with
        Bessel's correction (``correction=1``): a ``(d, d)`` matrix where
        ``d`` is the number of elements per feature. ``torch.cov`` squeezes
        its output, so ``d == 1`` yields a 0-d tensor. A class with a single
        sample has zero degrees of freedom and does not give a finite result.
    """
    cov_protos_label = {}
    for label, proto_list in feature_dicts.items():
        # One row per observation; transpose so rows are variables, which is
        # the layout torch.cov expects.
        samples = torch.stack([z.reshape(-1) for z in proto_list])
        cov_protos_label[label] = torch.cov(samples.T, correction=1)
    return cov_protos_label

def get_global_cov(npc_dict, device='cuda', Lambda=0.001) -> Dict:
    """
    Pool per-client class statistics into one regularized covariance per class.

    For each label, ``npc_dict[label]`` is a sequence of ``(nk, pk, ck)``
    triples: a client's sample count, its class mean ``pk`` (any shape,
    flattened to a ``d``-vector ``zk``) and its ``(d, d)`` covariance ``ck``.
    The result is

        (sum_k (nk - 1) ck + sum_k nk zk zk^T - N mu mu^T) / (N - 1) + Lambda I

    with ``N = sum_k nk`` and ``mu = sum_k nk zk / N``. When each ``ck`` is
    the Bessel-corrected sample covariance, the first term is the covariance
    of the union of all clients' samples.

    Args:
        npc_dict: mapping from label to ``(nk, pk, ck)`` triples.
        device: device the returned matrices are placed on.
        Lambda: value added to the diagonal.

    Returns:
        dict mapping each label to its ``(d, d)`` matrix on ``device``.
        ``d`` is inferred from ``pk.numel()`` (980 for 20x7x7 features). The
        division by ``N - 1`` means ``N`` must exceed 1 for a finite result.
    """
    global_cov_dict = {}
    for label, stats in npc_dict.items():
        nc = 0
        weighted_cov = 0  # sum_k (nk - 1) * ck
        scatter = 0       # sum_k nk * zk zk^T
        mean_sum = 0      # sum_k nk * zk
        for nk, pk, ck in stats:
            # Move each client's statistics to ``device`` before combining so
            # clients whose tensors live on different devices can be mixed.
            zk = pk.reshape(-1, 1).to(device)
            nc += nk
            weighted_cov += (nk - 1) * ck.to(device)
            scatter += nk * (zk @ zk.T)
            mean_sum += zk * nk
        muc = mean_sum / nc
        identity = torch.eye(muc.shape[0], device=device, requires_grad=False)
        pooled = (weighted_cov + scatter - nc * (muc @ muc.T)) / (nc - 1)
        global_cov_dict[label] = pooled + identity * Lambda
    return global_cov_dict

def get_proto_covariance(npc_dict, device='cuda') -> Dict:
    """
    Covariance of client class means, weighting each client by its count.

    Args:
        npc_dict: mapping from label to ``(nk, pk, ck)`` triples. Only ``nk``
            and ``pk`` are read; ``ck`` is ignored. ``nk`` is passed to
            ``torch.cov`` as a frequency weight, so it must be an integer.
            ``pk`` may have any shape; it is flattened to a vector.
        device: device the returned matrices are placed on.

    Returns:
        dict mapping each label to ``torch.cov`` over the flattened means
        with ``fweights=nk`` (default correction), i.e. as if each client's
        mean had been observed ``nk`` times.
    """
    global_cov_dict = {}
    for label, stats in npc_dict.items():
        counts, means = [], []
        for nk, pk, _ in stats:
            counts.append(nk)
            means.append(pk.reshape(-1))
        fweights = torch.tensor(counts, device=device)
        # Rows are variables and columns are observations for torch.cov.
        observations = torch.stack(means).T.to(device)
        global_cov_dict[label] = torch.cov(observations, fweights=fweights)
    return global_cov_dict
