import math
import numpy as np
import random
from K_Arm_Bandit import K_Arm_Bandit
from Lipschitz_Bandit import Lipschitz_Bandit
from Convex_Bandit import Convex_Bandit
import matplotlib.pyplot as plt
# Render all text in the figures produced by this script in Times New Roman.
plt.rc("font", family="Times New Roman")


def Test_K_Arm(round_num=1000, SAT=True):
    """Run the K-armed bandit experiment and collect cumulative regrets.

    For every repetition and every horizon ``T`` in ``time_budget``, Sat-UCB,
    Sat-UCB+ and a baseline are each run from a cleared history for ``T``
    steps. The baseline is Thompson sampling (``Algorithm_TS``) when ``SAT``
    is True and UCB (``Algorithm_UCB``) otherwise. SELECT
    (``Algorithm_SAT_Explore``) is run once per repetition at the largest
    horizon, and its cumulative regret is read off at every horizon.

    Args:
        round_num: Number of independent repetitions (rows of each result).
        SAT: If True, score each pull by the satisficing regret
            ``S - min(mean, S)``; otherwise by the standard regret
            ``max(true_mean) - mean``.

    Returns:
        ``(regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_baseline)``,
        each an array of shape ``(round_num, len(time_budget))``.
    """
    ############################
    # Change the settings here
    K, Total_Budget = 4, 5000
    noise_var = 1.0
    true_mean = np.array([0.6, 0.7, 0.8, 1])  # Note that K == len(true_mean)
    time_budget = [i * 500 for i in range(1, int(Total_Budget / 500) + 1)]
    S = 0.93  # 1.5 for unrealizable case
    ############################

    K_Arm = K_Arm_Bandit({"K": K, "T": Total_Budget, "noise_var": noise_var, "true_mean": true_mean})
    result_shape = (round_num, len(time_budget))
    regret_SELECT = np.zeros(result_shape)
    regret_SatUCB = np.zeros(result_shape)
    regret_SatUCB_plus = np.zeros(result_shape)
    # Holds the Thompson-sampling (SAT) or UCB (non-SAT) baseline.
    regret_baseline = np.zeros(result_shape)
    best_mean = np.max(true_mean)

    if SAT:
        print("\n##### Satisficing #####")
        baseline = K_Arm.Algorithm_TS

        def arm_regret(arm):
            # Shortfall of the pulled arm's mean below S; 0 once S is met.
            return S - min(true_mean[arm], S)
    else:
        print("\n##### Non-satisficing #####")
        baseline = K_Arm.Algorithm_UCB

        def arm_regret(arm):
            # Gap between the best arm's mean and the pulled arm's mean.
            return best_mean - true_mean[arm]

    def total_regret(arms):
        """Sum the per-pull regret over an arm trajectory."""
        return np.sum([arm_regret(arm) for arm in arms])

    # SELECT is run once at the largest horizon; these indices pick its
    # cumulative regret after exactly T pulls for each T in time_budget.
    checkpoint_idx = np.array(time_budget) - 1

    for i in range(round_num):
        print(f"##### Round {i} #####")
        for j, T in enumerate(time_budget):
            # Sat-UCB
            K_Arm.Clear_History()
            _, arms = K_Arm.Algorithm_Sat_UCB(T=T, S=S)
            regret_SatUCB[i, j] = total_regret(arms)

            # Sat-UCB+
            K_Arm.Clear_History()
            _, arms = K_Arm.Algorithm_Sat_UCB(T=T, S=S, plus=True)
            regret_SatUCB_plus[i, j] = total_regret(arms)

            # Baseline: Thompson sampling (SAT) or UCB (non-SAT)
            K_Arm.Clear_History()
            _, arms = baseline(T=T)
            regret_baseline[i, j] = total_regret(arms)

        # SELECT, run explicitly at the largest horizon
        K_Arm.Clear_History()
        _, arms = K_Arm.Algorithm_SAT_Explore(T=time_budget[-1], S=S)
        regret_SELECT[i] = np.cumsum([arm_regret(arm) for arm in arms])[checkpoint_idx]

    return regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_baseline


def Test_Concave_Bandit(round_num=1000, SAT=True):
    """Run the 1-d concave-reward bandit experiment and collect cumulative regrets.

    The mean reward of arm ``x`` in ``[0, 1]`` is ``1 - coeff * (x - best_arm)^2``.
    The convex-bandit algorithms are driven with the flipped threshold
    ``1 - S``. The Lipschitz bandit is given the reward-scale threshold
    ``S_lb`` over a ``K + 1`` point grid.

    For every repetition and every horizon ``T`` in ``time_budget``, the convex
    baseline (``Algorithm_Convex``), Sat-UCB and Sat-UCB+ are each run from a
    cleared history. SELECT (``Algorithm_SAT_Explore``) is run once per
    repetition at the largest horizon, and its cumulative regret is read off
    at every horizon.

    Args:
        round_num: Number of independent repetitions (rows of each result).
        SAT: If True, measure satisficing regret; otherwise standard regret.

    Returns:
        ``(regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_Convex)``,
        each an array of shape ``(round_num, len(time_budget))``.
    """
    ############################
    # Change the settings here
    K, Total_Budget = 10000, 5000  # K here does not matter, a larger K only means
                                   # a more accurate interpolation of lipschitz reward distribution
    time_budget = [i * 500 for i in range(1, int(Total_Budget / 500) + 1)]
    noise_var = 1.0
    coeff, best_arm = 16, 0.25  # reward = 1 - coeff * (x - best_arm)^2
    S = 0.3  # -0.5 for unrealizable cases
    ############################
    S = 1 - S  # this is because we implement the algorithm based on convex bandit
    S_lb = 1 - S  # threshold back on the reward scale, for the Lipschitz bandit

    CB = Convex_Bandit({"best_arm":best_arm, "coeff":coeff, "T":Total_Budget, "noise_var":noise_var})
    # CB.Get_Real_Value is on the flipped scale, so 1 - value gives the reward.
    true_mean = 1 - CB.Get_Real_Value(np.linspace(0, 1, K + 1))
    # |d/dx coeff*(x - best_arm)^2| on [0, 1] is largest at an endpoint.
    LB = Lipschitz_Bandit({"K": K, "noise_var": noise_var, "T": Total_Budget, "true_mean": true_mean,
                           "lip_coeff": max(abs(2 * coeff * best_arm), abs(2 * coeff * (1 - best_arm)))})
    result_shape = (round_num, len(time_budget))
    regret_SELECT = np.zeros(result_shape)
    regret_Convex = np.zeros(result_shape)
    regret_SatUCB = np.zeros(result_shape)
    regret_SatUCB_plus = np.zeros(result_shape)

    if SAT:
        print("\n##### Satisficing #####")

        def cb_regret(arm):
            # Flipped-scale excess of the arm's value over the flipped threshold S.
            return max(CB.Get_Real_Value(arm), S) - S

        def lb_regret(arm):
            # Reward-scale shortfall below the threshold S_lb.
            return S_lb - min(LB.Get_Real_Value(arm), S_lb)
    else:
        print("\n##### Non-Satisficing #####")

        def cb_regret(arm):
            return CB.Get_Real_Value(arm)

        def lb_regret(arm):
            return 1 - LB.Get_Real_Value(arm)

    # SELECT is run once at the largest horizon; these indices pick its
    # cumulative regret after exactly T pulls for each T in time_budget.
    checkpoint_idx = np.array(time_budget) - 1

    for i in range(round_num):
        print(f"##### Round {i} #####")
        for j, T in enumerate(time_budget):
            # Convex-bandit baseline
            CB.Clear_History()
            _, arms = CB.Algorithm_Convex(T=T)
            regret_Convex[i, j] = np.sum([cb_regret(arm) for arm in arms])

            # Sat-UCB
            LB.Clear_History()
            _, arms = LB.Algorithm_Sat_UCB_Uniform(T=T, S=S_lb)
            regret_SatUCB[i, j] = np.sum([lb_regret(arm) for arm in arms])

            # Sat-UCB+
            LB.Clear_History()
            _, arms = LB.Algorithm_Sat_UCB_Uniform(T=T, S=S_lb, plus=True)
            regret_SatUCB_plus[i, j] = np.sum([lb_regret(arm) for arm in arms])

        # SELECT, run explicitly at the largest horizon
        CB.Clear_History()
        _, arms = CB.Algorithm_SAT_Explore(T=time_budget[-1], S=S)
        regret_SELECT[i] = np.cumsum([cb_regret(arm) for arm in arms])[checkpoint_idx]

    return regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_Convex


def Test_Lipschitz_Bandit_2d(round_num=1000, SAT=True):
    """Run the 2-d Lipschitz bandit experiment and collect cumulative regrets.

    The reward surface on the unit square is a sum of Gaussian bumps
    ``coeff * exp(-100 * ((x - cx)^2 + (y - cy)^2))``, one per center in
    ``centers``. For every repetition and every horizon ``T`` in
    ``time_budget``, Sat-UCB+, Sat-UCB and uniform discretization
    (``Algorithm_Uniform``) are each run from a cleared history. SELECT
    (``Algorithm_SAT_Explore`` with ``method="uniform"``) is run once per
    repetition at the largest horizon, and its cumulative regret is read off
    at every horizon.

    Args:
        round_num: Number of independent repetitions (rows of each result).
        SAT: If True, score each pull by ``S - min(value, S)``; otherwise by
            ``1 - value``, where ``value = LB2.Get_Real_Value_2(arm)``.
            (Presumably ``Get_Real_Value_2`` is normalized so the best value
            is 1 -- confirm.)

    Returns:
        ``(regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_Uniform)``,
        each an array of shape ``(round_num, len(time_budget))``.
    """
    ############################
    # Change the settings here
    Total_Budget = 5000
    time_budget = [i * 500 for i in range(1, int(Total_Budget / 500) + 1)]
    S = 0.5  # 1.5 for unrealizable cases
    centers = [(0.5, 0.7)]  # should be a list of tuples
    coeff = 3
    x, y = np.linspace(0, 1, 400), np.linspace(0, 1, 400)
    X, Y = np.meshgrid(x, y)
    Z = np.zeros_like(X)  # array (not scalar 0) so Z.min()/Z.max() work even with no centers
    for center in centers:
        Z += coeff * np.exp(-100 * (X - center[0]) ** 2 - 100 * (Y - center[1]) ** 2)
    ############################

    min_max = (Z.min(), Z.max())
    # meshgrid's default "xy" indexing puts y along axis 0 and x along axis 1,
    # so np.gradient with spacings in (y, x) order returns (dZ/dy, dZ/dx).
    gradZ_Y, gradZ_X = np.gradient(Z, y[1] - y[0], x[1] - x[0])
    grad_norm = np.abs(gradZ_X) + np.abs(gradZ_Y)
    # Finite-difference maximum of the L1 gradient norm over the grid.
    # NOTE(review): measured on the raw surface Z; if Lipschitz_Bandit rescales
    # values using min_max, this constant may need the same rescaling -- confirm.
    lip_coeff = np.max(grad_norm)
    LB2 = Lipschitz_Bandit({"noise_var": 1.0, "T": Total_Budget, "centers": centers, "min_max": min_max, "coeff": coeff, "lip_coeff": lip_coeff})

    result_shape = (round_num, len(time_budget))
    regret_SELECT = np.zeros(result_shape)
    regret_SatUCB = np.zeros(result_shape)
    regret_Uniform = np.zeros(result_shape)
    regret_SatUCB_plus = np.zeros(result_shape)

    if SAT:
        print("\n##### Satisficing #####")

        def arm_regret(arm):
            # Shortfall of the arm's value below S; 0 once S is met.
            return S - min(LB2.Get_Real_Value_2(arm), S)
    else:
        print("\n##### Non-satisficing #####")

        def arm_regret(arm):
            return 1 - LB2.Get_Real_Value_2(arm)

    def total_regret(arms):
        """Sum the per-pull regret over an arm trajectory."""
        return np.sum([arm_regret(arm) for arm in arms])

    # SELECT is run once at the largest horizon; these indices pick its
    # cumulative regret after exactly T pulls for each T in time_budget.
    checkpoint_idx = np.array(time_budget) - 1

    for i in range(round_num):
        print(f"##### Round {i} #####")
        for t, T in enumerate(time_budget):
            # Sat-UCB+
            LB2.Clear_History()
            _, arms = LB2.Algorithm_Sat_UCB_Uniform(T=T, S=S, dim=2, plus=True)
            regret_SatUCB_plus[i, t] = total_regret(arms)

            # Sat-UCB
            LB2.Clear_History()
            _, arms = LB2.Algorithm_Sat_UCB_Uniform(T=T, S=S, dim=2, plus=False)
            regret_SatUCB[i, t] = total_regret(arms)

            # Uniform Discretize -- start from a clean history like every other run,
            # instead of inheriting the Sat-UCB run's history.
            LB2.Clear_History()
            _, arms = LB2.Algorithm_Uniform(T=T, S=S, dim=2, half=False)
            regret_Uniform[i, t] = total_regret(arms)

        # SELECT - Uniform, run explicitly at the largest horizon
        LB2.Clear_History()
        _, arms = LB2.Algorithm_SAT_Explore(T=time_budget[-1], S=S, dim=2, method="uniform")
        regret_SELECT[i] = np.cumsum([arm_regret(arm) for arm in arms])[checkpoint_idx]

    return regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_Uniform


if __name__ == '__main__':
    ##### Choose satisficing and repeat number #####
    SAT = True  # True for satisficing regret, False for standard regret
    round_num = 1000  # Number of rounds that experiment repeats to average
    ################################################

    ##### Choose the bandit to run #####
    # Each Test_* function returns (SELECT, Sat-UCB, Sat-UCB+, baseline) regret
    # arrays of shape (round_num, len(time_budget)); the baseline is plotted as "Oracle".
    # K Arm Bandit
    regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_Oracle = Test_K_Arm(round_num=round_num, SAT=SAT)

    # Lipschitz Bandit
    # regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_Oracle = Test_Lipschitz_Bandit_2d(round_num=round_num, SAT=SAT)

    # Concave Bandit
    # regret_SELECT, regret_SatUCB, regret_SatUCB_plus, regret_Oracle = Test_Concave_Bandit(round_num=round_num, SAT=SAT)
    #####################################

    ##### Plot the results #####
    # Must match the horizon grid used inside the Test_* functions.
    Total_Budget = 5000
    time_budget = [i * 500 for i in range(1, int(Total_Budget / 500) + 1)]
    x_axis = [0] + time_budget  # regret is 0 at t = 0

    # (regret matrix, legend label, colour, line style)
    curves = [
        (regret_SELECT, 'SELECT', 'r', '-'),
        (regret_SatUCB, 'SAT-UCB', '#1f77b4', ':'),
        (regret_SatUCB_plus, 'SAT-UCB+', '#ff7f0e', '--'),
        (regret_Oracle, 'Oracle', '#2ca02c', '-.'),
    ]

    summaries = []
    for regret, label, color, style in curves:
        # Use the actual number of repetitions (rows), not a hard-coded 1000.
        n_runs = regret.shape[0]
        mean = np.insert(np.mean(regret, axis=0), 0, 0)
        # Error bars span +/- 2 standard errors of the mean across repetitions.
        half_width = np.insert(2 * np.std(regret, axis=0) / np.sqrt(n_runs), 0, 0)
        summaries.append((mean, half_width, label, color, style))

    plt.figure(figsize=(9, 5))

    # Draw all error bars first, then the opaque mean curves on top of them.
    for mean, half_width, _, color, style in summaries:
        plt.errorbar(x_axis, mean, yerr=half_width, fmt=style, color=color, ecolor=color,
                     alpha=0.5, elinewidth=1, capsize=5, linewidth=3)
    for mean, _, label, color, style in summaries:
        plt.plot(x_axis, mean, color=color, linestyle=style, label=label, linewidth=3)

    # Adding legend
    plt.legend(fontsize=14, loc='upper left')
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.xlabel("Time horizon", fontsize=18)
    plt.xticks([i * 1000 for i in range(int(Total_Budget / 1000) + 1)])
    plt.ylabel("Regret", fontsize=18)
    plt.show()


