import numpy as np

from .base import AbstractAgent


class AgentVersatileDB(AbstractAgent):
    """Agent that picks a pair of arms from two separately learned distributions.

    Two probability vectors over the ``self.env.K`` arms are maintained:
    ``p`` (first arm of the pair) and ``p_prime`` (second arm). Each has its
    own cumulative importance-weighted loss vector (``L`` / ``L_prime``) and
    its own scalar ``x`` / ``x_prime`` that is carried over between rounds as
    the starting point of the normalisation search in ``update_p``.

    NOTE(review): the weight formula ``4 * (eta * (L - x)) ** -2`` together
    with the Newton-style update of ``x`` looks like the Tsallis-INF
    normalisation step -- confirm against the reference the agent is based on.
    ``self.env`` and ``self.t`` are read but never set in this class;
    presumably ``AbstractAgent`` provides them -- verify.
    """

    def sample_action(self):
        """Draw one arm index from ``p`` and one from ``p_prime``.

        Returns:
            Tuple ``(a, b)`` of arm indices; ``a`` is drawn first, then ``b``,
            both from NumPy's global random state.
        """
        # An explicit index array (rather than the bare int K) keeps the
        # returned indices NumPy integers, as with the previous range(K) call.
        arms = np.arange(self.env.K)
        a = np.random.choice(arms, p=self.p)
        b = np.random.choice(arms, p=self.p_prime)
        return (a, b)

    def reset_learning(self):
        """Reset both distributions to uniform and clear accumulated losses."""
        n_arms = self.env.K
        self.p = np.ones(n_arms) / n_arms
        self.p_prime = np.ones(n_arms) / n_arms
        # L and L_prime are modified in place by learn(), so they must be
        # two distinct arrays.
        self.L = np.zeros(n_arms)
        self.L_prime = np.zeros(n_arms)
        self.x = 1
        self.x_prime = 1

    def learn(self, action, observation):
        """Accumulate importance-weighted losses for the played pair.

        Args:
            action: Pair ``(a, b)`` of arm indices, as returned by
                ``sample_action``.
            observation: Numeric feedback. ``1 - observation`` is charged to
                arm ``a`` and ``observation`` to arm ``b`` (presumably a duel
                outcome in [0, 1] -- confirm against the environment).
        """
        first, second = action[0], action[1]
        # Divide by the probability with which each arm was selected
        # (importance weighting); half of each estimate is accumulated.
        loss_first = (1 - observation) / self.p[first]
        loss_second = observation / self.p_prime[second]
        self.L[first] += loss_first / 2
        self.L_prime[second] += loss_second / 2
        self.update_p()

    def update_p(self):
        """Recompute ``p`` and ``p_prime`` from the cumulative losses.

        The step size ``eta`` depends on the number of arms and on ``self.t``.
        The same ``eta`` is used for both distributions; each search starts
        from the ``x`` / ``x_prime`` value left by the previous call.
        """
        eta = (4 * self.env.K ** (-1 / 6)) * np.sqrt(1 / (self.t + 1))
        self.p, self.x = self._solve_weights(self.L, self.x, eta)
        self.p_prime, self.x_prime = self._solve_weights(
            self.L_prime, self.x_prime, eta
        )

    @staticmethod
    def _solve_weights(losses, x, eta, tol=1e-7):
        """Iterate on ``x`` until the weight vector stops changing.

        Args:
            losses: 1-D array of cumulative losses, one entry per arm.
            x: Starting value of the normalisation scalar.
            eta: Step size used in the weight formula.
            tol: Stop once the Euclidean distance between two consecutive
                weight vectors is at most this value.

        Returns:
            Tuple ``(probabilities, x)``: the final weights rescaled to sum
            to one, and the final value of the scalar.
        """
        prev_w = np.zeros(len(losses))
        change = np.inf
        while change > tol:
            gap = losses - x
            at_or_below = gap <= 0
            if at_or_below.any():
                # x is not strictly below every loss, so the weight formula
                # is not applied; fall back to a uniform distribution over
                # the arms whose loss is <= x. These weights sum to one, so
                # the step below leaves x unchanged and the loop ends on the
                # next pass.
                # NOTE(review): x stays put until the losses grow past it --
                # confirm this is the intended behaviour.
                w = at_or_below.astype(float)
                w /= w.sum()
            else:
                w = 4 * np.power(eta * gap, -2)
            # Newton step on f(x) = sum(w) - 1, using
            # d/dx sum(w) = eta * sum(w ** 1.5).
            x = x - (np.sum(w) - 1) / (eta * np.sum(w ** 1.5))
            change = np.linalg.norm(prev_w - w)
            prev_w = w
        return w / w.sum(), x
