import math
import numpy as np

def linear_model(train_x, train_v, candidate, client_characteristic):
    # train_x: 第i行，表示第i次的选择的最长对角线对应的client的特征
    # train_v: 第i个，表示第i次的valid participant
    # candidate: 第i个元素表示第i个client是否为候选参与者; 默认为1,表示候选参与者
    # client_characteristic: 第i行，表示第i个client的特征
    # Return confidence_interval  第i行表示第i个client对应的置信区间, 只更新候选参与者的置信区间, 其它client默认为0

    delta_pro = 0.02 # Delta参数
    alpha = 1 + math.sqrt(math.log(4/delta_pro)/2)
    lambda_value = 1 # Lambda参数

    count_train_x = len(train_x) # 获取历史样本个数
    count = len(client_characteristic)  # 获取client个数
    len_characteristic = client_characteristic.shape[1] # 获取特征长度

    h_matrix = lambda_value * np.identity(len_characteristic) + train_x.T.dot(train_x)
    inv_h_matrix = np.linalg.inv(h_matrix)

    theta_estimation = inv_h_matrix.dot(train_x.T.dot(train_v.T))  # 更新对参数theta的估计

    confidence_interval = np.zeros((count,2)) # 置信区间
    for i in range(count):
        if candidate[i] == 0:
            continue
        mu_estimation = client_characteristic[i].dot(theta_estimation.T)
        bound_estimation = alpha*math.sqrt(client_characteristic[i].dot(inv_h_matrix.dot(client_characteristic[i].T)))
        confidence_interval[i] = np.array([mu_estimation - bound_estimation, mu_estimation + bound_estimation])
    
    return confidence_interval



if __name__ == '__main__':
    train_x = np.array([[1, 2],[3, 4],[1,2],[7,8],[3, 4],[1,2]])
    train_v = np.array([1,0,1,1,0,0])
    client_characteristic = np.array([[1, 2],[3, 4],[7,8]])

    candidate = [1,1,1]
    lambda_value = 1 # Lambda参数
    #len_characteristic = client_characteristic.shape[1] # 获取特征长度

    #h_matrix = lambda_value * np.identity(len_characteristic) + train_x.T.dot(train_x)
    #inv_h_matrix = np.linalg.inv(h_matrix)
    #theta_estimation = inv_h_matrix.dot(train_x.T.dot(train_v.T))  # 更新对参数theta的估计

    #print(theta_estimation)
    #print(client_characteristic[1].dot(theta_estimation.T))

    confidence_interval = linear_model(train_x, train_v, candidate, client_characteristic)
    print(confidence_interval)