from common.arguments import get_args
from common.env_wrappers import SubprocVecEnv
from common.helper import Logger, set_all_seeds
import pprint
import sys
from tqdm import tqdm
from agent import Agents
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

def make_env(args):
    """Create the environment for ``args.scenario_name`` and record its shapes on ``args``.

    The following fields are written onto ``args`` from the created
    environment: ``n_agents``, ``obs_shape``, ``action_shape`` and
    ``message_shape``.

    Args:
        args: Argument namespace. Reads ``scenario_name``, ``vec_env`` and
            ``seed``.

    Returns:
        A tuple ``(env, args)``. When ``args.vec_env > 1`` ``env`` is a
        ``SubprocVecEnv`` built from ``args.vec_env`` environment factories;
        otherwise it is the single environment instance itself.

    Raises:
        NotImplementedError: If ``args.scenario_name`` is not a supported
            scenario (only ``'GuessingNumber'`` is handled here).
    """
    # Add the parent of this module's directory to the import path
    # (presumably so that the ``Env`` package imported below resolves).
    parent_dir = os.path.abspath(
        os.path.join(os.path.dirname(sys.modules[__name__].__file__), "..")
    )
    if parent_dir not in sys.path:
        sys.path.append(parent_dir)

    if args.scenario_name != 'GuessingNumber':
        # Fail loudly instead of implicitly returning None, which callers
        # would then try to unpack as ``(env, args)``.
        raise NotImplementedError(
            "Unsupported scenario_name: {!r}".format(args.scenario_name)
        )

    from Env.guessing_number import GuessingNumber

    env = GuessingNumber()

    args.n_agents = env.num_agents
    # Each entry is the observation dimension of the corresponding agent.
    args.obs_shape = env.observation_space

    args.action_shape = env.guess_space_noop
    args.message_shape = env.reveal_space_noop

    def get_env_fn(rank):
        """Return a zero-argument factory for the environment of worker ``rank``."""
        def init_env():
            env1 = GuessingNumber()
            # Offset the seed by rank so each worker gets a distinct seed.
            set_all_seeds(args.seed + rank * 12345)
            return env1
        return init_env

    if args.vec_env > 1:
        return SubprocVecEnv([get_env_fn(i) for i in range(args.vec_env)]), args
    return env, args

class Runner:
    """Roll out the agents in an environment and report evaluation statistics.

    The runner reads these fields from ``args``: ``max_episode_len``,
    ``save_dir``, ``eps``, ``vec_env``, ``scenario_name``, ``n_agents``,
    ``action_shape`` and ``message_shape``.
    """

    # Target number of evaluation episodes, summed over all parallel environments.
    EVAL_EPISODES = 10000

    def __init__(self, args, env):
        """Store the configuration, build the agents and ensure ``save_dir`` exists.

        Args:
            args: Argument namespace (see the class docstring for fields used).
            env: Environment exposing ``reset()`` and ``step(actions)``.
        """
        self.args = args
        self.episode_limit = args.max_episode_len
        self.env = env
        self.agents = Agents(self.args)
        self.save_path = self.args.save_dir
        self.eps = self.args.eps
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def eval(self):
        """Run evaluation rollouts and print averaged reward and step statistics.

        ``EVAL_EPISODES // vec_env`` calls to :meth:`run_episode` are made
        (at least one), each with ``evaluate=True``. The iteration count is
        printed first, then a one-line summary.
        """
        vec_env = self.args.vec_env
        # At least one iteration, so the per-iteration averages are defined
        # even when vec_env exceeds EVAL_EPISODES.
        num_iters = max(1, int(self.EVAL_EPISODES // vec_env))
        print(num_iters)

        rewards, steps = 0, 0
        for _ in tqdm(range(num_iters)):
            reward, step = self.run_episode(evaluate=True)
            rewards += reward
            steps += step

        print(self._format_eval_summary(num_iters, vec_env, rewards, steps))

    @staticmethod
    def _format_eval_summary(num_iters, vec_env, rewards, steps):
        """Build the summary line printed at the end of :meth:`eval`.

        Args:
            num_iters: Number of ``run_episode`` calls that were made.
            vec_env: Number of parallel environments per call.
            rewards: Sum of the rewards returned by ``run_episode``.
            steps: Sum of the step lengths returned by ``run_episode``.

        Returns:
            The formatted summary string.
        """
        # Each iteration covers vec_env episodes, so the total is
        # iterations * vec_env (not iterations squared).
        total_episodes = num_iters * vec_env
        # Reward per step is undefined when no step was taken.
        reward_per_step = rewards / steps if steps else float('nan')
        return ('Total {} episodes, Average reward: {:.2f}, Average step_len: {:.2f}, Reward pre step: {:.2f}'
                .format(total_episodes, rewards / num_iters, steps / num_iters, reward_per_step))

    def run_episode(self, evaluate=False):
        """Roll out one (possibly vectorized) episode.

        The loop stops when ``max_episode_len`` steps have been taken or
        when every entry of ``done`` is truthy.

        Args:
            evaluate: When true, ``0`` is passed as the last argument of
                ``agents.select_actions`` instead of ``self.eps``
                (presumably the exploration rate - confirm against Agents).

        Returns:
            A tuple ``(reward, step_len)``: ``reward`` accumulates, per step,
            the sum of ``r[:, 0]`` divided by ``vec_env``; ``step_len``
            accumulates, per step, the mean of ``1 - done`` as it was before
            that step.
        """
        s, available_actions = self.env.reset()
        is_guessing_number = self.args.scenario_name == 'GuessingNumber'
        epsilon = 0 if evaluate else self.eps

        step = 0
        reward = 0
        step_len = 0

        m = self.agents.init_message()
        hidden = self.agents.init_hidden()

        done = np.zeros(self.args.vec_env)
        a_u, u = None, None

        while step < self.episode_limit and not done.all():
            if is_guessing_number:
                a_u = self.guessing_number_mask(available_actions)

            actions, _q_value, hidden_next, message, m_next = self.agents.select_actions(
                s, m, a_u, hidden, epsilon)

            if is_guessing_number:
                u = self.guessing_number_action(actions, message)

            # Count this step only for the share of environments not yet done.
            step_len += (1 - done).mean()

            s_next, r, done, available_actions_next = self.env.step(u)

            reward += sum(r[:, 0]) / self.args.vec_env

            s = s_next
            available_actions = available_actions_next
            hidden = hidden_next
            m = m_next

            step += 1

        return reward, step_len

    def guessing_number_mask(self, available_actions):
        """Build the per-agent action mask for the GuessingNumber scenario.

        For each environment and agent: if the last entry of the agent's
        ``available_actions`` row equals 1, only the last action index is
        enabled. Otherwise the last action index is disabled and the first
        ``action_shape - 1`` entries are copied from ``available_actions``.
        (A separate reveal mask used to be built here as well; only the
        action mask is returned.)

        Args:
            available_actions: Indexable as ``[env_id][player_id][k]``.

        Returns:
            ``np.ndarray`` built from a ``vec_env`` x ``n_agents`` nesting of
            masks, each of length ``action_shape``.
        """
        n_actions = self.args.action_shape
        results = []
        for env_id in range(self.args.vec_env):
            result = []
            for player_id in range(self.args.n_agents):
                available = available_actions[env_id][player_id]
                if available[-1] == 1:
                    action_mask = np.zeros(n_actions)
                    action_mask[-1] = 1
                else:
                    action_mask = np.ones(n_actions)
                    action_mask[-1] = 0
                    action_mask[:n_actions - 1] = available[:n_actions - 1]
                result.append(action_mask)
            results.append(result)
        return np.array(results)

    def guessing_number_action(self, actions, message):
        """Merge each agent's action index and one-hot message into one env action.

        With ``A = action_shape`` and ``M = message_shape`` (the labels
        NOOP1 / NOOP2 follow the assertion messages below):

        * action ``A - 1`` or message ``M - 1`` (NOOP1): both must be NOOP1;
          the result is ``A + M - 4``.
        * action ``A - 2`` (NOOP2): the message must not be ``M - 2``; the
          result is ``message_id + A - 2``.
        * action below ``A - 2``: the message must be ``M - 2``; the result
          is the action index itself.

        Args:
            actions: Indexable as ``[env_id][player_id][0]`` giving the
                chosen action index.
            message: Indexable as ``[env_id][player_id]`` giving a one-hot
                vector accepted by ``from_one_hot``.

        Returns:
            ``np.ndarray`` built from a ``vec_env`` x ``n_agents`` nesting of
            combined action indices.

        Raises:
            AssertionError: If the action / message combination is inconsistent.
            NotImplementedError: If the action index matches no branch.
        """
        action_noop1 = self.args.action_shape - 1
        action_noop2 = self.args.action_shape - 2
        message_noop1 = self.args.message_shape - 1
        message_noop2 = self.args.message_shape - 2

        results = []
        for env_id in range(self.args.vec_env):
            result = []
            for player_id in range(self.args.n_agents):
                action_id = actions[env_id][player_id][0]
                message_id = from_one_hot(message[env_id][player_id])
                if action_id == action_noop1 or message_id == message_noop1:
                    assert (action_id == action_noop1 and message_id == message_noop1), (
                        'action and reveal are not both NOOP1', action_id, self.args.action_shape - 1, message_id,
                        self.args.message_shape - 1)
                    final_action = self.args.action_shape + self.args.message_shape - 4
                elif action_id == action_noop2:
                    assert (message_id != message_noop2), (
                        'action and reveal are both NOOP2', message_id, self.args.message_shape - 2)
                    final_action = message_id + self.args.action_shape - 2
                elif action_id < action_noop2:
                    assert (message_id == message_noop2), (
                        'action working, reveal are not NOOP2', message_id, self.args.message_shape - 2)
                    final_action = action_id
                else:
                    raise NotImplementedError(action_id, message_id)
                result.append(final_action)
            results.append(result)
        return np.array(results)


def from_one_hot(one_hot):
    """Return the index of the single 1 in a one-hot vector.

    Args:
        one_hot: One-hot vector. Lists and tuples are converted to a NumPy
            array first; any other input is used as given.

    Returns:
        The index of the element equal to 1, as returned by ``np.argmax``.

    Raises:
        ValueError: If the number of elements equal to 1 is not exactly one,
            or the number of elements equal to 0 is not ``len(one_hot) - 1``.
    """
    if isinstance(one_hot, (list, tuple)):
        # Comparing a plain sequence with ``== 1`` yields one bool rather than
        # an element-wise mask, which would reject every valid input.
        one_hot = np.asarray(one_hot)
    if np.sum(one_hot == 1) != 1 or np.sum(one_hot == 0) != len(one_hot) - 1:
        raise ValueError("input is not one-hot", one_hot)
    # Index of the element whose value is 1.
    return np.argmax(one_hot)


# Script entry point: build the environment and runner, load weights, then evaluate.
if __name__ == '__main__':
    # Parse the run configuration.
    args = get_args()

    # Turn on cuDNN benchmark mode and autograd anomaly detection.
    # NOTE(review): anomaly detection is a debugging aid that adds overhead -
    # confirm it is wanted for an evaluation-only run.
    torch.backends.cudnn.benchmark = True
    torch.autograd.set_detect_anomaly(True)

    # Make sure the output directory exists before the log file is created in it.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Replace stdout with the project Logger pointed at <save_dir>/eval.log, so
    # every later print goes through it (presumably written to that file -
    # confirm against common.helper.Logger).
    logger_path = os.path.join(args.save_dir, "eval.log")
    sys.stdout = Logger(logger_path)

    # Seed before the environment and agents are constructed.
    set_all_seeds(args.seed)

    # make_env also records the agent count and observation/action/message
    # shapes on args, which Runner passes on to Agents.
    env, args = make_env(args)

    runner = Runner(args, env)

    # NOTE(review): the checkpoint path is an empty-string placeholder; it must
    # be filled in with a real file for torch.load to succeed.
    runner.agents.policy.ICN.load_state_dict(torch.load(''))

    # Dump the final configuration (including the fields added by make_env).
    pprint.pprint(vars(args))

    runner.eval()

