import numpy as np
import random
import math
import matplotlib.pyplot as plt


class K_Arm_Bandit:
    """Simulated K-armed bandit with Gaussian rewards and several arm-selection policies.

    Problem parameters live in ``self.data_dict`` under the keys
    ``"K"`` (number of arms), ``"true_mean"`` (expected reward of each arm,
    indexed by arm), ``"noise_var"`` (variance of the Gaussian reward noise)
    and ``"T"`` (horizon, i.e. number of pulls).

    Pull statistics (``chosen_number``, ``empirical_mean``) and confidence
    bounds (``UCB``, ``LCB``) persist across algorithm calls until
    :meth:`Clear_History` is called.
    """

    def __init__(self, data_dict=None):
        """
        :param data_dict: optional initial parameters; keys other than
            K / true_mean / noise_var / T are ignored.
        """
        self.keys = ["K", "true_mean", "noise_var", "T"]
        self.data_dict = dict.fromkeys(self.keys, None)
        if data_dict is not None:
            for k, v in data_dict.items():
                if k in self.keys:
                    self.data_dict[k] = v
        self.chosen_number = None   # pulls per arm (float array, created lazily by Sample_Value)
        self.empirical_mean = None  # running mean of observed rewards per arm
        self.UCB = None             # upper confidence bound per arm
        self.LCB = None             # lower confidence bound per arm

    def Update_Parameters(self, **kwargs):
        """Overwrite parameters passed as keyword arguments; unknown names are ignored."""
        for k, v in kwargs.items():
            if k in self.keys:
                self.data_dict[k] = v

    def Plot_True_Mean(self, S=0.5):
        """Plot the arms' expected rewards (sorted ascending) against the threshold S.

        :param S: satisficing threshold, drawn as a horizontal red line.
        """
        K = self.data_dict.get("K")
        true_mean = self.data_dict.get("true_mean")

        # Sort the true means from lowest to highest so the curve is monotone.
        sorted_true_mean = np.sort(true_mean)

        plt.figure(figsize=(10, 6))
        plt.plot(sorted_true_mean, marker='o', label='Expected rewards')
        plt.axhline(y=S, color='r', linestyle='-', label=f'Satisficing S={S:.3f}')

        plt.legend()
        plt.title(f'Expected Rewards for {K:.0f} arms with Satisficing {S:.3f}')
        plt.xlabel('Arms')
        plt.ylabel('Rewards')

        plt.show()

    def Sample_Value(self, index):
        """Pull arm ``index`` once and update its statistics.

        The reward is drawn from N(true_mean[index], noise_var). The pull
        count and running empirical mean of the arm are updated in place.
        Both statistics arrays are created (zeros of length K) on first use.

        :param index: arm to pull.
        :return: the observed reward.
        """
        K = self.data_dict.get("K")
        true_mean = self.data_dict.get("true_mean")
        noise_var = self.data_dict.get("noise_var")
        if self.chosen_number is None:
            self.chosen_number = np.zeros(K)
        if self.empirical_mean is None:
            self.empirical_mean = np.zeros(K)

        self.chosen_number[index] += 1
        num = self.chosen_number[index]  # number of pulls of this arm, including this one
        observed_value = np.random.normal(true_mean[index], np.sqrt(noise_var))
        # Incremental mean update: mean_n = mean_{n-1} * (n-1)/n + x_n / n
        self.empirical_mean[index] = self.empirical_mean[index] * (num - 1) / num + observed_value / num
        return observed_value

    def Confidence_Radias(self, index):
        """Return the confidence radius sqrt(2 ln T / n_index) of arm ``index``.

        NOTE(review): the radius does not scale with ``noise_var``, so it
        implicitly assumes unit-variance rewards — confirm this is intended.
        """
        T = self.data_dict["T"]
        return np.sqrt(2 * np.log(T) / self.chosen_number[index])

    def Update_Confidence_Bounds(self, index=None):
        """Recompute UCB/LCB = empirical mean +/- confidence radius.

        :param index: arm to update; when None (or when the bounds have not
            been initialised yet) the bounds of all K arms are recomputed.
        """
        if index is None or self.UCB is None or self.LCB is None:
            K = self.data_dict.get("K")
            radius = np.array([self.Confidence_Radias(i) for i in range(K)])
            means = self.empirical_mean[:K]
            # New arrays (not views) so later per-arm updates stay independent.
            self.UCB = means + radius
            self.LCB = means - radius
        else:
            # Only the pulled arm's count changed, so only its bounds move.
            radius = self.Confidence_Radias(index)
            self.UCB[index] = self.empirical_mean[index] + radius
            self.LCB[index] = self.empirical_mean[index] - radius

    def Get_Best_Arm(self):
        """
        :return: the expected reward of the best arm (max of true_mean).
        """
        return np.max(self.data_dict.get("true_mean"))

    def _Pull_Each_Arm_Once(self, K, T):
        """Pull arms 0, 1, ... once each, stopping after min(K, T) pulls; return (rewards, arms)."""
        arm_trajectory = list(range(min(K, T)))
        reward_trajectory = [self.Sample_Value(arm) for arm in arm_trajectory]
        return reward_trajectory, arm_trajectory

    def Algorithm_UCB(self, T=None):
        """Classic UCB: pull every arm once, then repeatedly pull argmax UCB.

        :param T: horizon; if given, stored as data_dict["T"].
        :return: (reward_trajectory, arm_trajectory), each of length T.
        """
        if T is not None:
            self.Update_Parameters(T=T)
        K, T = self.data_dict.get("K"), self.data_dict.get("T")

        # 1. Initialisation round (truncated if the horizon is shorter than K).
        reward_trajectory, arm_trajectory = self._Pull_Each_Arm_Once(K, T)
        if T <= K:
            return reward_trajectory, arm_trajectory

        # 2. Pull the arm with the largest upper confidence bound.
        self.Update_Confidence_Bounds()
        for _ in range(T - K):
            arm = np.argmax(self.UCB)
            reward_trajectory.append(self.Sample_Value(arm))
            arm_trajectory.append(arm)
            self.Update_Confidence_Bounds(arm)

        return reward_trajectory, arm_trajectory

    def Clear_History(self):
        """Forget all pull statistics and confidence bounds."""
        self.chosen_number = None
        self.empirical_mean = None
        self.UCB = None
        self.LCB = None

    def Algorithm_SAT_Explore(self, T=None, S=0.5):
        """Satisficing explore-then-commit in rounds of growing length.

        Round i (i = 1, 2, ...):
          1. run Thompson sampling for t_i = ceil(1.2**(i/alpha) * K * 1.7) pulls;
          2. commit to the arm pulled most often in that run for
             T_i = ceil(1.2**(2*i*(1-alpha)/alpha) * K * 1.8) pulls;
          3. keep pulling that arm in further blocks of T_i pulls while
             mean(committed rewards) - 0.25*sqrt(log(k+1) / (T_i*(k+1))) >= S.
        Every phase is truncated so that exactly T pulls are made in total.
        Pull statistics are carried over between rounds (history is not cleared).

        :param T: total pull budget; if given, stored as data_dict["T"].
        :param S: satisficing threshold.
        :return: (Total_Reward, Total_Arm), each of length T.
        """
        if T is not None:
            self.Update_Parameters(T=T)
        K, T = self.data_dict.get("K"), self.data_dict.get("T")

        alpha = 0.5

        t, i = 0, 0
        Total_Reward = []
        Total_Arm = []
        while t < T:
            i += 1  # i = 1, 2, 3, ...

            # 1. Exploration with Thompson sampling, truncated to the remaining budget.
            t_i = math.ceil(1.2 ** (i / alpha) * K * 1.7)
            explore_reward_t, arm_trajectory = self.Algorithm_TS(T=min(t_i, T - t))
            t += t_i
            Total_Arm += arm_trajectory
            if t >= T:
                Total_Reward += explore_reward_t
                break

            # 2. Commit to the most frequently pulled arm of this round for T_i pulls.
            #    (Alternatives tried: argmax empirical mean, argmax pull count,
            #    random draw from the second half of the trajectory, z-score (mean - S)/std.)
            arm_pulled, count = np.unique(arm_trajectory, return_counts=True)
            arm = arm_pulled[np.argmax(count)]
            T_i = math.ceil(1.2 ** (2 * i * ((1 - alpha) / alpha)) * K * 1.8)
            explore_reward_arm = []
            for _ in range(min(T_i, T - t)):
                explore_reward_arm.append(self.Sample_Value(arm))
                Total_Arm.append(arm)
            t += T_i
            if t >= T:
                Total_Reward += (explore_reward_t + explore_reward_arm)
                break

            # 3. Extend the commitment while the lower-confidence estimate stays >= S.
            #    explore_reward_arm holds exactly T_i * (k + 1) rewards whenever the
            #    condition is evaluated (a truncated block is always followed by a break).
            k = 0
            while sum(explore_reward_arm) / T_i / (k + 1) - 0.25 * np.sqrt(np.log(k + 1) / T_i / (k + 1)) >= S:
                # Cap the block at the remaining budget so we never pull more than T times.
                for _ in range(min(T_i, T - t)):
                    explore_reward_arm.append(self.Sample_Value(arm))
                    Total_Arm.append(arm)
                    t += 1
                k += 1
                if t >= T:
                    break  # the outer loop also terminates because t >= T

            # 4. Collect the rewards of round i.
            Total_Reward += (explore_reward_t + explore_reward_arm)

        # Algorithm_TS stored each round's length in data_dict["T"]; restore the real horizon.
        self.Update_Parameters(T=T)
        return Total_Reward, Total_Arm

    def Algorithm_TS(self, T=None):
        """Gaussian Thompson sampling for T pulls.

        The first pull is always arm 0. Afterwards each arm k with n_k pulls and
        empirical mean mu_k gets a sample from N(n_k * mu_k / (n_k + 1), 1 / (n_k + 1)).
        This is the conjugate posterior under a N(0, 1) prior with unit reward
        variance; ``noise_var`` is not used here. The arm with the largest sample
        is pulled.

        :param T: number of pulls; if given, stored as data_dict["T"].
        :return: (reward_trajectory, arm_trajectory), each of length T
            (both empty when T <= 0).
        """
        if T is not None:
            self.Update_Parameters(T=T)
        T = self.data_dict.get("T")

        reward_trajectory = []
        arm_trajectory = []
        if T <= 0:
            return reward_trajectory, arm_trajectory

        # 1. Choose arm 0 at the first step (this also initialises the statistics).
        reward_trajectory.append(self.Sample_Value(0))
        arm_trajectory.append(0)

        # 2. Thompson sampling; an unpulled arm (n = 0) reduces to a N(0, 1) draw.
        for _ in range(1, T):
            n = self.chosen_number
            post_mean = (self.empirical_mean * n) / (n + 1)
            post_std = np.sqrt(1 / (n + 1))
            arm = np.argmax(np.random.normal(post_mean, post_std))
            reward_trajectory.append(self.Sample_Value(arm))
            arm_trajectory.append(arm)

        return reward_trajectory, arm_trajectory

    def Algorithm_Sat_UCB(self, T=None, S=0.5, plus=False):
        """Satisficing UCB (Sat-UCB / Sat-UCB+).

        After pulling every arm once, each step picks:
          * Sat-UCB (``plus=False``): a uniformly random arm whose empirical
            mean is >= S; otherwise a random arm whose UCB is >= S; otherwise
            argmax UCB.
          * Sat-UCB+ (``plus=True``): if some empirical mean is >= S, the arm
            maximising (UCB - max(S, LCB)) / radius; otherwise argmax UCB.

        :param T: horizon; if given, stored as data_dict["T"].
        :param S: satisficing threshold.
        :param plus: use the Sat-UCB+ selection rule.
        :return: (reward_trajectory, arm_trajectory), each of length T.
        """
        if T is not None:
            self.Update_Parameters(T=T)
        K, T = self.data_dict.get("K"), self.data_dict.get("T")

        # 1. Initialisation round (truncated if the horizon is shorter than K).
        reward_trajectory, arm_trajectory = self._Pull_Each_Arm_Once(K, T)
        if T <= K:
            return reward_trajectory, arm_trajectory

        # 2. Keep choosing the "best" arm according to the satisficing rule.
        self.Update_Confidence_Bounds()
        for _ in range(T - K):
            empirical, UCB = self.empirical_mean, self.UCB
            if plus:
                # Sat-UCB+
                if max(empirical) >= S:
                    score = [(UCB[i] - max(S, self.LCB[i])) / self.Confidence_Radias(i) for i in range(K)]
                    arm = np.argmax(score)
                else:
                    arm = np.argmax(UCB)
            else:
                # Sat-UCB
                if max(empirical) >= S:
                    arm = random.choice([i for i, m in enumerate(empirical) if m >= S])
                elif max(UCB) >= S:
                    arm = random.choice([i for i, u in enumerate(UCB) if u >= S])
                else:
                    arm = np.argmax(UCB)
            reward_trajectory.append(self.Sample_Value(arm))
            arm_trajectory.append(arm)
            self.Update_Confidence_Bounds(arm)

        return reward_trajectory, arm_trajectory



