import numpy as np
import math
import ray
from ray import cloudpickle as pickle
from collections import deque
from sklearn.cluster import Birch


class EXP3(object):
    """Exp3 bandit over ``num_tasks`` arms, tracked with log-weights.

    Args:
        gamma: Mixing coefficient. It is both the share of uniform
            exploration in the sampling distribution and the step size of
            the weight update.
        num_tasks: Number of arms.
        seed: Seed passed to ``np.random.seed``.
        min_rew: Reward value that is mapped to 0 before an update.
        max_rew: Reward value that is mapped to 1 before an update.
    """

    def __init__(self, gamma, num_tasks, seed, min_rew, max_rew):
        self._seed = seed
        # NOTE: this reseeds NumPy's *global* RNG, so it also affects any
        # other code drawing from ``np.random`` in the same process.
        np.random.seed(self._seed)
        self._num_tasks = num_tasks
        self._gamma = gamma
        self._log_weights = np.zeros(self._num_tasks)
        self._max_rew = max_rew
        self._min_rew = min_rew

    @property
    def task_probabilities(self):
        """Return the current sampling distribution over arms.

        The result is a softmax of the log-weights, scaled by
        ``1 - gamma`` and mixed with ``gamma / num_tasks`` of uniform mass.
        """
        # Softmax is invariant to a constant shift. Shifting by the max keeps
        # the largest exponent at exactly 0, so the exponentials can neither
        # all underflow to 0 (0/0 -> NaN) nor overflow to inf (inf/inf -> NaN).
        shifted = self._log_weights - np.max(self._log_weights)
        weights = np.exp(shifted)
        probs = (1 - self._gamma) * weights / np.sum(weights) + self._gamma / self._num_tasks
        return probs

    def sample_task(self):
        """Samples a task, according to current Exp3 belief."""
        return np.random.choice(self._num_tasks, p=self.task_probabilities)

    def update(self, task_i, reward):
        """Update the log-weight of arm ``task_i`` from an observed reward.

        The reward is rescaled with ``min_rew``/``max_rew``, divided by the
        probability of the chosen arm (importance weighting), and added to
        that arm's log-weight scaled by ``gamma / num_tasks``.
        """
        reward = (reward - self._min_rew) / (self._max_rew - self._min_rew)
        reward_corrected = reward / self.task_probabilities[task_i]
        self._log_weights[task_i] += self._gamma * reward_corrected / self._num_tasks


@ray.remote
class ContextualBanditTaskGenerator:
    """Ray actor that samples tasks with one EXP3 bandit per context class.

    Incoming contexts are assigned to one of ``num_contexts`` classes by a
    ``Birch`` clusterer. Each class owns its own ``EXP3`` instance and its own
    per-task reward windows. The bandit reward is the absolute learning
    progress (ALP): the absolute difference between the mean of the newest and
    the oldest half of a task's reward window.

    Args:
        seed: Seed passed to ``np.random.seed`` and to every ``EXP3``.
        num_contexts: Number of context classes (and of EXP3 instances).
        gamma: EXP3 mixing coefficient.
        num_agents: Despite its name, this is used as the *collection of
            tasks* to choose from (its length is the number of bandit arms).
            Items must be hashable because they key the reward windows.
        min_rew: Lower reward bound forwarded to EXP3.
        max_rew: Upper reward bound forwarded to EXP3.
        alp_window_size: Number of most recent rewards kept per task and
            context class for the ALP computation.
    """

    def __init__(
        self,
        seed,
        num_contexts,
        gamma,
        num_agents,
        min_rew,
        max_rew,
        alp_window_size=20,
    ):
        self._seed = seed
        np.random.seed(self._seed)
        # Copy into a list: ``episodic_update`` relies on ``list.index`` and
        # the actor should not alias a container owned by the caller.
        self._tasks = list(num_agents)
        self.num_tasks = len(self._tasks)
        self._num_contexts = num_contexts
        self.algo = [
            EXP3(
                gamma=gamma,
                num_tasks=self.num_tasks,
                seed=self._seed,
                min_rew=min_rew,
                max_rew=max_rew,
            )
            for _ in range(self._num_contexts)
        ]
        self.context_classifier = Birch(n_clusters=self._num_contexts)
        self.context_class = 0
        self.context_history = list()

        self.alp_window_size = alp_window_size
        # One bounded reward window per (context class, task) pair.
        self.reward_history = [
            {k: deque(maxlen=self.alp_window_size) for k in self._tasks}
            for _ in range(self._num_contexts)
        ]

    def episodic_update(self, task, reward):
        """Record an episodic reward for ``task`` and update the bandit.

        The reward is appended to the window of the current context class.
        Once that window is full, the ALP is computed and, if the warm-up
        phase (``2 * num_contexts`` observed contexts) is over, fed to the
        EXP3 instance of the current context class.
        """
        window = self.reward_history[self.context_class][task]
        window.append(reward)

        half = self.alp_window_size // 2
        # With a window shorter than 2 there is no "old" and "new" half to
        # compare; averaging the empty slice would give NaN and permanently
        # corrupt the EXP3 weights, so skip the update instead.
        if half < 1 or len(window) < self.alp_window_size:
            return

        rewards = list(window)
        lp = np.mean(rewards[-half:]) - np.mean(rewards[:half])
        alp = np.abs(lp)
        if len(self.context_history) >= 2 * self._num_contexts:
            self.algo[self.context_class].update(self._tasks.index(task), alp)

    def update_context(self, context):
        """Observe a new context and refresh ``self.context_class``.

        During warm-up the class simply cycles round-robin. When exactly
        ``2 * num_contexts`` contexts have been seen, the classifier is fitted
        on all of them. Afterwards the oldest context is dropped, the
        classifier is updated with the new one, and the class is its
        predicted cluster.
        """
        self.context_history.append(context)
        seen = len(self.context_history)
        warmup = 2 * self._num_contexts
        if seen <= warmup:
            if seen == warmup:
                # First fit, on the whole warm-up batch.
                self.context_classifier.partial_fit(list(self.context_history))
            self.context_class = seen % self._num_contexts
        else:
            # Keep the history bounded at the warm-up length.
            self.context_history.pop(0)
            self.context_classifier.partial_fit([context])
            self.context_class = self.context_classifier.predict([context])[0]

    def context_task_probs(self):
        """Return the task probabilities of every context class, in order."""
        return [bandit.task_probabilities for bandit in self.algo]

    def sample_task(self):
        """Sample a task: uniformly during warm-up, from EXP3 afterwards."""
        if len(self.context_history) < 2 * self._num_contexts:
            sample = np.random.choice(self._tasks)
        else:
            sample = self._tasks[self.algo[self.context_class].sample_task()]
        return sample

    def save(self) -> bytes:
        """Serialize classifier, context state, bandits and reward windows."""
        return pickle.dumps(
            {
                "context_classifier": self.context_classifier,
                "context_history": self.context_history,
                "context_class": self.context_class,
                "algo": self.algo,
                "reward_history": self.reward_history,
            }
        )

    def restore(self, objs: bytes) -> None:
        """Restore the state produced by ``save``.

        SECURITY: unpickling can execute arbitrary code. Only pass bytes that
        come from a trusted ``save`` call.
        """
        objs = pickle.loads(objs)
        self.context_classifier = objs["context_classifier"]
        self.context_history = objs["context_history"]
        self.context_class = objs["context_class"]
        self.algo = objs["algo"]
        self.reward_history = objs["reward_history"]

    def get_name(self):
        """Return the identifier string of this task generator."""
        return "contextual-bandit"


def test_exp3():
    """Run EXP3 on a Bernoulli bandit and print regret diagnostics.

    Rewards for every round are drawn up front so the best fixed arm in
    hindsight is known. The regret is printed every 500 rounds, followed by
    the cumulative reward and the regret bound computed from ``gamma``.
    """
    import random

    numActions = 3
    numRounds = 3000
    gamma = 0.07

    # Arm i pays 1 with probability biases[i]; only the first numActions
    # columns are ever played.
    biases = [1.0 / k for k in range(2, 12)]
    rewardVector = []
    for _ in range(numRounds):
        rewardVector.append([1 if random.random() < bias else 0 for bias in biases])

    def total_reward(action):
        return sum(row[action] for row in rewardVector)

    bestAction = max(range(numActions), key=total_reward)

    algo = EXP3(gamma, numActions, seed=123, min_rew=0, max_rew=1)
    cumulativeReward = 0
    bestActionCumulativeReward = 0
    for t, round_rewards in enumerate(rewardVector):
        # Snapshot the distribution before this round's update, for logging.
        probs = algo.task_probabilities
        choice = algo.sample_task()
        rew = round_rewards[choice]
        algo.update(choice, rew)

        cumulativeReward += rew
        bestActionCumulativeReward += round_rewards[bestAction]
        if (t + 1) % 500 == 0:
            weakRegret = bestActionCumulativeReward - cumulativeReward
            print(f"Iter {t+1}, regret={weakRegret}, probs: {probs}")

    exploitation_term = (math.e - 1) * gamma * bestActionCumulativeReward
    exploration_term = (numActions * math.log(numActions)) / gamma
    regretBound = exploitation_term + exploration_term
    print(f"Cumulative Reward: {cumulativeReward}, regretBound: {regretBound}")


def test_contextual_bandit():
    """Drive a ContextualBanditTaskGenerator actor on a two-context toy task.

    Each iteration picks one of two contexts at random, asks the actor for a
    task, looks up that task's reward for the context and reports it back.
    Two reward entries drift upwards to simulate learning progress. The
    per-context task probabilities are printed every 200 iterations.

    Requires ``ray.init()`` to have been called by the caller.
    """
    import random
    num_contexts = 2
    num_iterations = 2000
    task_generator = ContextualBanditTaskGenerator.options(name="task_generator").remote(
        seed=123,
        num_contexts=num_contexts,
        gamma=0.07,
        # The constructor has no ``update_interval`` parameter (passing it
        # raises TypeError); the window length is ``alp_window_size``.
        alp_window_size=20,
        num_agents=[0, 1, 2],
        min_rew=-10,
        max_rew=10,
    )
    rewards_for_context = {
        -1.0: [-10, 0, 10],
        1.0: [10, 0, -10],
    }

    for t in range(num_iterations):
        context = random.choice([-1.0, 1.0])
        # Fire-and-forget: nothing is read back from this call.
        task_generator.update_context.remote([context, 0.0])
        action = ray.get(task_generator.sample_task.remote())
        reward = rewards_for_context[context][action]
        rewards_for_context[-1.0][0] += 0.1  # simulate learning progress
        rewards_for_context[1.0][1] += 0.1
        task_generator.episodic_update.remote(action, reward)
        if (t + 1) % 200 == 0:
            print(f"Iter {t+1}, probs: {ray.get(task_generator.context_task_probs.remote())}")


if __name__ == "__main__":
    # Script entry point: run both smoke tests. Only the second needs Ray.
    print("Testing Exp3:")
    test_exp3()
    print("\nTesting Contextual Bandit:")
    ray.init()
    try:
        test_contextual_bandit()
    finally:
        # Shut Ray down even if the test raises.
        ray.shutdown()
