import numpy as np
import random
import matplotlib.pyplot as plt
import math

class Convex_Bandit:
    """One-dimensional convex bandit whose arms are the points of [0, 1].

    Pulling arm ``x`` returns the loss ``coeff * (x - best_arm) ** 2`` plus
    Gaussian noise of variance ``noise_var``; smaller observations are better.
    Plots show the corresponding expected reward ``1 - loss``.

    All parameters live in ``self.data_dict`` and are restricted to the names
    in ``self.keys``:

    * ``best_arm``, ``coeff`` -- shape of the loss curve.
    * ``noise_var`` -- variance of the observation noise.
    * ``T`` -- default horizon used by the algorithms.
    * ``K``, ``true_mean`` -- accepted but not read by any method of this class.

    Pull statistics (``chosen_number``, ``empirical_mean``), the current search
    segment and the round counter ``i`` persist across algorithm calls until
    :meth:`Clear_History` resets them.
    """

    def __init__(self, data_dict=None):
        """Create a bandit, copying any recognised keys from ``data_dict``."""
        self.keys = ["best_arm", "coeff", "noise_var", "T", "K", "true_mean"]
        self.data_dict = dict.fromkeys(self.keys, None)
        if data_dict is not None:
            for k, v in data_dict.items():
                if k in self.keys:
                    self.data_dict[k] = v
        # Empty pull history, full segment (0, 1), round counter 0.
        self.Clear_History()

    def Update_Parameters(self, **kwargs):
        """Overwrite parameters passed as keyword arguments; unknown names are ignored."""
        for k, v in kwargs.items():
            if k in self.keys:
                self.data_dict[k] = v

    def Get_Random_Function(self):
        """Draw a random loss curve ``f(x) = coeff * (x - best_arm) ** 2``.

        ``best_arm`` is uniform on [0, 1) and ``coeff`` uniform on [1, 2),
        both drawn (in that order) from Python's ``random`` module.
        """
        best_arm = random.random()
        coeff = 1 + random.random()
        self.Update_Parameters(best_arm=best_arm, coeff=coeff)

    def Get_Real_Value(self, arm):
        """Return the noiseless loss ``coeff * (arm - best_arm) ** 2``.

        :param arm: a scalar arm, or a NumPy array of arms (evaluated element-wise).
        """
        best_arm = self.data_dict.get("best_arm")
        coeff = self.data_dict.get("coeff")
        return coeff * (arm - best_arm) ** 2

    def Plot_True_Mean(self, S=0.5, save_path="./figures/Arm_Distribution_Convex.png"):
        """Plot the expected reward ``1 - loss`` over [0, 1] and save it to disk.

        :param S: satisficing threshold on the loss; the red line is drawn at reward ``1 - S``.
        :param save_path: output file; its directory is created if missing.
        """
        import os

        x = np.linspace(0, 1, 1000)
        y = 1 - self.Get_Real_Value(x)

        # Exactly one figure per call (an extra unused figure would stay open).
        plt.figure(figsize=(9, 5))
        plt.plot(x, y, label='Expected rewards')
        plt.axhline(y=1 - S, color='r', linestyle='-', label=f'Satisficing S={1-S:.3f}')

        plt.legend()
        plt.xlabel('Arms in [0,1]')
        plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
        plt.ylabel('Expected rewards')

        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_path)

    def Sample_Value(self, arm):
        """Pull ``arm`` once, update its statistics and return the noisy loss.

        The observation is ``Normal(Get_Real_Value(arm), sqrt(noise_var))``;
        ``chosen_number[arm]`` and the running mean ``empirical_mean[arm]``
        are updated to include it.
        """
        noise_var = self.data_dict['noise_var']

        # Number of times the arm has been pulled, this pull included.
        num = self.chosen_number.get(arm, 0) + 1
        self.chosen_number[arm] = num

        observed_value = np.random.normal(self.Get_Real_Value(arm), np.sqrt(noise_var))
        if arm in self.empirical_mean:
            # Incremental update of the running mean.
            self.empirical_mean[arm] = self.empirical_mean[arm] * (num - 1) / num + observed_value / num
        else:
            self.empirical_mean[arm] = observed_value
        return observed_value

    def Clear_History(self):
        """Forget all pulls and restart the search on the full segment (0, 1)."""
        self.chosen_number = {}
        self.empirical_mean = {}
        self.segment = (0, 1)
        self.i = 0

    def Algorithm_Convex(self, T=None):
        """Run the 1-D convex bandit algorithm of https://arxiv.org/abs/1107.1744.

        Continues from the stored segment / round counter of previous calls.

        :param T: number of pulls to spend; when given it is also stored as
            ``data_dict["T"]``, otherwise the stored value is used.
        :return: ``(reward_trajectory, arm_trajectory)`` -- the observed noisy
            losses and the pulled arms, in pull order (exactly ``T`` entries each).
        """
        if T is not None:
            self.Update_Parameters(T=T)
        return self._Run_Convex(self.data_dict.get("T"))

    def _Run_Convex(self, T):
        """Spend exactly ``T`` pulls on the bisection algorithm without touching ``data_dict["T"]``.

        Each round probes the quarter points x_l, x_c, x_r of the current
        segment; once the empirical means are separated by more than
        ``3 * gamma`` the quarter beside the worse outer probe is discarded.
        """
        noise_var = self.data_dict.get("noise_var")
        reward_trajectory = []
        arm_trajectory = []

        def pull(arm):
            reward_trajectory.append(self.Sample_Value(arm))
            arm_trajectory.append(arm)

        t = 0
        l, r = self.segment
        while t < T:
            w = r - l
            x_l, x_c, x_r = l + w / 4, l + w / 2, l + 3 * w / 4
            while t < T:
                self.i += 1
                gamma = 2 ** (-self.i)

                # 1. Explore. Take at least one sample per probe per round: with
                # T == 1 or noise_var == 0 the formula yields 0, no pulls happen
                # and the mean lookups below hit unseen arms (KeyError).
                spread = np.sqrt(noise_var) * np.log(T)
                iterate_num = math.ceil(spread / (2 * gamma ** 2)) if spread > 0 else 1
                if t + iterate_num * 3 >= T:
                    # Budget T will be exhausted: spend the remainder (< 3) on x_l,
                    # then as many full triples as still fit.
                    iterate_num = math.floor((T - t) / 3)
                    for _ in range(T - t - iterate_num * 3):
                        pull(x_l)
                        t += 1
                for _ in range(iterate_num):
                    pull(x_l)
                    pull(x_c)
                    pull(x_r)
                    t += 3

                if t >= T:
                    # Round cut short by the budget: redo this gamma next call.
                    self.i -= 1
                    break

                # 2. Check confidence bounds (losses: larger mean is worse).
                mean = self.empirical_mean
                outer_worst = max(mean[x_l], mean[x_r])
                # Case 1: x_l and x_r are gamma-separated, or
                # Case 2: x_c and the worse of x_l / x_r are gamma-separated.
                if (outer_worst > min(mean[x_l], mean[x_r]) + 3 * gamma
                        or outer_worst > mean[x_c] + 3 * gamma):
                    if mean[x_l] >= mean[x_r]:
                        l = x_l
                    else:
                        r = x_r
                    break

        self.segment = (l, r)
        return reward_trajectory, arm_trajectory

    def Algorithm_SAT_Explore(self, T=None, S=0.5):
        """Satisficing wrapper around the convex algorithm.

        With ``alpha = 0.5``, round ``i`` (1, 2, ...):

        1. runs the convex algorithm for ``ceil(2 ** (i / alpha))`` pulls;
        2. repeats the arm with the lowest empirical mean for
           ``ceil(2 ** (2 * i * (1 - alpha) / alpha))`` pulls;
        3. keeps pulling that arm while
           ``empirical_mean + sqrt(log(n) / n) <= S`` (``n`` = its pull count).

        Stops after exactly ``T`` pulls. ``data_dict["T"]`` is only changed
        when ``T`` is passed explicitly.

        :param T: horizon; stored as ``data_dict["T"]`` when given.
        :param S: satisficing threshold on the loss.
        :return: ``(total_reward, total_arm)`` -- observed noisy losses and
            pulled arms, in pull order.
        """
        if T is not None:
            self.Update_Parameters(T=T)
        T = self.data_dict.get("T")
        alpha = 0.5

        t, i = 0, 0
        total_reward = []
        total_arm = []
        while t < T:
            i += 1

            # 1. Use the convex algorithm to explore (without overwriting data_dict["T"]).
            t_i = math.ceil(2 ** (i / alpha))
            explore_reward_t, arm_trajectory = self._Run_Convex(min(t_i, T - t))
            t += t_i
            total_arm += list(arm_trajectory)
            if t >= T:
                total_reward += explore_reward_t
                break

            # 2. Repeat the empirically best arm T_i times.
            arm = min(self.empirical_mean, key=self.empirical_mean.get)
            T_i = math.ceil(2 ** (2 * i * (1 - alpha) / alpha))
            explore_reward_arm = []
            for _ in range(min(T_i, T - t)):
                explore_reward_arm.append(self.Sample_Value(arm))
                total_arm.append(arm)
            t += T_i
            if t >= T:
                total_reward += explore_reward_t + explore_reward_arm
                break

            # 3. Keep exploiting while the arm's upper confidence bound on the
            # loss stays within S (i.e. its reward LCB stays >= 1 - S).
            while self.empirical_mean[arm] + np.sqrt(np.log(self.chosen_number[arm]) / self.chosen_number[arm]) <= S:
                explore_reward_arm.append(self.Sample_Value(arm))
                total_arm.append(arm)
                t += 1
                if t >= T:
                    break  # the outer loop stops too, since t >= T

            # 4. Collect round i's observations.
            total_reward += explore_reward_t + explore_reward_arm

        return total_reward, total_arm