import numpy as np
from db.agents.base import AbstractAgent


class AgentMABtoDB(AbstractAgent):
    """Agent that builds a two-component action from two bandit learners.

    ``self.Mab[0]`` picks the first component and ``self.Mab[1]`` the second
    (it is told which first component was chosen). In :meth:`learn` the first
    learner receives ``observation`` unchanged while the second one receives
    ``1 - observation``.
    """

    def __init__(self, env, name='Unknown', Mab=(None, None)):
        """Store the two learners, filling missing ones with ``MabRandom``.

        Args:
            env: Environment forwarded to the base class and to any learner
                created here as a fallback.
            name: Agent name forwarded to the base class.
            Mab: Sequence whose items 0 and 1 are the learners. A ``None``
                item (or ``Mab=None`` altogether) is replaced by
                ``MabRandom(env)``.
        """
        # The default is an immutable tuple so that no list object is shared
        # between calls; callers passing a list keep working unchanged.
        if Mab is None:
            Mab = (None, None)
        first, second = Mab[0], Mab[1]
        self.Mab = [first if first is not None else MabRandom(env),
                    second if second is not None else MabRandom(env)]
        # self.Mab = [Mab[0] if Mab[0] is not None else MabEXP3(env),
        #             Mab[1] if Mab[1] is not None else MabEXP3(env)]
        # NOTE(review): the learners are assigned before the base constructor
        # runs; presumably AbstractAgent.__init__ may already use them (e.g.
        # through reset_learning) -- confirm before reordering.
        super().__init__(env, name)

    def sample_action(self):
        """Return ``(a, b)``: one action drawn from each learner.

        The second learner is given the first learner's choice through the
        ``other_action`` keyword.
        """
        a = self.Mab[0].sample_action()
        b = self.Mab[1].sample_action(other_action=a)
        return (a, b)

    def reset_learning(self):
        """Reset the learning state of both learners."""
        for mab in self.Mab:
            mab.reset_learning()

    def learn(self, action, observation):
        """Feed the outcome of ``action`` back to both learners.

        Args:
            action: Indexable pair; ``action[0]`` goes to the first learner
                and ``action[1]`` to the second.
            observation: Feedback value. The first learner sees it as-is,
                the second sees ``1 - observation``.
        """
        self.Mab[0].learn(action[0], observation)
        self.Mab[1].learn(action[1], 1 - observation)


class MabRandom:
    """Stateless learner that draws actions at random and never learns."""

    def __init__(self, env):
        """Keep a reference to ``env``; ``env.K`` is read when sampling."""
        self.env = env

    def reset_learning(self):
        """Do nothing: there is no learning state to reset."""
        return None

    def sample_action(self, other_action=None):
        """Return a random integer drawn from ``[0, env.K)``.

        ``other_action`` is accepted for interface compatibility and ignored.
        """
        n_arms = self.env.K
        return np.random.randint(n_arms)

    def learn(self, action, observation):
        """Do nothing: feedback is discarded."""
        return None


# class MabEXP3:
#     def __init__(self, env, eta=0.001):
#         self.env = env
#         self.eta = 0.1
#         self.reset_learning()

#     def reset_learning(self):
#         self.L = np.zeros(self.env.K)
#         self.p = np.ones(self.env.K) / self.env.K

#     def sample_action(self, other_action=None):
#         w = np.exp(-self.eta * self.L)
#         self.p = w / sum(w)
#         action = np.random.choice(range(self.env.K), p=self.p)
#         return action

#     def learn(self, action, observation, other_action=None):
#         self.L[action] += (1 - observation) / self.p[action]
