from __future__ import print_function

import time

import open3d as o3d
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering

import math
import argparse
import os
import random
import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data
import pprint
import sys
from pathlib import Path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
from model.AE import AE
from dataset.threedfront_dataset_incremental import ThreedFrontDatasetSceneGraphIncremental,RandomStepGroupedSampler
from helpers.util import bool_flag, batch_torch_denormalize_box_params, batch_torch_denormalize_box_params_tmp,sample_points
from helpers.metrics_3dfront import validate_box_accuracy, validate_shape_accuracy
from helpers.visualize_scene import  render_v1_full,render_incremental
from helpers.visualize_graph import run as vis_graph
from helpers.visualize_graph import incremental_run as vis_incremental_graph
from helpers.gui_hub import init_main_window
import extension.dist_chamfer as ext
chamfer = ext.chamferDist()
import json

# ---------------------------------------------------------------------------
# Command-line interface. Parsed once at import time; the resulting ``args``
# namespace is read as a module-level global by the functions below.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# --- data source -----------------------------------------------------------
parser.add_argument(
    '--dataset',
    required=False,
    type=str,
    default="/media/xxx/xxx_ssd/FRONT",
    help="dataset path",
)
# NOTE(review): evaluate() in this file reads modelArgs['with_feats'] from the
# experiment's args.json, not this flag.
parser.add_argument(
    '--with_feats',
    type=bool_flag,
    default=False,
    help="Load Feats directly instead of points.",
)

# --- which trained experiment / checkpoint to evaluate ----------------------
# --exp is used as a directory path (args.json and results/ live under it).
parser.add_argument(
    '--exp',
    default='/home/xxx/data_weight/baseline3',
    help='experiment name',
)
parser.add_argument('--epoch', type=str, default='30', help='saved epoch')

# --- evaluation / output behaviour -----------------------------------------
parser.add_argument('--visualize', default=True, type=bool_flag)
parser.add_argument(
    '--export_3d',
    default=False,
    type=bool_flag,
    help='Export the generated shapes and boxes in json files for future use',
)
parser.add_argument('--no_stool', default=False, type=bool_flag)
parser.add_argument(
    '--room_type',
    default='livingroom',
    help='all, bedroom, livingroom, diningroom, library',
)

args = parser.parse_args()


# _APP_GUI_READY = False
# _WIN_MAIN      = None          # gui.Window
# _PANEL         = {}            # {"graph": SceneWidget, "scene": SceneWidget}

# def _init_main_window(title="Incremental Viewer", w=1600, h=900):
#     global _APP_GUI_READY, _WIN_MAIN, _PANEL

#     app = gui.Application.instance
#     if not _APP_GUI_READY:
#         app.initialize()              
#         _APP_GUI_READY = True

#     if _WIN_MAIN:                      
#         return _PANEL

#     win  = app.create_window(title, w, h)
#     splitter = gui.Splitter(gui.Splitter.HORIZONTAL)

#     w_graph = gui.SceneWidget()
#     w_graph.scene = rendering.Open3DScene(win.renderer)
#     splitter.add_child(w_graph)

#     w_scene = gui.SceneWidget()
#     w_scene.scene = rendering.Open3DScene(win.renderer)
#     splitter.add_child(w_scene)

#     splitter.set_divider_position(int(w*0.45))   # 45% : 55%
#     win.add_child(splitter)

#     _WIN_MAIN = win
#     _PANEL = {"graph": w_graph, "scene": w_scene}
#     return _PANEL


def reseed(num):
    """Seed NumPy, PyTorch and Python's ``random`` module with ``num``.

    The three generators are seeded in that order so runs are reproducible.
    """
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(num)
def to_cuda(x, device='cuda'):
    """Recursively move every tensor inside ``x`` with ``Tensor.to(device)``.

    Dicts, lists and tuples (including namedtuples) are traversed and rebuilt;
    any other value is returned unchanged.

    Args:
        x: a tensor, or a (possibly nested) dict / list / tuple containing tensors.
        device: target passed straight to ``Tensor.to`` (default ``'cuda'``).

    Returns:
        A structure of the same shape with tensors moved; lists come back as
        plain lists, tuples keep their (named)tuple type.
    """
    if torch.is_tensor(x):
        return x.to(device, non_blocking=True)
    if isinstance(x, dict):
        return {k: to_cuda(v, device) for k, v in x.items()}
    if isinstance(x, list):
        return [to_cuda(v, device) for v in x]
    if isinstance(x, tuple):
        moved = [to_cuda(v, device) for v in x]
        # namedtuples take their fields positionally; plain tuples take an iterable
        return type(x)(*moved) if hasattr(x, '_fields') else tuple(moved)
    return x

def evaluate():
    """Load a trained experiment and run incremental evaluation on the val split.

    Reads ``<args.exp>/args.json`` for the training-time configuration, builds
    the incremental 3D-FRONT test set, restores the network weights saved at
    ``args.epoch`` and hands everything to ``incremental_validate_``.

    Raises:
        FileNotFoundError: if ``args.json`` is missing from the experiment directory.
    """
    print(torch.__version__)

    random.seed(48)
    torch.manual_seed(48)

    argsJson = os.path.join(args.exp, 'args.json')
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
    if not os.path.exists(argsJson):
        raise FileNotFoundError('Could not find args.json for experiment {}'.format(args.exp))
    with open(argsJson) as j:
        modelArgs = json.load(j)
    # Format only the file name; formatting the joined path would also try to
    # interpret any '{' / '}' characters contained in args.dataset.
    normalized_file = os.path.join(
        args.dataset, 'boxes_centered_stats_{}_test.txt'.format(modelArgs['room_type']))

    # NOTE(review): the dataset below is filtered by the CLI ``args.room_type``
    # while the normalization stats above use the training-time
    # ``modelArgs['room_type']`` -- confirm these are meant to be able to differ.
    test_dataset = ThreedFrontDatasetSceneGraphIncremental(
        root=args.dataset,
        split='val_scans',
        shuffle_objs=modelArgs['shuffle_objs'],
        eval=True,
        with_feats=modelArgs['with_feats'],
        large=modelArgs['large'],
        room_type=args.room_type)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_sampler=RandomStepGroupedSampler(test_dataset, 1),
        collate_fn=test_dataset.collate_fn_inc,
        num_workers=0)

    modeltype_ = modelArgs['network_type']
    # Older experiments may not have recorded 'no_stool'; fall back to the CLI flag.
    modelArgs.setdefault('no_stool', args.no_stool)

    model = AE(root=args.dataset, type=modeltype_, vocab=test_dataset.vocab,
               residual=modelArgs['residual'], gconv_pooling=modelArgs['pooling'],
               num_box_params=modelArgs['num_box_params'])

    model.load_networks(exp=args.exp, epoch=args.epoch, restart_optim=False)
    if torch.cuda.is_available():
        model = model.cuda()

    model = model.eval()

    reseed(47)
    print('\nGeneration Mode')
    incremental_validate_(modelArgs, test_dataloader, model, epoch=args.epoch,
                          normalized_file=normalized_file,
                          vocab=test_dataset.vocab,
                          point_classes_idx=test_dataset.point_classes_idx,
                          export_3d=args.export_3d,
                          datasize='large' if modelArgs['large'] else 'small')


# Radians -> degrees. Evaluates to 57.29577951308232, the literal this script
# used for the angle conversion.
_RAD_TO_DEG = 180.0 / math.pi


def _load_box_data(bbox_file, no_stool):
    """Read the JSON retrieval table passed to ``model.decode_g2sv1``.

    If ``no_stool`` is truthy, the entries under 'stool' are merged into 'chair'.
    """
    with open(bbox_file, "r") as read_file:
        box_data = json.load(read_file)
    if no_stool:
        box_data['chair'].update(box_data['stool'])
    return box_data


def _mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float('nan')


def _export_results(exp_dir, datasize, boxes_exp, shapes_exp):
    """Write the box/shape export dicts as JSON into ``<exp_dir>/results``."""
    result_path = os.path.join(exp_dir, 'results')
    os.makedirs(result_path, exist_ok=True)
    # datasize is a string ('large'/'small'); compare it explicitly, since any
    # non-empty string is truthy.
    tag = 'large' if datasize == 'large' else 'small'
    box_filename = os.path.join(result_path, 'boxes_' + tag + '.json')
    shape_filename = os.path.join(result_path, 'shapes_' + tag + '.json')
    with open(box_filename, 'w') as f:
        json.dump(boxes_exp, f)
    with open(shape_filename, 'w') as f:
        json.dump(shapes_exp, f)


def incremental_validate_(modelArgs, test_dataloader, model, epoch=None, normalized_file=None,  vocab=None,point_classes_idx=None,
                            export_3d=False, datasize='small', bbox_file=None):
    """Run step-by-step (incremental) generation over the test set and report metrics.

    For each scan the model's per-scene state is reset. The scan's objects are
    then fed to ``model.forward_incremental_3D_`` one step at a time, grouped by
    ``enc['obj_to_step']``. For every step, the newly added objects' predicted
    boxes/shapes are scored against ground truth with ``validate_box_accuracy``
    and ``validate_shape_accuracy``, and the per-step scores are averaged.

    Args:
        modelArgs: training-time config dict (this function reads 'no_stool').
        test_dataloader: yields dicts with 'encoder' and 'step_meta' entries.
        model: network exposing ``reset_all_scene_cls_states``,
            ``forward_incremental_3D_`` and ``decode_g2sv1``.
        epoch, point_classes_idx: accepted for interface compatibility; unused.
        normalized_file: box statistics file used for de-normalization.
        vocab: dataset vocabulary; only read when ``args.visualize`` is set.
        export_3d: if True, write result JSON files under ``<args.exp>/results``.
        datasize: 'large' or 'small'; selects the retrieval table and export names.
        bbox_file: optional path of the retrieval table; defaults to
            ``<args.dataset>/cat_jid_test.json`` ('large') or
            ``<args.dataset>/cat_jid_test_small.json`` (otherwise).
    """
    print('Evaluating...')
    # Follow evaluate(): the model is moved to CUDA only when it is available,
    # so the inputs must follow the same rule.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    step_num = 0
    all_inference_time = 0.0
    box_keys   = ['position','size','orientation','Overlap_GT(%)','position_p(%)','size_p(%)','orientation_p(%)']
    shape_keys = ['CD','F-score','Vox_IoU']
    box_sum    = {k: 0.0 for k in box_keys}
    shape_sum  = {k: 0.0 for k in shape_keys}
    # Per-metric sample counts: a non-finite value is skipped for that metric only.
    shape_count = {k: 0 for k in shape_keys}

    all_pred_shapes_exp = {}  # for export
    all_pred_boxes_exp = {}

    if bbox_file is None:
        bbox_name = 'cat_jid_test.json' if datasize == 'large' else 'cat_jid_test_small.json'
        bbox_file = os.path.join(args.dataset, bbox_name)
    box_data = _load_box_data(bbox_file, modelArgs['no_stool'])

    # Loop-invariant class list used by the visualizers.
    classes = sorted(set(vocab['object_idx_to_name'])) if args.visualize else None

    for data in test_dataloader:
        model.reset_all_scene_cls_states()

        enc  = to_cuda(data['encoder'], device)
        meta = to_cuda(data['step_meta'], device)
        scan = meta['scan_id_str'][0]
        print('evaluating scan_id--------------------------------------:', scan)

        if enc['obj_to_step'].numel() == 0:
            # .max() of an empty tensor raises; nothing to evaluate for this scan.
            print('no objects to evaluate in scan:', scan)
            continue
        K = int(enc['obj_to_step'].max().item()) + 1

        if args.visualize:
            panel = init_main_window()
            w_graph = panel["graph"]      # scene-graph view
            w_scene = panel["scene"]      # 3D scene view
            cache_graph    = None
            cache_rendered = None

        for k in range(K):
            obj_mask = enc['obj_to_step'] == k
            if not obj_mask.any():
                continue
            step_num += 1
            print('evaluating step --------------------------', k)
            print('all steps number in this batch:', K)

            idx_flat_step = torch.nonzero(obj_mask, as_tuple=False).squeeze(1)

            # Take only the objects belonging to the current step.
            objs  = enc['objs'][obj_mask]
            boxes = enc['boxes'][obj_mask]
            feats = enc.get('feats')
            feats = feats[obj_mask] if feats is not None else None
            new_mask = enc['new_mask'][obj_mask]   # indexes this step's objs/boxes

            obj_scene_ids = enc['obj_to_scene'][obj_mask]

            if enc['triples'].numel():
                tri_step_mask    = enc['triple_to_step'] == k
                triples_step     = enc['triples'][tri_step_mask]
                triple_scene_ids = enc['triple_to_scene'][tri_step_mask]  # (N_tri_step,)
            else:
                triples_step     = enc['triples'].new_empty(0, 3)
                triple_scene_ids = enc['triple_to_scene'].new_empty(0, dtype=torch.long)

            start_inference = time.time()

            with torch.no_grad():
                boxes_pred, pred_shapes = model.forward_incremental_3D_(
                        obj_batch_scene_ids = obj_scene_ids,
                        objs    = objs,
                        boxes   = boxes,
                        triples = triples_step,    # (N_tri, 3): new-new, new-old, old-old
                        new_mask= new_mask,
                        obj_indices = idx_flat_step,
                        triple_scene_ids = triple_scene_ids)
                gt_boxes  = boxes[new_mask]   # boolean indexing -> copy; `boxes` is untouched
                gt_shapes = feats[new_mask] if feats is not None else None

                # Column 6 is converted to degrees (in place) for the metrics.
                boxes_pred[:, 6] = boxes_pred[:, 6] * _RAD_TO_DEG
                gt_boxes[:, 6]   = gt_boxes[:, 6] * _RAD_TO_DEG

                incre_objs = objs[new_mask]
                shapes_mesh_pred, _ = model.decode_g2sv1(incre_objs, pred_shapes, box_data, retrieval=True)
                shapes_mesh_gt, _   = model.decode_g2sv1(incre_objs, gt_shapes, box_data, retrieval=True)

            boxes_pred_den_no_angle = batch_torch_denormalize_box_params(boxes_pred[:, :6], file=normalized_file)
            boxes_pred_den = torch.cat([boxes_pred_den_no_angle, boxes_pred[:, 6].unsqueeze(1)], dim=1)
            # NOTE(review): column 6 of row_boxes_den is NOT converted to degrees,
            # unlike pred/gt -- confirm what vis_incremental_graph expects.
            row_boxes_den_no_angle = batch_torch_denormalize_box_params(boxes[:, :6], file=normalized_file)
            row_boxes_den = torch.cat([row_boxes_den_no_angle, boxes[:, 6].unsqueeze(1)], dim=1)

            boxes_gt_den_no_angle = batch_torch_denormalize_box_params(gt_boxes[:, :6], file=normalized_file)
            boxes_gt_den = torch.cat([boxes_gt_den_no_angle, gt_boxes[:, 6].unsqueeze(1)], dim=1)

            all_inference_time += time.time() - start_inference

            scene_metrics_box = validate_box_accuracy(boxes_pred_den, boxes_gt_den)
            for key in box_keys:
                box_sum[key] += float(scene_metrics_box[key])
            scene_metrics_shape = validate_shape_accuracy(shapes_mesh_pred, shapes_mesh_gt, n_sample=2000, f_thresh=0.1, voxel_size=0.35)
            for key in shape_keys:
                value = float(scene_metrics_shape[key])
                if math.isfinite(value):
                    shape_sum[key] += value
                    shape_count[key] += 1

            if args.visualize:
                row_global_ids = torch.where(obj_mask)[0].cpu().numpy()
                new_global_ids = row_global_ids[new_mask.cpu().numpy()]
                row_cls_ids    = objs.cpu().numpy()
                print('new node number:', len(new_global_ids))

                w_graph, cache_graph = vis_incremental_graph(
                        scan, classes, new_global_ids, triples_step,
                        row_global_ids, row_boxes_den, row_cls_ids,
                        data_path=args.dataset,
                        w_graph=w_graph, cache=cache_graph)

                w_scene, cache_rendered = render_incremental(
                        scan, incre_objs.cpu().numpy(), boxes_gt_den,
                        shapes_mesh_gt, classes, w_scene,
                        rendered=cache_rendered,
                        without_lamp=False, no_stool=False)

        if args.visualize:
            # Blocks until the GUI window is closed, then moves to the next scan.
            gui.Application.instance.run()

    avg_box = {k: _mean(box_sum[k], step_num) for k in box_keys}
    print('\n[Box] average over steps:', {**avg_box, 'num_steps': step_num})
    avg_shape = {k: _mean(shape_sum[k], shape_count[k]) for k in shape_keys}
    print('[Shape]  average over steps:', {**avg_shape, 'num_steps': shape_count})
    print(f'average inference time: {_mean(all_inference_time, step_num):.3f} seconds')
    if export_3d:
        # TODO(review): nothing in this function populates all_pred_boxes_exp /
        # all_pred_shapes_exp, so the exported JSON files contain empty dicts.
        _export_results(args.exp, datasize, all_pred_boxes_exp, all_pred_shapes_exp)



def normalize(vertices, scale=1):
    """Centre a vertex array on its bounding-box midpoint and rescale it.

    The first three columns are shifted so that their bounding box is centred
    at the origin. Every column is then divided by its column-wise maximum and
    multiplied by ``scale``, so each of the first three axes spans
    ``[-scale, scale]``. An axis with zero extent stays at 0 instead of
    becoming NaN.

    Args:
        vertices: 2D array-like with at least three columns. It is not modified.
        scale: target half-extent of each axis.

    Returns:
        A new floating-point array with the same shape as ``vertices``.
    """
    v = np.array(vertices, copy=True)
    if not np.issubdtype(v.dtype, np.floating):
        # The in-place float arithmetic below needs a floating buffer.
        v = v.astype(np.float64)

    xyz = v[:, :3]
    # Subtract the bounding-box midpoint (min + max) / 2 of each axis.
    v[:, :3] -= (xyz.min(axis=0) + xyz.max(axis=0)) * 0.5

    scalars = np.max(v, axis=0)
    # A flat axis has max 0 after centring; avoid 0/0 -> NaN by dividing by 1.
    scalars = np.where(scalars == 0, 1.0, scalars)
    return v / scalars * scale


if __name__ == "__main__":
    # Script entry point: run evaluation using the module-level CLI ``args``.
    evaluate()
