import numpy as np

from .base import AbstractAgent


class AgentRRDB(AbstractAgent):
    """Pairwise agent that plays the least-sampled pair and eliminates arms.

    State kept per ordered pair ``(i, j)`` of the ``env.K`` arms:

    * ``nij`` -- number of recorded plays (set to ``inf`` for eliminated arms
      so that they are never the minimum in :meth:`sample_action`),
    * ``oij`` -- accumulated observations (``observation`` for ``(a, b)``,
      ``1 - observation`` for ``(b, a)``),
    * ``pij`` -- empirical mean ``oij / nij`` (``inf`` until first played).

    ``active_arms`` holds the indices of the arms not yet eliminated.
    """

    def __init__(self, env, name='Unknown', delta=0.01):
        """Store the confidence parameter and initialise the base agent.

        Args:
            env: Environment object; must expose the number of arms as ``K``.
            name: Display name forwarded to the base class.
            delta: Confidence parameter used in the elimination bound.
        """
        # Assigned before the base initializer, as in the original ordering
        # (presumably the base class may trigger reset_learning -- confirm).
        self.delta = delta
        super().__init__(env, name=name)

    def reset_learning(self):
        """Reset all pairwise statistics and mark every arm as active."""
        k = self.env.K
        self.nij = np.zeros((k, k))
        self.pij = np.full((k, k), np.inf)
        self.oij = np.zeros((k, k))
        self.active_arms = np.arange(k)

    def sample_action(self):
        """Return a uniformly random pair among those with the fewest plays.

        Returns:
            Tuple ``(a, b)`` of arm indices.
        """
        # NOTE(review): diagonal pairs (a, a) are eligible here as well --
        # confirm that self-comparisons are intended.
        candidates = np.argwhere(self.nij == np.min(self.nij))
        a, b = candidates[np.random.randint(candidates.shape[0])]
        return (a, b)

    def learn(self, action, observation):
        """Record one outcome and eliminate arms that fail the bound.

        An active arm ``i`` is eliminated when, for some active arm ``j``,
        ``pij[i, j] + sqrt(log(K * (t + 1) / delta) / nij[i, j]) < 1/2``.

        Args:
            action: Pair ``(a, b)`` of arm indices that was played.
            observation: Outcome credited to ``(a, b)``; ``1 - observation``
                is credited to ``(b, a)``.
        """
        a, b = action
        self.nij[a, b] += 1
        self.nij[b, a] += 1
        self.oij[a, b] += observation
        self.oij[b, a] += 1 - observation
        self.pij[a, b] = self.oij[a, b] / self.nij[a, b]
        self.pij[b, a] = self.oij[b, a] / self.nij[b, a]

        active = self.active_arms
        # The log term does not depend on the pair, so compute it once.
        log_term = np.log(self.env.K * (self.t + 1) / self.delta)
        pairs = np.ix_(active, active)
        # Unplayed pairs have nij == 0: the bound becomes inf (or nan) and the
        # comparison below is False, so the warnings carry no information.
        with np.errstate(divide='ignore', invalid='ignore'):
            upper = self.pij[pairs] + np.sqrt(log_term / self.nij[pairs])
        beaten = (upper < 0.5).any(axis=1)

        # Work on a copy so self.nij is rebound rather than mutated further.
        new_nij = self.nij.copy()
        dropped = active[beaten]
        new_nij[dropped, :] = np.inf
        new_nij[:, dropped] = np.inf
        # Boolean-mask indexing instead of np.delete(arr, bool_mask), whose
        # treatment of boolean arguments differs between NumPy versions.
        self.active_arms = active[~beaten]
        self.nij = new_nij
