import numpy as np
import sys
sys.path.append('../env/')
sys.path.append('../policy/')
sys.path.append('../config/')
from AttrDict import AttrDict
from box_world import Environment_Pick
from PPO_agent_pick import Agent_PNet_PPO
import torch
import argparse
from torch.utils.tensorboard import SummaryWriter
import time
from stable_baselines3.common.vec_env import SubprocVecEnv
import os
import yaml

# NOTE(review): presumably set so the Intel OpenMP runtime tolerates being
# loaded more than once (e.g. by both numpy and torch) — confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

class Test:
    """Periodically evaluate the policy, checkpoint the best models, and log metrics.

    NOTE(review): relies on module-level globals created in the ``__main__``
    section of this script (``args``, ``model_dir``, ``len_str_n_episode`` and
    ``writer``), so it only works when the file is run as a script.
    """

    def __init__(self, agent, start_time):
        """
        Args:
            agent: evaluation agent whose ``target_net`` (policy) and
                ``value_net`` (critic) are rolled out and saved.
            start_time: ``time.time()`` value at training start, used to
                report elapsed wall-clock time.
        """
        self.agent = agent
        self.best_succeed_rate = 0
        # Start at -inf so the first evaluation always counts as the best
        # reward, even when average rewards are negative.
        self.best_total_reward = float('-inf')
        self.start_time = start_time

    def _save_models(self, episode, tag):
        """Save policy (target_net) and critic (value_net) weights under one shared timestamp."""
        # A single timestamp keeps the policy/critic file pair matchable.
        stamp = time.strftime('%Y%m%d_%H%M%S')
        prefix = f"{stamp}_episode{episode:0{len_str_n_episode}d}"
        torch.save(self.agent.target_net.state_dict(),
                   os.path.join(model_dir, f"{prefix}_policy_model_{tag}.pth"))
        torch.save(self.agent.value_net.state_dict(),
                   os.path.join(model_dir, f"{prefix}_crit_model_{tag}.pth"))

    def _run_episodes(self, env_test):
        """Roll out ``args.num_episodes_per_test`` greedy episodes on ``env_test``.

        Returns:
            (total env steps, finished episodes, successful episodes,
            summed per-episode reward)
        """
        step = 0
        episodes_test = 0
        n_succeed = 0
        total_rewards = 0.0
        total_reward = 0.0
        state_est = self.agent.state_estimator(env_test.reset(), test=True)
        while episodes_test < args.num_episodes_per_test:
            step += 1
            action = self.agent.step(state_est, test=True, exploration_rate=0.0)
            states_next, rewards, dones, infos = env_test.step(action)
            states_next = self.agent.state_estimator(states_next, test=True)
            total_reward += rewards
            if dones:
                episodes_test += 1
                # A final reward of exactly 10 is counted as a success
                # (same criterion as the training loop in main()).
                if rewards == 10:
                    n_succeed += 1
                total_rewards += total_reward
                total_reward = 0.0
                # Start a fresh episode with a fresh recurrent hidden state.
                states_next = env_test.reset()
                self.agent.init_hidden_test()
                states_next = self.agent.state_estimator(states_next, test=True)
            state_est = states_next
        return step, episodes_test, n_succeed, total_rewards

    def test(self, episode, actor_loss, value_loss):
        """Evaluate the agent, save improved checkpoints, print and log results.

        Args:
            episode: training episode index, used in checkpoint names and as
                the tensorboard step.
            actor_loss: latest policy loss, logged to tensorboard.
            value_loss: latest value loss, logged to tensorboard.
        """
        self.agent.init_hidden_test()
        env_test = Environment_Pick(isGUI=False, max_steps_per_episode=150)
        try:
            step, episodes_test, n_succeed, total_rewards = self._run_episodes(env_test)
        finally:
            # Release the per-call evaluation env instead of leaking one per test.
            close = getattr(env_test, 'close', None)
            if callable(close):
                close()
        self.agent.clear_replay_memory()

        # Guard against num_episodes_per_test <= 0 (no finished episodes).
        n_done = max(episodes_test, 1)
        succeed_rate = n_succeed / n_done
        average_total_reward = total_rewards / n_done

        # '>=' so that on ties the newer model is also saved as best.
        if succeed_rate >= self.best_succeed_rate:
            self._save_models(episode, 'best_succeed_rate')
            self.best_succeed_rate = succeed_rate
        if average_total_reward > self.best_total_reward:
            self._save_models(episode, 'best_total_reward')
            self.best_total_reward = average_total_reward

        # write log
        hours, rem = divmod(int(time.time() - self.start_time), 3600)
        mins, secs = divmod(rem, 60)
        print('Episode:', episode, '| lr: ', args.lr,
              "| Processing time: %d hours %d minutes %d seconds" % (hours, mins, secs))
        print('Reward average:', average_total_reward,
              '| Reward best ever:', self.best_total_reward,
              '| Succeed rate:', succeed_rate,
              '| Succeed rate best ever:', self.best_succeed_rate,
              '| Average step:', step / n_done, '\n'
              )
        writer.add_scalar('policy_net_loss', actor_loss, episode)
        writer.add_scalar('value_net_loss', value_loss, episode)
        writer.add_scalar('success_rate', succeed_rate, episode)
        writer.add_scalar('reward_average_per_game', average_total_reward, episode)

def make_env(n):
    """Return a zero-argument factory that builds a headless ``Environment_Pick``.

    The factories are collected into the list handed to ``SubprocVecEnv`` in
    ``main()``; ``n`` is forwarded as ``max_steps_per_episode``.
    """
    def _build_env():
        return Environment_Pick(max_steps_per_episode=n, isGUI=False)

    return _build_env

def main():
    """Train the PPO pick agent on parallel environments with periodic evaluation.

    Relies on the module-level ``args``, ``device`` and ``writer`` created in
    the ``__main__`` section of this script.
    """
    make_env_list = [make_env(args.step_per_trajectory) for _ in range(args.num_processes)]
    envs = SubprocVecEnv(make_env_list, start_method=args.start_method)
    try:
        states = envs.reset()
        agent = Agent_PNet_PPO(input_size=args.input_size,
                               device=device,
                               test=False,
                               lr=args.lr,
                               discount=args.discount,
                               batch_size=args.batch_size,
                               num_envs=args.num_processes,
                               num_test_envs=args.num_envs_per_test,
                               LMapNet_pretrain_model=args.lnet_path)
        # Separate agent used only for evaluation; its weights are copied from
        # ``agent`` right before every test.
        agent_test = Agent_PNet_PPO(input_size=args.input_size,
                                    device=device,
                                    test=False,
                                    lr=args.lr,
                                    discount=args.discount,
                                    batch_size=args.batch_size,
                                    num_envs=args.num_envs_per_test,
                                    num_test_envs=args.num_envs_per_test,
                                    LMapNet_pretrain_model=args.lnet_path)
        agent.update_target()
        trajectory = 0   # finished games since the last progress print
        actor_loss = 0
        value_loss = 0
        c_success = 0    # successful games since the last progress print
        c_steps = 0      # learned-on steps since the last progress print
        step = 0         # vectorized env steps collected since the last update
        # Per-env step count of the current game (reset on done and after each update).
        steps = np.zeros(args.num_processes, dtype=int)
        total_step = 0   # steps of finished games since the last update
        episode = 0      # number of completed learning updates
        start_time = time.time()
        test = Test(agent_test, start_time)
        agent.reset_agent()
        states_est = agent.state_estimator(states)
        # Test on episodes 1, 1 + p, 1 + 2p, ...  Comparing with ``1 % p``
        # rather than ``1`` keeps that schedule for p > 1 and makes p == 1
        # test every episode (``episode % 1 == 1`` is never true).
        test_phase = 1 % args.test_period_episode
        while episode < args.n_episodes:
            # train once enough steps from finished games are collected
            if total_step > args.batch_size:
                agent.get_adv_func()
                c_steps += total_step
                # Several learn() passes over the same collected batch.
                for _ in range(5):
                    actor_loss, value_loss = agent.learn()
                agent.clear_replay_memory()
                if episode % 10 == 0:
                    print(f"Learn in episode {episode}, {c_success} succeed(s) in"
                          f" {c_steps} steps, in {trajectory} games.")
                    c_success = 0
                    c_steps = 0
                    trajectory = 0
                total_step = 0
                steps = np.zeros(args.num_processes, dtype=int)
                agent.update_target()
                episode += 1
                step = 0
            # test
            if episode % args.test_period_episode == test_phase and step == 0:
                print(f"\nTest {args.num_episodes_per_test} times in episode {episode}")
                test.agent.target_net.load_state_dict(agent.policy_net.state_dict())
                test.agent.value_net.load_state_dict(agent.value_net.state_dict())
                test.test(episode, actor_loss, value_loss)
            # collect steps
            step += 1
            steps += 1
            actions = agent.step_mp(states_est, exploration_rate=0.0)
            states_next, rewards, dones, infos = envs.step(actions)
            agent.store_replay_memory(states_est, actions, rewards, dones)
            for idx, done in enumerate(dones):
                if done:  # done is numpy.bool_, so rely on truthiness rather than "is True"
                    total_step += steps[idx]
                    steps[idx] = 0
                    trajectory += 1
                    # A final reward of exactly 10 is counted as a success.
                    if rewards[idx] == 10:
                        c_success += 1
                    agent.reset_per_memory(idx)
                    agent.compute_rtgs(idx)
            states_est = agent.state_estimator(states_next)
    finally:
        # Shut down worker processes and flush tensorboard logs even on error.
        envs.close()
        writer.close()
    print('finished')
    
if __name__ == "__main__":
    # The command line only selects a YAML config; every hyper-parameter
    # (and the globals used by main() / Test) comes from that file.
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', default='../config/train_pick.yaml', type=str)
    config_path = cli.parse_args().config
    with open(config_path, 'r') as f:
        args = AttrDict(yaml.safe_load(f))
    print(args)

    # Zero-pad width for episode numbers embedded in checkpoint file names.
    len_str_n_episode = len(str(args.n_episodes))
    # Run sub-directory: <YYYYmmddHH>_lr_<lr>_batchSize_<bs>_discount_<discount>/
    out_dir = (f"{time.strftime('%Y%m%d%H')}_lr_{args.lr}"
               f"_batchSize_{args.batch_size}_discount_{args.discount}/")
    model_dir = os.path.join(os.path.realpath(args.model_dir), out_dir)
    log_dir = os.path.join(os.path.realpath(args.log_dir), out_dir)
    for run_dir in (model_dir, log_dir):
        os.makedirs(run_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device:', device)
    main()
    
