import numpy as np
from numba import float64, int64    # import the types
from bisect import insort


# Attribute-name -> numba-type table, flattened into a list of (name, type)
# pairs in declaration order.
# NOTE(review): this has the shape of a numba jitclass spec, but nothing in the
# visible part of this file applies it -- confirm it is still needed.
# NOTE(review): 'Sa_bin' is declared int64[:] here, whereas Tracker2.reset()
# allocates Sa_bin with np.zeros' default (float) dtype -- confirm which one
# is intended.
spec = list({
    # problem description
    'means': float64[:],
    'nb_arms': int64,
    'T': int64,
    # per-arm statistics
    'Sa': float64[:],
    'Na': int64[:],
    'Sa_bin': int64[:],
    'Na_bin': int64[:],
    # per-round history
    'reward': float64[:],
    'arm_sequence': int64[:],
    't': int64,
}.items())


class Tracker2:
    """
    Bookkeeping for a single bandit experiment: per-arm statistics, the
    per-round history of rewards and chosen arms, and a few optional extra
    stores that are only allocated when the matching flag is set.
    """

    def __init__(self,
                 means,
                 T,
                 store_rewards_arm=False,
                 store_max_rewards_arm=False,
                 store_sorted_rewards_arm=False,
                 q_left=None,
                 q_right=None,
                 binarization_trick=False,
                 ):
        """
        :param means: 1-d array-like, mean reward of each arm (one entry per arm)
        :param T: int, time horizon (length of the per-round history arrays)
        :param store_rewards_arm: bool, keep the list of rewards seen by each arm
        :param store_max_rewards_arm: bool, keep the largest reward seen by each arm
        :param store_sorted_rewards_arm: bool, keep a sorted list of rewards per
            arm together with the indices idx_q_left / idx_q_right
        :param q_left: float, multiplier used for idx_q_left; must be a number
            when store_sorted_rewards_arm is True
        :param q_right: float, multiplier used for idx_q_right; must be a number
            when store_sorted_rewards_arm is True
        :param binarization_trick: bool, allocate Sa_bin / Na_bin so that
            update_binarized can be used
        """
        # np.asarray returns an ndarray argument unchanged (no copy) while also
        # accepting plain lists/tuples, which `.shape` and `.max()` need below.
        self.means = np.asarray(means)
        self.nb_arms = self.means.shape[0]
        self.T = T
        self.store_rewards_arm = store_rewards_arm
        self.store_max_rewards_arm = store_max_rewards_arm
        self.store_sorted_rewards_arm = store_sorted_rewards_arm
        self.q_left = q_left
        self.q_right = q_right
        self.binarization_trick = binarization_trick
        self.reset()

    def reset(self):
        """
        (Re)initialise every tracked quantity to its empty state.

        Always allocated:
            - Sa: np array, cumulative reward of arm a
            - Na: np array, number of times arm a has been pulled
            - reward: np array of length T, reward received at each round
            - arm_sequence: np array of length T, arm chosen at each round
            - t: int, last round passed to update (0 after a reset)
        Allocated only when the corresponding constructor flag is True:
            - Sa_bin, Na_bin (binarization_trick)
            - rewards_arm (store_rewards_arm)
            - max_rewards_arm (store_max_rewards_arm)
            - sorted_rewards_arm, idx_q_left, idx_q_right (store_sorted_rewards_arm)
        """
        self.Sa = np.zeros(self.nb_arms)
        self.Na = np.zeros(self.nb_arms, dtype=int)
        self.reward = np.zeros(self.T)
        self.arm_sequence = np.zeros(self.T, dtype=int)
        self.t = 0
        if self.binarization_trick:
            self.Sa_bin = np.zeros(self.nb_arms)
            self.Na_bin = np.zeros(self.nb_arms, dtype=int)
        if self.store_rewards_arm:
            self.rewards_arm = [[] for _ in range(self.nb_arms)]
        if self.store_max_rewards_arm:
            # -inf so that the first observed reward always becomes the maximum.
            self.max_rewards_arm = [-np.inf for _ in range(self.nb_arms)]
        if self.store_sorted_rewards_arm:
            self.sorted_rewards_arm = [[] for _ in range(self.nb_arms)]
            self.idx_q_left = np.zeros(self.nb_arms, dtype=int)
            self.idx_q_right = np.zeros(self.nb_arms, dtype=int)

    def update(self, t, arm, reward):
        """
        Record that `arm` was pulled at round `t` and returned `reward`.

        :param t: int, current time/round; used as an index into the length-T
            history arrays, so it must satisfy 0 <= t < T
        :param arm: int, arm chosen at this round
        :param reward: float, reward obtained at this round
        """
        self.Na[arm] += 1
        self.arm_sequence[t] = arm
        self.reward[t] = reward
        self.Sa[arm] += reward
        self.t = t
        if self.store_rewards_arm:
            self.rewards_arm[arm].append(reward)
        if self.store_max_rewards_arm:
            if reward > self.max_rewards_arm[arm]:
                self.max_rewards_arm[arm] = reward
        if self.store_sorted_rewards_arm:
            # ceil(q * Na[arm]), computed with the already-incremented count.
            # NOTE(review): presumably a 1-based rank into sorted_rewards_arm
            # -- confirm against the code that consumes these indices.
            pulls = self.Na[arm]
            self.idx_q_left[arm] = int(np.ceil(self.q_left * pulls))
            self.idx_q_right[arm] = int(np.ceil(self.q_right * pulls))
            insort(self.sorted_rewards_arm[arm], reward)

    def update_binarized(self, t, arm, reward_bin):
        """
        Update the binarization-trick counters for `arm`.

        Sa_bin / Na_bin are only allocated when the tracker was built with
        binarization_trick=True; otherwise this raises AttributeError.

        :param t: int, current time/round (accepted but not used here)
        :param arm: int, arm chosen at this round
        :param reward_bin: value added to Sa_bin[arm]
        """
        self.Na_bin[arm] += 1
        self.Sa_bin[arm] += reward_bin

    def regret(self):
        """
        Compute the cumulative regret of a single experiment from the means of
        the arms that were played.

        The result always covers the full horizon T: entries of arm_sequence
        that update has not written yet are still 0 from reset, so those
        rounds are counted as pulls of arm 0.

        :return: np.array of length T, cumulative regret at each round
        """
        best_cumulative = self.means.max() * np.arange(1, self.T + 1)
        played_cumulative = np.cumsum(self.means[self.arm_sequence])
        return best_cumulative - played_cumulative
