import math

import numpy as np
from scipy.linalg import cho_factor, cho_solve
from scipy.spatial.distance import cdist, pdist, squareform

# class GaussianProcessModel(object):
#     def __init__(self, train_x, train_u):

def kernel_function(client_characteristic, client_characteristic_other, length_scale):
    """Evaluate the Squared Exponential kernel between two clients' features.

    Args:
        client_characteristic: feature vector of one client.
        client_characteristic_other: feature vector of another client.
        length_scale: length-scale parameter of the kernel.

    Returns:
        exp(-||a - b||^2 / (2 * length_scale^2)).
    """
    diff = client_characteristic - client_characteristic_other
    squared_distance = np.sum(diff * diff)
    denominator = 2 * length_scale ** 2
    return np.exp(-squared_distance / denominator)

def kernel_matrix_function(train_x, length_scale):
    """Build the Squared Exponential kernel (Gram) matrix of the training samples.

    Args:
        train_x: row i holds the features of the client on the longest
            diagonal selected at round i. A 1-D sequence is treated as one
            scalar feature per sample.
        length_scale: length-scale parameter of the GP kernel.

    Returns:
        (n, n) ndarray K with K[i, j] = exp(-||x_i - x_j||^2 / (2 * length_scale^2)),
        i.e. the kernel matrix used by the Gaussian Process model.
    """
    samples = np.asarray(train_x, dtype=float)
    if samples.ndim == 1:
        samples = samples.reshape(-1, 1)
    if samples.shape[0] == 0:
        # squareform() of an empty condensed vector yields a 1x1 matrix,
        # which would be wrong for zero samples.
        return np.zeros((0, 0))
    # 'sqeuclidean' yields squared distances directly (matching the sum of
    # squares in kernel_function) instead of a sqrt followed by squaring.
    squared_distances = squareform(pdist(samples, 'sqeuclidean'))
    return np.exp(-squared_distances / (2 * length_scale ** 2))


def _squared_exponential_cross_kernel(x_a, x_b, length_scale):
    # Squared Exponential kernel between every row of x_a and every row of x_b; shape (len(x_a), len(x_b)).
    squared_distances = cdist(x_a, x_b, 'sqeuclidean')
    return np.exp(-squared_distances / (2 * length_scale ** 2))


def _as_sample_matrix(samples):
    # Convert to a 2-D float array with one sample per row; 1-D input means one scalar feature per sample.
    arr = np.asarray(samples, dtype=float)
    return arr.reshape(-1, 1) if arr.ndim == 1 else arr


def gaussian_process_model(train_x, train_u, candidate, client_characteristic,
                           sigma_par=0.02, length_scale=3, delta_pro=0.02):
    """Compute GP-UCB style confidence intervals for the candidate clients.

    Fits a zero-mean Gaussian Process with a Squared Exponential kernel
    (k(x, x) = 1) to the observed utilities. For every candidate client it
    returns [mu - sqrt(beta_t) * sd, mu + sqrt(beta_t) * sd], where mu and sd
    are the posterior mean and standard deviation at that client's features.

    Args:
        train_x: row i holds the features of the client on the longest
            diagonal selected at round i.
        train_u: element i is the utility observed at round i.
        candidate: element i is non-zero when client i is a candidate
            participant (the caller normally fills it with 1); clients whose
            entry is 0 are skipped.
        client_characteristic: row i holds the features of client i.
        sigma_par: observation-noise standard deviation of the GP.
        length_scale: length-scale parameter of the kernel.
        delta_pro: confidence parameter delta used in beta_t.

    Returns:
        (count, 2) ndarray; row i is the confidence interval of client i.
        Only candidate clients are filled in; other rows stay 0.

    Raises:
        ValueError: if train_x is empty (beta_t is undefined for t = 0).
        numpy.linalg.LinAlgError: if K + sigma_par^2 * I is not positive definite.
    """
    count = len(client_characteristic)
    confidence_interval = np.zeros((count, 2))
    if count == 0:
        return confidence_interval
    count_train_x = len(train_x)
    if count_train_x == 0:
        raise ValueError("train_x must contain at least one sample")

    # beta_t with |D| = count clients and t = count_train_x observed rounds.
    # NOTE(review): Srinivas et al. (GP-UCB) use |D| t^2 pi^2 / (6 delta);
    # the original /3 constant is kept here -- confirm it is intended.
    # Clamped at 0 so unusual delta_pro values cannot make sqrt() fail.
    log_term = math.log((count * math.pi ** 2 * count_train_x ** 2) / 3 / delta_pro)
    beta_t_sqrt = math.sqrt(max(0.0, 2 * log_term))

    is_candidate = np.asarray(candidate)[:count] != 0
    if not np.any(is_candidate):
        return confidence_interval

    train_arr = _as_sample_matrix(train_x)
    utilities = np.asarray(train_u, dtype=float).ravel()

    # Cholesky factor of K + sigma^2 I. With sigma_par > 0 this matrix is
    # positive definite even when train_x repeats samples. Solving with the
    # factor is more stable than forming the explicit inverse.
    gram = _squared_exponential_cross_kernel(train_arr, train_arr, length_scale)
    factor = cho_factor(gram + sigma_par ** 2 * np.eye(count_train_x), lower=True)
    alpha = cho_solve(factor, utilities)  # (K + sigma^2 I)^-1 u

    query = _as_sample_matrix(client_characteristic)[is_candidate]
    k_star = _squared_exponential_cross_kernel(query, train_arr, length_scale)  # (m, n)
    mean = k_star @ alpha
    # Posterior variance: k(x, x) - k*^T (K + sigma^2 I)^-1 k*, with k(x, x) = 1.
    # Tiny negative values from round-off are clipped before the sqrt.
    explained = np.einsum('ij,ji->i', k_star, cho_solve(factor, k_star.T))
    std = np.sqrt(np.clip(1.0 - explained, 0.0, None))

    confidence_interval[is_candidate, 0] = mean - beta_t_sqrt * std
    confidence_interval[is_candidate, 1] = mean + beta_t_sqrt * std
    return confidence_interval



        


if __name__ == '__main__':
    # Small demo: five observed rounds (client features [1, 2] were selected
    # three times) and three clients, all of them candidates.
    train_x = np.array([[1, 2], [3, 4], [7, 8], [1, 2], [1, 2]])
    train_u = np.array([2, 3, 1, 6, 6])
    candidate = [1, 1, 1]
    client_characteristic = np.array([[1, 2], [3, 4], [7, 8]])
    confidence_interval = gaussian_process_model(train_x, train_u, candidate, client_characteristic)
    print(confidence_interval)
    