#  python online_real.py --mode train --load-model False --item-seq load
# import faulthandler
# faulthandler.enable()
import sys
# Drop the ROS Kinetic Python 2.7 dist-packages directory from the module
# search path before the remaining imports run.
# NOTE(review): presumably this keeps ROS's Python 2.7 packages from shadowing
# the ones this script needs — confirm on a machine with ROS installed.
if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
    sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import os
import time
from collections import deque
import numpy as np
import torch
import math
from shutil import copyfile
import config
from acktr import algo, utils
from acktr.utils import get_mask_from_obs, get_rotmask_from_obs
from acktr.envs import make_vec_envs
from acktr.arguments import get_args
from acktr.model import Policy
from acktr.storage import RolloutStorage
from evaluation import evaluate
from tensorboardX import SummaryWriter
from unified_test import unified_test
# python online_real.py --mode train --load-model False --use-cuda --item-seq sample
def random_dataset(num):
    """Return the ``num``-th entry of a fixed catalogue of 18 box sizes.

    The catalogue enumerates every triple whose first component is in 4..6,
    second in 3..4 and third in 2..4, with the last component varying fastest.
    Despite the name, no randomness is involved: the same ``num`` always
    yields the same triple.

    Args:
        num: Index into the catalogue; normal list indexing rules apply, so
            negative indices count from the end.

    Returns:
        A 3-tuple of ints.

    Raises:
        IndexError: if ``num`` is outside the catalogue.
    """
    catalogue = [
        (first, second, third)
        for first in range(4, 7)
        for second in range(3, 5)
        for third in range(2, 5)
    ]
    return catalogue[num]

def main(args):
    """Copy the command-line options into ``config`` and run the chosen mode.

    Args:
        args: Parsed arguments; ``mode``, ``data_name``, ``item_seq`` and
            ``use_cuda`` are read from it.

    Side effects:
        Sets ``config.test``, ``config.data_name``, ``config.cuda`` and
        ``config.no_cuda``; ``config.data_type`` is set only when not testing.
        Then calls ``test_model()`` or ``train_model()``.
    """
    # Environment-related switches.
    is_test = (args.mode == 'test')
    config.test = is_test
    # config.load_name = args.load_name
    config.data_name = args.data_name

    # The item sequence type is only overridden for training runs.
    if not is_test:
        config.data_type = args.item_seq

    # CUDA is used only when requested and actually available.
    use_cuda = args.use_cuda and torch.cuda.is_available()
    config.cuda = use_cuda
    config.no_cuda = not use_cuda

    entry_point = test_model if is_test else train_model
    entry_point()

def test_model():
    """Evaluate the checkpoint named by ``config`` with ``unified_test``.

    The checkpoint path is ``config.load_dir`` joined with
    ``config.load_name``. It is built with ``os.path.join`` — the same way
    ``train_model`` locates its pretrained checkpoints — so it is valid whether
    or not ``config.load_dir`` ends with a path separator.

    Raises:
        AssertionError: if ``config.test`` is not ``True``.
    """
    assert config.test is True
    model_url = os.path.join(config.load_dir, config.load_name)
    unified_test(model_url, config)

def train_model():
    """Train the packing policy together with its sub-policy.

    Prompts on stdin for a run name, creates the output directories, copies
    the main source files next to the checkpoints, builds the vectorised
    environments, the two ``Policy`` networks and their ``algo.ACKTR``
    trainers, and then repeatedly collects ``config.num_steps`` transitions
    and updates both agents. Checkpoints are written every
    ``config.save_interval`` iterations when ``config.save_model`` is set; a
    summary is printed and appended to ``log.txt`` in the run directory every
    ``config.log_interval`` iterations.

    All settings are read from the ``config`` module. Nothing is returned.

    Raises:
        ValueError: if ``config.algo`` is neither ``'a2c'`` nor ``'acktr'``.
    """
    # Fail before any side effect instead of hitting a NameError on the
    # trainers after the first rollout.
    if config.algo not in ('a2c', 'acktr'):
        raise ValueError(
            "unsupported config.algo: {!r} (expected 'a2c' or 'acktr')".format(config.algo))

    custom = input('please input the test name: ')

    env_name = config.env_name
    # Select the CUDA device only when CUDA is in use; calling this
    # unconditionally fails on hosts without CUDA.
    if config.cuda:
        torch.cuda.set_device(config.device)
    # set random seed
    # torch.manual_seed(config.seed)
    # torch.cuda.manual_seed_all(config.seed)

    save_path = config.save_dir
    load_path = config.load_dir

    os.makedirs(save_path, exist_ok=True)
    os.makedirs(load_path, exist_ok=True)
    # Per-run directory holding checkpoints, source backups and log.txt.
    data_path = os.path.join(save_path, custom)
    os.makedirs(data_path, exist_ok=True)

    if config.cuda and torch.cuda.is_available() and config.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(config.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:" + str(config.device) if config.cuda else "cpu")
    envs = make_vec_envs(env_name, config.seed, config.num_processes, config.gamma, config.log_dir, device, True)

    # Both networks are built the same way in the pretrained and the
    # from-scratch case; only the weight loading differs.
    base_kwargs = {'recurrent': config.recurrent_policy, 'hidden_size': config.hidden_size}
    actor_critic = Policy(
        envs.observation_space.shape, envs.action_space,
        base_kwargs=dict(base_kwargs))
    actor_critic_sub = Policy(
        None, envs.action_space,
        base_kwargs=dict(base_kwargs))
    if config.pretrain:
        print("___________LOAD PRETRAIN MODEL____________")
        ob_rms = _load_pretrained_weights(
            actor_critic, os.path.join(load_path, config.load_name))
        setattr(utils.get_vec_normalize(envs), 'ob_rms', ob_rms)
        # The normalisation statistics stored with the sub-policy checkpoint
        # are not used.
        _load_pretrained_weights(
            actor_critic_sub, os.path.join(load_path, config.load_name_sub))
    print("Rotation:", config.enable_rotation)
    actor_critic.to(device)
    actor_critic_sub.to(device)
    # transformer_trainer = Transformer(device)

    # leave a backup for parameter tuning
    backups = (
        ('config.py', 'config.py'),
        ('online_real.py', 'main.py'),
        ('./acktr/envs.py', 'envs.py'),
        ('./acktr/distributions.py', 'distributions.py'),
        ('./acktr/storage.py', 'storage.py'),
        ('./acktr/model.py', 'model.py'),
        ('./acktr/algo/acktr_pipeline.py', 'acktr_pipeline.py'),
        ('./dataset.py', 'dataset.py'),
    )
    for source_file, backup_name in backups:
        copyfile(source_file, os.path.join(data_path, backup_name))

    loss_coefs = (config.value_loss_coef, config.env_loss_coef,
                  config.entropy_coef, config.invalid_coef)
    if config.algo == 'a2c':
        # The sub-agent was previously created only in the 'acktr' branch,
        # which made update_sub() below fail with NameError for 'a2c'. It is
        # now built with the same arguments as the main agent.
        # NOTE(review): confirm these optimiser settings suit the sub-policy.
        optim_args = (config.lr, config.eps, config.alpha, config.max_grad_norm)
        agent_packing = algo.ACKTR(actor_critic, *loss_coefs, *optim_args)
        agent_packing_sub = algo.ACKTR(actor_critic_sub, *loss_coefs, *optim_args)
    else:  # 'acktr' (validated at the top of the function)
        agent_packing = algo.ACKTR(actor_critic, *loss_coefs, acktr=True)
        agent_packing_sub = algo.ACKTR(actor_critic_sub, *loss_coefs, acktr=True)

    size_x = config.container_size[0]
    size_y = config.container_size[1]
    rollouts = RolloutStorage(config.num_steps,  # forward steps
                              config.num_processes,  # agent processes
                              envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              hmap_shape=100,
                              can_give_up=config.give_up,
                              enable_rotation=config.enable_rotation,
                              pallet_size=size_x*size_y)
    rollouts.to(device)

    # Most recent finished episodes, at most one entry per process.
    episode_rewards = deque(maxlen=config.num_processes)
    episode_ratio = deque(maxlen=config.num_processes)

    start = time.time()
    # NOTE(review): num_updates is only referenced by the disabled learning
    # rate schedule below; the training loop is bounded by num_env_steps
    # iterations instead. Confirm that is intended.
    num_updates = int(config.num_env_steps) // config.num_steps // config.num_processes

    tbx_path = '{}/{}/{}'.format(config.tbx_dir, env_name, custom)
    os.makedirs(tbx_path, exist_ok=True)
    if config.tensorboard:
        # NOTE(review): nothing in this function writes to the writer yet.
        writer = SummaryWriter(logdir=tbx_path)

    # Number of completed collect-and-update iterations.
    index = 0

    # Prep
    ###############################################
    envs.reset()
    # Scratch buffer for the 10x10 height-map patch picked by the main
    # action; allocated on the training device so CPU-only runs work.
    hmap_sub = torch.zeros(config.num_processes, config.channel, 10, 10, device=device)
    hmap = envs.get_hmap()
    rollouts.hmap[0].copy_(hmap)

    obs = envs.get_obses()
    location_masks = _location_masks(obs, device, config.enable_rotation)
    rollouts.obs[0].copy_(obs)
    rollouts.location_masks[0].copy_(location_masks)

    for env_steps in range(int(config.num_env_steps)):
        # if config.use_linear_lr_decay:
        #     # decrease learning rate linearly
        #     utils.update_linear_schedule(
        #         agent_packing.optimizer, index, num_updates,
        #         agent_packing.optimizer.lr if config.algo == "acktr" else config.lr)
        #     utils.update_linear_schedule(
        #         agent_packing_sub.optimizer, index, num_updates,
        #         agent_packing_sub.optimizer.lr if config.algo == "acktr" else config.lr)

        # Update each config.num_steps
        ###############################################
        for step in range(config.num_steps):
            # Get Action
            ###############################################
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states, env_loss = actor_critic.act(
                    rollouts.obs[step], rollouts.hmap[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step], location_masks, rollouts.box[step])
            # Shift the stored box history by one slot and append the values
            # read from channels 1..4 of the current height map at (0, 0).
            box = rollouts.hmap[step,:,1:5,0,0]
            rollouts.box[step + 1,:,:-1] = rollouts.box[step,:,1:]
            rollouts.box[step + 1,:,-1] = box

            # Decode the main action into a cell (lx, ly) and copy the
            # corresponding 10x10 patch of the height map for the sub-policy.
            for proc, act in enumerate(action):
                ly = act % 10
                lx = ((act%(size_x*size_y))-ly) // 10
                hmap_sub[proc] = hmap[proc,:,lx*10:(lx+1)*10,ly*10:(ly+1)*10]
            # Get Subaction
            ###############################################
            obs_sub = envs.get_obses_sub(action.cpu().detach())

            location_masks_sub = _location_masks(obs_sub, device)

            # rollouts.obs_sub[0].copy_(obs_sub)
            # rollouts.location_masks_sub[0].copy_(location_masks_sub)
            with torch.no_grad():
                value_sub, action_sub, action_log_prob_sub, recurrent_hidden_states_sub = actor_critic_sub.act_sub(
                    obs_sub, hmap_sub, rollouts.recurrent_hidden_states_sub[step], rollouts.masks[step], location_masks_sub)

            rollouts.insert_sub(obs_sub, hmap_sub, recurrent_hidden_states_sub, action_sub, action_log_prob_sub, value_sub, location_masks_sub)

            # Calc action and act in env
            ###############################################
            # Merge the cell (lx, ly) and the in-cell offset (lxx, lyy) into
            # one index: (lx*10 + lxx)*100 + (ly*10 + lyy), plus a block
            # offset of size_x*size_y*100 per multiple of size_x*size_y in
            # the main action. action_sub is overwritten in place.
            for proc, act in enumerate(action):
                ly = act % 10
                lx = ((act%(size_x*size_y))-ly) // 10
                lyy = action_sub[proc] % 10
                lxx = (action_sub[proc]-lyy) // 10
                action_sub[proc] = math.floor(act/(size_x*size_y))*(size_x*size_y)*100 + lx*1000 + ly*10 + lxx*100 + lyy
            _, reward, done, infos = envs.step(action_sub)

            hmap = envs.get_hmap()
            rollouts.rewards[step].copy_(reward)

            obs = envs.get_obses()
            # Record statistics of every episode that finished on this step.
            for info in infos:
                if 'episode' in info:
                    episode_rewards.append(info['episode']['r'])
                    episode_ratio.append(info['ratio'])

            location_masks = _location_masks(obs, device, config.enable_rotation)
            # 0.0 marks a finished episode / a 'bad_transition' info entry.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info else [1.0] for info in infos])

            rollouts.insert(obs, hmap, recurrent_hidden_states, action, action_log_prob, env_loss, value, masks, bad_masks, location_masks)

        # Update Packing Agent
        ###############################################
        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.hmap[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1], rollouts.box[-1]).detach()
        rollouts.compute_returns(next_value, False, config.gamma, 0.95, config.use_proper_time_limits)
        value_loss, env_loss, action_loss, dist_entropy, prob_loss = agent_packing.update(rollouts)

        with torch.no_grad():
            next_value_sub = actor_critic_sub.get_value_sub(
                rollouts.obs_sub[-1], rollouts.hmap_sub[-1], rollouts.recurrent_hidden_states_sub[-1],
                rollouts.masks[-1]).detach()
        rollouts.compute_returns_sub(next_value_sub, False, config.gamma, 0.95, config.use_proper_time_limits)
        value_loss_sub, action_loss_sub, dist_entropy_sub, prob_loss_sub = agent_packing_sub.update_sub(rollouts)
        rollouts.after_update()
        # Save model
        ###############################################
        if config.save_model:
            if (index % config.save_interval == 0) and config.save_dir != "":
                time_now = time.strftime('%Y.%m.%d-%H-%M', time.localtime(time.time()))
                ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                # Main policy and sub-policy go to separate files; the
                # sub-policy file carries the "_s" suffix.
                torch.save([
                    actor_critic.state_dict(),
                    ob_rms
                ], os.path.join(data_path, env_name + time_now + ".pt"))
                torch.save([
                    actor_critic_sub.state_dict(),
                    ob_rms
                ], os.path.join(data_path, env_name + time_now + "_s.pt"))

        # print useful information of training
        index += 1
        if index % config.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = index * config.num_processes * config.num_steps
            end = time.time()
            print(
                "The algorithm is {}, the recurrent policy is {}\nThe env is {}, the version is {}".format(
                    config.algo, config.recurrent_policy, env_name, custom))
            print(
                "Updates {}, num timesteps {}, FPS {} \n"
                "Last {} training episodes: mean/median reward {:.1f}/{:.1f}, var reward {:.1f}, min/max reward {:.1f}/{:.1f}\n"
                "The dist entropy {:.5f}, The value loss {:.5f}, The env loss {:.5f},the action loss {:.5f}\n"
                "The mean space ratio is {}\n"
                    .format(index*config.num_steps, total_num_steps,
                            int(total_num_steps / (end - start)),
                            len(episode_rewards), np.mean(episode_rewards),
                            np.median(episode_rewards), np.var(episode_rewards), np.min(episode_rewards),
                            np.max(episode_rewards), dist_entropy, value_loss, env_loss, action_loss,
                            np.mean(episode_ratio)))

            # One line per logged iteration: mean ratio, action+value loss,
            # reward variance, env loss.
            path = os.path.join(data_path, 'log.txt')
            with open(path,"a") as f:
                f.write("{:.5f}\t, {:.5f}\t, {:.5f}\t, {:.5f}\n".format(np.mean(episode_ratio),action_loss+value_loss,np.var(episode_rewards),env_loss))


def _load_pretrained_weights(policy, checkpoint_path):
    """Load a ``[state_dict, ob_rms]`` checkpoint into ``policy``.

    State-dict keys are normalised by removing ``'module.'`` and
    ``'add_bias.'`` and by rewriting ``'_bias'`` to ``'bias'``, one
    substitution at a time, before ``load_state_dict`` is called.

    Args:
        policy: Network whose ``load_state_dict`` receives the renamed weights.
        checkpoint_path: Path of the checkpoint file.

    Returns:
        The second element stored in the checkpoint (``ob_rms``).
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict, ob_rms = torch.load(checkpoint_path)
    for old, new in (('module.', ''), ('add_bias.', ''), ('_bias', 'bias')):
        state_dict = {key.replace(old, new): weight for key, weight in state_dict.items()}
    policy.load_state_dict(state_dict)
    return ob_rms


def _location_masks(observations, device, use_rotation=False):
    """Build one placement mask per observation and stack them on ``device``.

    Uses ``get_rotmask_from_obs`` when ``use_rotation`` is true and
    ``get_mask_from_obs`` otherwise, each called with a leading argument of 1.
    """
    mask_fn = get_rotmask_from_obs if use_rotation else get_mask_from_obs
    masks = [mask_fn(1, observation) for observation in observations]
    return torch.FloatTensor(masks).to(device)

if __name__ == "__main__":
    # Script entry point: parse the command-line options and hand them to main().
    main(get_args())