import numpy as np
from sklearn.preprocessing import normalize, minmax_scale
import matplotlib.pyplot as plt
import time
import random
from random import sample, shuffle, choice
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error , r2_score

# Module-level hyper-parameters. FedSupUCB and generate_contexts read these
# globals directly.
d = 25  # dimension of the context vectors and of theta
sigma = 0.01  # reward noise scale: passed as `scale` (standard deviation) to np.random.normal
lambda1 = 0.01  # multiplier of the layer-0 confidence coefficient alpha_0
lambda2 = 0.01  # multiplier of the confidence coefficients alpha_s for layers s in [S]
gamma = 0.06   # multiplier of the layer-0 width threshold w_0 (halved for each further layer)

class FedSupUCB(): # type = sync async
    """Simulate a layered (SupLinUCB-style) linear bandit shared by M clients.

    Every client keeps, for each of the ``S + 1`` layers, a local copy of the
    server statistics ``(A, b)`` plus an upload buffer ``(A_delta, b_delta)``
    holding the observations it has not sent to the server yet.  One client
    (given by ``pattern[t]``) acts in each round ``t``.

    Attributes written during a run (all of length ``T = M * T_0``):
        game_regrets / game_cumulregrets: per-round and cumulative regret.
        game_rewards / game_cumulrewards: per-round and cumulative reward.
        comm_cost: cumulative communication cost (only updated when ``comm``
            is truthy).
        game_width: confidence width of the chosen arm.
        diff_theta: ``||theta - theta_hat||`` of the last estimate used.
    """

    def __init__(self, contexts, theta, regrets, rewards, pattern, C_thred, D_thred, K, M, T_0, comm, type = 'async'):
        """Store the problem instance and allocate all per-layer statistics.

        Args:
            contexts: indexed as ``contexts[t][arm]`` -> context vector.
            theta: true parameter, only used to report ``diff_theta``.
            regrets: indexed as ``regrets[t][arm]``.
            rewards: indexed as ``rewards[t][arm]``.
            pattern: ``pattern[t]`` is the index of the client acting in round t.
            C_thred: determinant-ratio threshold used in ``'async'`` mode.
            D_thred: threshold used in ``'sync'`` mode.
            K: number of arms.
            M: number of clients.
            T_0: rounds per client; the horizon is ``M * T_0``.
            comm: if falsy, clients never talk to the server and update their
                local statistics directly.
            type: ``'async'`` or ``'sync'``.  (The name shadows the builtin
                but is kept so existing keyword calls keep working.)
        """
        self.d = d  # d dimension (module-level constant)
        self.K = K  # K arms
        self.T_0 = T_0
        self.M = M
        self.T = M * T_0  # total rounds
        self.C = C_thred
        self.D = D_thred
        self.type = type
        self.delta = 1 / (self.M * self.T)  # risk probability
        self.contexts = contexts
        self.theta = theta
        self.regrets = regrets
        self.rewards = rewards
        self.pattern = pattern
        self.comm = comm

        self.game_regrets = np.zeros(self.T)
        self.game_cumulregrets = np.zeros(self.T)

        self.game_rewards = np.zeros(self.T)
        self.game_cumulrewards = np.zeros(self.T)

        self.comm_cost = np.zeros(self.T)
        self.game_width = np.zeros(self.T)
        self.diff_theta = np.zeros(self.T)

        self.S = (int(np.log(self.d)) + 4)  # S layers
        n_layers = self.S + 1  # layers are indexed 0..S
        self.com_last = np.zeros(n_layers)  # layer s last comm time
        self.alpha = np.zeros(n_layers)  # alpha_s s\in {0, [S]}
        self.w = np.zeros(n_layers)  # w_s s\in {0, [S]}
        self.alpha[0] = lambda1 * (2 * np.sqrt(self.d * np.log(2 * self.T * self.M / self.delta)) + 1)
        self.w[0] = gamma * self.d ** 1.5 / np.sqrt(self.T)
        for i in range(1, n_layers):
            self.alpha[i] = lambda2 * (np.sqrt(np.log(2 * self.d * self.K * self.M * self.T / self.delta)) + 1)
            self.w[i] = self.w[i - 1] * 0.5  # each layer halves the width threshold
        print('initialized model!')

        identity = np.identity(self.d)
        self.A_ser = np.tile(identity, (n_layers, 1, 1))  # server A, (1+S)*d*d
        self.b_ser = np.zeros((n_layers, self.d))  # server b, (1+S)*d
        self.A_client = np.tile(identity, (self.M, n_layers, 1, 1))  # client copies, M*(1+S)*d*d
        self.b_client = np.zeros((self.M, n_layers, self.d))  # M*(1+S)*d
        # Upload buffers must start at zero: they hold only data that has not
        # been uploaded yet and communicate() adds them to the server
        # matrices, so a non-zero start would inflate the server's
        # regulariser once per client.
        self.A_delta_client = np.zeros((self.M, n_layers, self.d, self.d))  # M*(1+S)*d*d
        self.b_delta_client = np.zeros((self.M, n_layers, self.d))  # M*(1+S)*d

    def _prev_comm_cost(self, t):
        """Return the cumulative communication cost before round ``t`` (0 for the first round)."""
        return self.comm_cost[t - 1] if t > 0 else 0.0

    def communicate(self, client, s, t):  # only communicate data in layer s
        """Exchange layer-``s`` statistics with the server in round ``t``.

        ``'async'``: only ``client`` uploads its buffer and downloads the
        server state (cost +1).  ``'sync'``: every client uploads, then every
        client downloads (cost +M).  Any other ``type`` does nothing.
        """
        if self.type == 'async':
            self.comm_cost[t] = self._prev_comm_cost(t) + 1
            self.A_ser[s] += self.A_delta_client[client, s]  # server aggregate data
            self.b_ser[s] += self.b_delta_client[client, s]
            self.A_client[client, s] = self.A_ser[s]  # client receive updated data
            self.b_client[client, s] = self.b_ser[s]
            self.A_delta_client[client, s] = 0.0  # client empty local buffer
            self.b_delta_client[client, s] = 0.0
        elif self.type == 'sync':
            self.comm_cost[t] = self._prev_comm_cost(t) + self.M
            for i in range(self.M):
                self.A_ser[s] += self.A_delta_client[i, s]  # server aggregate data
                self.b_ser[s] += self.b_delta_client[i, s]
                self.A_delta_client[i, s] = 0.0  # client empty local buffer
                self.b_delta_client[i, s] = 0.0
            # Broadcast only after all uploads so every client gets the same state.
            for i in range(self.M):
                self.A_client[i, s] = self.A_ser[s]  # client receive updated data
                self.b_client[i, s] = self.b_ser[s]

    def base_linucb(self, K_contexts, client, s):
        """Estimate rewards and confidence widths from ``client``'s layer-``s`` copy.

        Returns:
            ``(r_hat, width, theta_hat)`` where ``r_hat`` and ``width`` have
            one entry per arm (the first ``K`` rows of ``K_contexts``) and
            ``theta_hat = A^{-1} b``.
        """
        A_inv = np.linalg.inv(self.A_client[client][s])
        theta_hat = np.dot(A_inv, self.b_client[client][s])
        X = np.asarray(K_contexts)[:self.K]
        r_hat = X @ theta_hat
        # Row-wise quadratic form x^T A^{-1} x for every arm at once.
        width = self.alpha[s] * np.sqrt(np.sum((X @ A_inv) * X, axis=1))
        return r_hat, width, theta_hat

    def select_arm(self, t):
        """Play round ``t``: pick an arm, record the outcome, maybe communicate.

        Returns:
            ``(best, action, action_set, s, width)``: the zero-regret arm, the
            arm played, the surviving candidate set, the layer the
            observation was stored in, and the chosen arm's width.
        """
        s = 0
        client = self.pattern[t]
        K_contexts = self.contexts[t]
        K_rewards = self.rewards[t]
        best = np.argmin(self.regrets[t])
        action = None
        action_set = np.arange(self.K)

        # Layer-0 pre-filter: drop arms whose UCB is below the best LCB.
        r_hat, width, theta_hat = self.base_linucb(K_contexts, client, s)
        ucb = r_hat + width
        action_set = action_set[ucb >= np.max(r_hat - width)]

        while action is None:
            if s == self.S:
                # Last layer: every survivor is acceptable, pick one at random.
                action = np.random.choice(action_set)
            elif np.max(width[action_set]) <= self.w[s]:
                # All candidates are well estimated at this layer: eliminate
                # clearly worse arms and move to the next layer's statistics.
                keep = ucb[action_set] >= np.max(ucb[action_set]) - 2 * self.w[s]
                action_set = action_set[keep]
                s = s + 1
                r_hat, width, theta_hat = self.base_linucb(K_contexts, client, s)
                ucb = r_hat + width
            else:
                # Explore: play the widest of the still-uncertain candidates.
                action_set = action_set[width[action_set] > self.w[s]]
                action = action_set[np.argmax(width[action_set])]

        x = K_contexts[action]
        gram_update = np.outer(x, x)
        moment_update = K_rewards[action] * x
        self.A_delta_client[client][s] += gram_update
        self.b_delta_client[client][s] += moment_update
        if self.comm:
            local_A = self.A_client[client][s]
            temp = np.linalg.det(local_A + self.A_delta_client[client][s]) / np.linalg.det(local_A)
            if (self.type == 'async') and temp > (1 + self.C):  # communicate with server
                self.communicate(client, s, t)
            elif (self.type == 'sync') and np.log(temp) * (t - self.com_last[s]) > self.D:
                self.com_last[s] = t
                self.communicate(client, s, t)
            else:
                self.comm_cost[t] = self._prev_comm_cost(t)
        else:  # when no communication, just do local update
            self.A_client[client][s] += gram_update
            self.b_client[client][s] += moment_update

        # Running totals: O(1) per round instead of re-summing the whole history.
        regret = self.regrets[t][action]
        reward = self.rewards[t][action]
        prev_regret = self.game_cumulregrets[t - 1] if t > 0 else 0.0
        prev_reward = self.game_cumulrewards[t - 1] if t > 0 else 0.0
        self.game_regrets[t] = regret
        self.game_cumulregrets[t] = prev_regret + regret
        self.game_rewards[t] = reward
        self.game_cumulrewards[t] = prev_reward + reward

        self.diff_theta[t] = np.linalg.norm(self.theta - theta_hat)
        self.game_width[t] = width[action]
        return best, action, action_set, s, width[action]

    def run(self):
        """Play all ``T`` rounds in order and return the recorded trajectories.

        Returns:
            ``(game_cumulregrets, game_cumulrewards, comm_cost, diff_theta,
            game_width)``.
        """
        for t in range(self.T):
            self.select_arm(t)
        print('finish running!')
        return self.game_cumulregrets, self.game_cumulrewards, self.comm_cost, self.diff_theta, self.game_width

def generate_contexts(M, T_0, K):
    """Generate a synthetic linear bandit instance for ``T = M * T_0`` rounds.

    Uses the module-level ``d`` (dimension) and ``sigma`` (noise standard
    deviation) and draws from NumPy's global random state.

    Args:
        M: number of clients.
        T_0: rounds per client.
        K: number of arms per round.

    Returns:
        ``(contexts, theta, regrets, rewards)`` where ``contexts`` has shape
        ``(T, K, d)`` with unit-norm rows, ``theta`` is a unit vector of
        length ``d``, ``regrets[t, k]`` is the gap between the best expected
        reward of round ``t`` and arm ``k``'s, and ``rewards`` are the
        expected rewards plus Gaussian noise; both have shape ``(T, K)``.
    """
    T = M * T_0
    theta = np.random.uniform(low=-1, high=1, size=d)
    theta = theta / np.linalg.norm(theta)
    print(f'true theta norm = {np.linalg.norm(theta)}')

    contexts = np.random.uniform(low=-1, high=1, size=(T, K, d))  # context table [T, K, d]
    # Normalise every context to unit L2 norm in place; einsum avoids a
    # full-size temporary for the squared entries.
    norms = np.sqrt(np.einsum('tkd,tkd->tk', contexts, contexts))
    norms[norms == 0.0] = 1.0  # leave an all-zero context untouched
    contexts /= norms[:, :, np.newaxis]

    truemean = contexts @ theta  # [T, K] expected rewards
    regrets = truemean.max(axis=1, keepdims=True) - truemean  # [T, K]
    rewards = truemean + np.random.normal(scale=sigma, size=(T, K))  # [T, K]
    return contexts, theta, regrets, rewards

def generate_pattern(M, T_0):
    """Build three client arrival schedules of length ``M * T_0``.

    Returns:
        A tuple ``(random, round_robin, click_leave)``:
        uniformly random client indices, the clients cycled in order
        (e.g. 0 1 2 3 0 1 2 3), and each client repeated ``T_0`` times in a
        row (e.g. 0 0 1 1 2 2 3 3).
    """
    horizon = M * T_0
    clients = np.arange(M)
    random_arrivals = np.random.randint(M, size=horizon)
    round_robin = np.tile(clients, T_0)
    click_leave = np.repeat(clients, T_0)
    return random_arrivals, round_robin, click_leave

#synthetic figure1
def synthetic_compare_pattern():
    """Compare arrival patterns on synthetic data and plot regret and communication cost.

    For each of 3 repetitions a fresh problem instance is generated and four
    models are run on it, in this order: async with random, round-robin and
    click-leave arrivals, then sync with round-robin arrivals.  The mean over
    repetitions is plotted with a one-standard-deviation band, once for the
    cumulative regret and once for the communication cost.

    The model's hyper-parameters (``d``, ``sigma``, ``lambda1``, ``lambda2``,
    ``gamma``) are the module-level constants; figures are saved under the
    module-level ``path`` prefix, which is defined outside this function.
    """
    K = 20
    M = 20
    T_0 = 2000
    T = T_0 * M
    C_thred = 0.15  # doubling trick parameter
    D_thred = 170  # sync communication threshold
    epoch = 3
    band = 1  # half-width of the shaded band, in standard deviations

    # (legend label, index into generate_pattern's result, communication type)
    settings = (
        ('Async-random', 0, 'async'),
        ('Async-round-robin', 1, 'async'),
        ('Async-click-leave', 2, 'async'),
        ('Sync-round-robin', 1, 'sync'),
    )
    regret_runs = [[] for _ in settings]
    commcost_runs = [[] for _ in settings]

    for _ in range(epoch):
        patterns = generate_pattern(M, T_0)  # random, round-robin, click-leave
        print('generate pattern!')
        contexts, theta, regrets, rewards = generate_contexts(M, T_0, K)
        print('generate contexts!')
        for idx, (_, pattern_idx, comm_type) in enumerate(settings):
            sup_model = FedSupUCB(contexts, theta, regrets, rewards, patterns[pattern_idx],
                                  C_thred, D_thred, K, M, T_0, comm=True, type=comm_type)
            cumul_regrets, _, comm_cost, _, _ = sup_model.run()
            regret_runs[idx].append(cumul_regrets)
            commcost_runs[idx].append(comm_cost)

    x = np.linspace(0, T - 1, T)

    def plot_mean_with_band(runs_per_setting, ylabel, file_stem):
        """Plot mean +/- band*std of every setting, then save and show the figure."""
        plt.figure()
        for (label, _, _), runs in zip(settings, runs_per_setting):
            mean = np.mean(runs, axis=0)
            std = np.std(runs, axis=0)
            plt.plot(x, mean, label=label)
            plt.fill_between(x, mean - band * std, mean + band * std, alpha=0.3)
        plt.legend()
        plt.xlabel('Number of arm pulling')
        plt.ylabel(ylabel)
        plt.grid(linestyle='dashed', color="grey")
        time1 = time.strftime('%m-%d %H:%M:%S', time.localtime())
        plt.savefig(path + file_stem + '_C%s_M%s_d%s_K%s_Time%s' % (C_thred, M, d, K, time1) + '.png', dpi=300)
        plt.show()

    plot_mean_with_band(regret_runs, 'Cumulative regret', 'syntheticpattern_regret')
    plot_mean_with_band(commcost_runs, 'Communication cost', 'syntheticpattern_commcost')

def synthetic_async_clients():
    """Plot per-client cumulative regret of the async model for 1, 20, 40 and 80 clients.

    Each of 3 repetitions generates, for every client count ``M``, a fresh
    round-robin schedule and problem instance.  The single-client run has
    communication disabled and is labelled 'SupLinUCB'.  The cumulative
    regret is sampled every ``M`` rounds and divided by ``M`` so all curves
    have ``T_0`` points.

    The model's hyper-parameters are the module-level constants; the figure
    is saved under the module-level ``path`` prefix, which is defined outside
    this function.
    """
    K = 20
    T_0 = 2000
    C_thred = 0.15  # doubling trick parameter
    D_thred = 170   # sync communication threshold (not used by async runs)
    epoch = 3
    band = 1  # half-width of the shaded band, in standard deviations
    client_counts = (1, 20, 40, 80)

    regret_runs = {M: [] for M in client_counts}
    for _ in range(epoch):
        for M in client_counts:
            _, round_robin, _ = generate_pattern(M, T_0)
            contexts, theta, regrets, rewards = generate_contexts(M, T_0, K)
            sup_model = FedSupUCB(contexts, theta, regrets, rewards, round_robin,
                                  C_thred, D_thred, K, M, T_0, comm=(M > 1), type='async')
            cumul_regrets = sup_model.run()[0]
            regret_runs[M].append(cumul_regrets[0:M * T_0:M] / M)

    x = np.linspace(0, T_0 - 1, T_0)
    plt.figure()
    for M in client_counts:
        mean = np.mean(regret_runs[M], axis=0)
        std = np.std(regret_runs[M], axis=0)
        label = 'SupLinUCB' if M == 1 else '%s clients async' % M
        plt.plot(x, mean, label=label)
        plt.fill_between(x, mean - band * std, mean + band * std, alpha=0.3)
    plt.legend()
    plt.xlabel('Number of arm pulling')
    plt.ylabel('Cumulative regret')
    plt.grid(linestyle='dashed', color="grey")
    time1 = time.strftime('%m-%d %H:%M:%S', time.localtime())
    # The file name carries the largest client count, as it always has.
    plt.savefig(path + 'synthetic_async_clients_C%s_M%s_d%s_K%s_Time%s' % (C_thred, client_counts[-1], d, K, time1) + '.png', dpi=300)
    plt.show()

def synthetic_sync_clients():
    """Plot per-client cumulative regret of the sync model for 1, 20, 40 and 80 clients.

    Each of 4 repetitions generates, for every client count ``M``, a fresh
    round-robin schedule and problem instance.  The single-client run has
    communication disabled and is labelled 'SupLinUCB'.  The cumulative
    regret is sampled every ``M`` rounds and divided by ``M`` so all curves
    have ``T_0`` points.

    The model's hyper-parameters are the module-level constants; the figure
    is saved under the module-level ``path`` prefix, which is defined outside
    this function.
    """
    K = 20
    T_0 = 2000
    C_thred = 0.15  # doubling trick parameter (not used by sync runs)
    D_thred = 170  # sync communication threshold
    epoch = 4
    band = 1  # half-width of the shaded band, in standard deviations
    client_counts = (1, 20, 40, 80)

    regret_runs = {M: [] for M in client_counts}
    for _ in range(epoch):
        for M in client_counts:
            _, round_robin, _ = generate_pattern(M, T_0)
            contexts, theta, regrets, rewards = generate_contexts(M, T_0, K)
            sup_model = FedSupUCB(contexts, theta, regrets, rewards, round_robin,
                                  C_thred, D_thred, K, M, T_0, comm=(M > 1), type='sync')
            cumul_regrets = sup_model.run()[0]
            regret_runs[M].append(cumul_regrets[0:M * T_0:M] / M)

    x = np.linspace(0, T_0 - 1, T_0)
    plt.figure()
    for M in client_counts:
        mean = np.mean(regret_runs[M], axis=0)
        std = np.std(regret_runs[M], axis=0)
        label = 'SupLinUCB' if M == 1 else '%s clients sync' % M
        plt.plot(x, mean, label=label)
        plt.fill_between(x, mean - band * std, mean + band * std, alpha=0.3)
    plt.legend()
    plt.xlabel('Number of arm pulling')
    plt.ylabel('Cumulative regret')
    plt.grid(linestyle='dashed', color="grey")
    time1 = time.strftime('%m-%d %H:%M:%S', time.localtime())
    # The file name carries the largest client count, as it always has.
    plt.savefig(path + 'synthetic_sync_clients_C%s_M%s_d%s_K%s_Time%s' % (C_thred, client_counts[-1], d, K, time1) + '.png', dpi=300)
    plt.show()

def synthetic_tradeoff():
    """Scatter final regret against final communication cost while sweeping the thresholds.

    One problem instance (M = 400 clients, click-leave arrivals) is shared by
    all runs.  First ``C_thred`` is swept over 10 log-spaced values in
    [0.001, 4] with the async model, then ``D_thred`` over ``exp`` of 10
    evenly spaced values in [0.9, 10] with the sync model.  Three figures are
    produced: async only, sync only, and both together.

    The model's hyper-parameters are the module-level constants; figures are
    saved under the module-level ``path`` prefix, which is defined outside
    this function.  Every file name embeds the value ``C_thred`` holds when
    the figure is saved, i.e. the last value of the async sweep.
    """
    K = 20
    M = 400
    T_0 = 200
    C_thred = 0.15  # doubling trick parameter (overwritten by the async sweep)
    D_thred = 170  # sync communication threshold (overwritten by the sync sweep)

    _, _, click_leave = generate_pattern(M, T_0)  # random, round-robin, click-leave
    print('generate pattern!')
    contexts, theta, regrets, rewards = generate_contexts(M, T_0, K)
    print('generate contexts!')
    epochs = 10

    def final_regret_and_cost(comm_type):
        """Run one model with the current thresholds; return its final regret and cost."""
        sup_model = FedSupUCB(contexts, theta, regrets, rewards, click_leave,
                              C_thred, D_thred, K, M, T_0, comm=True, type=comm_type)
        cumul_regrets, _, comm_cost, _, _ = sup_model.run()
        return cumul_regrets[-1], comm_cost[-1]

    def save_scatter(series, file_stem):
        """Scatter every (costs, regrets, label) triple on one figure, then save and show it."""
        plt.figure()
        for costs, final_regrets, label in series:
            plt.scatter(costs, final_regrets, label=label)
        plt.legend()
        plt.xlabel('Communication cost')
        plt.ylabel('Cumulative regret')
        plt.grid(linestyle='dashed', color="grey")
        time1 = time.strftime('%m-%d %H:%M:%S', time.localtime())
        plt.savefig(path + file_stem + '_C%s_M%s_d%s_K%s_Time%s' % (C_thred, M, d, K, time1) + '.png', dpi=300)
        plt.show()

    # Async sweep over log-spaced C thresholds.
    C_values = np.exp(np.linspace(np.log(0.001), np.log(4), epochs))
    regrets_C = np.zeros(epochs)
    commcosts_C = np.zeros(epochs)
    for i, C_thred in enumerate(C_values):
        regrets_C[i], commcosts_C[i] = final_regret_and_cost('async')
        print(f'C: {C_thred}_Async_regret: {regrets_C[i]}_commcost: {commcosts_C[i]}')
    save_scatter([(commcosts_C, regrets_C, 'Async-ununiform')], 'ununifrom_async')

    # Sync sweep over D thresholds; C_thred keeps its last async value.
    D_values = np.exp(np.linspace(0.9, 10, epochs))
    regrets_D = np.zeros(epochs)
    commcosts_D = np.zeros(epochs)
    for i, D_thred in enumerate(D_values):
        regrets_D[i], commcosts_D[i] = final_regret_and_cost('sync')
        print(f'D: {D_thred}_Sync_regret: {regrets_D[i]}_commcost: {commcosts_D[i]}')
    save_scatter([(commcosts_D, regrets_D, 'Sync-ununiform')], 'ununifrom_sync')

    save_scatter([(commcosts_D, regrets_D, 'Sync'), (commcosts_C, regrets_C, 'Async')], '')

def Movielens_async_client():
    """Plot the async model's reward relative to a random policy for 1, 5, 10 and 20 clients.

    For every client count ``M`` a round-robin schedule and a data set from
    ``realworld_contexts(M * T_0)`` (defined outside this function) are
    created and one model is run; the single-client run has communication
    disabled and is labelled 'SupLinUCB'.  The cumulative reward is sampled
    every ``M`` rounds, divided by ``M`` and then by the cumulative reward of
    a uniformly random policy.  That baseline is computed once, on the
    single-client data set, and reused for all client counts.  One figure is
    saved per client count plus a final comparison figure; the first
    ``start`` points are not drawn.

    Figures are saved under the module-level ``path`` prefix, which is
    defined outside this function.
    """
    start = 100  # number of leading points left out of every plot

    K = 20
    C_thred = 0.0885866
    D_thred = 1.3  # sync communication threshold (not used by async runs)
    T_0 = 5000
    client_counts = (1, 5, 10, 20)

    def save_plot(curves, name_template):
        """Plot every (curve, label) pair on one figure, then save and show it."""
        plt.figure()
        for curve, label in curves:
            plt.plot(curve[start:], label=label)
        plt.legend()
        plt.xlabel('Number of arm pulling')
        plt.ylabel('Normalized reward')
        plt.grid(linestyle='dashed', color="grey")
        time1 = time.strftime('%m-%d %H:%M:%S', time.localtime())
        plt.savefig(path + name_template % (K, d, C_thred, time1) + '.png', dpi=300)
        plt.show()

    random_cumulrewards = None
    curves = []
    for M in client_counts:
        _, round_robin, _ = generate_pattern(M, T_0)
        contexts, theta, regrets, rewards = realworld_contexts(M * T_0)

        if random_cumulrewards is None:
            # Random-policy baseline on the first (single-client) data set.
            # NOTE(review): the running total starts at 1 rather than 0, so
            # the baseline carries a +1 offset -- presumably to keep the
            # division below away from zero; confirm before changing.
            arms = np.arange(K)
            random_cumulrewards = np.empty(T_0)
            running = 1.0
            for t in range(T_0):
                running = rewards[t][choice(arms)] + running
                random_cumulrewards[t] = running
            print('finish random strategy!')

        sup_model = FedSupUCB(contexts, theta, regrets, rewards, round_robin,
                              C_thred, D_thred, K, M, T_0, comm=(M > 1), type='async')
        cumul_rewards = sup_model.run()[1]
        normalized = np.divide(cumul_rewards[0:M * T_0:M] / M, random_cumulrewards)

        label = 'SupLinUCB' if M == 1 else '%s clients async' % M
        file_tag = 'SupLinUCB' if M == 1 else 'M%s' % M
        save_plot([(normalized, label)], 'Movielens_async_' + file_tag + '_K%s_d%s_C%s_Time%s')
        curves.append((normalized, label))

    # NOTE(review): this file name says "D" but is filled with C_thred; kept
    # as is so existing output names do not change.
    save_plot(curves, 'Movielens_async_compare_clients_K%s_d%s_D%s_Time%s')

def _movielens_sync_save_plot(curves, start, file_stem):
    """Plot normalized-reward curves, save the figure under ``path`` and show it.

    Args:
        curves: iterable of ``(series, label)`` pairs; every series is drawn
            from index ``start`` onwards.
        start: number of leading entries dropped from each curve.
        file_stem: output file name without the timestamp suffix and extension.
    """
    plt.figure()
    for series, label in curves:
        plt.plot(series[start:], label=label)
    plt.legend()
    plt.xlabel('Number of arm pulling')
    plt.ylabel('Normalized reward')
    plt.grid(linestyle='dashed', color="grey")
    time1 = time.strftime('%m-%d %H:%M:%S', time.localtime())
    # `path` is the module-level output directory assigned in the __main__ guard.
    plt.savefig(path + file_stem + '_Time%s' % time1 + '.png', dpi=300)
    plt.show()


def Movielens_sync_client():
    """Compare synchronous FedSupUCB runs with 1, 5, 10 and 20 clients on MovieLens data.

    For each client count M the function builds M * T_0 rounds of contexts with
    ``realworld_contexts``, runs ``FedSupUCB`` with ``type='sync'`` and takes every
    M-th entry of the returned cumulative reward divided by M. That curve is then
    divided element-wise by the cumulative reward of a uniformly random arm choice
    (computed once, on the data of the first, single-client run). One figure is
    saved per run plus a final figure overlaying all four curves.

    The single-client run is created with ``comm=False`` and labelled 'SupLinUCB';
    the multi-client runs use ``comm=True``.
    """
    start = 100  # leading rounds left out of every plot
    K = 20
    C_thred = 0.01
    D_thred = 72.928075  # communication threshold
    T_0 = 5000

    # (number of clients M, comm flag, legend label, file-name tag)
    runs = (
        (1, False, 'SupLinUCB', 'SupLinUCB'),
        (5, True, '5 clients sync', 'M5'),
        (10, True, '10 clients sync', 'M10'),
        (20, True, '20 clients sync', 'M20'),
    )

    random_cumulrewards = None
    curves = []
    for M, comm, label, tag in runs:
        _, pattern2, _ = generate_pattern(M, T_0)
        contexts, theta, regrets, rewards = realworld_contexts(M * T_0)

        if random_cumulrewards is None:
            # Random-arm baseline, built once from the first (M = 1) data set.
            # The running total starts at 1.0: the original code obtained the
            # same offset by reading element -1 of an all-ones array at t = 0.
            random_cumulrewards = np.ones(T_0)
            arm_ids = np.arange(K)
            running = 1.0
            for t in range(T_0):
                random_pick = choice(arm_ids)
                running = rewards[t][random_pick] + running
                random_cumulrewards[t] = running
            print(f'finish random strategy!')

        sup_model = FedSupUCB(contexts, theta, regrets, rewards, pattern2, C_thred, D_thred, K, M, T_0,
                              comm=comm, type='sync')
        _, game_cumlrewards, _, _, _ = sup_model.run()
        # Keep every M-th cumulative reward, scale by 1/M, then normalize by the baseline.
        normalized = np.divide(game_cumlrewards[0:M * T_0:M] / M, random_cumulrewards)
        curves.append((normalized, label))
        _movielens_sync_save_plot([(normalized, label)], start,
                                  'Movielens_sync_%s_K%s_d%s_C%s' % (tag, K, d, C_thred))

    # NOTE(review): the file name advertises "D" but the value written is C_thred,
    # exactly as in the original code -- confirm which threshold is intended.
    _movielens_sync_save_plot(curves, start,
                              'Movielens_sync_compare_clients_K%s_d%s_D%s' % (K, d, C_thred))

def readFeatureVectorFile(path='./dataset/processed_data/Arm_FeatureVectors_d25.dat'):
    """Load arm feature vectors from a tab-separated text file.

    The first line is treated as a header and skipped. Every other non-blank
    line must contain an integer arm id, a tab, and a ';'-separated list of
    numbers that may optionally be wrapped in square brackets.

    Args:
        path: file to read. Defaults to the location that used to be hard-coded.

    Returns:
        dict mapping ``int`` arm id to a 1-D float ``numpy`` array.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if an id or a vector component is not numeric.
        IndexError: if a non-blank line has no tab-separated second field.
    """
    FeatureVectors = {}
    with open(path, 'r') as f:
        f.readline()  # skip the header line
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            fields = line.split("\t")
            # Remove the newline *before* the brackets: stripping '[]' first
            # stops at the trailing '\n' and leaves a stray ']' on the last value.
            vec = fields[1].strip().strip('[]').split(';')
            FeatureVectors[int(fields[0])] = np.array(vec, dtype=float)
    return FeatureVectors

def parseLine(line):
    """Parse one tab-separated record into ``(userID, tim, pool_articles)``.

    The line must hold exactly three tab-separated fields: two integers and a
    comma-separated list of integers that may be wrapped in square brackets and
    followed by a newline. The list is returned as an integer numpy array.
    """
    user_field, time_field, pool_field = line.split("\t")
    user_id = int(user_field)
    time_value = int(time_field)
    # Peel off the opening bracket, the newline, then the closing bracket.
    pool_text = pool_field.strip('[').strip('\n').strip(']')
    pool_tokens = np.array(pool_text.split(','))
    return user_id, time_value, np.array(pool_tokens, dtype=int)

def realworld_armpool(path='./dataset/processed_data/K20_N37_ObsMoreThan2500_PosOverThree.dat'):
    """Read the per-round arm pools used to generate contexts from MovieLens.

    The first line of the file is a title row and is skipped. Every other
    non-blank line must consist of exactly three tab-separated fields; only the
    third one, a bracketed comma-separated list of integer arm ids, is used.

    Args:
        path: file to read. Defaults to the location that used to be hard-coded.

    Returns:
        list with one integer numpy array of arm ids per data line, in file order.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a line does not have three fields or an id is not an integer.
    """
    arm_pool = []
    with open(path, 'r') as f:
        f.readline()  # just read the first line title
        for line in f:
            if not line.strip():
                continue  # a blank line would otherwise break the 3-way unpack
            _user_id, _tim, pool_articles = line.split("\t")
            pool_articles = pool_articles.strip('[').strip('\n').strip(']').split(',')
            arm_pool.append(np.array(pool_articles, dtype=int))
    return arm_pool

def realworld_contexts(T = 120000):
    """Build contexts, rewards and regrets for ``T`` rounds and fit a linear model.

    For round ``t`` the first ``K = 20`` arm ids of the t-th pool returned by
    ``realworld_armpool`` are mapped to their feature vectors (from
    ``readFeatureVectorFile``) and each vector is normalized with sklearn's
    ``normalize(..., axis=1)``. In every round the first five arms get reward 1
    and the remaining arms reward 0.1; regret is ``1 - reward``. A
    ``LinearRegression`` without intercept is then fitted on all
    (context, reward) pairs and its R2 score is printed.

    Args:
        T: number of rounds to generate; must not exceed the number of pools
            in the arm-pool file.

    Returns:
        Tuple ``(contexts, theta, regrets, rewards)`` where ``contexts`` has shape
        ``(T, K, d)``, ``rewards`` and ``regrets`` have shape ``(T, K)`` and
        ``theta`` is the fitted coefficient vector.

    Raises:
        IndexError: if ``T`` is larger than the number of available arm pools,
            or a pool holds fewer than ``K`` arms.
        KeyError: if an arm id has no feature vector.
    """
    K = 20
    feature_vector = readFeatureVectorFile()
    arm_pool = realworld_armpool()
    if T > len(arm_pool):
        # Fail up front with a clear message instead of midway through the loop.
        raise IndexError(f'requested T={T} rounds but only {len(arm_pool)} arm pools are available')

    contexts = np.zeros((T, K, d))  # context table [T, K, d]
    for t in range(T):
        pool = arm_pool[t]
        for i in range(K):
            contexts[t][i] = feature_vector[pool[i]]
        contexts[t] = normalize(contexts[t], axis=1)

    # The reward pattern is identical in every round, so fill it in one shot:
    # arms at positions 0..4 earn 1, all others 0.1.
    rewards = np.full((T, K), 0.1)  # [T, K]
    rewards[:, :5] = 1
    regrets = np.ones((T, K)) - rewards  # [T, K]

    x = np.reshape(contexts, (-1, d))
    y = np.reshape(rewards, (-1, 1))
    clf = LinearRegression(fit_intercept=False, copy_X=True)
    clf.fit(x, y)
    # y is a single-column target, so coef_ is 2-D; take its only row.
    theta = np.array(clf.coef_, dtype=float)[0]
    R2 = r2_score(y, clf.predict(x))
    print(f'Linear regression R2 score = {R2}')
    print(f'Linear regression estimated theta length = {len(theta)}')
    return contexts, theta, regrets, rewards


if __name__ == '__main__':
    import os

    # Output directory; the plotting functions read it as the global `path`.
    path = "./results/"
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(path, exist_ok=True)

    # Enable exactly the experiments that should be run.
    synthetic_compare_pattern()
    # synthetic_async_clients()
    # synthetic_sync_clients()
    # synthetic_tradeoff()
    # Movielens_async_client()
    # Movielens_sync_client()



