# --coding:utf-8--
import numpy as np
import torch


def numpy_generalized_variance(pop):
    """Return the generalized variance of *pop* computed with NumPy.

    The generalized variance is the determinant of the covariance matrix.
    ``np.cov`` is called with its defaults, so each ROW of *pop* is treated as
    one variable (presumably one agent's embedding, per the original comment)
    and each column as one observation, normalised by ``N - 1``.

    Args:
        pop: 2-D array-like with at least two rows. A single row makes
            ``np.cov`` return a scalar, which ``np.linalg.det`` rejects.

    Returns:
        The determinant of the sample covariance matrix, as a NumPy float.
    """
    sample_cov = np.cov(pop)
    generalized_variance = np.linalg.det(sample_cov)
    return generalized_variance


def torch_generalized_variance(input, flag, tol=0.0):
    """Compute the generalized variance (det of the covariance matrix) of *input*.

    Rows of *input* are observations and columns are variables: column means
    are subtracted and the covariance is ``x.T @ x / n``. This is population
    normalisation, dividing by ``n`` rather than ``n - 1``.

    NOTE(review): this differs from ``numpy_generalized_variance``, which uses
    ``np.cov`` defaults (rows as variables, ``n - 1`` normalisation). Confirm
    which convention callers expect.

    If ``|det|`` is at most *tol*, the covariance matrix is passed to
    ``Gerschgorin_Circle`` and its result is returned instead.

    Args:
        input: 2-D tensor of shape ``(n, m)``.
        flag: label interpolated into the printed log line (presumably an
            agent id; verify against callers).
        tol: singularity threshold for the fallback. The default ``0.0``
            reproduces the original exact ``det == 0`` check. Determinants of
            rank-deficient matrices (e.g. ``n <= m``) are usually tiny but
            non-zero in floating point, so a small positive value lets the
            fallback actually trigger.

    Returns:
        0-d tensor holding the (possibly regularised) generalized variance.

    Raises:
        ValueError: if *input* has no rows.
    """
    n = input.shape[0]
    if n == 0:
        raise ValueError("torch_generalized_variance needs at least one row")

    # keepdim=True yields a (1, m) mean row that broadcasts over all n rows.
    # This replaces the former O(n^2) loop of torch.cat calls that tiled it.
    x = input - torch.mean(input, dim=0, keepdim=True)

    cov_matrix = torch.matmul(x.T, x) / n

    gv = torch.det(cov_matrix)

    if torch.abs(gv) <= tol:
        gv = Gerschgorin_Circle(cov_matrix)

    print("Agent:{} compute GV={}".format(flag, gv))

    return gv

def Gerschgorin_Circle(cov_matrix):
    """Shift the diagonal of a square matrix by its row sums and return the det.

    Each diagonal entry ``a_ii`` becomes ``a_ii + sum_j a_ij``: the full signed
    row sum, diagonal included, is added to the diagonal. The determinant of
    the shifted matrix is returned. The input tensor is NOT modified. The
    previous implementation mutated it in place, which also raised on leaf
    tensors with ``requires_grad=True``.

    NOTE(review): despite the name, this is not the Gershgorin radius (the sum
    of absolute off-diagonal entries). With signed sums the shifted matrix is
    not guaranteed to be non-singular. The formula is kept unchanged for
    behavioral compatibility.

    Args:
        cov_matrix: square 2-D tensor.

    Returns:
        0-d tensor, the determinant of the shifted matrix.
    """
    # Row sums come from the unmodified matrix. This matches the old loop,
    # because row i never read a diagonal entry already shifted for another row.
    row_sums = cov_matrix.sum(dim=1)
    shifted = cov_matrix + torch.diag(row_sums)
    return torch.det(shifted)