import sys
import torch.cuda
# if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
#     sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import time
from model import *
from tools import *
from envs import make_vec_envs
import numpy as np
import random
from train_tools_asap import train_tools_proposal
from tensorboardX import SummaryWriter
from tools import get_args, registration_envs
from insCreator import insCreator
from torch.distributions.categorical import Categorical





# ---------------------------------------------------------------------------
# Run configuration -- fill these in before launching the script.
# ---------------------------------------------------------------------------
# Names of the generated datasets; each one is looked up under
# ./test_dataset/continuous_dataset or ./test_dataset/discrete_dataset.
dataset_list = []
# Sequence length of each dataset, in the same order as ``dataset_list``.
seq_len_list = []
# Steps actually run during training (may simply equal the sequence lengths);
# the largest entry also sizes args.num_steps.
num_steps = []

def main(args):
    """Fine-tune the PCT action policy and proposal head on pre-generated datasets.

    Uses the module-level ``dataset_list`` / ``seq_len_list`` / ``num_steps``
    configuration. Each dataset is loaded from
    ``./test_dataset/continuous_dataset`` or ``./test_dataset/discrete_dataset``
    (chosen by ``args.continuous``). The function then builds the vectorised
    environments and both networks. If ``load_model`` is set, it loads the saved
    'meta-pct' action and proposal models from ``args.model_path``. Finally it
    runs the two training stages: proposal first for continuous packing, action
    first for discrete packing. TensorBoard logs are written under
    ``./logs/runs``.

    Args:
        args: parsed command-line namespace from ``get_args()``. Mutated in
            place: ``epochs``, ``num_steps``, ``internal_node_holder`` and
            ``next_holder`` are overwritten from the configuration above.

    Raises:
        ValueError: if ``num_steps`` is empty, or if ``dataset_list`` and
            ``seq_len_list`` differ in length.
    """
    # Imported explicitly instead of relying on it leaking in via a star import.
    import os

    # Fail early with an actionable message instead of an opaque max()/IndexError.
    if not num_steps:
        raise ValueError(
            'num_steps is empty: fill in dataset_list, seq_len_list and '
            'num_steps at the top of this file before running'
        )
    if len(dataset_list) != len(seq_len_list):
        raise ValueError(
            'dataset_list and seq_len_list must have the same length '
            '(got {} and {})'.format(len(dataset_list), len(seq_len_list))
        )

    change_box_set = True
    load_model = True
    time_str = time.strftime('%Y.%m.%d-%H-%M-%S', time.localtime(time.time()))
    train_steps = 10 if args.continuous else 20
    num_episode = 30 if args.continuous else 50

    args.epochs = num_episode
    args.num_steps = max(num_steps)
    args.internal_node_holder = max(80, args.num_steps)
    args.next_holder = args.num_steps
    print(args)

    mode = 'continuous' if args.continuous else 'discrete'

    # Load every configured dataset at its single requested sequence length.
    dataset_root = './test_dataset/{}_dataset'.format(mode)
    dataset = []
    box_set_list = []
    for dataset_name, seq_len in zip(dataset_list, seq_len_list):
        data_save_path = os.path.join(dataset_root, dataset_name)
        dist, ins, box_set = load_multiple_datasets(data_save_path, [seq_len])
        dataset.append({'instances': ins[seq_len], 'distributions': dist[seq_len]})
        box_set_list.append(box_set)

    if args.no_cuda:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda', args.device)
        torch.cuda.set_device(args.device)

    # TensorBoard logs for this run go to ./logs/runs/<mode>-PCT-<timestamp>.
    log_writer_path = './logs/runs/{}-{}'.format(mode, 'PCT-' + time_str)
    os.makedirs(log_writer_path, exist_ok=True)
    writer = SummaryWriter(logdir=log_writer_path)

    # Ensure the writer is flushed/closed even if setup or training raises.
    try:
        envs = make_vec_envs(args, './logs/runinfo', True)

        # Main actor & critic network of PCT (distribution-aware variant if requested).
        if args.allow_dist_input:
            action_policy = DRL_GAT_search_with_dist(args, device, search_model=True)
        else:
            action_policy = DRL_GAT(args, search_model=True)
        action_policy = action_policy.to(device)

        proposal_head = DRL_GAT(args)
        proposal_head = proposal_head.to(device)

        ins_policy = insCreator()
        train_tool = train_tools_proposal(
            writer, time_str, action_policy, proposal_head, ins_policy, args
        )

        if load_model:
            tag = 'meta-pct'
            train_tool.load_action_model(args.model_path, tag, args.sub_time_str)
            train_tool.load_proposal_model(args.model_path, tag, args.sub_time_str)

        # Stage order depends on the packing mode.
        if args.continuous:
            stages = (train_tool.train_n_steps_proposal, train_tool.train_n_steps_action)
        else:
            stages = (train_tool.train_n_steps_action, train_tool.train_n_steps_proposal)

        for train_stage in stages:
            # NOTE(review): the positional 20 is forwarded unchanged from the
            # original call; its meaning is defined in train_tools_asap -- confirm.
            train_stage(
                envs, args, train_steps, num_episode, device, log_writer_path,
                dataset, 20, num_steps, box_set_list,
                change_box_set=change_box_set,
            )
    finally:
        writer.close()

if __name__ == '__main__':
    # Environment registration happens first, before the command line is
    # parsed and main() builds any environments.
    registration_envs()
    cli_args = get_args()
    main(cli_args)