import math
import matplotlib.pyplot as plt
import numpy as np
import random


def pull_arm(mu, i, j):
    """Draw one Bernoulli reward for arm ``i`` as seen by agent ``j``.

    A single uniform sample from ``random.uniform(0, 1)`` is compared with
    ``mu[i][j]``; the pull pays out when the sample is strictly smaller.

    Args:
        mu: Two-level indexable container of success probabilities,
            addressed as ``mu[i][j]``.
        i: Arm index (first index into ``mu``).
        j: Agent index (second index into ``mu``).

    Returns:
        ``1`` on success, ``0`` otherwise.
    """
    success_probability = mu[i][j]
    # Exactly one RNG draw per pull, so seeded runs stay reproducible.
    return int(random.uniform(0, 1) < success_probability)


def distributed_successive_eliminate(W, T, C, gamma, K, N, mu, lambda_2, best_arm, mu_star):
    """Simulate distributed successive elimination with N agents and K arms.

    Every agent keeps a set of active arms. In each round an agent pulls all
    of its active arms once, flags the arms whose mean estimate falls outside
    a confidence band around its current best estimate, mixes its estimates
    with those of the other agents through ``W``, and finally drops flagged
    arms whose elimination deadline has passed (always keeping at least one).

    Args:
        W: Mixing weights; ``W[j]`` is the length-N weight vector agent ``j``
            applies to all agents' estimates of one arm.
        T: Horizon. Rounds continue while the largest local clock is <= T;
            ``log(T)`` also enters the confidence radius.
        C: Constant scaling the second term of the confidence radius.
        gamma: Multiplier on the confidence radius in the elimination test.
        K: Number of arms.
        N: Number of agents.
        mu: Success probabilities indexed ``mu[i][j]`` (arm ``i``, agent ``j``).
        lambda_2: Enters the radius through ``1 / (1 - lambda_2)``.
        best_arm: Index of the reference arm used for regret.
        mu_star: Per-arm reference means; pulling arm ``i`` adds
            ``mu_star[best_arm] - mu_star[i]`` to the agent's regret.

    Returns:
        Tuple ``(sum_pull_time, t, S, sum_reward, regret_list)``:
        ``sum_pull_time`` and ``sum_reward`` are (K, N) arrays of pull counts
        and accumulated rewards, ``t`` is the length-N array of local clocks,
        ``S`` lists the surviving arms per agent, and ``regret_list`` holds
        one cumulative-regret trace per agent (each starting at 0).
    """
    # Delay factor: a flagged arm is dropped once the agent's clock has
    # advanced by len(S[j]) * D beyond the moment it was flagged.
    D = 4

    t = np.zeros(N)
    S = [list(range(K)) for _ in range(N)]
    B = [[] for _ in range(N)]  # per agent: [arm, elimination deadline] pairs
    hat_mu = np.zeros((K, N))
    hat_tau = np.ones((K, N))
    reward_x = np.zeros((K, N))
    sum_pull_time = np.zeros((K, N))
    sum_reward = np.zeros((K, N))
    # One trace per agent (previously hard-coded to exactly 8 agents).
    regret_list = [[0] for _ in range(N)]

    # Warm-up: every agent samples every arm once; counters start at f.
    f = 1
    for j in range(N):
        for i in range(K):
            hat_mu[i][j] = pull_arm(mu, i, j)
            sum_pull_time[i][j] = f
            # NOTE(review): this credits the expected reward mu[i][j] * f, not
            # the sample drawn above. Kept as in the original; confirm intent.
            sum_reward[i][j] += mu[i][j] * f
    hat_tau.fill(f)
    t.fill(f)

    log_T = math.log(T)  # loop-invariant part of the confidence radius

    while t.max() <= T:
        for j in range(N):
            # Best active arm by current estimate; ties go to the arm that
            # appears first in S[j].
            i_max = max(S[j], key=lambda arm: hat_mu[arm][j])
            hat_mu_max = hat_mu[i_max][j]

            for i in S[j]:
                reward_x[i][j] = pull_arm(mu, i, j)
                sum_pull_time[i][j] += 1
                sum_reward[i][j] += reward_x[i][j]
                regret_list[j].append(regret_list[j][-1] + mu_star[best_arm] - mu_star[i])

                t[j] += 1
                hat_tau[i][j] += 1

                radius = math.sqrt(log_T / (N * hat_tau[i][j])) + C / (
                    (1 - lambda_2) * (hat_tau[i][j] + 1))

                if hat_mu[i, j] > hat_mu_max - gamma * radius:
                    continue  # arm i is still within the band of the leader
                # Flag arm i for delayed elimination.
                B[j].append([i, t[j] + len(S[j]) * D])

        # Consensus step. Work on a copy so every agent mixes the estimates
        # from the START of this step; updating in place would let agent j
        # read values already overwritten for agents 0..j-1.
        new_hat_mu = hat_mu.copy()
        for j in range(N):
            for i in S[j]:
                keep = hat_tau[i][j] / (hat_tau[i][j] + 1)
                new_hat_mu[i, j] = keep * np.dot(W[j], hat_mu[i]) + (
                    1 / (1 + hat_tau[i][j])) * reward_x[i][j]
        hat_mu = new_hat_mu

        # Drop flagged arms whose deadline has passed, never emptying S[j].
        for j in range(N):
            for i, deadline in B[j]:
                if t[j] >= deadline and len(S[j]) > 1 and i in S[j]:
                    S[j].remove(i)

    return sum_pull_time, t, S, sum_reward, regret_list


def DSE_DiffAEM_list():
    """Measure agent 0's final regret for several arm counts.

    For each K in (8, 10, 12, 14, 16), ``distributed_successive_eliminate``
    is run ``repeated_time`` times on an 8-agent network, and the last entry
    of agent 0's cumulative-regret trace is collected from every run.

    Returns:
        A 3-tuple of lists with one entry per arm count, in the order above:
        mean final regret, mean minus one standard deviation, and mean plus
        one standard deviation (``np.std`` with its default settings).
    """
    # Mixing matrix: each agent weights itself and the agents at ring
    # distance 1 and 2 with 0.2, so every row sums to 1.
    W = np.array([[0.2, 0.2, 0.2, 0, 0, 0, 0.2, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0, 0, 0, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0.2, 0, 0, 0],
                  [0, 0.2, 0.2, 0.2, 0.2, 0.2, 0, 0],
                  [0, 0, 0.2, 0.2, 0.2, 0.2, 0.2, 0],
                  [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0, 0, 0, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0.2, 0, 0, 0, 0.2, 0.2, 0.2]])

    # Success probabilities, one row per agent and one column per arm.
    # The K-arm experiment uses the first K columns; the smaller tables that
    # used to be written out separately were exact prefixes of this one.
    agent_arm_means = np.array([
        [0.9, 0.6, 0.8, 0.8, 0.6, 0.8, 0.7, 0.6, 0.6, 0.9, 0.7, 0.6, 0.4, 0.5, 0.6, 0.6],
        [0.8, 0.6, 0.7, 0.1, 0.6, 0.2, 0.3, 0.5, 0.6, 0.7, 0.1, 0.7, 0.2, 0.3, 0.5, 0.5],
        [0.7, 0.7, 0.6, 0.4, 0.6, 0.4, 0.5, 0.7, 0.7, 0.7, 0.6, 0.4, 0.8, 0.4, 0.5, 0.5],
        [0.8, 0.8, 0.6, 0.3, 0.6, 0.6, 0.6, 0.6, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7, 0.1, 0.1],
        [0.7, 0.8, 0.8, 0.2, 0.6, 0.3, 0.3, 0.5, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7, 0.1, 0.1],
        [0.9, 0.9, 0.6, 0.7, 0.6, 0.6, 0.6, 0.5, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7, 0.9, 0.4],
        [0.8, 0.6, 0.9, 0.7, 0.6, 0.6, 0.3, 0.7, 0.7, 0.9, 0.7, 0.6, 0.5, 0.7, 0.9, 0.4],
        [0.8, 0.6, 0.6, 0.7, 0.6, 0.7, 0.5, 0.4, 0.6, 0.5, 0.8, 0.1, 0.4, 0.3, 0.4, 0.4],
    ])

    arm_counts = (8, 10, 12, 14, 16)
    best_arm = 0  # arm 0 is used as the regret reference in every experiment

    T = int(8e4)  # horizon
    N = 8  # number of agents
    LAMDA_2 = 0.5759  # passed through as lambda_2; origin of the value not shown here
    GAMMA = 1
    C = math.sqrt(N)
    repeated_time = 10

    regret_means = []
    regret_lower = []
    regret_upper = []
    for n_arms in arm_counts:
        # Contiguous (agents, arms) copy, then transpose to (arms, agents):
        # the same layout the separately written tables had.
        mu = np.ascontiguousarray(agent_arm_means[:, :n_arms]).T
        mu_star = np.mean(mu, axis=1)  # per-arm mean across agents

        final_regrets = []
        for _ in range(repeated_time):
            _, _, _, _, regret_list = distributed_successive_eliminate(
                W, T, C, GAMMA, n_arms, N, mu, LAMDA_2, best_arm, mu_star)
            final_regrets.append(regret_list[0][-1])

        final_regrets = np.array(final_regrets)
        regret_mean = np.mean(final_regrets)
        regret_std = np.std(final_regrets)

        regret_means.append(regret_mean)
        regret_lower.append(regret_mean - regret_std)
        regret_upper.append(regret_mean + regret_std)

    return (regret_means, regret_lower, regret_upper)
