from common.arguments import get_args
from common.env_wrappers import SubprocVecEnv
from common.helper import Logger, set_all_seeds
import pprint
import sys
from tqdm import tqdm
from agent import Agents
from common.replay_buffer import ReplayBuffer
import torch
import os
import numpy as np
import matplotlib.pyplot as plt


class Runner:
    """Evaluate trained agents by rolling them out in a (possibly vectorized) environment.

    Args:
        args: Parsed config namespace. Reads ``eps``, ``max_episode_len``,
            ``buffer_size``, ``save_dir`` and ``vec_env`` (number of parallel envs).
        env: Environment exposing ``reset()`` and ``step(actions)``. As used in
            ``run_episode``, it returns batched arrays with one row per parallel env.
    """

    def __init__(self, args, env):
        self.args = args
        self.eps = args.eps
        self.episode_limit = args.max_episode_len
        self.env = env
        self.agents = Agents(self.args)
        # Not used by the methods of this class.
        self.buffer = ReplayBuffer(self.args.buffer_size)
        self.save_path = self.args.save_dir
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.save_path, exist_ok=True)

    def eval(self, n_episodes=10000):
        """Run about ``n_episodes`` greedy evaluation episodes and print summary stats.

        Episodes run in rounds of ``args.vec_env`` parallel environments, so
        ``n_episodes // vec_env`` rounds are executed (at least one).

        Args:
            n_episodes: Target number of episodes. Defaults to the previously
                hard-coded 10,000.

        Returns:
            dict with keys ``episodes``, ``avg_reward``, ``avg_step_len`` and
            ``reward_per_step``. A ratio is NaN when its denominator is zero.
        """
        vec_env = self.args.vec_env
        # Run at least one round; otherwise the averages below have no data.
        n_rounds = max(1, int(n_episodes // vec_env))

        total_reward, total_steps = 0.0, 0.0
        for _ in tqdm(range(n_rounds)):
            reward, step_len = self.run_episode(evaluate=True)
            total_reward += reward
            total_steps += step_len

        stats = self._summarize(total_reward, total_steps, n_rounds, vec_env)
        print('Total {} episodes, Average reward: {:.2f}, Average step_len: {:.2f}, Reward pre step: {:.2f}'
              .format(stats['episodes'], stats['avg_reward'], stats['avg_step_len'],
                      stats['reward_per_step']))
        return stats

    @staticmethod
    def _summarize(total_reward, total_steps, n_rounds, vec_env):
        """Convert accumulated round totals into the averages reported by ``eval``."""
        nan = float('nan')
        return {
            'episodes': vec_env * n_rounds,
            # run_episode already averages over the parallel envs, so divide by rounds only.
            'avg_reward': total_reward / n_rounds if n_rounds else nan,
            'avg_step_len': total_steps / n_rounds if n_rounds else nan,
            'reward_per_step': total_reward / total_steps if total_steps else nan,
        }

    def run_episode(self, evaluate=False):
        """Roll out one batch of episodes for ``episode_limit`` steps.

        Args:
            evaluate: If True, act greedily (epsilon 0). Otherwise use ``self.eps``.

        Returns:
            ``(reward, step_len)``. ``reward`` sums, over steps, column 0 of the
            env reward averaged over the ``vec_env`` envs. ``step_len`` sums,
            over steps, the fraction of envs not yet done before that step.
        """
        epsilon = 0 if evaluate else self.eps
        s, available_actions = self.env.reset()

        reward = 0
        step_len = 0
        # All parallel envs start "not done".
        done = np.zeros(self.args.vec_env)

        hidden = self.agents.init_hidden()
        actions, hidden = self.agents.select_actions(s, available_actions, hidden, epsilon)

        step = 0
        # NOTE(review): the loop always runs episode_limit steps and does not stop
        # once every env reports done. Presumably the env keeps reporting done with
        # zero reward afterwards -- confirm.
        while step < self.episode_limit:
            # Count the envs that were still running before this step.
            step_len += (1 - done).mean()

            s, r, done, available_actions = self.env.step(actions)
            # Choose the next actions from the new observation, carrying the
            # agents' hidden state forward.
            actions, hidden = self.agents.select_actions(s, available_actions, hidden, epsilon)

            # Presumably r is (vec_env, n_agents) with a shared reward in column 0 -- confirm.
            reward += sum(r[:, 0]) / self.args.vec_env
            step += 1

        return reward, step_len

def make_env(args):
    """Build the environment described by ``args`` and fill in env-derived fields.

    Sets ``args.n_agents``, ``args.obs_shape`` and ``args.action_shape`` from a
    probe instance of the environment.

    Args:
        args: Config namespace. Reads ``scenario_name``, ``seed`` and ``vec_env``.

    Returns:
        ``(env, args)``. ``env`` is a ``SubprocVecEnv`` of ``args.vec_env``
        workers when ``vec_env > 1``, otherwise the single probe environment.

    Raises:
        ValueError: If ``args.scenario_name`` is not a supported scenario.
    """
    # Put the directory above this file on sys.path (presumably where the ``Env``
    # package lives). Skip the append if it is already there, so repeated calls
    # do not grow sys.path.
    project_root = os.path.abspath(
        os.path.join(os.path.dirname(sys.modules[__name__].__file__), "..")
    )
    if project_root not in sys.path:
        sys.path.append(project_root)

    if args.scenario_name != 'GuessingNumber':
        # Previously this fell through and returned None, which broke the
        # caller's ``env, args = make_env(args)`` with an unclear TypeError.
        raise ValueError('Unsupported scenario_name: {!r}'.format(args.scenario_name))

    from Env.guessing_number import GuessingNumber

    env = GuessingNumber()

    args.n_agents = env.num_agents
    # Each entry is the observation dimension of the corresponding agent.
    args.obs_shape = env.observation_space

    args.action_shape = env.action_space_noop

    def get_env_fn(rank):
        def init_env():
            env1 = GuessingNumber()
            # Each worker gets a distinct seed offset by its rank.
            set_all_seeds(args.seed + rank * 12345)
            return env1
        return init_env

    if args.vec_env > 1:
        return SubprocVecEnv([get_env_fn(i) for i in range(args.vec_env)]), args
    return env, args


if __name__ == '__main__':
    # Parse the run parameters.
    args = get_args()

    torch.backends.cudnn.benchmark = True

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.save_dir, exist_ok=True)

    # Route stdout through Logger targeting <save_dir>/eval.log
    # (presumably it also echoes to the console -- confirm against common.helper).
    logger_path = os.path.join(args.save_dir, "eval.log")
    sys.stdout = Logger(logger_path)

    set_all_seeds(args.seed)

    env, args = make_env(args)
    runner = Runner(args, env)

    # Hard-coded checkpoint path, resolved relative to the current working directory.
    critic_params_path = os.path.join('local', '80000_critic_params.pkl')
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # machine. load_state_dict then copies the tensors onto the network's own device.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    critic_state = torch.load(critic_params_path, map_location='cpu')
    runner.agents.policy.critic_network.load_state_dict(critic_state)

    pprint.pprint(vars(args))

    runner.eval()
