import numpy as np
import sys
sys.path.append('../env/')
sys.path.append('../policy/')
sys.path.append('../config/')
from AttrDict import AttrDict
from box_world import Environment_Decision
from PPO_agent_decision import Agent_PNet_PPO
import torch
import argparse
from torch.utils.tensorboard import SummaryWriter
import time
from stable_baselines3.common.vec_env import SubprocVecEnv
import os
import yaml

# Set before any environment/agent is created so that subprocesses inherit it.
# NOTE(review): presumably a workaround for the "OMP: Error #15" abort raised
# when more than one OpenMP runtime gets loaded into the process (e.g. through
# numpy and torch) - confirm it is still required on the target platform.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


class Test:
    """Periodic evaluation of the agent with best-checkpoint tracking.

    Each call to :meth:`test` spawns ``args.num_envs_per_test`` environments
    in a ``SubprocVecEnv``, plays several rounds with exploration disabled,
    saves the networks whenever a metric reaches a new best, and logs the
    results to stdout and to the module-level TensorBoard ``writer``.

    The class relies on module-level globals set up by the script entry
    point: ``args``, ``writer``, ``model_dir`` and ``len_str_n_episode``.
    """

    def __init__(self, agent, start_time):
        """Store the agent under evaluation and the training start time.

        Args:
            agent: Agent exposing ``init_hidden_test``, ``init_test``,
                ``state_estimator``, ``step_MP``, ``target_net`` and
                ``value_net``.
            start_time: ``time.time()`` stamp used to report elapsed time.
        """
        self.agent = agent
        self.best_succeed_rate = 0
        # Start at -inf (not 0) so that a best-reward checkpoint is still
        # produced when the average reward is negative.
        self.best_total_reward = float('-inf')
        self.start_time = start_time

    def test(self, episode, actor_loss, value_loss, num_rounds=5):
        """Evaluate the agent, save improved checkpoints and log metrics.

        Args:
            episode: Current training episode; used in checkpoint file names
                and as the TensorBoard step.
            actor_loss: Latest policy loss, forwarded to TensorBoard.
            value_loss: Latest value loss, forwarded to TensorBoard.
            num_rounds: Number of evaluation rounds; every round plays one
                game in each of the ``args.num_envs_per_test`` environments.

        Raises:
            ValueError: If ``num_rounds`` is smaller than 1.
        """
        if num_rounds < 1:
            raise ValueError("num_rounds must be at least 1")

        make_env_list = [make_env(args.step_per_trajectory)
                         for _ in range(args.num_envs_per_test)]
        env_test = SubprocVecEnv(make_env_list, start_method=args.start_method)
        try:
            rounds = [self._run_round(env_test) for _ in range(num_rounds)]
        finally:
            # Always shut the worker processes down; otherwise every
            # evaluation leaks a fresh set of subprocesses.
            env_test.close()

        total_games = num_rounds * args.num_envs_per_test
        reward_test_average = sum(r.sum() for r, _, _ in rounds) / total_games
        success_rate_average = sum(s.sum() for _, s, _ in rounds) / total_games
        steps_average = sum(n for _, _, n in rounds) / total_games

        # save model
        if success_rate_average >= self.best_succeed_rate:
            self._save_checkpoint(episode, "best_succeed_rate")
            self.best_succeed_rate = success_rate_average
        if reward_test_average > self.best_total_reward:
            self._save_checkpoint(episode, "best_total_reward")
            self.best_total_reward = reward_test_average

        # write log
        mins, secs = divmod(int(time.time() - self.start_time), 60)
        hours, mins = divmod(mins, 60)
        print('Episode:', episode, '| lr: ', args.lr,
              "| Processing time: %d hours %d minutes %d seconds" % (hours, mins, secs))
        print('Reward average:', reward_test_average,
              '| Reward best ever:', self.best_total_reward,
              '| Succeed rate:', success_rate_average,
              '| Succeed rate best ever:', self.best_succeed_rate,
              '| Average step:', steps_average, '\n'
              )
        writer.add_scalar('policy_net_loss', actor_loss, episode)
        writer.add_scalar('value_net_loss', value_loss, episode)
        writer.add_scalar('success_rate', success_rate_average, episode)
        writer.add_scalar('reward_average_per_game', reward_test_average, episode)

    def _run_round(self, env_test):
        """Play one game per environment; return (rewards, success, steps).

        ``rewards`` and ``success`` are per-environment arrays, ``steps`` is
        the summed length of the first finished game of every environment.
        """
        num_envs = args.num_envs_per_test
        rewards_total = np.zeros(num_envs)
        success_rate = np.zeros(num_envs)
        steps = np.zeros(num_envs)
        total_steps = 0
        finished = [False] * num_envs
        finished_count = 0

        self.agent.init_hidden_test()
        self.agent.init_test()
        state_est = self.agent.state_estimator(env_test.reset(), test=True)
        while finished_count < num_envs:
            steps += 1
            actions = self.agent.step_MP(state_est, test=True, exploration_rate=0.0)
            states_next, rewards, dones, infos = env_test.step(actions)
            for i, reward in enumerate(rewards):
                # Only the first game of each environment is counted; whatever
                # an environment reports after that game ended is ignored.
                if finished[i]:
                    continue
                rewards_total[i] += reward
                if dones[i]:
                    total_steps += steps[i]
                    success_rate[i] = infos[i]['success_rate']
                    finished[i] = True
                    finished_count += 1
            state_est = self.agent.state_estimator(states_next, test=True)
        return rewards_total, success_rate, total_steps

    def _save_checkpoint(self, episode, tag):
        """Save the target (policy) and value (critic) networks under *tag*."""
        # One timestamp for both files so a policy/critic pair always shares
        # the same prefix, even when the clock ticks between the two saves.
        stamp = time.strftime('%Y%m%d_%H%M%S')
        prefix = f"{stamp}_episode{episode:0{len_str_n_episode}d}"
        torch.save(self.agent.target_net.state_dict(),
                   os.path.join(model_dir, f"{prefix}_policy_model_{tag}.pth"))
        torch.save(self.agent.value_net.state_dict(),
                   os.path.join(model_dir, f"{prefix}_crit_model_{tag}.pth"))

def make_env(n):
    """Return a zero-argument factory for one decision environment.

    The environment itself is not created here; it is built only when the
    returned callable is invoked. In this file the factories are handed to
    ``SubprocVecEnv``.

    Args:
        n: Value passed as ``max_steps_per_episode`` to the environment.

    Returns:
        A callable that builds ``Environment_Decision`` with the GUI disabled.
    """
    def _build_env():
        return Environment_Decision(max_steps_per_episode=n, isGUI=False)

    return _build_env

def main():
    """Run the PPO training loop with periodic evaluation.

    Configuration comes from the module-level ``args``; the module-level
    ``device`` and TensorBoard ``writer`` are used as well. Transitions are
    collected from ``args.num_processes`` parallel environments; once more
    than ``args.batch_size`` records are stored the agent is trained, which
    counts as one "episode". Training stops after ``args.n_episodes``
    episodes. The environments and the writer are closed on exit, including
    when an exception interrupts training.
    """
    make_env_list = [make_env(args.step_per_trajectory) for _ in range(args.num_processes)]
    envs = SubprocVecEnv(make_env_list, start_method=args.start_method)
    try:
        states = envs.reset()
        agent = Agent_PNet_PPO(input_size=args.input_size,
                               device=device,
                               test=False,
                               discount=args.discount,
                               num_envs=args.num_processes,
                               num_test_envs=args.num_envs_per_test,
                               lr=args.lr,
                               batch_size=args.batch_size,
                               pick_model_path=args.pick_path,
                               drop_model_path=args.drop_path,
                               LMapNet_pretrain_model=args.lnet_path)
        agent.update_target()
        trajectory = 0   # games finished since the last progress print
        actor_loss = 0
        value_loss = 0
        c_steps = 0      # records consumed since the last progress print
        total_step = 0   # records collected since the last learning phase
        step = 0         # environment steps since the last learning phase
        episode = 0
        start_time = time.time()
        test = Test(agent, start_time)
        agent.reset_agent()
        states_est = agent.state_estimator(states)

        # Evaluate when ``episode % period`` equals 1. Reducing the 1 modulo
        # the period keeps that schedule for every period > 1 and makes a
        # period of 1 evaluate every episode (including once at episode 0,
        # before any training) instead of never.
        test_phase = 1 % args.test_period_episode

        # train
        while episode < args.n_episodes:
            # train after a number of state-action pairs are collected
            if total_step > args.batch_size:
                agent.get_adv_func()
                c_steps += total_step
                for _ in range(5):
                    actor_loss, value_loss = agent.learn()
                agent.clear_replay_memory()
                if episode % 10 == 0:
                    print(f"Learn in episode {episode},"
                          f" {c_steps} steps, in {trajectory} games.")
                    mins, secs = divmod(int(time.time() - start_time), 60)
                    hours, mins = divmod(mins, 60)
                    print('Episode:', episode, '| lr: ', args.lr,
                          "| Processing time: %d hours %d minutes %d seconds" % (hours, mins, secs))
                    c_steps = 0
                    trajectory = 0
                total_step = 0
                agent.update_target()
                episode += 1
                step = 0

            # test (``step == 0`` restricts it to right after a learning
            # phase, or to the very first loop iteration)
            if episode % args.test_period_episode == test_phase and step == 0:
                print(f"\nTest {5*args.num_envs_per_test} times in episode {episode}")
                test.test(episode, actor_loss, value_loss)

            # collect steps
            step += 1
            actions = agent.step_MP(states_est, test=False, exploration_rate=0.0)
            states_next, rewards, dones, infos = envs.step(actions)
            total_step += agent.store_replay_memory_MP(rewards, dones)
            for idx, done in enumerate(dones):
                # ``done`` is a numpy.bool_, so rely on truthiness rather
                # than an identity check such as ``done is True``.
                if done:
                    trajectory += 1
                    agent.compute_rtgs(idx)
            states_est = agent.state_estimator(states_next)
    finally:
        # Flush the logs and stop the worker processes even on failure.
        writer.close()
        envs.close()
    print('finished')

if __name__ == "__main__":
    # Script entry point: load the YAML configuration, create the output
    # directories and the TensorBoard writer, pick a device, then train.
    # ``args``, ``len_str_n_episode``, ``model_dir``, ``writer`` and ``device``
    # are module-level globals read by ``Test`` and ``main``.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='../config/train_decision.yaml', type=str)
    cli_args = parser.parse_args()
    with open(cli_args.config, 'r') as f:
        args = AttrDict(yaml.safe_load(f))
    print(args)

    # Zero-pad width for episode numbers in checkpoint file names.
    len_str_n_episode = len(str(args.n_episodes))
    out_dir = (f"{time.strftime('%Y%m%d%H')}_lr_{args.lr}"
               f"_batchSize_{args.batch_size}_discount_{args.discount}/")
    model_dir = os.path.join(os.path.realpath(args.model_dir), out_dir)
    log_dir = os.path.join(os.path.realpath(args.log_dir), out_dir)
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    if torch.cuda.is_available():
        # Keep the original preference for the second GPU, but fall back to
        # the first one when fewer than two GPUs are visible; 'cuda:1' would
        # otherwise be an invalid device on a single-GPU machine.
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f'cuda:{gpu_index}')
    else:
        device = torch.device('cpu')
    print('Device:', device)
    main()
