import sys
import torch.cuda
import os
# if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
#     sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import time
from model import *
from tools import *
from envs import make_vec_envs
import numpy as np
import random
from train_tools_ppo import train_tools_ppo
from train_tools_acktr import train_tools_acktr
from tensorboardX import SummaryWriter
from tools import get_args, registration_envs
from insCreator import insCreator
from torch.distributions.categorical import Categorical
import torch

def generate_dataset(ins_policy, box_set, num_distribution, batch_size, seq_len, save_path, device, continuous):
    """Sample instances for ``num_distribution`` random box distributions and save each to disk.

    For every step a random probability vector over the rows of ``box_set`` is
    drawn and tiled ``batch_size`` times, so every row of the batch shares the
    same distribution. It is passed to ``ins_policy`` to sample a batch of
    instances. Each result is written to
    ``<save_path>/num-box-<seq_len>/instance-<step>.pt`` as a dict with keys
    ``'dist'`` (shape ``(batch_size, box_set.size(0))``) and ``'ins'`` (the first
    item returned by ``ins_policy``).

    Args:
        ins_policy: callable invoked as ``ins_policy(box_set, seq_len, batch_size,
            distribution, deterministic=False, random_mode=True, continuous=continuous)``
            and returning a pair; the first item is saved.
        box_set: tensor whose first dimension enumerates the candidate boxes.
        num_distribution: number of distributions (and output files) to generate.
        batch_size: number of rows in each distribution batch, forwarded to ``ins_policy``.
        seq_len: forwarded to ``ins_policy``; also names the output sub-directory.
        save_path: root directory; ``num-box-<seq_len>`` is created under it if needed.
        device: device the distribution tensor is moved to before sampling.
        continuous: forwarded unchanged to ``ins_policy``.
    """
    save_path = os.path.join(save_path, 'num-box-{}'.format(seq_len))
    # exist_ok avoids the race of a separate exists() check followed by makedirs().
    os.makedirs(save_path, exist_ok=True)

    num_boxes = box_set.size(0)
    for step in range(num_distribution):
        # One random vector over the box types, repeated for every batch row.
        # Sampled on CPU and then moved, so the RNG stream matches the previous code.
        distribution = torch.rand((num_boxes,)).repeat((batch_size, 1)).to(device)
        distribution = distribution / distribution.sum(dim=1, keepdim=True)

        # Sample instances without building an autograd graph.
        with torch.no_grad():
            ins, _ = ins_policy(box_set, seq_len, batch_size, distribution,
                                deterministic=False, random_mode=True, continuous=continuous)

        print(ins)
        # NOTE(review): 'dist' is saved on ``device``; a file written from CUDA must be
        # loaded with ``torch.load(..., map_location='cpu')`` on a CPU-only machine.
        torch.save(
            {'dist': distribution, 'ins': ins},
            os.path.join(save_path, 'instance-{}.pt'.format(step)),
        )




def _build_box_set(least_size, least_size_1, cross_range):
    """Return the de-duplicated list of box shapes (edge-sorted tuples), in first-seen order.

    For every offset triple ``(i, j, k)`` with each component in ``range(cross_range)``,
    two shapes are produced:
    ``(least_size_1+i, least_size_1+j, least_size+k)`` and
    ``(least_size+i, least_size+j, least_size_1+k)``. Each shape's edges are sorted
    ascending. Duplicates keep the position of their first occurrence.
    """
    candidates = []
    for i in range(cross_range):
        for j in range(cross_range):
            for k in range(cross_range):
                candidates.append(tuple(sorted((least_size_1 + i, least_size_1 + j, least_size + k))))
                candidates.append(tuple(sorted((least_size + i, least_size + j, least_size_1 + k))))
    # dict keeps insertion order, so this is an O(n) order-preserving dedup
    # (same result as list(set(...)) sorted by first index, which is O(n^2)).
    return list(dict.fromkeys(candidates))


def main(args):
    """Build a box set from the CLI options and write test datasets for several sequence lengths.

    Reads ``args.continuous``, ``args.no_cuda``, ``args.device``, ``args.least_size``,
    ``args.cross_range`` and ``args.num_processes`` (used as the batch size). It
    saves the box set to ``<root>/box-set.pt``. For each sequence length it then
    writes ``num_distribution`` instance files via ``generate_dataset``.
    """
    continuous = args.continuous
    seq_len_set = [70, 80]   # sequence lengths passed to generate_dataset
    num_distribution = 100   # instance files written per sequence length
    tag = 'id'               # prefix of the output directory name

    if args.no_cuda:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda', args.device)
        torch.cuda.set_device(args.device)

    # NOTE(review): both bounds are read from ``args.least_size``, so they are always
    # equal. The swap below never fires, and both shape families in _build_box_set
    # coincide. A separate second option was presumably intended — confirm against get_args.
    least_size = args.least_size
    least_size_1 = args.least_size
    if least_size > least_size_1:
        least_size, least_size_1 = least_size_1, least_size
    cross_range = args.cross_range

    # An earlier variant shuffled this list and kept only the first 20 boxes.
    new_box_set = _build_box_set(least_size, least_size_1, cross_range)
    print(new_box_set)
    box_set = torch.tensor(new_box_set)

    ins_policy = insCreator()

    mode = 'continuous' if continuous else 'discrete'
    save_path = './test_dataset/{}_dataset/{}_box_set_{}_{}_{}'.format(
        mode, tag, least_size, least_size_1, cross_range)
    os.makedirs(save_path, exist_ok=True)

    torch.save(box_set, os.path.join(save_path, 'box-set.pt'))

    for seq_len in seq_len_set:
        generate_dataset(ins_policy, box_set, num_distribution, args.num_processes,
                         seq_len, save_path, device, continuous)


if __name__ == "__main__":
    # Script entry point: parse the command-line options, then generate the datasets.
    args = get_args()
    main(args=args)