import numpy as np
from collections import defaultdict
from itertools import product
from scipy.stats import norm, geom, randint


def get_reward(x, y):
    """Return the mean reward of the 2-D test function at ``(x, y)``.

    The loss is a weighted sum of Euclidean distances to two anchors:
    ``(0.7, 0.8)`` with weight 0.7 and ``(0, 0.1)`` with weight 0.4. The
    reward is ``1 - loss``. Because the heavier weight is on ``(0.7, 0.8)``,
    that anchor is where the loss is smallest.
    """
    dist_to_peak = ((x - 0.7) ** 2 + (y - 0.8) ** 2) ** 0.5
    dist_to_anchor = (x ** 2 + (y - 0.1) ** 2) ** 0.5
    loss = 0.7 * dist_to_peak + 0.4 * dist_to_anchor
    return 1 - loss


class Balls:
    """Square cell of the arm space, described by its centre and half-width.

    ``observed`` and ``reward`` start at 0 and are updated by the algorithm
    that plays the cell; in this file they hold an observation count and the
    running mean of those observations.
    """

    def __init__(self, center, radius):
        self.center = center  # (x, y) coordinates of the cell centre
        self.radius = radius  # half of the cell's side length
        self.observed = 0
        self.reward = 0

    def split(self):
        """Return the four quadrant sub-cells, each with half the radius.

        Children are ordered by sign offsets (+,+), (+,-), (-,+), (-,-) and
        their centres are tuples.
        """
        half = self.radius / 2
        cx, cy = self.center[0], self.center[1]
        return [
            Balls((cx + sx * half, cy + sy * half), half)
            for sx, sy in product([1, -1], repeat=2)
        ]


class DelayedLipsPhasedPruning:
    """Phased-pruning Lipschitz bandit on the unit square with delayed rewards.

    The arm space is covered by square cells (``Balls``), starting from the
    four quadrants of ``[0, 1]^2``. In each phase, ``pull_arms`` plays the live
    cells round-robin until each has enough delayed observations. ``prune``
    then drops cells whose empirical mean is too far below the best and splits
    the rest into quadrants.

    Args:
        mu: mean-reward function, called as ``mu(x, y)``.
        mu_max: optimal mean reward; per-pull regret is ``mu_max - mu(x, y)``.
        time_horizon: pulls are made while ``timer <= time_horizon``.
        delta: confidence parameter of the per-phase sample size.
        dim: stored only; cell splitting (``Balls.split``) is 2-D.
        order: stored only; not read by this class.
    """

    def __init__(self, mu, mu_max, time_horizon, delta, dim, order):
        self.mu = mu
        self.mu_max = mu_max
        self.time_horizon = time_horizon
        self.delta = delta
        self.dim = dim
        self.order = order
        self.rewards = []  # not filled in by this class
        self.regrets = []  # one entry per pull: mu_max - mean reward of the pulled point
        # scheduled_rewards[t] holds the (cell, reward) pairs whose feedback arrives at step t.
        self.scheduled_rewards = [[] for _ in range(self.time_horizon + 5)]
        self.balls = set(Balls([0.5, 0.5], 0.5).split())
        self.timer = 0
        self.phase_counter = 1

    def get_delay(self):
        """Return the feedback delay, in steps, for one pull.

        The delay is drawn uniformly from the integers ``0..100``. Other delay
        models, e.g. no delay (``0``) or ``geom.rvs(1 / 51) - 1``, can be
        swapped in here.
        """
        return randint.rvs(0, 101)

    def sample(self, ball):
        """Pull a uniformly random point of ``ball``.

        Returns:
            ``(noisy_reward, mean_reward)``. The noise is Gaussian with
            standard deviation 0.01 around ``mu(x, y)``.
        """
        x = np.random.uniform(low=ball.center[0] - ball.radius, high=ball.center[0] + ball.radius)
        y = np.random.uniform(low=ball.center[1] - ball.radius, high=ball.center[1] + ball.radius)
        # Evaluate the mean once so the returned mean is exactly the one the
        # noisy reward was drawn around (and ``mu`` is not computed twice).
        mean_reward = self.mu(x, y)
        return np.random.normal(mean_reward, 0.01), mean_reward

    def prune(self, threshold):
        """End the current phase: eliminate weak cells and split the survivors.

        A cell is discarded when its empirical mean trails the best empirical
        mean by more than ``threshold / 4``. Each remaining cell is replaced by
        its four quadrants. The phase counter is always advanced.
        """
        if self.balls:  # np.max would raise on an empty set
            mu_hat_max = np.max([ball.reward for ball in self.balls])
            survivors = [ball for ball in self.balls
                         if not mu_hat_max - ball.reward > 0.25 * threshold]
            self.balls = {child for ball in survivors for child in ball.split()}
        self.phase_counter += 1

    def pull_arms(self, phase):
        """Run one phase: play live cells round-robin until each is observed enough.

        Each cell must collect
        ``exit_number = 0.01 * (8 ln T + 2 ln(2 / delta)) * 4 ** phase``
        observations. The reward of a pull arrives ``get_delay()`` steps later.
        It is dropped if it would arrive after ``time_horizon``, and it is only
        credited if its cell is still in this phase's rotation. A cell leaves
        the rotation once it reaches ``exit_number`` observations. The phase
        ends when no cell is left or the horizon is exhausted.

        Note: the outer loop allows ``timer == time_horizon`` while the inner
        loop breaks at ``timer >= time_horizon``. As a result the step with
        ``timer == time_horizon`` is still played, so running phases back to
        back makes exactly ``time_horizon + 1`` pulls (and regret entries).
        """
        exit_number = 0.01 * (8 * np.log(self.time_horizon) + 2 * np.log(2 / self.delta)) / (0.5 ** (2 * phase))
        # progress output
        print(phase)
        print(exit_number)
        balls_to_play = self.balls.copy()
        while balls_to_play and self.timer <= self.time_horizon:
            for ball in balls_to_play.copy():
                # The cell may have left the rotation earlier in this sweep.
                if ball not in balls_to_play:
                    continue
                reward, mu_s = self.sample(ball)
                delay = self.get_delay()
                if self.timer + delay <= self.time_horizon:
                    self.scheduled_rewards[self.timer + delay].append((ball, reward))
                self.regrets.append(self.mu_max - mu_s)
                # Deliver the feedback due now (including a zero-delay reward from this pull).
                for prev_ball, prev_reward in self.scheduled_rewards[self.timer]:
                    if prev_ball not in balls_to_play:
                        continue
                    prev_ball.reward = (prev_ball.reward * prev_ball.observed + prev_reward) / (prev_ball.observed + 1)
                    prev_ball.observed += 1
                    if prev_ball.observed >= exit_number:
                        balls_to_play.remove(prev_ball)
                self.timer += 1
                if self.timer >= self.time_horizon:
                    break


class DelayedZooming:
    """Zooming algorithm for Lipschitz bandits with delayed reward feedback.

    Arms are points of ``[0, 1] ** dim`` that are activated on demand. Each
    active arm ``i`` carries an empirical mean ``mu_hat[i]`` of its observed
    rewards and a confidence radius ``conf_radius[i]``. A new arm is activated
    whenever some point is not covered by any confidence ball. Otherwise the
    active arm with the largest index ``mu_hat[i] + 2 * conf_radius[i]`` is
    played. One time step is ``select_or_activate()`` followed by
    ``pull_arm()``.

    Args:
        mu: mean-reward function, called as ``mu(arm)`` on a single arm.
        mu_max: optimal mean reward; per-pull regret is ``mu_max - mu(arm)``.
        time_horizon: horizon ``T``. It enters the confidence radius, and
            rewards due after it are dropped.
        sigma: standard deviation of the Gaussian reward noise.
        delta: confidence parameter.
        dim: arm-space dimension; ``get_uncovered_arms`` handles 1 and 2.
        order: ``ord`` of ``np.linalg.norm`` used as the metric in 2-D.
    """

    def __init__(self, mu, mu_max, time_horizon, sigma, delta, dim, order):
        self.mu = mu
        self.mu_max = mu_max
        self.time_horizon = time_horizon
        self.sigma = sigma
        self.delta = delta
        self.dim = dim
        self.order = order
        # Per-arm state: position i of each list refers to active_arms[i].
        self.active_arms = []
        self.mu_hat = []  # empirical mean of the observed rewards
        self.conf_radius = []  # confidence radius (1 before any observation)
        self.num_of_pulled = []
        self.ucb = 0  # largest index found by the last pull_arm
        self.selected_arm_idx = None
        self.num_of_observed = []
        self.rewards = []  # not filled in by this class
        self.regrets = []  # one entry per pull: mu_max - mu(arm)
        # scheduled_rewards[t] holds the (arm index, reward) pairs arriving at step t.
        self.scheduled_rewards = [[] for _ in range(self.time_horizon + 5)]
        self.timer = 0

    def get_uncovered_arms(self):
        """Return a point not covered by any active arm's confidence ball.

        With no active arms, the point is uniform on ``[0, 1] ** dim``. Returns
        ``None`` if no uncovered point is found: in 1-D, the intervals cover
        ``[0, 1]``; in 2-D, every point of a 6x6 grid is covered. ``None`` is
        also returned when ``dim`` is neither 1 nor 2.
        """
        covered_arms = list(zip(self.active_arms, self.conf_radius))
        if not covered_arms:
            return np.random.uniform(low=0, high=1, size=self.dim)
        if self.dim == 1:
            # Sweep the covering intervals left to right. The first gap (or
            # the tail after the last interval) is uncovered.
            intervals = sorted(([x - r, x + r] for x, r in covered_arms), key=lambda iv: iv[0])
            low = 0
            for left, right in intervals:
                if left > low:
                    return np.random.uniform(low=low, high=left)
                low = max(low, right)
                if low >= 1:
                    return None
            return np.random.uniform(low=low, high=1)
        if self.dim == 2:
            grids_x, grids_y = np.mgrid[0:1:6j, 0:1:6j]
            grids_pos = np.vstack([grids_x.ravel(), grids_y.ravel()])
            grids = np.array([np.array([x, y]) for x, y in zip(grids_pos[0], grids_pos[1])])
            for grid in grids:
                # dist_diff[j] = distance from the grid point to ball j's boundary.
                dist_diff = np.array([np.linalg.norm(center - grid, ord=self.order) - radius
                                      for center, radius in covered_arms])
                if np.all(dist_diff > 0):
                    # Step toward the centre of the *nearest* ball by at most
                    # its gap. By the triangle inequality, the new point stays
                    # outside every ball. The segment lies inside the unit
                    # square because both endpoints do.
                    k = np.argmin(dist_diff)
                    d = np.random.uniform(low=0, high=dist_diff[k])
                    center_k, radius_k = covered_arms[k]
                    # radius_k + dist_diff[k] == ||center_k - grid||, so this
                    # moves a distance d along the unit direction.
                    return grid + d * (center_k - grid) / (radius_k + dist_diff[k])
            return None
        return None

    def select_or_activate(self):
        """Activate an uncovered arm if one exists; return the arm index to play.

        A newly activated arm starts with ``mu_hat = 0``, ``conf_radius = 1``
        and zero pull/observation counts, and is selected immediately.
        Otherwise the selection left by the last ``pull_arm`` is kept.
        """
        uncovered_arm = self.get_uncovered_arms()
        if uncovered_arm is not None:
            self.active_arms.append(uncovered_arm)
            self.mu_hat.append(0)  # estimate is 0 until the first delayed reward arrives
            self.conf_radius.append(1)
            self.num_of_pulled.append(0)
            self.num_of_observed.append(0)
            self.selected_arm_idx = len(self.active_arms) - 1
        return self.selected_arm_idx

    def get_delay(self):
        """Return a geometric feedback delay on ``{0, 1, 2, ...}`` with mean 20 steps."""
        return geom.rvs(1 / 21) - 1

    def pull_arm(self, *args, **kwargs):
        """Play the selected arm for one time step and process the arriving rewards.

        The reward of this pull is scheduled ``get_delay()`` steps ahead. It is
        dropped if it would arrive after ``time_horizon``. Every reward due at
        the current step updates its arm's empirical mean and confidence
        radius. Afterwards, the active arm with the largest index
        ``mu_hat + 2 * conf_radius`` becomes the selected arm, and its index is
        stored in ``ucb``. Extra positional/keyword arguments are ignored.
        """
        idx = self.selected_arm_idx
        self.num_of_pulled[idx] += 1
        arm = self.active_arms[idx]
        mean_reward = self.mu(arm)
        reward = np.random.normal(mean_reward, self.sigma)
        self.regrets.append(self.mu_max - mean_reward)
        delay = self.get_delay()
        if self.timer + delay <= self.time_horizon:
            self.scheduled_rewards[self.timer + delay].append((idx, reward))

        log_term = 4 * np.log(self.time_horizon) + 2 * np.log(2 / self.delta)
        for obs_idx, obs_reward in self.scheduled_rewards[self.timer]:
            count = self.num_of_observed[obs_idx]
            self.mu_hat[obs_idx] = (self.mu_hat[obs_idx] * count + obs_reward) / (count + 1)
            self.num_of_observed[obs_idx] = count + 1
            self.conf_radius[obs_idx] = self.sigma * np.sqrt(log_term / (1 + self.num_of_observed[obs_idx]))

        # Zooming selection rule: re-rank every active arm. Confidence radii
        # only shrink, so a never-decreasing running maximum would freeze the
        # selection on a stale arm.
        indices = np.array(self.mu_hat, dtype=float) + 2 * np.array(self.conf_radius, dtype=float)
        self.selected_arm_idx = int(np.argmax(indices))
        self.ucb = float(indices[self.selected_arm_idx])
        self.timer += 1



def simulate(trial, time_horizon):
    """Average the cumulative regret of ``DelayedLipsPhasedPruning`` over ``trial`` runs.

    Each run optimises ``get_reward`` on the unit square, with confidence
    ``delta = 0.01`` and ``order = np.inf``. Phases are run back to back until
    the horizon is exhausted.

    Returns:
        An array of length ``time_horizon + 1``. Entry ``t`` is the mean, over
        runs, of the cumulative regret after pull ``t``.
    """
    best_reward = get_reward(0.7, 0.8)
    regret_sum = np.zeros(time_horizon + 1)
    for _ in range(trial):
        learner = DelayedLipsPhasedPruning(get_reward, best_reward, time_horizon, 0.01, dim=2, order=np.inf)
        while learner.timer <= learner.time_horizon:
            current_phase = learner.phase_counter
            learner.pull_arms(current_phase)
            learner.prune(0.5 ** learner.phase_counter)
        regret_sum += np.cumsum(learner.regrets)
    return regret_sum / trial


def run_simulation(trial, time_horizon, output_path=None):
    """Run ``simulate`` and return the averaged cumulative-regret curve.

    Args:
        trial: number of independent runs to average.
        time_horizon: horizon passed to ``simulate``.
        output_path: optional file path. When given, the curve is also written
            there with ``np.savetxt``, one value per line (e.g.
            ``"output_dlpp_unif_50_2d.txt"``).

    Returns:
        The array returned by ``simulate``. Previously it was computed and
        discarded.
    """
    cum_regret = simulate(trial, time_horizon)
    if output_path is not None:
        np.savetxt(output_path, cum_regret, delimiter='\n')
    return cum_regret


if __name__ == '__main__':
    # Script entry point: average 30 independent runs over a 60,000-step horizon.
    run_simulation(time_horizon=60000, trial=30)