import re
import numpy as np
from scipy.special import logsumexp


class Context:
    """EXP3-style adversarial multi-armed bandit over ``num_of_arms`` arms.

    Arms are sampled from a mixture of the normalised weights and the
    uniform distribution, with mixing coefficient ``lr``. After each pull,
    the chosen arm's weight is multiplied by
    ``exp(lr * reward_hat / num_of_arms)``. Here ``reward_hat`` is the
    squashed reward divided by the probability with which that arm was
    sampled.

    Only the ratios between weights affect the sampling distribution.
    ``update`` therefore rescales the stored weights so they cannot
    overflow.
    """

    # Not read by any method of this class; kept for external users.
    EPSILON = 1e-5

    def __init__(self, num_of_arms, lr, gamma=None, weights=None):
        """Create a bandit.

        Args:
            num_of_arms: Number of arms (K).
            lr: Used both as the uniform-mixing coefficient in
                ``_get_probabilities`` and as the step size of the weight
                update in ``update``.
            gamma: Optional exploration parameter. When omitted it is set to
                ``min(1, sqrt(ln K / K))`` and ``t`` to 1; otherwise ``t`` is
                derived from it. Neither ``gamma`` nor ``t`` is read by the
                other methods of this class.
            weights: Optional initial per-arm weights; defaults to all ones.
        """
        self._num_of_arms = num_of_arms
        if gamma is None:
            self.gamma = min(1.0, np.sqrt(np.log(num_of_arms) / num_of_arms))
            self.t = 1
        else:
            self.gamma = gamma
            self.t = int(np.log(num_of_arms) /
                         (num_of_arms * np.square(gamma)))

        # Force a float dtype: with integer weights the in-place
        # multiplicative step in ``update`` would raise a casting error.
        if weights is None:
            self.weights = np.ones(num_of_arms)
        else:
            self.weights = np.array(weights, dtype=float)
        self._lr = lr

    def arm_probability(self, arm_idx):
        """Return the stored (unnormalised) weight of ``arm_idx``.

        NOTE(review): despite its name this returns the raw weight, not the
        sampling probability, and its scale is arbitrary because ``update``
        rescales the weights. ``_get_probabilities()[arm_idx]`` gives the
        actual sampling probability. Confirm which one callers expect.
        """
        return self.weights[arm_idx]

    def _get_probabilities(self):
        """Return the sampling distribution over arms as a numpy array.

        Computes ``(1 - lr) * w / sum(w) + lr / K``. The weights are first
        divided by their maximum so the sum cannot overflow to ``inf``, even
        for very large (e.g. externally loaded) weights.
        """
        weights = np.asarray(self.weights, dtype=float)
        scaled = weights / np.max(weights)
        exploit = scaled / np.sum(scaled)
        # Mix with the uniform distribution so every arm keeps probability
        # of at least lr / K.
        return (1 - self._lr) * exploit + self._lr / self._num_of_arms

    def predict(self):
        """Sample and return an arm index from ``_get_probabilities()``."""
        prob_dist = self._get_probabilities()
        return np.random.choice(a=self._num_of_arms, p=prob_dist)

    def update(self, reward, last_arm=None):
        """Apply the multiplicative weight update for the arm just played.

        Args:
            reward: Raw, unbounded reward. It is squashed into (0, 1) with
                ``arctan(reward / 10) / pi + 0.5`` before use.
            last_arm: Index of the arm that produced ``reward``. Required
                despite the ``None`` default.

        Raises:
            ValueError: If ``last_arm`` is None.
        """
        if last_arm is None:
            raise ValueError("Got none for arm")

        # Map the unbounded reward into the open interval (0, 1).
        squashed = np.arctan(reward * 1.0 / 10) / np.pi + 0.5
        prob_dist = self._get_probabilities()
        # Importance weighting: divide by the probability of having chosen
        # this arm.
        norm_reward = squashed / prob_dist[last_arm]

        # Only weight ratios affect the probabilities, so rescale to a max
        # of 1 before the multiplicative step. Without this, repeated
        # updates grow the weights until they overflow to inf and the
        # probabilities become NaN. The result is a fresh float array, so
        # previously returned snapshots are not mutated.
        weights = np.asarray(self.weights, dtype=float)
        self.weights = weights / np.max(weights)
        self.weights[last_arm] *= np.exp(
            self._lr * norm_reward / self._num_of_arms)

    def get_dict(self):
        """Return a serialisable snapshot of the bandit state.

        The weights are copied so that later ``update`` calls do not mutate
        a previously returned snapshot.
        """
        return {
            "weights": np.array(self.weights)
        }

    def load_dict(self, dict):
        """Restore state produced by ``get_dict``.

        The parameter name shadows the builtin ``dict``; it is kept for
        backward compatibility with keyword callers. The weights are copied
        and coerced to float so the multiplicative update works on them.
        """
        self.weights = np.array(dict["weights"], dtype=float)
