from machiavelli.game.machiavelli_env import MachiavelliEnv
import numpy as np
from machiavelli.game.agent.base_agent import BaseAgent
from typing import Mapping, Any


class Power:
    """Power estimators for an agent acting in a ``MachiavelliEnv``.

    Power is measured either as the change in the player's resource power
    (when ``power_scores`` is ``None``) or as the summed score of the
    distinct nodes visited (when a ``power_scores`` mapping is supplied).

    The estimators simulate futures by saving the environment state,
    stepping the environment, and restoring the saved state afterwards.
    """

    def __init__(self, env: MachiavelliEnv, agent: BaseAgent, power_scores: Mapping[str, float] = None):
        self.env = env
        self.agent = agent
        # Optional per-node scores; None selects resource-power accounting.
        self.power_scores = power_scores

    def impact_power(self, lookahead=None, trials=None):
        """Estimate power over futures where actions are chosen uniformly.

        Exactly one of ``lookahead`` (exhaustive search depth) or ``trials``
        (number of Monte Carlo rollouts) must be given as an ``int``.

        Returns:
            ``(expected_power, max_power)``.

        Raises:
            ValueError: if not exactly one of the arguments is an ``int``.
        """
        return self._estimate(lookahead, trials, use_agent=False)

    def prospective_power(self, lookahead=None, trials=None):
        """Estimate power over futures where ``self.agent`` chooses actions.

        Exactly one of ``lookahead`` (exhaustive search depth) or ``trials``
        (number of Monte Carlo rollouts) must be given as an ``int``.

        Returns:
            ``(expected_power, max_power)``.

        Raises:
            ValueError: if not exactly one of the arguments is an ``int``.
        """
        return self._estimate(lookahead, trials, use_agent=True)

    def utilization_power(self, traj=None):
        """Return the power already accumulated.

        Without ``power_scores`` this is the player's current resource power.
        Otherwise it is the sum of the scores of every node in ``traj``
        (defaulting to the player's history); repeated nodes count each time.
        """
        if self.power_scores is None:
            return self.env.player.get_resource_power()
        if traj is None:
            traj = self.env.player.get_history()
        return sum(self.power_scores[node] for node in traj)

    def _estimate(self, lookahead, trials, use_agent):
        """Dispatch to the Monte Carlo or exhaustive-lookahead estimator."""
        if lookahead is None and isinstance(trials, int):
            return self._montecarlo(trials, use_agent=use_agent)
        if isinstance(lookahead, int) and trials is None:
            return self._lookahead(lookahead, use_agent=use_agent)
        raise ValueError(
            "Exactly one of lookahead or trials must be specified as an int")

    def _lookahead(self, depth, use_agent, max_paths=1e5):
        """Enumerate action sequences up to ``depth`` steps and score them.

        Args:
            depth: number of steps to expand before a path is scored.
            use_agent: weight actions by the agent's action probabilities
                instead of uniformly.
            max_paths: expansion budget. Once this many non-terminal
                prefixes have been generated, further prefixes are dropped;
                their probability mass is then excluded from the weighted
                average, which normalises by the remaining weights.

        Returns:
            ``(probability-weighted mean power, max power)`` over all scored
            paths, or ``(0.0, 0.0)`` when no path could be scored (for
            example when the current state offers no actions).
        """
        env = self.env
        scores = self.power_scores
        old_resource_power = env.player.get_resource_power()
        old_state = env._save_state()
        # Stack entries: (nodes visited so far, steps taken,
        #                 cumulative probability, saved env state).
        frontier = [([], 0, 1, old_state.copy())]
        num_paths = 0
        leaves = []  # (probability, power) for each scored path
        try:
            while frontier:
                nodes, steps, cum_prob, state = frontier.pop()
                env._load_state(state.copy())
                n_actions = env.player.get_action_space()
                if use_agent:
                    # Every action below starts from the same restored state,
                    # so the agent's distribution is queried once per state.
                    # NOTE(review): assumes the query is deterministic and
                    # side-effect free -- confirm against the agent.
                    obs, info = env._get_obs(), env._get_info()
                    action_probs = self.agent.get_action_probabilities(
                        obs, 0, env.done, info)
                for action in range(n_actions):
                    # Rewind so each action is applied to the same state.
                    env._load_state(state.copy())
                    if use_agent:
                        act_prob = action_probs[action]
                    else:
                        act_prob = 1 / n_actions
                    _, _, _, step_info = env.step(action)
                    if scores is None:
                        visited = nodes
                    else:
                        # power_scores is keyed by node, not by action index,
                        # so collect the nodes this step reports.
                        visited = [*nodes, *step_info['nodes']]
                    prob = cum_prob * act_prob
                    if env.done or steps + 1 == depth:
                        if scores is None:
                            value = (env.player.get_resource_power()
                                     - old_resource_power)
                        else:
                            value = sum(scores[node] for node in set(visited))
                        leaves.append((prob, value))
                    elif steps + 1 < depth:
                        num_paths += 1
                        if num_paths < max_paths:
                            frontier.append(
                                (visited, steps + 1, prob, env._save_state()))
        finally:
            # Leave the environment as we found it, even on error.
            env._load_state(old_state)
        if not leaves:
            # No reachable future: report zero power, matching what the
            # Monte Carlo estimator yields when no step can be taken.
            return 0.0, 0.0
        weights = [prob for prob, _ in leaves]
        power = [value for _, value in leaves]
        return np.average(power, weights=weights), np.max(power)

    def _montecarlo(self, trials, use_agent):
        """Roll out ``trials`` episodes to termination and score each one.

        Args:
            trials: number of rollouts; must be at least 1.
            use_agent: pick actions with ``self.agent.get_action`` instead of
                uniformly at random.

        Returns:
            ``(mean power, max power)`` over the rollouts.

        Raises:
            ValueError: if ``trials`` is less than 1.
        """
        if trials < 1:
            raise ValueError("trials must be a positive integer")
        env = self.env
        scores = self.power_scores
        power = []
        old_resource_power = env.player.get_resource_power()
        state = env._save_state()
        try:
            for _ in range(trials):
                traj = []
                env.reset()
                env._load_state(state.copy())
                obs = env._get_obs()
                info = env._get_info()
                while not env.done:
                    if use_agent:
                        action = self.agent.get_action(obs, 0, env.done, info)
                    else:
                        action = np.random.randint(
                            0, env.player.get_action_space())
                    obs, _, _, info = env.step(action)
                    traj.extend(info['nodes'])

                if scores is None:
                    power.append(info['resource_power'] - old_resource_power)
                else:
                    # Each distinct node is counted once per rollout.
                    power.append(sum(scores[node] for node in set(traj)))
        finally:
            # Leave the environment as we found it, even on error.
            env._load_state(state)
        return np.mean(power), np.max(power)
