import math
import matplotlib.pyplot as plt
import numpy as np
import random


def pull_arm(mu, i, j):
    """Draw one Bernoulli reward for arm ``i`` at agent ``j``.

    A single ``random.uniform(0, 1)`` sample is taken from the module-level
    ``random`` generator; the pull succeeds when that sample is strictly below
    ``mu[i][j]``.

    Returns:
        int: 1 on success, 0 otherwise.
    """
    draw = random.uniform(0, 1)
    return 1 if draw < mu[i][j] else 0


def distributed_successive_eliminate(W, T, C, gamma, K, N, mu, lambda_2, best_arm, mu_star=None):
    """Run one trial of distributed successive elimination.

    Every agent ``j`` keeps an active arm set ``S[j]``. In each round it pulls
    every active arm once, compares each arm's estimate with the current best
    estimate minus ``gamma * U`` (a confidence width), and schedules failing
    arms for removal ``D`` rounds' worth of pulls later. After every round the
    estimates of all agents are mixed through ``W`` together with the newest
    local sample (a synchronous consensus step), and due eliminations are
    applied. The loop runs while any agent's local clock is ``<= T``.

    Args:
        W: (N, N) mixing matrix; row ``j`` holds agent ``j``'s weights over agents.
        T: horizon used both as the stopping clock and inside ``log(T)``.
        C: constant scaling the second term of the confidence width.
        gamma: multiplier on the confidence width in the elimination test.
        K: number of arms.
        N: number of agents.
        mu: (K, N) success probabilities, ``mu[i][j]`` for arm ``i`` at agent ``j``.
        lambda_2: enters the width as ``1 - lambda_2``; presumably the
            second-largest eigenvalue of ``W`` -- TODO confirm.
        best_arm: index of the arm whose global mean is the regret reference.
        mu_star: optional length-K global arm means used for the regret.
            Defaults to ``np.mean(mu, axis=1)`` (the mean of each arm over agents).

    Returns:
        tuple: ``(sum_pull_time, t, S, sum_reward, regret_list)`` where
        ``sum_pull_time``/``sum_reward`` are (K, N) arrays, ``t`` holds each
        agent's local clock, ``S`` the surviving arm sets, and
        ``regret_list[j]`` agent ``j``'s cumulative regret after every pull.
    """
    if mu_star is None:
        mu_star = np.mean(mu, axis=1)

    # An arm flagged at clock t[j] is removed once t[j] reaches
    # t[j] + len(S[j]) * D, i.e. after about D more rounds at the current set size.
    D = 4
    # Number of warm-up pulls recorded per (arm, agent); only one sample is drawn.
    f = 1

    t = np.full(N, float(f))
    S = [list(range(K)) for _ in range(N)]
    B = [[] for _ in range(N)]  # pending eliminations: [arm, due clock] per agent
    hat_mu = np.zeros((K, N))
    hat_tau = np.full((K, N), float(f))
    reward_x = np.zeros((K, N))
    U = np.zeros((K, N))
    sum_pull_time = np.full((K, N), float(f))
    sum_reward = np.zeros((K, N))
    regret_list = [[0] for _ in range(N)]

    # Warm-up: one sample per (arm, agent) seeds the estimates.
    for j in range(N):
        for i in range(K):
            hat_mu[i][j] = pull_arm(mu, i, j)
            # NOTE(review): records the expected reward mu*f, not the observed
            # sample, unlike the main loop below -- confirm this is intended.
            sum_reward[i][j] += mu[i][j] * f

    while t.max() <= T:
        for j in range(N):
            # Empirically best active arm; the first maximal arm in S[j] wins ties.
            i_max = max(S[j], key=lambda arm: hat_mu[arm][j])
            hat_mu_max = hat_mu[i_max][j]

            for i in S[j]:
                reward_x[i][j] = pull_arm(mu, i, j)
                sum_pull_time[i][j] += 1
                sum_reward[i][j] += reward_x[i][j]
                regret_list[j].append(regret_list[j][-1] + mu_star[best_arm] - mu_star[i])

                t[j] += 1
                hat_tau[i][j] += 1
                U[i, j] = math.sqrt(math.log(T) / (N * hat_tau[i][j])) + C / (
                    (1 - lambda_2) * (hat_tau[i][j] + 1))

                # Arm falls outside the confidence band of the leader: schedule it.
                if hat_mu[i, j] <= hat_mu_max - gamma * U[i, j]:
                    B[j].append([i, t[j] + len(S[j]) * D])

        # Synchronous consensus: every agent mixes the PREVIOUS round's
        # estimates, so read from prev_hat_mu and write into a fresh copy.
        prev_hat_mu = hat_mu
        hat_mu = prev_hat_mu.copy()
        for j in range(N):
            for i in S[j]:
                weight = hat_tau[i][j] / (hat_tau[i][j] + 1)
                hat_mu[i, j] = weight * (W[j] @ prev_hat_mu[i]) + reward_x[i][j] / (hat_tau[i][j] + 1)

        # Apply eliminations that are due, always keeping at least one arm.
        for j in range(N):
            for i, due in B[j]:
                if t[j] >= due and len(S[j]) > 1 and i in S[j]:
                    S[j].remove(i)

    return sum_pull_time, t, S, sum_reward, regret_list


if __name__ == '__main__':
    # Mixing matrix for 8 agents on a ring: each agent gives weight 0.2 to
    # itself and to its two nearest neighbours on each side (with wrap-around).
    W = np.array([[0.2, 0.2, 0.2, 0, 0, 0, 0.2, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0, 0, 0, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0.2, 0, 0, 0],
                  [0, 0.2, 0.2, 0.2, 0.2, 0.2, 0, 0],
                  [0, 0, 0.2, 0.2, 0.2, 0.2, 0.2, 0],
                  [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0, 0, 0, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0.2, 0, 0, 0, 0.2, 0.2, 0.2]])
    # Success probabilities written one row per agent, transposed to (K arms, N agents).
    mu = np.array([
        [0.8, 0.9, 0.95, 0.85, 0.85, 0.8, 0.7, 0.65, 0.75, 0.75],
        [0.7, 0.6, 0.2, 0.1, 0.7, 0.2, 0.3, 0.5, 0.7, 0.4],
        [0.3, 0.3, 0.6, 0.4, 0.8, 0.4, 0.5, 0.7, 0.6, 0.6],
        [0.6, 0.5, 0.7, 0.3, 0.85, 0.9, 0.6, 0.6, 0.3, 0.3],
        [0.7, 0.6, 0.8, 0.2, 0.85, 0.3, 0.3, 0.5, 0.3, 0.2],
        [0.5, 0.7, 0.9, 0.7, 0.9, 0.7, 0.6, 0.5, 0.3, 0.6],
        [0.4, 0.5, 0.7, 0.7, 0.7, 0.8, 0.3, 0.7, 0.7, 0.5],
        [0.3, 0.4, 0.6, 0.7, 0.75, 0.7, 0.5, 0.4, 0.2, 0.3]
    ]).T
    # Global mean of each arm over agents. Kept at module level because
    # distributed_successive_eliminate may read it as a global for the regret.
    mu_star = np.mean(mu, axis=1)
    best_arm = int(np.argmax(mu_star))  # arm 4 for the table above

    T = int(1e4)  # horizon
    K, N = mu.shape  # 10 arms, 8 agents
    LAMDA_2 = 0.5759
    GAMMA = 1
    C = math.sqrt(N)  # confidence-width constant

    repeated_time = 50
    regret_list_zero = []  # agent 0's cumulative-regret curve from each run
    for _ in range(repeated_time):
        _, _, _, _, regret_list = distributed_successive_eliminate(
            W, T, C, GAMMA, K, N, mu, LAMDA_2, best_arm)
        regret_list_zero.append(regret_list[0])

    # Runs can end at different lengths; truncate to the shortest so they stack.
    min_length = min(len(run) for run in regret_list_zero)
    regret_list_zero_np = np.array([run[:min_length] for run in regret_list_zero])
    print(regret_list_zero_np.shape)

    regret_mean = np.mean(regret_list_zero_np, axis=0)
    regret_std = np.std(regret_list_zero_np, axis=0)
    index_list = range(min_length)
    print(regret_std)

    MARKERS = ["o", "D", "s", "^", "v", "p", "*"]
    COLORS = ["#0e5ad3", "#bc2d14", "#22aa16", "#a011a3", "#d1ba0e", "#14ccc2", "#d67413"]
    LINES = ["solid", "dashed", "dashdot", "dotted", "solid", "dashed", "dashdot", "dotted"]

    plt.plot(index_list, regret_mean, linewidth=0.1, color=COLORS[0], marker=MARKERS[0],
             markersize=0.1, linestyle=LINES[0], label='mean regret (agent 0)')
    plt.fill_between(index_list, regret_mean - regret_std, regret_mean + regret_std,
                     facecolor=COLORS[1], edgecolor='gray', alpha=0.7, label='mean ± 1 std')

    plt.title('regret')
    plt.xlabel('time')
    plt.ylabel('value')
    plt.legend()  # artists above are labelled, so the legend is non-empty
    plt.show()

    with open('list_data_DSE_mean.txt', 'w', encoding='utf-8') as fh:
        fh.writelines(f'{item!s}\n' for item in regret_mean)

    with open('list_data_DSE_std.txt', 'w', encoding='utf-8') as fh:
        fh.writelines(f'{item!s}\n' for item in regret_std)




