import random
import sys
sys.path.append('../env/')
sys.path.append('../policy/')
sys.path.append('../config/')
from AttrDict import AttrDict
from box_world import Environment_Decision
from PPO_agent_decision import Agent_PNet_PPO
import time
import os
import torch
import numpy as np
import argparse
import yaml

def load_config(path):
    """Read the YAML file at *path* and return its contents as an ``AttrDict``."""
    with open(path, 'r', encoding='utf-8') as f:
        return AttrDict(yaml.safe_load(f))


def set_seed(seeds):
    """Seed the torch, numpy and stdlib ``random`` generators with *seeds*.

    Also switches cuDNN to deterministic mode with benchmarking disabled and
    exports ``PYTHONHASHSEED``.
    """
    torch.manual_seed(seeds)
    torch.cuda.manual_seed_all(seeds)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seeds)
    random.seed(seeds)
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this export
    # can only influence child processes, not the hashing of this process.
    os.environ['PYTHONHASHSEED'] = str(seeds)


def select_device():
    """Return ``"cuda:0"`` when CUDA is usable, otherwise ``"cpu"``."""
    # NOTE(review): the CPU fallback assumes the agent and its checkpoints can
    # be loaded on CPU -- confirm against Agent_PNet_PPO.
    return "cuda:0" if torch.cuda.is_available() else "cpu"


def per_game_average(total, games):
    """Return ``total / games``, or NaN when no game was played.

    Guards the final report against ``ZeroDivisionError`` when the configured
    number of test games is zero or negative.
    """
    if games <= 0:
        return float("nan")
    return total / games


def run_test(env, agent, state, num_test):
    """Step *agent* through *env* until *num_test* episodes have finished.

    Args:
        env: environment exposing ``step(action)`` and ``reset()``.
        agent: agent exposing ``state_estimator``, ``step``,
            ``init_hidden_test`` and ``init_test``.
        state: observation returned by the ``env.reset()`` call that precedes
            this function.
        num_test: number of episodes to play.

    Returns:
        Tuple ``(games, success, strict_fixed, remain_energy)`` where the last
        three are sums of ``info["success"]``, ``info["fixed_strict"]`` and
        ``info["Energy_Remaining"]`` over the finished episodes.
    """
    games = 0
    success = 0
    strict_fixed = 0
    remain_energy = 0
    state_est = agent.state_estimator([state], test=True)
    while games < num_test:
        action = agent.step(state_est[:, 0], idx=0, exploration_rate=0, test=True)
        state_next, _, done, info = env.step(action)
        # The estimator is called on every step, including the terminal one,
        # so the sequence of agent calls matches the original script exactly.
        state_est = agent.state_estimator([state_next], test=True)
        if done:
            games += 1
            print("game:", games)
            success += info["success"]
            strict_fixed += info["fixed_strict"]
            remain_energy += info["Energy_Remaining"]
            # Running totals after each finished episode.
            print(success)
            print(strict_fixed)
            print(remain_energy)
            # Start a fresh episode: reset the environment and the agent's
            # per-episode state before estimating the new initial state.
            state_next = env.reset()
            agent.init_hidden_test()
            agent.init_test()
            state_est = agent.state_estimator([state_next], test=True)
    return games, success, strict_fixed, remain_energy


def main():
    """Load the test config, run the decision agent and print summary stats."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='../config/test_decision.yaml', type=str)
    cli_args = parser.parse_args()
    args = load_config(cli_args.config)
    print(args)

    set_seed(args.random_seed)
    # Lets the Intel OpenMP runtime tolerate being loaded more than once.
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    device = select_device()

    # The environment is reset before the agent is built, as in the original
    # script, so any shared RNG is consumed in the same order.
    env = Environment_Decision(max_steps_per_episode=args.step_per_trajectory,
                               isGUI=False,
                               test=True)
    state = env.reset()

    agent = Agent_PNet_PPO(input_size=args.input_size,
                           device=device,
                           test=True,
                           num_test_envs=args.num_test_envs,
                           test_model=args.decision_path,
                           pick_model_path=args.pick_path,
                           drop_model_path=args.drop_path,
                           LMapNet_pretrain_model=args.lnet_path)
    agent.init_hidden_test()
    agent.init_test()

    num_test = args.num_test
    print("*" * 20)
    tic = time.time()
    games, success, strict_fixed, remain_energy = run_test(env, agent, state, num_test)
    # A separate name keeps the `time` module usable after this point.
    elapsed = time.time() - tic
    minute, sec = divmod(elapsed, 60)
    print(f'test result in {num_test} games:')
    print(f'average success: {per_game_average(success, games)}')
    print(f'average strict_fixed: {per_game_average(strict_fixed, games)}')
    print(f'average remain_energy: {per_game_average(remain_energy, games)}')
    print(f'Processing time: {minute}min {sec}sec')
    print("*" * 20)


if __name__ == "__main__":
    main()