import numpy as np
import random
from K_Arm_Bandit import K_Arm_Bandit
import matplotlib.pyplot as plt
import math


def divide_number(a, K):
    """
    Draw K random offsets in [0, a), returned in ascending order.

    Callers add the i-th returned value to the i-th point of a uniform grid, using ``a`` as the
    slack the grid leaves before 1.

    :param a: upper end of the sampling interval (exclusive)
    :param K: number of offsets to draw
    :return: numpy array of shape (K,), sorted ascending
    """
    # K uniform draws (not K-1): each of the K grid points needs its own offset.
    return np.sort(np.random.random(K) * a)

class Lipschitz_Bandit:
    """
    Continuum-armed (Lipschitz) bandit on [0, 1] (dim=1) or [0, 1] x [0, 1] (dim=2).

    Problem parameters live in ``self.data_dict``; only the names in ``self.keys`` are accepted:
      K          -- number of linear pieces of the 1-D reward curve (true_mean has K+1 knots)
      lip_coeff  -- Lipschitz coefficient used for generation and discretization
      true_mean  -- expected rewards at the knots i/K, i = 0..K (1-D only)
      noise_var  -- variance of the Gaussian observation noise
      T          -- horizon (number of pulls) of the current run
      centers    -- centres of the Gaussian bumps defining the 2-D reward surface
      min_max    -- min_max[0] is subtracted from the 2-D surface before clipping to [0, 1]
      coeff      -- amplitude of each 2-D bump

    Per-arm sampling history is kept in ``chosen_number``, ``empirical_mean``, ``UCB`` and
    ``LCB``; arms are floats (dim=1) or (x, y) tuples (dim=2).
    """

    def __init__(self, data_dict=None):
        self.keys = ["K", "lip_coeff", "true_mean", "noise_var", "T", "centers", "min_max", "coeff"]
        self.data_dict = dict.fromkeys(self.keys, None)
        if data_dict is not None:
            self._merge_parameters(data_dict)
        self.Clear_History()

    def _merge_parameters(self, params):
        """Copy entries of ``params`` whose key is a known parameter name; ignore the rest."""
        for key, value in params.items():
            if key in self.keys:
                self.data_dict[key] = value

    def Update_Parameters(self, **kwargs):
        """Update any known parameters passed as keyword arguments (unknown names are ignored)."""
        self._merge_parameters(kwargs)

    def Get_Random_True_Value(self, lip_coeff=None):
        """
        Draw a random 1-D reward curve whose slope is bounded by lip_coeff.

        The curve is stored as K+1 knots spaced 1/K apart (accuracy = 1/K).  Starting at 0.5,
        each knot moves from the previous one by a step drawn uniformly from
        [-lip_coeff/K, lip_coeff/K] and is clipped to [0, 1], so
        r(arm + 1/K) lies in r(arm) +/- lip_coeff/K.

        :param lip_coeff: optional new Lipschitz coefficient, stored in data_dict before use
        :return: the generated ``true_mean`` array (also stored in data_dict)
        """
        if lip_coeff is not None:
            self.Update_Parameters(lip_coeff=lip_coeff)
        K = self.data_dict.get("K")
        lip_coeff = self.data_dict.get("lip_coeff")

        true_mean = np.zeros(K + 1)
        true_mean[0] = 0.5
        for i in range(K):
            step = random.uniform(-lip_coeff, lip_coeff) / K
            # Clip so every knot remains a valid expected reward in [0, 1].
            true_mean[i + 1] = min(max(true_mean[i] + step, 0.0), 1.0)
        self.Update_Parameters(true_mean=true_mean)
        return true_mean

    def Get_Lipschitz_Coefficient(self):
        """
        Return the realised Lipschitz coefficient of the 1-D reward curve.

        The curve is piecewise linear with knots 1/K apart, so its Lipschitz constant is the
        largest absolute knot-to-knot slope.  Since Get_Random_True_Value draws steps of at most
        lip_coeff/K and then clips, the realised value is at most the input lip_coeff.

        :return: max_i |true_mean[i+1] - true_mean[i]| * K, or 0 when K == 0
        """
        K = self.data_dict.get("K")
        true_mean = self.data_dict.get("true_mean")
        # Absolute slopes: decreasing segments bound the Lipschitz constant just as much.
        slopes = np.abs(np.diff(np.asarray(true_mean[:K + 1], dtype=float))) * K
        return float(slopes.max()) if slopes.size else 0

    def Get_Real_Value(self, arm):
        """
        Expected reward of a 1-D arm, linearly interpolated between the knots of true_mean.

        :param arm: point in [0, 1]; values outside are clamped to the nearest end
        :return: interpolated expected reward
        """
        K = self.data_dict.get("K")
        true_mean = self.data_dict.get("true_mean")

        arm = min(max(arm, 0), 1)  # keep the knot index inside true_mean
        i = int(arm * K)  # i/K <= arm < (i+1)/K
        if i == K:
            return true_mean[i]
        return true_mean[i] + (true_mean[i + 1] - true_mean[i]) * (arm - i / K) * K

    def Plot_True_Mean(self, S=0.5, setting=None):
        """
        Save a plot of the 1-D reward knots together with the satisficing level S.

        Writes ./figures/Arm_Distribution_Lip.png, or ./figures/Arm_Distribution_Lip_{setting}.png
        when ``setting`` is given.  The figure is closed afterwards so repeated calls do not
        accumulate open matplotlib figures.

        :param S: satisficing threshold drawn as a horizontal line
        :param setting: optional suffix for the output file name
        """
        true_mean = self.data_dict.get("true_mean")

        fig = plt.figure(figsize=(9, 5))
        plt.plot(true_mean, label='Expected rewards')
        plt.axhline(y=S, color='r', linestyle='-', label=f'Satisficing S={S:.3f}')
        plt.legend()
        plt.xlabel('Arms in [0,1]')
        plt.xticks([])
        plt.ylabel('Expected rewards')

        if setting is None:
            plt.savefig("./figures/Arm_Distribution_Lip.png")
        else:
            plt.savefig(f"./figures/Arm_Distribution_Lip_{setting}.png")
        plt.close(fig)

    def Plot_True_Mean_2(self, S=0.5):
        """
        Save a 3-D surface plot of the 2-D reward landscape to ./figures/Two-Dim-Arm-Distribution.png.

        The surface is the sum of coeff * exp(-100 * ||(x, y) - center||^2) over ``centers`` on a
        400 x 400 grid, shifted by its grid minimum and clipped to [0, 1].
        NOTE(review): Get_Real_Value_2 subtracts min_max[0] instead of the grid minimum; the two
        agree only if min_max[0] equals that minimum -- confirm.

        :param S: unused; kept for interface symmetry with Plot_True_Mean
        """
        centers = self.data_dict.get("centers")
        coeff = self.data_dict.get("coeff")

        axis = np.linspace(0, 1, 400)
        X, Y = np.meshgrid(axis, axis)
        Z = np.zeros_like(X)  # an array even when there are no centers
        for center in centers:
            Z += coeff * np.exp(-100 * (X - center[0]) ** 2 - 100 * (Y - center[1]) ** 2)
        Z = np.clip(Z - Z.min(), 0, 1)

        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(X, Y, Z, cmap='viridis')
        ax.set_xlabel('X axis')
        ax.set_ylabel('Y axis')
        ax.set_zlabel('Z axis')

        plt.savefig("./figures/Two-Dim-Arm-Distribution.png")
        plt.close(fig)

    def Sample_Value(self, arm, dim=1):
        """
        Pull ``arm`` once: draw a noisy reward and update its pull count and running mean.

        :param arm: float in [0, 1] (dim=1) or hashable (x, y) pair (dim=2)
        :param dim: 1 uses Get_Real_Value, anything else uses Get_Real_Value_2
        :return: the observed reward, drawn from N(expected reward, noise_var)
        """
        noise_var = self.data_dict.get("noise_var")

        num = self.chosen_number.get(arm, 0) + 1  # pull count including this pull
        self.chosen_number[arm] = num

        mean = self.Get_Real_Value(arm) if dim == 1 else self.Get_Real_Value_2(arm)
        observed_value = np.random.normal(mean, np.sqrt(noise_var))

        if arm in self.empirical_mean:
            # Incremental mean: m_n = m_{n-1} * (n-1)/n + x_n / n
            self.empirical_mean[arm] = self.empirical_mean[arm] * (num - 1) / num + observed_value / num
        else:
            self.empirical_mean[arm] = observed_value
        return observed_value

    def Confidence_Radias(self, index):
        """
        UCB confidence radius sqrt(2 log T / n) of arm ``index``, where n is its pull count.

        (The spelling of the name is kept for compatibility with existing callers.)
        """
        T = self.data_dict["T"]
        return np.sqrt(2 * np.log(T) / self.chosen_number[index])

    def Update_Confidence_Bounds(self, index=None):
        """
        Recompute UCB/LCB = empirical mean +/- confidence radius.

        :param index: update only this arm; when None, rebuild both dicts for every sampled arm
        """
        empirical_mean = self.empirical_mean
        if index is None:
            radius = {arm: self.Confidence_Radias(arm) for arm in empirical_mean}
            self.UCB = {arm: empirical_mean[arm] + r for arm, r in radius.items()}
            self.LCB = {arm: empirical_mean[arm] - r for arm, r in radius.items()}
        else:
            r = self.Confidence_Radias(index)
            self.UCB[index] = empirical_mean[index] + r
            self.LCB[index] = empirical_mean[index] - r

    def Get_Uncovered_Intervals(self, centers, radii):
        """
        Return the parts of [0, 1] not covered by the balls [x - r, x + r] around active arms.

        :param centers: activated arms
        :param radii: covering radius of each activated arm
        :return: uncovered intervals in increasing order, e.g. [(0, 0.2), (0.5, 0.6)];
                 empty when [0, 1] is fully covered, [(0, 1)] when there are no arms
        """
        if not centers:
            return [(0, 1)]
        # Clip each ball to [0, 1]; sort by left end, longer interval first on ties.
        intervals = sorted(
            ((max(0, x - r), min(1, x + r)) for x, r in zip(centers, radii)),
            key=lambda iv: (iv[0], -iv[1]),
        )

        # Merge overlapping or touching intervals.
        merged = []
        for start, end in intervals:
            if merged and start <= merged[-1][1]:
                merged[-1] = (merged[-1][0], max(end, merged[-1][1]))
            else:
                merged.append((start, end))

        # Gaps between merged intervals (and towards 0 and 1) are uncovered.
        uncovered = []
        prev_end = 0
        for start, end in merged:
            if start > prev_end:
                uncovered.append((prev_end, start))
            prev_end = end
        if prev_end < 1:
            uncovered.append((prev_end, 1))
        return uncovered

    def Get_Uncovered_Intervals_2(self, centers, radii):
        """
        Find a point of [0, 1] x [0, 1] not covered by the squares around active 2-D arms.

        The domain is discretized into grid_size x grid_size cells (grid_size = 10 / smallest
        positive radius).  A cell is marked covered when it intersects some square
        [x - r, x + r] x [y - r, y + r], and a random unmarked cell is returned.

        :param centers: list of sampled (x, y) arms
        :param radii: covering radius of each sampled arm
        :return: (x, y) of an uncovered grid point, (0.5, 0.5) when no arm is active, or None
                 when everything is covered
        """
        if not centers:
            return (0.5, 0.5)
        # A zero radius (e.g. T == 1 makes log(T) == 0) would divide by zero; size the grid from
        # the positive radii only, or fall back to a 10 x 10 grid.
        positive_radii = [r for r in radii if r > 0]
        grid_size = int(10 / min(positive_radii)) if positive_radii else 10
        grid = np.linspace(0, 1, grid_size)

        covered = np.full((grid_size, grid_size), False)  # rows index y, columns index x
        for center, rad in zip(centers, radii):
            x_start = max(0, int((center[0] - rad) * grid_size))
            x_end = min(grid_size, int((center[0] + rad) * grid_size) + 1)
            y_start = max(0, int((center[1] - rad) * grid_size))
            y_end = min(grid_size, int((center[1] + rad) * grid_size) + 1)
            covered[y_start:y_end, x_start:x_end] = True

        uncovered_indices = np.argwhere(~covered)
        if uncovered_indices.size == 0:
            return None
        row, col = uncovered_indices[np.random.choice(uncovered_indices.shape[0])]
        return (grid[col], grid[row])

    def _find_uncovered_arm(self, scale, dim):
        """Return a not-yet-covered arm (covering radius = Confidence_Radias / scale), or None."""
        arm_list = list(self.empirical_mean.keys())
        radii = [self.Confidence_Radias(arm) / scale for arm in arm_list]
        if dim == 1:
            intervals = self.Get_Uncovered_Intervals(arm_list, radii)
            # Pick a random point of the first uncovered interval.
            return random.uniform(intervals[0][0], intervals[0][1]) if intervals else None
        return self.Get_Uncovered_Intervals_2(arm_list, radii)

    def Algorithm_Zooming(self, T=None, dim=1):
        """
        Zooming (adaptive discretization) algorithm for the Lipschitz bandit.

        Each round, every sampled arm covers a ball of radius Confidence_Radias / lip_coeff.  If
        some point of the domain is uncovered, a new arm is activated there; otherwise the arm
        with the largest UCB is pulled.

        :param T: number of rounds; also stored as data_dict["T"], which the radii use
        :param dim: 1 for [0, 1], otherwise [0, 1] x [0, 1]
        :return: (reward_trajectory, arm_trajectory), each of length T
        """
        if T is not None:
            self.Update_Parameters(T=T)
        T = self.data_dict.get("T")
        lip_coeff = self.data_dict.get("lip_coeff")
        # Radii depend on T, which may differ from earlier calls: refresh every stored bound.
        self.Update_Confidence_Bounds()

        reward_trajectory, arm_trajectory = [], []
        for _ in range(T):
            arm = self._find_uncovered_arm(scale=lip_coeff, dim=dim)
            if arm is None:
                arm = max(self.UCB, key=self.UCB.get)
            arm_trajectory.append(arm)
            reward_trajectory.append(self.Sample_Value(arm, dim=dim))
            self.Update_Confidence_Bounds(arm)
        return reward_trajectory, arm_trajectory

    def _mesh_constant(self):
        """
        Right-hand side sqrt(8 log T / T) / lip_coeff of the mesh-size equation.

        T is floored at 2: at T == 1, log(T) == 0 makes the constant 0, which would drive the
        halving search into a float division by zero and the bisection into an endless loop.
        """
        T = max(self.data_dict.get("T"), 2)
        lip_coeff = self.data_dict.get("lip_coeff")
        return np.sqrt(8 * np.log(T) / T) / lip_coeff

    @staticmethod
    def _solve_mesh_size(const, ratio, half):
        """
        Solve ratio(epsilon) = const for epsilon in (0, 1]; ``ratio`` is increasing in epsilon.

        half=True: halve epsilon from 0.5 until ratio(epsilon) < const, then return twice that
        epsilon (1.0 when 0.5 already satisfies ratio < const).
        half=False: bisection on [0, 1] down to an interval width of 0.001 * const.
        """
        if half:
            epsilon = 0.5
            while ratio(epsilon) >= const:
                epsilon /= 2
            return 2 * epsilon
        l, c, r = 0, 0.5, 1
        while abs(l - r) >= 0.001 * const:
            if ratio(c) > const:
                r = c
            else:
                l = c
            c = (l + r) / 2
        return c

    def Get_Mesh_Size(self, half=True):
        """
        Mesh size for uniform discretization of the 1-D domain.

        Balances epsilon * T against C * sqrt(|S| * T * log T) by solving
        epsilon / sqrt(0.5 / epsilon + 1) = sqrt(8 log T / T) / lip_coeff.

        :param half: True -> power-of-two search (Algorithm_Uniform's default); False -> bisection
        :return: mesh size
        """
        return self._solve_mesh_size(self._mesh_constant(), lambda e: e / np.sqrt(0.5 / e + 1), half)

    def Get_Mesh_Size_2(self, half=True):
        """
        Mesh size for uniform discretization of the 2-D domain.

        Balances epsilon * T against C * sqrt(|S| * T * log T) by solving
        epsilon / (0.5 / epsilon + 1) = sqrt(8 log T / T) / lip_coeff.

        :param half: True -> power-of-two search (Algorithm_Uniform's default); False -> bisection
        :return: mesh size
        """
        return self._solve_mesh_size(self._mesh_constant(), lambda e: e / (0.5 / e + 1), half)

    def Get_Discretized_K_Arm(self, dim=1):
        """
        Uniformly discretize the domain and wrap the resulting arms in a K_Arm_Bandit.

        Uses the bisection mesh size (half=False).  epsilon is half the mesh; each axis gets
        new_K = int(0.5 / epsilon) + 1 points spaced 2 * epsilon apart, each shifted by a sorted
        random offset from divide_number so the leftover slack before 1 is used.

        :param dim: 1 for [0, 1] (new_K arms), otherwise [0, 1]^2 (new_K ** 2 arms, x-major order)
        :return: (K_Arm, epsilon, noise) where noise is the offset array (dim=1) or
                 (noise_x, noise_y) (dim=2)
        """
        T = self.data_dict.get("T")
        noise_var = self.data_dict.get("noise_var")

        if dim == 1:
            epsilon = self.Get_Mesh_Size(half=False) / 2
            new_K = int(0.5 / epsilon) + 1
            remainder = 1 - (new_K - 1) * epsilon * 2
            noise = divide_number(remainder, new_K)
            true_mean = [self.Get_Real_Value(i * 2 * epsilon + noise[i]) for i in range(new_K)]
            K_Arm = K_Arm_Bandit({"K": new_K, "T": T, "noise_var": noise_var, "true_mean": true_mean})
        else:
            epsilon = self.Get_Mesh_Size_2(half=False) / 2
            new_K = int(0.5 / epsilon) + 1
            remainder = 1 - (new_K - 1) * epsilon * 2
            noise_x = divide_number(remainder, new_K)
            noise_y = divide_number(remainder, new_K)
            noise = (noise_x, noise_y)
            arm_list = [(2 * i * epsilon + noise_x[i], 2 * j * epsilon + noise_y[j])
                        for i in range(new_K) for j in range(new_K)]
            true_mean = np.array([self.Get_Real_Value_2(arm) for arm in arm_list])
            K_Arm = K_Arm_Bandit({"K": new_K ** 2, "T": T, "noise_var": noise_var, "true_mean": true_mean})
        return K_Arm, epsilon, noise

    def Algorithm_Uniform(self, T=None, S=0.5, dim=1, half=True):
        """
        UCB on a uniform discretization of the domain.

        Every grid arm not sampled before is pulled once, in grid order (stopping early if the
        budget runs out).  Confidence bounds are then refreshed for all sampled arms, and the
        remaining rounds pull the arm with the largest UCB.

        :param T: number of rounds; also stored as data_dict["T"]
        :param S: unused; kept for interface symmetry with the other algorithms
        :param dim: 1 for [0, 1], otherwise [0, 1] x [0, 1]
        :param half: True -> power-of-two mesh size (the setting used by SAT-Explore);
                     False -> mesh size solved by bisection
        :return: (reward_trajectory, arm_trajectory), each of length T
        """
        if T is not None:
            self.Update_Parameters(T=T)
        T = self.data_dict['T']

        if dim == 1:
            epsilon = self.Get_Mesh_Size(half=half) / 2
            new_K = int(0.5 / epsilon) + 1
            remainder = 1 - (new_K - 1) * epsilon * 2
            noise = divide_number(remainder, new_K)
            arm_list = np.array([2 * i * epsilon for i in range(new_K)]) + np.array(noise)
            new_arm_list = [arm for arm in arm_list if arm not in self.empirical_mean]
        else:
            epsilon = self.Get_Mesh_Size_2(half=half)
            if not half:
                epsilon /= 2
                new_K = int(0.5 / epsilon) + 1
                remainder = 1 - (new_K - 1) * epsilon * 2
                noise_x = divide_number(remainder, new_K)
                noise_y = divide_number(remainder, new_K)
                arm_list = [(2 * i * epsilon + noise_x[i], 2 * j * epsilon + noise_y[j])
                            for i in range(new_K) for j in range(new_K)]
            else:
                # NOTE(review): unlike every other branch, epsilon is not halved here, so the grid
                # spacing is twice the mesh size and points are not jittered -- confirm intent.
                new_K = int(0.5 / epsilon) + 1
                arm_list = np.array([(2 * i * epsilon, 2 * j * epsilon)
                                     for i in range(new_K) for j in range(new_K)])
            # Tuples so that 2-D arms are hashable dict keys.
            new_arm_list = [(arm[0], arm[1]) for arm in arm_list
                            if (arm[0], arm[1]) not in self.empirical_mean]

        reward_trajectory, arm_trajectory = [], []
        n_new = min(T, len(new_arm_list))
        for arm in new_arm_list[:n_new]:
            reward_trajectory.append(self.Sample_Value(arm, dim=dim))
            arm_trajectory.append(arm)

        # Refresh every bound (also when the budget ran out above) so UCB covers all sampled arms.
        self.Update_Confidence_Bounds()
        for _ in range(T - n_new):
            arm = max(self.UCB, key=self.UCB.get)
            reward_trajectory.append(self.Sample_Value(arm, dim=dim))
            arm_trajectory.append(arm)
            self.Update_Confidence_Bounds(arm)
        return reward_trajectory, arm_trajectory

    def Get_Real_Value_2(self, arm):
        """
        Expected reward of a 2-D arm.

        Sum of Gaussian bumps coeff * exp(-100 * ||(x, y) - center||^2) over ``centers``, shifted
        down by min_max[0] and clipped to [0, 1].

        :param arm: (x, y) point
        :return: expected reward
        """
        centers = self.data_dict.get("centers")
        min_max = self.data_dict.get("min_max")
        coeff = self.data_dict.get("coeff")
        x, y = arm[0], arm[1]

        z = 0
        for center in centers:
            z += coeff * np.exp(-100 * (x - center[0]) ** 2 - 100 * (y - center[1]) ** 2)
        return np.clip(z - min_max[0], 0, 1)

    def Algorithm_Sat_UCB_Uniform(self, T=None, S=0.5, plus=False, dim=1):
        """
        SAT-UCB on a uniform discretization (delegates to K_Arm_Bandit.Algorithm_Sat_UCB).

        :param T: horizon; also stored as data_dict["T"]
        :param S: satisficing threshold
        :param plus: True -> SAT-UCB+, False -> SAT-UCB
        :param dim: 1 for [0, 1], otherwise [0, 1] x [0, 1]
        :return: (reward_trajectory, arm_trajectory) with arm indices mapped back to points
        """
        if T is not None:
            self.Update_Parameters(T=T)

        K_Arm, epsilon, noise = self.Get_Discretized_K_Arm(dim=dim)
        reward_trajectory, arm_trajectory = K_Arm.Algorithm_Sat_UCB(S=S, plus=plus)
        if dim == 1:
            arm_trajectory = [2 * epsilon * arm + noise[arm] for arm in arm_trajectory]
            return reward_trajectory, arm_trajectory

        # Rebuild the same x-major grid as Get_Discretized_K_Arm to map indices to points.
        noise_x, noise_y = noise[0], noise[1]
        new_K = int(0.5 / epsilon) + 1
        arm_list = [(2 * i * epsilon + noise_x[i], 2 * j * epsilon + noise_y[j])
                    for i in range(new_K) for j in range(new_K)]
        arm_trajectory = [arm_list[i] for i in arm_trajectory]
        return reward_trajectory, arm_trajectory

    def _sat_ucb_select(self, S, plus):
        """Choose among the sampled arms with the SAT-UCB (plus=False) or SAT-UCB+ rule."""
        empirical = self.empirical_mean
        UCB = self.UCB
        if plus:
            if max(empirical.values()) >= S:
                # Largest gap between UCB and max(S, LCB), measured in confidence radii.
                heuri_values = {i: (UCB[i] - max(S, self.LCB[i])) / self.Confidence_Radias(i)
                                for i in empirical}
                return max(heuri_values, key=heuri_values.get)
            return max(UCB, key=UCB.get)
        if max(empirical.values()) >= S:
            # Some arm already looks satisficing: pick one of them uniformly at random.
            return random.choice([i for i in empirical if empirical[i] >= S])
        if max(UCB.values()) >= S:
            # Some arm may still be satisficing optimistically: pick one uniformly at random.
            return random.choice([i for i in UCB if UCB[i] >= S])
        return max(UCB, key=UCB.get)

    def Algorithm_Sat_UCB(self, T=None, S=0.5, dim=1, plus=False):
        """
        SAT-UCB with adaptive discretization.

        Like Algorithm_Zooming but with covering radius Confidence_Radias / (2 * lip_coeff); once
        the domain is covered, the arm is chosen by the SAT-UCB (or SAT-UCB+) rule.

        :param T: number of rounds; also stored as data_dict["T"]
        :param S: satisficing threshold
        :param dim: 1 for [0, 1], otherwise [0, 1] x [0, 1]
        :param plus: True -> SAT-UCB+, False -> SAT-UCB
        :return: (reward_trajectory, arm_trajectory), each of length T
        """
        if T is not None:
            self.Update_Parameters(T=T)
        T = self.data_dict.get("T")
        lip_coeff = self.data_dict.get("lip_coeff")
        # Radii depend on T, which may differ from earlier calls: refresh every stored bound.
        self.Update_Confidence_Bounds()

        reward_trajectory, arm_trajectory = [], []
        for _ in range(T):
            arm = self._find_uncovered_arm(scale=2 * lip_coeff, dim=dim)
            if arm is None:
                arm = self._sat_ucb_select(S, plus)
            arm_trajectory.append(arm)
            reward_trajectory.append(self.Sample_Value(arm, dim=dim))
            self.Update_Confidence_Bounds(arm)
        return reward_trajectory, arm_trajectory

    def Clear_History(self):
        """Forget all pulls, empirical means and confidence bounds."""
        self.chosen_number = {}
        self.empirical_mean = {}
        self.UCB = {}
        self.LCB = {}

    def Algorithm_SAT_Explore(self, T=None, S=0.5, dim=1, method="uniform"):
        """
        SAT-Explore: alternate exploration phases with commitment to the empirically best arm.

        Phase i (i = 2, 3, ...):
          1. run Algorithm_Uniform (method == "uniform") or Algorithm_Zooming for
             t_i = ceil(2 ** (i / alpha)) rounds;
          2. pull the arm with the largest empirical mean for
             T_i = ceil(2 ** (2 * i * (1 - alpha) / alpha)) rounds;
          3. keep pulling it while mean - sqrt(log n / n) >= S.
        alpha = 2/3 for dim=1 and 3/4 otherwise.  Sampling history is shared across phases.
        data_dict["T"] is restored to the horizon on exit, since the sub-algorithms overwrite it
        with each phase length.

        :param T: horizon; also stored as data_dict["T"]
        :param S: satisficing threshold
        :param dim: 1 for [0, 1], otherwise [0, 1] x [0, 1]
        :param method: "uniform" for uniform discretization, anything else for zooming
        :return: (Total_Reward, Total_Arm), each of length T
        """
        if T is not None:
            self.Update_Parameters(T=T)
        T = self.data_dict.get("T")

        alpha = 2 / 3 if dim == 1 else 3 / 4

        t, i = 0, 1
        Total_Reward, Total_Arm = [], []
        try:
            while t < T:
                i += 1  # NOTE(review): the first phase therefore uses i = 2 -- confirm offset

                # 1. Explore with the discretization algorithm.
                t_i = math.ceil(2 ** (i / alpha))
                if method == "uniform":
                    explore_reward_t, arm_trajectory = self.Algorithm_Uniform(T=min(t_i, T - t), dim=dim)
                else:
                    explore_reward_t, arm_trajectory = self.Algorithm_Zooming(T=min(t_i, T - t), dim=dim)
                t += t_i
                Total_Arm += list(arm_trajectory)
                if t >= T:
                    Total_Reward += explore_reward_t
                    break

                # 2. Commit to the empirically best arm for T_i rounds.
                arm = max(self.empirical_mean, key=self.empirical_mean.get)
                T_i = math.ceil(2 ** (2 * i * (1 - alpha) / alpha))
                explore_reward_arm = []
                for _ in range(min(T_i, T - t)):
                    explore_reward_arm.append(self.Sample_Value(arm, dim=dim))
                    Total_Arm.append(arm)
                t += T_i
                if t >= T:
                    Total_Reward += explore_reward_t + explore_reward_arm
                    break

                # 3. Keep pulling while the arm's lower bound stays at or above S.
                while self.empirical_mean[arm] - np.sqrt(np.log(self.chosen_number[arm]) / self.chosen_number[arm]) >= S:
                    explore_reward_arm.append(self.Sample_Value(arm, dim=dim))
                    Total_Arm.append(arm)
                    t += 1
                    if t >= T:
                        break  # the outer loop stops too, because t >= T

                # 4. Record this phase's rewards.
                Total_Reward += explore_reward_t + explore_reward_arm
        finally:
            self.Update_Parameters(T=T)
        return Total_Reward, Total_Arm

    def Get_Best_Arm(self):
        """
        :return: the largest expected reward in true_mean (the best arm's expected reward)
        """
        return np.max(self.data_dict.get("true_mean"))




