import math
import matplotlib.pyplot as plt
import numpy as np
import random


def pull_arm(mu, i, j):
    """Draw a single Bernoulli reward for arm ``i`` as observed by agent ``j``.

    A uniform sample on [0, 1] is drawn and compared with ``mu[i][j]``.

    Returns:
        ``1`` if the sample falls below ``mu[i][j]`` (probability ``mu[i][j]``
        for values in [0, 1]), otherwise ``0``.
    """
    draw = random.uniform(0, 1)  # sample first, then look up the arm mean
    return 1 if draw < mu[i][j] else 0


def distributed_successive_eliminate(W, T, C, gamma, K, N, mu, lambda_2, best_arm, mu_star):
    """Run distributed successive elimination for ``N`` agents over ``K`` arms.

    Every agent starts with all ``K`` arms active. In each round:

    * each agent pulls every arm still in its active set once;
    * it compares the arm's estimate against its current best estimate using
      the confidence width ``U``;
    * it schedules arms that fall too far behind for delayed removal.

    After all agents have pulled, each estimate is mixed with the other
    agents' estimates through row ``j`` of ``W`` (consensus step). Scheduled
    removals whose deadline has passed are then applied. The loop runs while
    the largest per-agent step counter is ``<= T``.

    Args:
        W: ``N x N`` mixing weights (list of lists or ndarray); row ``j`` is
            dotted with the estimates of all agents for a given arm.
        T: time horizon; also used inside ``log(T)`` in the confidence width.
        C: constant scaling the second (network) term of the confidence width.
        gamma: multiplier on the confidence width in the elimination test.
        K: number of arms.
        N: number of agents.
        mu: indexable as ``mu[i][j]``; success probability of arm ``i`` for
            agent ``j`` (passed to ``pull_arm``).
        lambda_2: appears as ``C / ((1 - lambda_2) * ...)``; presumably the
            second-largest eigenvalue of ``W`` -- confirm with the caller.
        best_arm: index into ``mu_star`` of the reference arm for regret.
        mu_star: per-arm global means used to compute regret.

    Returns:
        Tuple ``(sum_pull_time, t, S, sum_reward, regret_list)``:

        * ``sum_pull_time`` and ``sum_reward``: ``(K, N)`` arrays of pull
          counts and accumulated rewards.
        * ``t``: the per-agent step counters.
        * ``S``: each agent's list of still-active arms.
        * ``regret_list``: one cumulative-regret trajectory per agent,
          starting at 0.
    """
    # LINE 1: D -- removals are deferred by len(S[j]) * D steps
    D = 1

    t = np.zeros(N)
    S = [list(range(K)) for _ in range(N)]
    B = [[] for _ in range(N)]  # per-agent pending removals as [arm, deadline]
    hat_mu = np.zeros((K, N))
    tau = np.zeros((K, N))
    hat_tau = np.ones((K, N))
    reward_x = np.zeros((K, N))  # reward observed in the current round
    U = np.zeros((K, N))
    sum_pull_time = np.zeros((K, N))
    sum_reward = np.zeros((K, N))
    regret_list = [[0] for _ in range(N)]

    # Warm-up: every agent pulls every arm f times before the main loop.
    f = 1
    for j in range(N):
        for i in range(K):
            hat_mu[i][j] = pull_arm(mu, i, j)
            tau[i][j] = f
            U[i][j] = f
            sum_pull_time[i][j] = f
            # NOTE(review): this accumulates the expected reward mu * f, not the
            # sampled reward stored in hat_mu above -- confirm this is intended.
            sum_reward[i][j] += mu[i][j] * f
    hat_tau[:, :] = f
    t[:] = f

    log_T = math.log(T)

    # LINE 2
    while t.max() <= T:
        for j in range(N):
            # LINE 3: best active arm of agent j. max() yields the arm index
            # itself (not its position within S[j]) and, like a stable sort,
            # keeps the first arm in S[j] on ties.
            i_max = max(S[j], key=lambda arm: hat_mu[arm][j])
            hat_mu_max = hat_mu[i_max][j]

            # LINE 4
            for i in S[j]:
                # LINE 5: pull the arm, record reward and cumulative regret
                reward_x[i][j] = pull_arm(mu, i, j)
                sum_pull_time[i][j] += 1
                sum_reward[i][j] += reward_x[i][j]
                regret_list[j].append(regret_list[j][-1] + mu_star[best_arm] - mu_star[i])

                # LINE 6-7: advance the agent clock and this arm's pull count
                t[j] += 1
                hat_tau[i][j] += 1

                # LINE 8: confidence width = statistical term + network term
                U[i, j] = math.sqrt(log_T / (N * hat_tau[i][j])) + C / (
                        (1 - lambda_2) * (hat_tau[i][j] + 1))

                # LINE 9-15: keep the arm, or schedule it for delayed removal
                if hat_mu[i, j] > hat_mu_max - gamma * U[i, j]:
                    tau[i, j] += 1
                else:
                    deadline = t[j] + len(S[j]) * D
                    B[j].append([i, deadline])

        # LINE 19: consensus step. Read every neighbour estimate from a snapshot
        # taken before this step, so all agents mix the same values regardless of
        # iteration order (writing into hat_mu while reading it would leak
        # already-updated estimates into later agents' averages).
        prev_hat_mu = hat_mu.copy()
        for j in range(N):
            for i in S[j]:
                n_ij = hat_tau[i][j]
                hat_mu[i, j] = (n_ij / (n_ij + 1)) * np.dot(W[j], prev_hat_mu[i]) + (
                        1 / (1 + n_ij)) * reward_x[i][j]

        # LINE 20-22: apply removals whose deadline has passed, never removing
        # an agent's last active arm.
        for j in range(N):
            for i, deadline in B[j]:
                if t[j] >= deadline and len(S[j]) > 1 and i in S[j]:
                    S[j].remove(i)

    return sum_pull_time, t, S, sum_reward, regret_list


def DSE_DiffAEM_list(agent_counts=(5, 8, 11, 14, 17, 20), repeated_time=50, T=10000, gamma=0.5):
    """Mean and spread of agent 0's final regret for several network sizes.

    Each agent count ``n`` in ``agent_counts`` gets its own instance:

    * 20 arms with means ``mu[m][k] = 0.95 - 0.03*m - 0.01*k`` (arm ``m``,
      agent ``k``), so arm 0 is best for every agent;
    * a fully connected mixing matrix whose entries are all ``1/n``.

    ``distributed_successive_eliminate`` is run ``repeated_time`` times on each
    instance, recording agent 0's cumulative regret at the end of each run.
    Counts are processed in order, so the random draws happen in the same
    sequence as running each size back to back.

    Args:
        agent_counts: network sizes to evaluate.
        repeated_time: independent runs per network size.
        T: time horizon passed to ``distributed_successive_eliminate``.
        gamma: elimination-test multiplier passed through.

    Returns:
        Three lists aligned with ``agent_counts``: mean final regret, mean
        minus standard deviation, and mean plus standard deviation (e.g. for
        ``plt.plot`` / ``plt.fill_between``).
    """
    num_arms = 20
    lambda_2 = 0.5759
    # NOTE(review): C is sqrt(20) for every network size (matching the previous
    # fixed N = 20); confirm whether it should scale with the agent count.
    C = math.sqrt(20)
    best_arm = 0  # means decrease in m, so arm 0 is the best arm

    means, lower, upper = [], [], []
    for n in agent_counts:
        mu = np.array([[0.95 - 0.03 * m - 0.01 * k for m in range(num_arms)] for k in range(n)]).T
        W = [[1 / n for _ in range(n)] for _ in range(n)]
        mu_star = np.mean(mu, axis=1)  # per-arm mean over agents

        final_regrets = []
        for _ in range(repeated_time):
            regret_list = distributed_successive_eliminate(
                W, T, C, gamma, num_arms, n, mu, lambda_2, best_arm, mu_star)[-1]
            final_regrets.append(regret_list[0][-1])

        final_regrets = np.array(final_regrets)
        mean = np.mean(final_regrets)
        std = np.std(final_regrets)
        means.append(mean)
        lower.append(mean - std)
        upper.append(mean + std)

    return means, lower, upper
