from typing import Any, Mapping, Optional

import numpy as np

from machiavelli.agent.base_agent import BaseAgent


class InterpolatedAgent(BaseAgent):
    """Agent that samples actions from a blend of two agents' distributions.

    For every step, both wrapped agents are asked for their action
    probabilities and the final distribution is the convex combination
    ``(1 - alpha) * probs1 + alpha * probs2``. An action index is then
    sampled from that blended distribution.

    Both wrapped agents must provide
    ``get_action_probabilities(obs, rew, done, info)`` returning a 1-D
    array-like of probabilities over the same set of actions.
    """

    def __init__(self, agent1, agent2, alpha: float):
        """Store the two agents and the interpolation weight.

        Args:
            agent1: Agent whose probabilities are weighted by ``1 - alpha``.
            agent2: Agent whose probabilities are weighted by ``alpha``.
            alpha: Interpolation weight in ``[0, 1]``; ``0`` follows
                ``agent1`` only, ``1`` follows ``agent2`` only.

        Raises:
            ValueError: If ``alpha`` is outside ``[0, 1]`` (including NaN).
        """
        super().__init__()

        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let an invalid weight through.
        if not 0 <= alpha <= 1:
            raise ValueError(f'alpha must be in [0, 1], got {alpha}')

        self.agent1 = agent1
        self.agent2 = agent2
        self.alpha = alpha

    def __str__(self):
        """Return an identifier built from alpha and both wrapped agents."""
        return f'InterpolatedAgent-{self.alpha}-{self.agent1}-{self.agent2}'

    def get_action(self, obs: str, rew: float, done: bool, info: Mapping[str, Any]) -> Optional[int]:
        """Sample an action index from the interpolated distribution.

        Args:
            obs: Current observation, forwarded unchanged to both agents.
            rew: Latest reward, forwarded unchanged to both agents.
            done: Whether the episode has ended.
            info: Extra step information, forwarded unchanged to both agents.

        Returns:
            The sampled action index, or ``None`` when ``done`` is true.

        Raises:
            ValueError: If the two agents return distributions of different
                shapes, or the blended probabilities do not sum to 1.
        """
        if done:
            return None

        # np.asarray leaves ndarrays untouched but lets agents that return a
        # plain list/tuple participate in the arithmetic below.
        probs1 = np.asarray(self.agent1.get_action_probabilities(obs, rew, done, info))
        probs2 = np.asarray(self.agent2.get_action_probabilities(obs, rew, done, info))

        # Fail loudly rather than letting NumPy broadcast mismatched shapes.
        if probs1.shape != probs2.shape:
            raise ValueError(
                'agents returned action distributions of different shapes: '
                f'{probs1.shape} vs {probs2.shape}'
            )

        probs = (1 - self.alpha) * probs1 + self.alpha * probs2
        print(f"probs1 ({str(self.agent1)}): ", probs1)
        print(f"probs2 ({str(self.agent2)}): ", probs2)
        print(f"Interpolated (alpha={self.alpha}): ", probs)

        total = np.sum(probs)
        if not np.isclose(total, 1.0):
            raise ValueError(f'interpolated probabilities must sum to 1, got {total}')

        # Passing the population size directly samples from range(len(probs));
        # int() guarantees a built-in int, matching the annotated return type.
        action = int(np.random.choice(len(probs), p=probs))

        print(action, flush=True)
        return action

    def train(self):
        """No-op; this agent has no training-mode state of its own."""
        pass

    def test(self):
        """No-op; this agent has no evaluation-mode state of its own."""
        pass
