import os
import sys
import numpy as np
import argparse
import importlib
import time
import copy
from torch.utils.data import DataLoader
import torch

# Command-line options. Parsed at import time into the module-level FLAGS
# namespace, which the helpers and the __main__ section below read.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='scannet', help='Dataset: sunrgbd or scannet [default: scannet]')
parser.add_argument('--num_point', type=int, default=50000, help='Point Number [default: 50000]')
parser.add_argument('--dataloader', action='store_true',
                    help='Feed scans through ScannetDetectionDataset/DataLoader and save one .npy per scan, '
                         'instead of loading the *_vert.npy files directly')
parser.add_argument('--split', default='val',
                    help='Scan split to process (all, train, val or test); also used as the suffix of the '
                         'output directory [default: val]')
# The direct (non --dataloader) path forwards FLAGS.viz and FLAGS.gt_boxes to
# dump_results, but neither option used to be declared, so reading them raised
# AttributeError.
# NOTE(review): both are declared as boolean switches defaulting to False;
# confirm against the dump_results signature.
parser.add_argument('--viz', action='store_true',
                    help='Forwarded to dump_results as its viz argument [default: False]')
parser.add_argument('--gt_boxes', action='store_true',
                    help='Forwarded to dump_results as its gt_boxes argument [default: False]')
FLAGS = parser.parse_args()

import torch
import torch.nn as nn
import torch.optim as optim

# Directory containing this script; it doubles as the project root.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
# Make the sibling package directories importable by bare module name.
sys.path.extend(
    os.path.join(ROOT_DIR, package_dir)
    for package_dir in ('utils', 'scannet', 'models')
)
from pc_util import random_sampling, read_ply
from models import GroupFreeDetector
from models import parse_predictions, parse_groundtruths
from models.dump_helper import dump_results, dump_boxes
import scannet.scannet_utils as scannet_utils

MAX_NUM_OBJ = 132
# Per-channel mean colour subtracted from the RGB columns in
# preprocess_point_cloud.
MEAN_COLOR_RGB = np.array([109.8, 97.2, 83.8])

# `st()` drops into an interactive debugger. ipdb is only a debugging
# convenience, so fall back to the standard-library debugger when it is not
# installed instead of making the whole script fail at import.
try:
    import ipdb
    st = ipdb.set_trace
except ImportError:
    import pdb
    st = pdb.set_trace


def get_model(DATASET_CONFIG):
    """Construct the detection network for the given dataset configuration.

    Args:
        DATASET_CONFIG: dataset configuration object. This function reads its
            ``num_heading_bin``, ``num_size_cluster`` and ``mean_size_arr``
            attributes, plus ``num_class`` (assumed to exist on the config
            object -- TODO confirm).

    Returns:
        The constructed model instance.

    NOTE(review): ``GroupFreeModulator`` and ``args`` are neither defined nor
    imported in this file (only ``GroupFreeDetector`` is imported, and the
    command-line namespace here is called ``FLAGS``, which has none of the
    hyper-parameters read below). Until both are provided this function still
    raises ``NameError``; the intended model class and hyper-parameter source
    must be confirmed by the author.
    """
    # The network consumes 3 extra feature channels per point.
    num_input_channel = 3

    # Read the dataset settings from the parameter actually passed in; the
    # previous code referenced the undefined names `num_class` and
    # `dataset_config` and ignored its argument.
    model = GroupFreeModulator(
                DATASET_CONFIG.num_class,
                num_heading_bin=DATASET_CONFIG.num_heading_bin,
                num_size_cluster=DATASET_CONFIG.num_size_cluster,
                mean_size_arr=DATASET_CONFIG.mean_size_arr,
                input_feature_dim=num_input_channel,
                width=args.width,
                bn_momentum=args.bn_momentum,
                sync_bn=args.syncbn,
                num_proposal=args.num_target,
                sampling=args.sampling,
                dropout=args.transformer_dropout,
                activation=args.transformer_activation,
                nhead=args.nhead,
                num_decoder_layers=args.num_decoder_layers,
                dim_feedforward=args.dim_feedforward,
                self_position_embedding=args.self_position_embedding,
                size_cls_agnostic=args.size_cls_agnostic,
                contrastive_align_loss=args.use_contrastive_align,
                contrastive_hungarian=args.contrastive_hungarian,
                sa_lang=not args.no_sa_lang,
                sa_vis=not args.no_sa_vis,
                use_gt_box=args.use_gt_box,
                use_gt_class=args.use_gt_class,
                num_obj_classes=485,
                gt_with_bbox_loss=args.gt_with_bbox_loss,
                gt_with_bbox_sampling=args.gt_with_bbox_sampling,
                train_viewpoint_module=args.train_viewpoint_module
            )

    return model


def load_checkpoint(checkpoint_path, model):
    """Load model weights from a checkpoint file into ``model``.

    The checkpoint must be a mapping with a ``'model'`` entry holding the
    state dict. A leading ``"module."`` on a parameter name is removed before
    loading; names without that prefix are used unchanged.

    Args:
        checkpoint_path: path to the checkpoint file.
        model: module whose ``load_state_dict`` receives the weights.

    Returns:
        The checkpoint's ``'save_path'`` entry, or the string ``'none'`` when
        the checkpoint has no such entry.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is ``None`` or not a file.
    """
    if checkpoint_path is None or not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f'Checkpoint not found: {checkpoint_path}')

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    save_path = checkpoint.get('save_path', 'none')

    # Only strip the prefix when it is really there. Slicing every key
    # unconditionally chops the first 7 characters off parameter names that
    # were saved without "module.".
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['model'].items()
    }
    model.load_state_dict(state_dict)

    del checkpoint
    torch.cuda.empty_cache()
    return save_path


def preprocess_point_cloud(point_cloud):
    ''' Prepare the numpy point cloud (N,6) for forward pass.

    Keeps the first six columns, subtracts MEAN_COLOR_RGB from columns 3-5 and
    divides them by 256, passes the result through
    random_sampling(..., FLAGS.num_point), and returns it as a float32 array
    with a leading axis of length 1.

    The input array is not modified.
    '''
    # Work on a floating-point copy. Slicing alone returns a view, so the
    # in-place colour update used to write into the caller's array, and on an
    # integer array it also truncated the normalised colours.
    points = np.array(point_cloud[:, 0:6], dtype=np.float64)  # use color
    points[:, 3:] = (points[:, 3:] - MEAN_COLOR_RGB) / 256.0
    points = random_sampling(points, FLAGS.num_point)
    pc = np.expand_dims(points.astype(np.float32), 0)  # add leading batch axis
    return pc


def load_axis_aligned_matrix(scene_name):
    """Load a scene's mesh vertices and apply its axis-alignment transform.

    Reads ``<scene_name>_vh_clean_2.ply`` through
    ``scannet_utils.read_mesh_vertices_rgb`` and the ``axisAlignment`` entry
    of ``<scene_name>.txt`` (16 numbers, reshaped to 4x4), both under
    ``./dataset/language_grounding/scans/<scene_name>/``. The first three
    vertex columns are transformed as homogeneous coordinates; any remaining
    columns are left untouched.

    Args:
        scene_name: scan identifier used for the directory and file names.

    Returns:
        The vertex array with transformed coordinates and a leading axis of
        length 1 added.

    Raises:
        ValueError: if the meta file has no ``axisAlignment`` line, or the
            line does not hold exactly 16 numbers.
    """
    scan_dir = os.path.join('./dataset/language_grounding/scans', scene_name)
    mesh_file = os.path.join(scan_dir, scene_name + '_vh_clean_2.ply')
    meta_file = os.path.join(scan_dir, scene_name + '.txt')  # includes axisAlignment info for the train set scans.

    mesh_vertices = scannet_utils.read_mesh_vertices_rgb(mesh_file)

    # Load scene axis alignment matrix
    alignment_values = None
    with open(meta_file) as meta:
        for line in meta:
            if 'axisAlignment' in line:
                # Take everything after '='. str.strip('axisAlignment = ')
                # strips a *set of characters*, not a prefix, and only worked
                # by accident.
                alignment_values = [float(x) for x in line.partition('=')[2].split()]
                break
    if alignment_values is None:
        raise ValueError(f'No axisAlignment entry found in {meta_file}')

    axis_align_matrix = np.array(alignment_values).reshape((4, 4))
    # Homogeneous coordinates: append a column of ones, transform, drop it.
    pts = np.ones((mesh_vertices.shape[0], 4))
    pts[:, 0:3] = mesh_vertices[:, 0:3]
    pts = np.dot(pts, axis_align_matrix.transpose())  # Nx4
    mesh_vertices[:, 0:3] = pts[:, 0:3]
    mesh_vertices = np.expand_dims(mesh_vertices, axis=0)
    return mesh_vertices


def load_scan_list(split_set='val'):
    """Return the names of the scans belonging to a split.

    Available scans are the distinct 12-character prefixes of the entries in
    ``./dataset/language_grounding/scans/scannet_train_detection_data`` whose
    names start with ``'scene'``.

    Args:
        split_set: ``'all'`` for every available scan, or one of ``'train'``,
            ``'val'``, ``'test'`` to read
            ``scannet/meta_data/scannetv2_<split_set>.txt`` under ROOT_DIR and
            keep, in file order, only the scans that are available.

    Returns:
        List of scan name strings.

    Raises:
        ValueError: if ``split_set`` is not one of the supported values.
    """
    # Fail fast with a clear message; an unknown split used to fall through
    # both branches and die with UnboundLocalError on the return statement.
    if split_set not in ('all', 'train', 'val', 'test'):
        raise ValueError(
            'Unknown split_set %r; expected all, train, val or test' % (split_set,))

    data_dir = './dataset/language_grounding/scans/scannet_train_detection_data'
    available = {os.path.basename(x)[0:12]
                 for x in os.listdir(data_dir) if x.startswith('scene')}
    if split_set == 'all':
        return list(available)

    split_filenames = os.path.join(ROOT_DIR, 'scannet/meta_data',
                                   'scannetv2_{}.txt'.format(split_set))
    with open(split_filenames, 'r') as f:
        listed = f.read().splitlines()
    # Remove unavailable scans (set membership keeps this linear).
    return [sname for sname in listed if sname in available]


if __name__=='__main__':
    # Script entry point: run the detector over every scan of FLAGS.split and
    # write the results into `demo_dir`.

    # Set file paths and dataset config
    demo_dir =  f'./dataset/language_grounding/group_free_pred_bboxes_{FLAGS.split}'
    if FLAGS.dataset == 'scannet':
        sys.path.append(os.path.join(ROOT_DIR, 'scannet'))
        from scannet_detection_dataset import DC # dataset config
        checkpoint_path = os.path.join('./dataset/language_grounding/full_train_detector/group_free/scannet_1628017918/91896780/ckpt_epoch_400.pth')

        from scannet.scannet_detection_dataset import ScannetDetectionDataset
        # Both inference paths below iterate over the scans of the requested
        # split (the direct path used to reference an undefined
        # TRAIN_SCAN_NAMES).
        scan_list = load_scan_list(split_set=FLAGS.split)
        if FLAGS.dataloader:
            # Init datasets and dataloaders
            def my_worker_init_fn(worker_id):
                np.random.seed(np.random.get_state()[1][0] + worker_id)

            SCANNET_TEST_DATASET = ScannetDetectionDataset(f'{FLAGS.split}', num_points=FLAGS.num_point,
                                                   augment=False,
                                                   use_color=True,
                                                   use_height=False,
                                                   data_root='./dataset/language_grounding/scans/')
            SCANNET_TEST_DATALOADER = DataLoader(SCANNET_TEST_DATASET, batch_size=1,
                                         shuffle=False,
                                         num_workers=0,
                                         worker_init_fn=my_worker_init_fn)
    else:
        # `DATASET` was never defined; report the value actually requested.
        print('Unkown dataset %s. Exiting.'%(FLAGS.dataset))
        sys.exit(-1)

    eval_config_dict = {'remove_empty_box': True, 'use_3d_nms': True, 'nms_iou': 0.25,
        'use_old_type_nms': False, 'cls_nms': False, 'per_class_proposal': False,
        'conf_thresh': 0.5, 'dataset_config': DC}

    # Init the model and load the trained weights
    net = get_model(DC)
    save_path = load_checkpoint(checkpoint_path, net)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Honour the selected device instead of unconditionally calling .cuda(),
    # which fails on CPU-only machines.
    net = net.to(device)

    net.eval()  # set model to eval mode (for bn and dp)
    not_exported = []

    # Create the output directory once, including missing parents.
    os.makedirs(demo_dir, exist_ok=True)

    if FLAGS.dataloader:
        for batch_idx, batch_data_label in enumerate(SCANNET_TEST_DATALOADER):
            for key in batch_data_label:
                batch_data_label[key] = batch_data_label[key].to(
                    device, non_blocking=True)
            # NOTE(review): assumes the dataset yields scans in the same order
            # as load_scan_list (batch_size=1, shuffle=False) -- confirm.
            scan_name = scan_list[batch_idx]  # "scene_" + str(batch_idx)

            if batch_idx % 10 == 0:
                print('Eval batch: %d'% (batch_idx))

            # Model inference
            inputs = {'point_clouds': batch_data_label['point_clouds']}
            with torch.no_grad():
                end_points = net(inputs)

            # Merge the batch inputs/labels into the network outputs.
            for key in batch_data_label:
                assert (key not in end_points)
                end_points[key] = batch_data_label[key]

            prefix = 'last_'
            end_points['point_clouds'] = inputs['point_clouds']
            batch_pred_map_cls = parse_predictions(end_points, eval_config_dict, prefix)
            point_clouds = end_points['point_clouds'].cpu().numpy()
            batch_size = point_clouds.shape[0]

            for i in range(batch_size):

                sr3d_boxes = dump_boxes(end_points, eval_config_dict, prefix, i)

                class_label_list = [DC.class2type[p[0]] for p in batch_pred_map_cls[i]]

                print(len(class_label_list))
                print(len(sr3d_boxes))
                assert(len(sr3d_boxes) == len(class_label_list))

                data_dict = {
                    "class": class_label_list,
                    "box": sr3d_boxes,
                    "pc": point_clouds[i]
                }
                np.save(f'{demo_dir}/{scan_name}.npy', data_dict)

    else:
        # FLAGS may not declare --viz / --gt_boxes; default both to False
        # rather than failing with AttributeError on every scan.
        viz = getattr(FLAGS, 'viz', False)
        gt_boxes = getattr(FLAGS, 'gt_boxes', False)
        for scan_name in scan_list:
            try:
                pc_path = os.path.join('./dataset/language_grounding/scans', 'scannet_train_detection_data', scan_name + '_vert.npy')
                point_cloud = np.load(pc_path)
                pc = preprocess_point_cloud(copy.deepcopy(point_cloud))
                print('Loaded point cloud data: %s'%(pc_path))

                # Model inference
                inputs = {'point_clouds': torch.from_numpy(pc).to(device)}
                tic = time.time()
                with torch.no_grad():
                    end_points = net(inputs)
                toc = time.time()
                print('Inference time: %f'%(toc-tic))
                end_points['point_clouds'] = inputs['point_clouds']
                end_points['colored_pc'] = np.expand_dims(point_cloud.astype(np.float32), 0)
                pred_map_cls = parse_predictions(end_points, eval_config_dict, 'last_')
                gt_map_cls = parse_groundtruths(end_points, eval_config_dict, False)
                print('Finished detection. %d object detected.'%(len(pred_map_cls[0])))
                dump_dir = demo_dir
                dump_results(end_points, dump_dir, scan_name, DC, viz, gt_boxes, False)
                print('Dumped detection results to folder %s'%(dump_dir))
            except Exception as exc:
                # Best-effort export: record the failure and move on. A bare
                # `except:` also swallowed KeyboardInterrupt/SystemExit and
                # hid the reason a scan was skipped.
                print('Failed to export %s: %r' % (scan_name, exc))
                not_exported.append(scan_name)

    print("Not exported", not_exported)
    print("==========> DONE <===========")