import os 
# Make only GPUs 0-3 visible to CUDA in this process. This overwrites any
# CUDA_VISIBLE_DEVICES value already in the environment, and it is placed
# before `import torch` below on purpose: the variable only takes effect if it
# is set before CUDA is initialized, so keep it above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import numpy as np
import torch
import time

from a2c_ppo_acktr.envs import make_vec_envs_eval, make_vec_rewenvs_eval

def evaluate_attack(params, args, obs_size, ckpt_path, device, num_episodes=400):
    """Roll out a saved policy deterministically and report its mean episode reward.

    Builds evaluation vector envs (the reward-model variant when
    ``args.use_rew_model`` is truthy) and loads the policy from ``ckpt_path``.
    It then steps all ``args.num_processes`` envs with deterministic actions.
    Stepping stops once ``num_episodes`` episodes of env 0 have finished, or
    when env 0 reports a ``'finish_all'`` key in its info dict at the end of an
    episode. Running time and mean reward are printed.

    Args:
        params: Environment parameters forwarded to the env factory; must
            contain a ``'seed'`` key.
        args: Run configuration; reads ``use_rew_model``, ``max_steps``,
            ``num_processes``, ``gamma`` and ``cuda_id``.
        obs_size: Observation size forwarded to the env factory.
        ckpt_path: Path to a ``torch.load``-able checkpoint whose element
            ``[0]`` is the actor-critic. This presumably matches a checkpoint
            saved as ``[actor_critic, ...]``; confirm against the training script.
        device: Torch device used to load the policy and allocate its
            recurrent state and masks.
        num_episodes: Maximum number of env-0 episodes to collect. Defaults
            to 400, the previously hard-coded budget.

    Returns:
        list: The per-episode rewards read from ``infos[0]['episode_r']``.
    """
    make_envs = make_vec_rewenvs_eval if args.use_rew_model else make_vec_envs_eval
    eval_envs = make_envs(params['seed'], params, args.max_steps, args.num_processes,
                          args.gamma, obs_size, args.cuda_id, eval=True)
    num_processes = args.num_processes
    eval_episode_rewards = []

    # Close the envs even if loading the checkpoint or a rollout step raises,
    # so whatever the vec env holds is not leaked.
    try:
        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # ckpt_path at trusted checkpoints.
        actor_critic = torch.load(ckpt_path, map_location=device)[0]

        obs = eval_envs.reset()
        eval_recurrent_hidden_states = torch.zeros(
            num_processes, actor_critic.recurrent_hidden_state_size, device=device)
        # All-zero masks on the first step: every env is at an episode start.
        eval_masks = torch.zeros(num_processes, 1, device=device)

        time_start = time.time()
        while len(eval_episode_rewards) < num_episodes:
            with torch.no_grad():
                _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                    obs,
                    eval_recurrent_hidden_states,
                    eval_masks,
                    deterministic=True)

            obs, _, done, infos = eval_envs.step(action)

            # 0.0 for envs whose episode just ended, 1.0 otherwise; presumably
            # act() uses this to reset the recurrent state at episode boundaries.
            eval_masks = torch.tensor(
                [[0.0] if done_ else [1.0] for done_ in done],
                dtype=torch.float32,
                device=device)

            # Only env 0's episodes are recorded; the other envs are stepped
            # but their episode ends are ignored.
            if done[0]:
                eval_episode_rewards.append(infos[0]['episode_r'])
                if 'finish_all' in infos[0]:
                    break

        time_end = time.time()
        print('='*50)
        print('Running time:', (time_end - time_start) / 60, 'm')
        print('='*50)
    finally:
        eval_envs.close()

    if eval_episode_rewards:
        print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
            len(eval_episode_rewards), np.mean(eval_episode_rewards)))
    else:
        # np.mean([]) would warn and return nan; report the empty run instead.
        print(" Evaluation using 0 episodes: no episode finished\n")
    return eval_episode_rewards


