from .base import AbstractAgent
import numpy as np

class AgentWSW(AbstractAgent):
    """Agent that plays pairs of arms ranked by a running per-arm score.

    Each arm has a score in ``self.C``. ``sample_action`` returns the
    top-scored arm together with the top-scored arm among the remaining ones;
    ``learn`` moves the two played arms' scores in opposite directions
    according to the observation.

    NOTE(review): the class name and the tie-breaking order look like the
    "Winner Stays" (weak regret) dueling-bandit rule -- confirm against the
    project's references.
    """

    def reset_learning(self):
        """Zero every arm score and forget the previously played pair."""
        # One score per arm; ``self.env.K`` is used as the length of the
        # score vector (presumably the number of arms -- verify in the env).
        self.C = np.zeros(self.env.K)
        # -1 is a "no previous arm" sentinel: candidate indices are always
        # non-negative, so it can never match during tie-breaking.
        self.previous_a = -1
        self.previous_b = -1

    def _break_tie(self, candidates):
        """Pick one index out of ``candidates`` (arms tied for the best score).

        Preference order: the only candidate if there is just one, then the
        previous first arm, then the previous second arm, and finally a
        uniformly random candidate. The random generator is only consulted
        in that last case.
        """
        if candidates.size == 1:
            return candidates[0]
        if self.previous_a in candidates:
            return self.previous_a
        if self.previous_b in candidates:
            return self.previous_b
        return np.random.choice(candidates)

    def sample_action(self):
        """Return the pair ``(a, b)`` of arm indices to play next.

        ``a`` is an arm with the maximal score. ``b`` is an arm with the
        maximal score once ``a`` has been excluded. Ties are resolved by
        ``_break_tie`` in both cases.
        """
        scores = self.C
        a = self._break_tie(np.flatnonzero(scores == np.amax(scores)))

        # Exclude ``a`` from the second pick by pushing its score strictly
        # below every other one. Work on a copy so ``self.C`` is untouched.
        masked = scores.copy()
        masked[a] = np.amin(masked) - 1
        b = self._break_tie(np.flatnonzero(masked == np.amax(masked)))

        return (a, b)

    def learn(self, action, observation):
        """Record the played pair and update both arms' scores.

        ``action`` is indexed as ``(first_arm, second_arm)``. ``observation``
        is rescaled with ``2 * observation - 1``, so 1 becomes +1 and 0
        becomes -1 (presumably 1 means the first arm won -- confirm with the
        environment). The first arm's score gains that amount and the second
        arm's score loses it.
        """
        self.previous_a = action[0]
        self.previous_b = action[1]

        o = observation * 2 - 1

        self.C[action[0]] += o
        self.C[action[1]] -= o