import sys
import torch.cuda
# if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
#     sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import time
from model import *
from tools import *
from envs import make_vec_envs
import numpy as np
import random
from train_tools_meta import train_tools_meta
from tensorboardX import SummaryWriter
from tools import get_args, registration_envs
from insCreator import insCreator
from torch.distributions.categorical import Categorical



# Training configuration -- fill these in before running. The three lists
# must line up entry-for-entry:
#   dataset_list : names of generated datasets, looked up under
#                  ./dataset/continuous_dataset or ./dataset/discrete_dataset
#   seq_len_list : sequence length of each dataset (used to select that
#                  length from the loaded data)
#   num_steps    : steps actually run per dataset (can equal the sequence
#                  length); the largest value becomes args.num_steps
dataset_list, seq_len_list, num_steps = [], [], []

def main(args):
    """Train the meta-learning PCT policy on the configured datasets.

    Uses the module-level ``dataset_list`` / ``seq_len_list`` / ``num_steps``
    configuration. Each dataset is loaded from
    ``./dataset/{continuous|discrete}_dataset/<name>`` (chosen by
    ``args.continuous``). The function then builds the vectorised environments
    and the networks, and runs ``train_tools_meta.train_n_steps``.

    TensorBoard logs are written to
    ``./logs/runs/{continuous|discrete}-PCT-<timestamp>``.

    Args:
        args: Parsed command-line namespace (from ``get_args()``). It is
            mutated in place: ``epochs``, ``num_steps``,
            ``internal_node_holder`` and ``next_holder`` are overwritten.

    Raises:
        ValueError: If ``num_steps`` is empty (the configuration was never
            filled in), or if ``dataset_list`` and ``seq_len_list`` differ in
            length.
    """
    # Imported explicitly instead of relying on it leaking in via the
    # ``from ... import *`` lines at the top of the file.
    import os

    # Validate the hand-edited configuration up front, before any expensive
    # setup, so a misconfiguration fails with an actionable message.
    if not num_steps:
        raise ValueError(
            'num_steps is empty: fill in dataset_list, seq_len_list and '
            'num_steps at the top of this file before training')
    if len(dataset_list) != len(seq_len_list):
        raise ValueError(
            'dataset_list ({}) and seq_len_list ({}) must have the same '
            'length'.format(len(dataset_list), len(seq_len_list)))

    time_str = time.strftime('%Y.%m.%d-%H-%M-%S')
    mode = 'continuous' if args.continuous else 'discrete'
    train_steps = 8 if args.continuous else 16
    num_episode = 70 if args.continuous else 150
    load_model = False  # set True to resume from args.model_path
    args.epochs = num_episode

    # Load each configured dataset at its single requested sequence length.
    data_root = './dataset/{}_dataset'.format(mode)
    dataset = []
    box_set_list = []
    for dataset_name, seq_len in zip(dataset_list, seq_len_list):
        data_save_path = os.path.join(data_root, dataset_name)
        dist, ins, box_set = load_multiple_datasets(data_save_path, [seq_len])
        dataset.append({'instances': ins[seq_len], 'distributions': dist[seq_len]})
        box_set_list.append(box_set)

    # Size the model's node holders from the longest run; the internal
    # holder never drops below 80.
    args.num_steps = max(num_steps)
    args.internal_node_holder = max(80, args.num_steps)
    args.next_holder = args.num_steps
    print(args)

    if args.no_cuda:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda', args.device)
        torch.cuda.set_device(args.device)

    log_writer_path = './logs/runs/{}-PCT-{}'.format(mode, time_str)
    os.makedirs(log_writer_path, exist_ok=True)
    writer = SummaryWriter(logdir=log_writer_path)
    try:
        # Parallel packing environments that collect training samples online.
        envs = make_vec_envs(args, './logs/runinfo', True)

        # Main PCT actor-critic network and its inner (meta) copy.
        PCT_policy = DRL_GAT(args).to(device)
        PCT_inner_policy = DRL_GAT(args).to(device)

        # NOTE(review): meaning of these constructor constants is defined in
        # insCreator; confirm there before changing them.
        ins_policy = insCreator(3, 128, 512, 2, 3, 8, 1000)

        trainTool = train_tools_meta(writer, time_str, PCT_policy,
                                     PCT_inner_policy, ins_policy, args)
        if load_model:
            trainTool.load_model(args.model_path, 'meta-pct', args.sub_time_str)

        trainTool.train_n_steps(envs, args, train_steps, num_episode, device,
                                log_writer_path, dataset, 20, num_steps,
                                box_set_list, default_mod=True)
    finally:
        # Flush and close TensorBoard logs even if setup or training fails.
        writer.close()

if __name__ == '__main__':
    # Register the project's environments first (presumably required before
    # make_vec_envs can build them -- confirm), then parse CLI args and train.
    registration_envs()
    main(get_args())