# python online_test.py --mode test --load-model True --use-cuda --data-name cut_2.pt --load-name RealBpp-v02022.09.20-18-06.pt --load-name-sub RealBpp-v02022.09.20-18-06_s.pt --cases 49
# python online_test.py --mode test --load-model True --use-cuda --data-name cut_2.pt --load-name RealBpp-v02022.09.29-15-51.pt --load-name-sub RealBpp-v02022.09.29-15-51_s.pt --cases 500
# python online_test.py --mode test --load-model True --use-cuda --data-name cut_2.pt --load-name RealBpp-v02022.10.28-18-39.pt --load-name-sub RealBpp-v02022.10.28-18-39_s.pt --cases 500
# 
import sys
if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
    sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import os
from time import clock
from acktr.model_loader import nnModel
from acktr.reorder import ReorderTree
import gym
import copy
import config
import torch
import argparse
import json
import numpy as np
import pybullet as p
# import rospy
# from geometry_msgs.msg import Pose
# from order.Tester import Transformer_Tester
# from geometry_msgs.msg import Pose
# from xyz_bpp_benchmark.simulator import BulletSimulator
# pallet_size = [1.0,1.0,1e-9]
# tool_size = [0.6,0.4]
def trans_point_to_pose(box_size, box_point):
    """Build a ``Pose`` centred on a box whose minimum corner is ``box_point``.

    For each axis k in (x, y, z) the position is
    ``(box_point[k] + 0.5 * box_size[k]) * 0.1``; the orientation is the
    identity quaternion (x, y, z, w) = (0, 0, 0, 1).

    NOTE(review): ``Pose`` is not imported in this file (the
    ``geometry_msgs.msg`` import is commented out), so calling this raises
    NameError unless ``Pose`` is made available some other way. The 0.1
    factor is presumably a grid-unit-to-metre scale -- confirm.
    """
    pose = Pose()
    centre = [(box_point[axis] + 0.5 * box_size[axis]) * 0.1 for axis in range(3)]
    pose.position.x, pose.position.y, pose.position.z = centre
    pose.orientation.x = pose.orientation.y = pose.orientation.z = 0.0
    pose.orientation.w = 1.0
    return pose

def trans_act_to_location(act):
    """Decode a flat action index into an ``(lx, ly)`` location pair.

    The index is split in base 100: the remainder modulo 100 is ``ly`` and
    the quotient is ``lx`` (e.g. 1234 -> (12, 34)).
    """
    col = act % 100
    row = (act - col) // 100
    return (row, col)

# def simulation(box_size,box_pose):
#     # Initialize the simulator.
#     # pallet_size = [10,10,1e-9]
#     # tool_size = [6,4.6]
#     # bullet_simulator = BulletSimulator(pallet_size, tool_size, mode="gui")
#     bullet_simulator.reset()
#     if bullet_simulator.plan(box_size, box_pose):
#         bullet_simulator.place(box_size, box_pose)
#     else:
#         return
#     if bullet_simulator.plan(box_size, box_pose):
#         bullet_simulator.place(box_size, box_pose)
#     bullet_simulator.spin()

def run_sequence(bullet_simulator, nmodel, nmodel_sub, raw_env, preview_num, c_bound):
    """Pack one box sequence on a private copy of ``raw_env`` until it ends.

    At every step the next ``preview_num`` boxes are previewed, a
    ``ReorderTree`` search (``times=100``) chooses an action, and the action
    is applied with ``env.step``. The loop stops when ``env.step`` reports
    ``done``.

    Args:
        bullet_simulator: unused in this function (the physics-simulator
            check is disabled).
        nmodel: main model handed to ``ReorderTree``.
        nmodel_sub: sub model handed to ``ReorderTree``.
        raw_env: environment to evaluate. It is deep-copied first, so the
            caller's instance is never stepped.
        preview_num: number of upcoming boxes requested from
            ``env.box_creator.preview``.
        c_bound: unused in this function.

    Returns:
        ``(ratio, counter, elapsed, default_rate, succ)``:
        ``ratio`` and ``counter`` are ``info['ratio']`` / ``info['counter']``
        from the final step, ``elapsed`` is wall-clock seconds for the whole
        sequence, and ``default_rate`` is the fraction of placed boxes for
        which ``reorder_search`` returned a truthy ``default`` flag.
        ``succ`` is always 0, except that ``(0, 0, 0, 0, 1)`` is returned when
        the very first step already ends the episode.
    """
    # Wall-clock timer: time.clock() measured CPU time on Unix (not elapsed
    # time) and no longer exists from Python 3.8 on.
    from time import perf_counter

    env = copy.deepcopy(raw_env)
    default_counter = 0
    box_counter = 0
    succ = 0  # never updated below; see Returns
    start = perf_counter()
    while True:
        box_size = env.box_creator.preview(preview_num)
        tree = ReorderTree(nmodel, nmodel_sub, box_size, env, times=100)
        act, _val, default = tree.reorder_search()
        _layout, _, done, info = env.step([act])
        # Result is unused now that the simulator hook is disabled.
        # TODO(review): drop this call if get_hmap is a pure getter.
        env.get_hmap()
        if done:
            elapsed = perf_counter() - start
            print('Time cost:', elapsed)
            print('Ratio:', info['ratio'])
            if box_counter < 1:
                # Episode ended before any box was counted; also avoids a
                # division by zero in the default rate below.
                return 0, 0, 0, 0, 1
            return info['ratio'], info['counter'], elapsed, default_counter / box_counter, succ
        box_counter += 1
        default_counter += int(default)
        
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float('nan')


def _print_summary(header, n_cases, sum_ratio, sum_counter, sum_time, sum_succ):
    """Print averaged statistics over ``n_cases`` finished test cases.

    ``sum_*`` are running totals over those cases. Averages whose
    denominator is zero print as ``nan`` instead of raising
    ZeroDivisionError.
    """
    print()
    print(header)
    print('----------------------------------------------')
    print('average space utilization: %.4f' % _safe_div(sum_ratio, n_cases))
    print('average put item number: %.4f' % _safe_div(sum_counter, n_cases))
    print('average sequence time: %.4f' % _safe_div(sum_time, n_cases))
    print('average time per item: %.4f' % _safe_div(sum_time, sum_counter))
    print('successfully packing all items: %.4f' % _safe_div(sum_succ, n_cases))
    print('----------------------------------------------')


def unified_test(url, url_sub, config, report_every=50):
    """Run ``config.cases`` packing sequences and print aggregate statistics.

    Loads the main and sub models, builds the test environment from
    ``config``, and for each case resets the environment and packs one
    sequence with ``run_sequence``. Running averages are printed every
    ``report_every`` cases and once more at the end.

    Args:
        url: main model checkpoint path, passed to ``nnModel(..., sub=False)``.
        url_sub: sub model checkpoint path, passed to ``nnModel(..., sub=True)``.
        config: configuration object. Attributes read here: ``data_dir``,
            ``data_name``, ``env_name``, ``box_size_set``, ``container_size``,
            ``boxlist_len``, ``enable_rotation``, ``data_type``, ``cases``,
            ``pruning_threshold`` and ``preview``.
        report_every: number of cases between intermediate reports; also
            the window size for the variance statistics (default 50).

    Raises:
        ValueError: if ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError('report_every must be >= 1, got %r' % (report_every,))

    # Load models and build the test environment.
    nmodel = nnModel(url, config, sub=False)
    nmodel_sub = nnModel(url_sub, config, sub=True)
    data_url = config.data_dir + config.data_name
    env = gym.make(config.env_name, _adjust_ratio=0, adjust=False,
                   box_set=config.box_size_set,
                   container_size=config.container_size,
                   boxlist_len=config.boxlist_len,
                   test=True, data_name=data_url,
                   enable_rotation=config.enable_rotation,
                   data_type=config.data_type)
    print('Env name: ', config.env_name)
    print('Data url: ', data_url)
    print('Model url: ', url)
    print('Case number: ', config.cases)
    print('pruning threshold: ', config.pruning_threshold)
    print('Known item number: ', config.preview)

    times = config.cases
    c_bound = config.pruning_threshold
    # The physics simulator is disabled; run_sequence gets a placeholder.
    bullet_simulator = 0

    sum_ratio = sum_counter = sum_time = sum_succ = 0.0
    # Totals over the current report window, restricted to cases for which
    # run_sequence returned succ != 1.
    window_ratio = window_counter = 0.0
    list_var, list_var_c = [], []

    for i in range(times):
        print("______", i)
        env.reset()
        ratio, counter, seq_time, _depen_rate, succ = run_sequence(
            bullet_simulator, nmodel, nmodel_sub, env, config.preview, c_bound)
        sum_ratio += ratio
        sum_counter += counter
        # Added exactly once per case (it used to be added twice, doubling
        # both reported time averages).
        sum_time += seq_time
        if succ == 1:
            sum_succ += succ
        else:
            window_ratio += ratio
            window_counter += counter

        n_done = i + 1
        if n_done % report_every == 0:
            _print_summary('%d cases have been done!' % n_done, n_done,
                           sum_ratio, sum_counter, sum_time, sum_succ)
            list_var.append(window_ratio / report_every)
            list_var_c.append(window_counter / report_every)
            window_ratio = window_counter = 0.0

    # Variance of the per-window averages (nan if fewer than report_every
    # cases were run, since the lists are then empty).
    print(np.var(list_var))
    print(np.var(list_var_c))

    _print_summary('All cases have been done!', times,
                   sum_ratio, sum_counter, sum_time, sum_succ)

# def generate_data(size):
#     data_test = []
#     env = gym.make(config.env_name, _adjust_ratio=0, adjust=False,
#                     box_set=config.box_size_set,
#                     container_size=config.container_size,
#                     boxlist_len=config.boxlist_len,
#                     test=False,
#                     enable_rotation=config.enable_rotation,
#                     data_type=config.data_type)
#     for i in range(10000):
#         print(i)
#         box_list = env.reset()
#         result = [list(x) for x in box_list]
#         data_test.append(result)
#     torch.save(data_test, 'cut_20_10000.pt')

if __name__ == '__main__':
    # Command-line entry point: parse options, copy them onto the shared
    # ``config`` module, select/seed the device, then run the evaluation.
    parser = argparse.ArgumentParser(description='RL')
    parser.add_argument(
       '--mode', default='test')  # parsed but not used below
    parser.add_argument(
        '--enable-rotation', action='store_true', default=False)
    # NOTE: parsed but not used below; any non-empty string (even 'False')
    # is truthy because no type conversion is applied.
    parser.add_argument(
        '--load-model', default=True)
    parser.add_argument(
        '--load-name', default='default_cut_2.pt')
    parser.add_argument(
        '--load-name-sub', default='default_cut_2.pt')
    parser.add_argument(
        '--tsf-name', default='checkpoint-2022.01.12-16-11.pt')
    parser.add_argument(
        '--data-name', default='cut_2.pt')
    # NOTE: default=True with store_true means this flag can never be turned
    # off from the command line; CUDA is still skipped when unavailable.
    parser.add_argument(
        '--use-cuda', default=True, action='store_true')
    # default=None so config.preview is only overridden when given explicitly.
    parser.add_argument(
        '--preview', default=None, type=int)
    parser.add_argument(
        '--cases', default=50, type=int)
    args = parser.parse_args()
    config.cases = args.cases
    config.load_name = args.load_name
    config.load_name_sub = args.load_name_sub
    config.data_name = args.data_name
    config.tester_params['model_name'] = args.tsf_name
    # These two flags used to be parsed and then silently ignored. Only
    # override config when the user actually supplied them.
    if args.enable_rotation:
        config.enable_rotation = True
    if args.preview is not None:
        config.preview = args.preview
    config.cuda = args.use_cuda and torch.cuda.is_available()
    config.no_cuda = not config.cuda
    # Select the GPU before any model is built. This call used to run after
    # unified_test (so it had no effect) and crashed on CPU-only machines.
    if config.cuda:
        torch.cuda.set_device(config.device)
    model_url = config.load_dir + config.load_name
    model_url_sub = config.load_dir + config.load_name_sub
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    unified_test(model_url, model_url_sub, config)
    # generate_data(1000)