import numpy as np

from .base import AbstractAgent


class AgentAlgOld2(AbstractAgent):
    """Pair-playing agent driven by importance-weighted cumulative losses.

    Each round the agent samples two arms (``sample_action``), receives a
    scalar observation and adds importance-weighted loss estimates to
    ``self.L`` (``learn``).  The sampling distribution ``self.p`` is then
    recomputed from ``self.L`` by ``update_p``, which solves
    ``sum_i 4 / (eta * (L_i - x))**2 == 1`` for ``x`` with Newton's method.

    NOTE(review): the weight formula and the Newton step have the form of
    the 1/2-Tsallis-INF update (Zimmert & Seldin) -- confirm that this is
    the intended algorithm.

    The class reads ``self.env.K`` (number of arms) and ``self.t`` (round
    counter used for the learning rate) but never sets them; presumably
    they are maintained by ``AbstractAgent`` -- verify against the base.
    """

    def sample_action(self):
        """Draw the pair of arms ``(a, b)`` to play this round.

        ``a`` is always drawn from ``self.p``.  ``b`` is drawn uniformly
        with probability ``sum(p**2)`` and from ``self.p`` otherwise; the
        ``sum(p**2)`` event is realised by checking whether two throw-away
        draws from ``self.p`` coincide.  ``learn`` relies on exactly this
        mixture when it computes the probability of ``b``.

        Returns:
            Tuple ``(a, b)`` of arm indices in ``[0, K)``.
        """
        arms = range(self.env.K)

        # Throw-away pair: only its equality is used, as a biased coin with
        # success probability sum(p**2).  Both draws are kept so the amount
        # of randomness consumed per call is unchanged.
        coin_first = np.random.choice(arms, p=self.p)
        coin_second = np.random.choice(arms, p=self.p)
        uniform_b = coin_first == coin_second

        a = np.random.choice(arms, p=self.p)
        if uniform_b:
            b = np.random.randint(self.env.K)
        else:
            b = np.random.choice(arms, p=self.p)
        return (a, b)

    def reset_learning(self):
        """Reset the distribution, the loss estimates and the Newton state."""
        self.p = np.ones(self.env.K) / self.env.K  # uniform over the K arms
        self.L = np.zeros(self.env.K)  # cumulative loss estimate per arm
        # Normalisation variable of update_p, warm-started between rounds.
        # update_p clamps it below min(L) before iterating, so this initial
        # value only needs to be finite.
        self.x = 1

    def learn(self, action, observation):
        """Update the loss estimates with one round and refresh ``self.p``.

        Args:
            action: The pair ``(a, b)`` that was played; must have been
                sampled by ``sample_action`` under the current ``self.p``
                for the importance weights below to match.
            observation: Scalar feedback.  ``1 - observation`` is charged to
                ``a`` and ``observation`` to ``b``.  Presumably a value in
                ``[0, 1]`` -- confirm against the environment.
        """
        a = action[0]
        b = action[1]
        K = self.env.K

        # Probability that sample_action picked b uniformly (see there).
        p_same = np.sum(self.p ** 2)
        pa = self.p[a]
        pb = p_same * (1 / K) + (1 - p_same) * self.p[b]

        la = (1 - observation) / pa
        lb = observation / pb

        # Two separate in-place updates so that both are applied when a == b.
        self.L[a] += la / 2
        self.L[b] += lb / 2
        self.update_p()

    def update_p(self):
        """Recompute ``self.p`` from the cumulative losses ``self.L``.

        Finds ``x < min(L)`` such that the weights
        ``w_i = 4 / (eta * (L_i - x))**2`` sum to one, using the Newton
        step ``x -= (sum(w) - 1) / (eta * sum(w**1.5))``, and stores the
        normalised weights in ``self.p``.  ``self.x`` is kept as a warm
        start for the next call.

        The step formula is the Newton step only on the branch
        ``x < min(L)``.  Started at or above ``min(L)`` the iteration hits
        or creeps towards the pole at ``min(L)`` and ends in NaN, so the
        iterate is clamped to at most ``min(L) - 2 / eta``.  At that point
        the best arm has weight 1, hence ``sum(w) >= 1``; on this side the
        function is increasing and convex in ``x``, so Newton steps move
        ``x`` monotonically down to the root and never cross the pole.
        """
        threshold = 0.0000001
        eta = 2 * np.sqrt(1 / (self.t + 1))

        infinite = self.L == np.inf
        if np.any(infinite):
            # NOTE(review): this puts all mass, uniformly, on the arms whose
            # loss estimate is +inf.  Behaviour kept as it was; confirm that
            # it is intended rather than the complement.
            w = np.zeros(self.env.K)
            w[infinite] = 1
            self.p = w / np.sum(w)
            return

        # Largest admissible iterate; written so that a NaN warm start is
        # replaced as well.
        safe_x = np.min(self.L) - 2 / eta
        if not self.x < safe_x:
            self.x = safe_x

        prev_w = np.zeros(self.env.K)
        norm = np.inf
        while norm > threshold:
            w = 4 / ((eta * (self.L - self.x)) ** 2)
            step = (np.sum(w) - 1) / (eta * np.sum(np.power(w, 3 / 2)))

            # A step taken from the left of the root may overshoot; pull it
            # back to the safe point instead of letting it cross the pole.
            candidate = self.x - step
            self.x = candidate if candidate < safe_x else safe_x

            norm = np.linalg.norm(prev_w - w)
            prev_w = w

        self.p = w / np.sum(w)
