import random
import sys
sys.path.append('../env/')
sys.path.append('../policy/')
sys.path.append('../config/')
from AttrDict import AttrDict
from box_world import Environment_Pick
from PPO_agent_pick import Agent_PNet_PPO
import time
import os
import torch
import numpy as np
import argparse
import yaml

def set_seed(seeds):
    """Seed every RNG this script relies on so evaluation runs are repeatable.

    Args:
        seeds: Integer seed applied to torch (CPU and all CUDA devices), NumPy
            and Python's ``random`` module.
    """
    torch.manual_seed(seeds)
    torch.cuda.manual_seed_all(seeds)
    # Trade cuDNN autotuning speed for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seeds)
    random.seed(seeds)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # only affects child processes launched afterwards, not this process.
    os.environ['PYTHONHASHSEED'] = str(seeds)


def run_episodes(env, agent, num_test, state, success_reward=10):
    """Roll the agent out greedily until ``num_test`` episodes have finished.

    Args:
        env: Environment exposing ``reset()`` and
            ``step(action) -> (obs, reward, done, info)``.
        agent: Agent exposing ``state_estimator(obs, test=True)``,
            ``step(state_est, exploration_rate=0, test=True)`` and
            ``init_hidden_test()``.
        num_test: Number of episodes to play.
        state: Initial observation, as returned by ``env.reset()``.
        success_reward: Terminal reward that marks an episode as successful
            (this script has always treated exactly 10 as success).

    Returns:
        dict with keys ``num_games``, ``num_success``, ``total_steps``,
        ``success_steps_total`` (steps summed over successful episodes only,
        including each one's terminal step) and ``total_reward``.
    """
    num_games = 0
    num_success = 0
    total_steps = 0
    success_steps_total = 0
    episode_steps = 0
    total_reward = 0

    state_est = agent.state_estimator(state, test=True)
    while num_games < num_test:
        action = agent.step(state_est, exploration_rate=0, test=True)
        state_next, reward, done, _info = env.step(action)
        # Keep the original call order: the next observation is always run
        # through the estimator, even when it is terminal and about to be
        # replaced by a reset (the estimator may carry internal state).
        state_next_est = agent.state_estimator(state_next, test=True)

        # Count this step before inspecting `done` so the terminal step is
        # attributed to the episode it ends.
        total_steps += 1
        episode_steps += 1
        total_reward += reward

        if done:
            num_games += 1
            print("game:", num_games)
            if reward == success_reward:
                num_success += 1
                success_steps_total += episode_steps
            episode_steps = 0
            state_next = env.reset()
            agent.init_hidden_test()
            state_next_est = agent.state_estimator(state_next, test=True)

        state_est = state_next_est

    return {
        "num_games": num_games,
        "num_success": num_success,
        "total_steps": total_steps,
        "success_steps_total": success_steps_total,
        "total_reward": total_reward,
    }


def summarize_results(num_games, num_success, total_steps, success_steps_total, total_reward):
    """Turn raw rollout counters into the averages reported by this script.

    Returns:
        dict with keys ``avg_step``, ``avg_success_step``, ``avg_reward`` and
        ``success_rate``. A value is ``None`` when its denominator is zero
        (no games played, or no successful game), instead of raising
        ``ZeroDivisionError``.
    """
    def ratio(numerator, denominator):
        return numerator / denominator if denominator else None

    return {
        "avg_step": ratio(total_steps, num_games),
        "avg_success_step": ratio(success_steps_total, num_success),
        "avg_reward": ratio(total_reward, num_games),
        "success_rate": ratio(num_success, num_games),
    }


def main():
    """Evaluate a trained pick policy and print a summary of the test run.

    The YAML config (``--config``) must provide ``random_seed``,
    ``step_per_trajectory``, ``input_size``, ``num_test_envs``, ``pick_path``,
    ``lnet_path`` and ``num_test``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='../config/test_pick.yaml', type=str)
    cli_args = parser.parse_args()
    with open(cli_args.config, 'r', encoding='utf-8') as f:
        args = AttrDict(yaml.safe_load(f))
    print(args)

    set_seed(args.random_seed)
    # Presumably works around Intel OpenMP aborting when its runtime is loaded
    # twice (e.g. by both torch and numpy) — kept as in the original script.
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Init processes
    env = Environment_Pick(max_steps_per_episode=args.step_per_trajectory, isGUI=False)
    state = env.reset()
    agent = Agent_PNet_PPO(input_size=args.input_size,
                           device=device,
                           test=True,
                           num_test_envs=args.num_test_envs,
                           test_model=args.pick_path,
                           LMapNet_pretrain_model=args.lnet_path)
    agent.init_hidden_test()
    num_test = args.num_test

    print("*" * 20)
    tic = time.time()
    counters = run_episodes(env, agent, num_test, state)
    elapsed = time.time() - tic  # do not shadow the `time` module
    minutes, seconds = divmod(elapsed, 60)

    stats = summarize_results(**counters)

    def fmt(value):
        return 'n/a' if value is None else value

    print(f'test result in {num_test} games:')
    print(f"average step: {fmt(stats['avg_step'])}")
    print(f"average success step: {fmt(stats['avg_success_step'])}")
    print(f"average reward: {fmt(stats['avg_reward'])}")
    print(f"success rate: {fmt(stats['success_rate'])}")
    print(f'Processing time: {int(minutes)}min {seconds:.2f}sec')
    print("*" * 20)


if __name__ == "__main__":
    sys.exit(main())