import numpy as np

from .base import AbstractAgent


class AgentAlg2(AbstractAgent):
    """Agent that samples an arm pair (I, J) from a Tsallis-INF style distribution.

    State (initialised in ``reset_learning``):
      p -- sampling distribution over the K arms, recomputed by ``update_p``
           from the cumulative loss estimates ``L``.
      r -- auxiliary distribution proportional to ``p ** (2/3)``.
      L -- cumulative importance-weighted loss estimates, one per arm.
      q -- probability with which the most recently sampled J was drawn;
           used as the importance weight in ``learn``.
      x -- normalisation constant solved for in ``update_p``; kept between
           rounds so each solve is warm-started from the previous one.

    ``self.env.K`` (number of arms) and ``self.t`` (round counter) are read
    but not defined here -- presumably provided by ``AbstractAgent`` / the
    environment; TODO confirm.
    """

    def sample_action(self):
        """Sample an arm pair ``(I, J)`` and store the probability of J in ``self.q``.

        Two auxiliary draws from ``p`` act as a coin. If they coincide, J is
        drawn from ``r``; otherwise J is drawn from ``p``. I is always drawn
        from ``p``. ``self.q`` is set to the probability of the drawn J under
        the distribution it was actually drawn from.

        Returns:
            tuple: ``(I, J)`` arm indices.
        """
        arms = range(self.env.K)
        # The "prime" draws only select which distribution J is drawn from.
        I_prime = np.random.choice(arms, p=self.p)
        J_prime = np.random.choice(arms, p=self.p)
        # I is drawn from p in both cases (same RNG call order as before).
        I = np.random.choice(arms, p=self.p)
        if I_prime == J_prime:
            J = np.random.choice(arms, p=self.r)
            self.q = self.r[J]
        else:
            J = np.random.choice(arms, p=self.p)
            self.q = self.p[J]
        return (I, J)

    def reset_learning(self):
        """Reset all learning state to its initial values (uniform ``p`` and ``r``)."""
        self.p = np.ones(self.env.K) / self.env.K
        self.r = np.ones(self.env.K) / self.env.K
        self.L = np.zeros(self.env.K)
        self.q = 1      # q_{k,t} used to create loss estimate
        # Initial guess for the normaliser; update_p moves it below min(L)
        # whenever it is not strictly below every L_i.
        self.x = 1

    def learn(self, action, observation):
        """Update the loss estimates with one round of feedback, then refresh ``p`` and ``r``.

        Args:
            action: the ``(I, J)`` pair returned by ``sample_action``; only
                the estimate for ``action[1]`` (J) is updated.
            observation: scalar feedback for the round. A larger value lowers
                J's future probability, so it acts as a loss (presumably in
                [0, 1] -- TODO confirm against the environment).
        """
        # Importance weighting: divide by the probability J was drawn with.
        l_hat = observation / self.q
        self.L[action[1]] += l_hat
        self.update_p()
        powered = self.p ** (2 / 3)
        self.r = powered / np.sum(powered)

    def update_p(self):
        """Recompute ``self.p`` from ``self.L`` by solving for ``self.x`` with Newton's method.

        Weights are ``w_i = 4 / (eta * (L_i - x))**2`` with
        ``eta = 20 * K**(-1/6) / sqrt(t + 1)``. ``x`` is moved by Newton steps
        on ``sum(w) - 1`` until successive weight vectors differ by at most
        ``1e-7`` (or an iteration cap is hit). The final weights are then
        normalised into ``self.p``.

        The weights are only valid for ``x < min(L)``. Whenever the current
        ``x`` violates this (the initial ``x = 1`` with ``L = 0``, or a Newton
        step that overshoots), ``x`` is reset to ``min(L) - 2 / eta``. There
        the best arm's weight is exactly 1, so ``sum(w) >= 1``. From such a
        point the Newton iterates decrease towards the root and never cross
        ``min(L)``, so every arm keeps a strictly positive probability.
        """
        K = self.env.K
        threshold = 0.0000001
        # Safety cap; convergence is normally reached within a few iterations.
        max_iter = 1000
        eta = (20 * K ** (-1 / 6)) * np.sqrt(1 / (self.t + 1))
        prev_w = np.zeros(K) + 1 / K
        w = prev_w
        for _ in range(max_iter):
            if np.any(self.L - self.x <= 0):
                # Restart from the right of the root instead of zeroing arms.
                self.x = np.min(self.L) - 2 / eta
            w = 4 / ((eta * (self.L - self.x)) ** 2)
            # Newton step: d/dx sum(w) == eta * sum(w ** 1.5).
            self.x -= (np.sum(w) - 1) / (eta * np.sum(w ** (3 / 2)))
            if np.linalg.norm(prev_w - w) <= threshold:
                break
            prev_w = w
        self.p = w / np.sum(w)
