from tqdm import tqdm

import matplotlib.pylab as pl
import matplotlib.gridspec as gridspec
import numpy as np
import itertools
import abc

# Compact numpy array printing when dumping learner state:
# 3 decimal places, scientific notation suppressed for small values.
np.set_printoptions(suppress=True, precision=3)


class Learner:
    """Base class for single-play multi-armed bandit learners.

    Tracks per-arm pull counts, per-arm running-mean reward estimates and the
    reward obtained in each round. Subclasses implement :meth:`pull_arm`;
    :meth:`update` records the outcome of a round and advances ``t``.
    """

    def __init__(self, n_arms: int, T: int) -> None:
        self.T = T  # time horizon (length of self.rewards)
        self.n_arms = n_arms
        self.t = 0  # current round index, advanced by update()
        self.arm_pulls = np.zeros(n_arms)
        self.estimates = np.zeros(n_arms)
        self.rewards = np.zeros(T)

    def update(self, reward, arm) -> None:
        """Record ``reward`` obtained by playing ``arm`` in the current round.

        The arm's running mean is updated incrementally and ``t`` advances by
        one. Raises IndexError once more than ``T`` rounds have been recorded.
        """
        self.rewards[self.t] = reward
        self.estimates[arm] = (self.estimates[arm] * self.arm_pulls[arm] + reward) / (
            self.arm_pulls[arm] + 1
        )
        self.arm_pulls[arm] += 1
        self.t += 1

    def reset(self):
        """Forget all history: round counter, pull counts, estimates and rewards."""
        self.t = 0
        self.arm_pulls = np.zeros(self.n_arms)
        self.estimates = np.zeros(self.n_arms)
        self.rewards = np.zeros(self.T)

    def get_rewards(self):
        """Return the array of per-round rewards (length ``T``; unplayed rounds are 0)."""
        return self.rewards

    def get_arm_pulls(self):
        """Return the array of per-arm pull counts."""
        return self.arm_pulls

    def __str__(self) -> str:
        return f"""
            Learner: {self.__class__.__name__}
            T={self.T}, K={self.n_arms}
            arm_pulls  : {self.arm_pulls}
            estimates  : {self.estimates}
            tot reward : {np.sum(self.rewards)}
        """

    # `abc.abstractclassmethod` (deprecated) made this a classmethod that
    # silently returned None. Without an ABC metaclass the decorator only marks
    # intent, so the body raises to make a missing override fail loudly.
    @abc.abstractmethod
    def pull_arm(self) -> int:
        """Return the index of the arm to play in the current round."""
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement pull_arm()"
        )


class Greedy(Learner):
    """Pure exploitation: try each arm once, then always play the best empirical mean."""

    def __init__(self, n_arms: int, T: int) -> None:
        super().__init__(n_arms, T)

    def pull_arm(self) -> int:
        """Return arm ``t`` during the first ``n_arms`` rounds, else the argmax estimate."""
        in_warm_up = self.t < self.n_arms
        # argmax returns the lowest index among tied estimates
        return self.t if in_warm_up else np.argmax(self.estimates)


class EpsilonGreedy(Learner):
    """Epsilon-greedy learner with a round-dependent exploration probability.

    ``epsilon`` maps the current round index ``t`` to the probability of
    playing a uniformly random arm instead of the best empirical mean.
    """

    def __init__(self, n_arms: int, T: int, epsilon: callable = lambda x: 1 / x):
        super().__init__(n_arms, T)
        self.epsilon = epsilon

    def pull_arm(self) -> int:
        """Play each arm once, then explore with probability ``epsilon(t)``."""
        if self.t < self.n_arms:
            return self.t

        # The uniform draw is taken before epsilon(t) is evaluated.
        explore = np.random.random() <= self.epsilon(self.t)
        if not explore:
            return np.argmax(self.estimates)
        return np.random.choice(range(self.n_arms))


class UCB(Learner):
    """Upper-confidence-bound learner.

    After one pull of every arm, plays the argmax of
    ``estimate + 3 * sigma * sqrt(log(t) / pulls)``.
    """

    def __init__(self, n_arms: int, T: int, sigma: float = 1) -> None:
        super().__init__(n_arms, T)
        self.sigma = sigma

    def pull_arm(self) -> int:
        """Return the arm with the largest upper confidence bound."""
        if self.t < self.n_arms:
            return self.t

        bonus = 3 * self.sigma * np.sqrt(np.log(self.t) / self.arm_pulls)
        return np.argmax(self.estimates + bonus)


class CombinatorialLearner(Learner):
    """Base class for learners that play a *superarm* each round.

    A superarm is a 0/1 vector over the ``n_arms`` basic arms. ``d`` is the
    intended maximum number of basic arms in a superarm; subclasses are
    responsible for respecting it when choosing superarms. Feedback is one
    reward per basic arm, so ``self.rewards`` has shape ``(T, n_arms)``.
    """

    def __init__(self, n_arms: int, T: int, d: int = None) -> None:
        super().__init__(n_arms, T)
        self.rewards = np.zeros((T, n_arms))  # one row of per-arm rewards per round
        self.d = n_arms if d is None else d

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.d > n_arms:
            raise ValueError(f"d cannot be > n_arms (got d={self.d}, n_arms={n_arms})")

    def update(self, reward: np.ndarray, superarm: np.ndarray) -> None:
        """Record one round of per-arm feedback and advance ``t``.

        Args:
            reward: per-basic-arm rewards for this round, shape ``(n_arms,)``.
            superarm: 0/1 vector of shape ``(n_arms,)`` marking the arms played.

        Only arms present in ``superarm`` have their running mean updated.
        Rewards reported for arms that were not played are stored in
        ``self.rewards`` but do not affect ``self.estimates``.
        """
        reward = np.broadcast_to(np.asarray(reward, dtype=float), (self.n_arms,))
        superarm = np.asarray(superarm)
        self.rewards[self.t, :] = reward

        played = superarm > 0
        # Restricting the incremental mean to played arms avoids the 0/0
        # (NaN + RuntimeWarning) for never-played arms. It also stops rewards
        # of unplayed arms from shifting their estimates without a pull.
        # A new array is built so references to the old estimates stay intact.
        estimates = self.estimates.copy()
        estimates[played] = (
            estimates[played] * self.arm_pulls[played] + reward[played]
        ) / (self.arm_pulls[played] + superarm[played])
        self.estimates = estimates
        self.arm_pulls += superarm
        self.t += 1

    def get_rewards(self) -> np.ndarray:
        """Return the reward collected at each timestep (sum over basic arms)."""
        return np.sum(self.rewards, axis=1)

    def reset(self):
        """Forget all history, restoring the ``(T, n_arms)`` reward table."""
        self.t = 0
        self.arm_pulls = np.zeros(self.n_arms)
        self.estimates = np.zeros(self.n_arms)
        self.rewards = np.zeros((self.T, self.n_arms))

    def __str__(self) -> str:
        return f"""
            Learner: {self.__class__.__name__}
            T={self.T}, K={self.n_arms}
            arm_pulls  : {self.arm_pulls}
            estimates  : {self.estimates}
            tot reward : {np.sum(self.get_rewards())}
        """


class CUCB(CombinatorialLearner):
    """Combinatorial UCB: play the superarm an oracle picks from per-arm UCB indices.

    ``oracle(ucb_values, d)`` must return a 0/1 superarm; when it is None,
    :meth:`default_oracle` is used.
    """

    def __init__(
        self, n_arms: int, T: int, oracle: callable = None, sigma: float = 0.1, d: int = None
    ) -> None:
        super().__init__(n_arms, T, d)
        self.oracle = oracle
        self.sigma = sigma

    def default_oracle(self, estimates, d):
        """Return the 0/1 superarm of at most ``d`` arms maximising ``sum(superarm * estimates)``.

        The objective is linear and the only constraint is cardinality, so the
        optimum is the ``d`` largest strictly positive values. This takes
        O(n log n) instead of enumerating all 2**n subsets. Zero- and
        negative-valued arms are never added because they cannot increase the
        objective. Among equal values the higher arm index is preferred, which
        matches the lexicographic order of an exhaustive 0/1 search.
        Inputs are assumed to be finite.
        """
        values = np.asarray(estimates, dtype=float)
        n = len(values)
        # Stable descending sort of the reversed vector gives, among ties,
        # the higher original index first; map indices back with (n - 1) - j.
        order = (n - 1) - np.argsort(-values[::-1], kind="stable")
        superarm = np.zeros(n, dtype=int)
        for arm in order[: max(d, 0)]:
            if not values[arm] > 0:  # sorted descending: nothing later helps
                break
            superarm[arm] = 1
        return superarm

    def pull_arm(self):
        """Return the superarm to play in the current round.

        In the first ``n_arms`` rounds, arm ``t`` is played together with up
        to ``d - 1`` other, distinct, uniformly chosen arms, so every arm gets
        at least one observation. Afterwards the oracle is applied to
        ``estimate + 3 * sigma * sqrt(log(t) / pulls)``.
        """
        if self.t < self.n_arms:
            superarm = np.zeros(self.n_arms)
            superarm[self.t] = 1
            # Sample companions without replacement and excluding arm t, so the
            # warm-up superarm really contains min(d, n_arms) distinct arms.
            others = np.delete(np.arange(self.n_arms), self.t)
            n_extra = min(self.d - 1, others.size)
            if n_extra > 0:
                superarm[np.random.choice(others, size=n_extra, replace=False)] = 1
            return superarm

        exploration = 3 * self.sigma * np.sqrt(np.log(self.t) / self.arm_pulls)
        ucb_arms = self.estimates + exploration

        oracle = self.default_oracle if self.oracle is None else self.oracle
        return oracle(ucb_arms, self.d)


