import os
import time
import argparse
import cv2
import glob
import numpy as np
from tqdm import tqdm
import _pickle as cPickle
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from lib.network import DeformNet, PoseNet, PoseNetV2, PoseNetV3
from lib.align import estimateSimilarityTransform, RansacPnP
from lib.utils import load_depth, get_bbox, compute_mAP, plot_mAP, zoom_in, xywh_to_cs, load_obj, find_model
from lib.transformations import quaternion_matrix
from lib.nn_distance.chamfer_loss import ChamferLoss
import pdb

# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
# --- dataset / checkpoint selection ---
parser.add_argument('--data', type=str, default='val', help='val, real_test')
parser.add_argument('--data_dir', type=str, default='data/NOCS/', help='data directory')
parser.add_argument('--n_cat', type=int, default=6, help='number of object categories')
parser.add_argument('--nv_prior', type=int, default=1024, help='number of vertices in shape priors')
parser.add_argument('--model', type=str, default='results/camera/model_50.pth', help='resume from saved model')
parser.add_argument('--n_pts', type=int, default=1024, help='number of foreground points')
parser.add_argument('--img_size', type=int, default=192, help='cropped image size')
# NOTE(review): nothing in this file applies --gpu to CUDA_VISIBLE_DEVICES.
parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
# The help text used to be a copy of the --model help; it now describes the flag.
parser.add_argument('--select_class', type=str, default='bottle',
                    help="object category to evaluate (bottle, bowl, camera, can, laptop, mug), "
                         "or 'all' to keep every category")
parser.add_argument('--use_pose_reg', action='store_true',
                    help='forwarded to compute_mAP as use_pose_reg')
parser.add_argument('--only_eval', action='store_true',
                    help='skip detection and only evaluate previously saved results')
# The switches below are not read directly in this file; presumably they are
# consumed by the network constructors, which receive the whole `opt` namespace.
parser.add_argument('--use_point_reg', action='store_true')
parser.add_argument('--use_fc', action='store_true')
parser.add_argument('--use_nocs_map', action='store_true')
parser.add_argument('--implict', action='store_true')
parser.add_argument('--max_point', action='store_true')
parser.add_argument('--version', type=str, default='v1',
                    help="network variant: 'v2' -> PoseNetV2, 'v3'/'v4' -> PoseNetV3, "
                         "anything else -> PoseNet")
parser.add_argument('--use_rgb', action='store_true')
parser.add_argument('--with_recon', action='store_true')



# Parse the command line once at import time; the functions below read `opt`
# as a module-level global.
opt = parser.parse_args()

# Pre-computed assets. `mean_shapes` is loaded here but not referenced again in
# the code visible in this file.
mean_shapes = np.load('assets/mean_points_emb.npy')
cat_names = ['bottle', 'bowl', 'camera', 'can', 'laptop', 'mug']
# First element returned by load_obj for each category mesh, in cat_names order
# (used further down as the per-category prior point set).
mean_meshes = [load_obj('./assets/{}.obj'.format(cat_name))[0] for cat_name in cat_names]

# Dataset-specific output directory, image list and pinhole intrinsics.
assert opt.data in ['val', 'real_test']
if opt.data != 'val':
    result_dir = 'results/eval_real'
    file_path = 'Real/test_list.txt'
    cam_fx, cam_fy, cam_cx, cam_cy = 591.0125, 590.16775, 322.525, 244.11084
else:
    result_dir = 'results/eval_camera'
    file_path = 'CAMERA/val_list.txt'
    cam_fx, cam_fy, cam_cx, cam_cy = 577.5, 577.5, 319.5, 239.5

# Results are grouped by the third '/'-separated component of the model path.
result_dir = os.path.join(result_dir, opt.model.split('/')[2])

# 3x3 intrinsic matrix assembled from the values chosen above.
cam = np.array(
    [[cam_fx, 0.0, cam_cx],
     [0.0, cam_fy, cam_cy],
     [0.0, 0.0, 1.0]],
    dtype=np.float32,
)

if not os.path.exists(result_dir):
    os.makedirs(result_dir)

# Per-pixel column (xmap) and row (ymap) indices for a 480x640 image.
xmap = np.tile(np.arange(640), (480, 1))
ymap = np.tile(np.arange(480)[:, np.newaxis], (1, 640))
# Divisor applied to raw depth values before back-projection.
norm_scale = 1000.0
# RGB crop -> tensor, then per-channel mean/std normalisation.
norm_color = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

chamferloss = ChamferLoss()

def compute_T(trans, s_box, c_box, bbox, cam, img_size=None):
    """Convert the translation-head output into a camera-frame translation.

    The head predicts, per instance, a 2D centre offset expressed as a ratio of
    the box width/height and a depth expressed relative to the crop zoom
    factor. Both are undone here and the resulting pixel centre is
    back-projected through the pinhole intrinsics.

    Args:
        trans: tensor indexed as ``trans[:, :2]`` (centre offset ratios) and
            ``trans[:, 2]`` (depth ratio), i.e. at least 3 columns per row.
        s_box: per-instance crop size; must broadcast against ``trans[:, 2]``.
        c_box: per-instance crop centre added to the scaled offset
            (assumed to be (x, y) pixel coordinates -- confirm with caller).
        bbox: per-instance box; columns ``2:`` scale the offset ratios
            (the caller in this file passes [x, y, w, h]).
        cam: batch of intrinsic matrices, read as ``cam[:, 0, 0]`` (fx),
            ``cam[:, 1, 1]`` (fy), ``cam[:, 0, 2]`` (cx), ``cam[:, 1, 2]`` (cy).
        img_size: side length of the network input crop. Defaults to
            ``opt.img_size`` so existing five-argument calls are unaffected.

    Returns:
        Tensor with one row ``[x, y, depth]`` per instance.
    """
    if img_size is None:
        # Fall back to the CLI setting, which was the only option before.
        img_size = opt.img_size
    ratio_delta_c = trans[:, :2]
    ratio_depth = trans[:, 2]
    # Undo the crop zoom: depth was predicted relative to img_size / s_box.
    pred_depth = ratio_depth * (img_size / s_box)
    # Offset ratios are relative to the box width/height.
    pred_c = ratio_delta_c * bbox[:, 2:] + c_box
    # Pinhole back-projection of the predicted 2D centre at the predicted depth.
    pred_x = (pred_c[:, 0] - cam[:, 0, 2]) * pred_depth / cam[:, 0, 0]
    pred_y = (pred_c[:, 1] - cam[:, 1, 2]) * pred_depth / cam[:, 1, 1]
    return torch.stack([pred_x, pred_y, pred_depth], dim=1)

def detect():
    """Run the pose network on every listed image and save per-image results.

    For each image the Mask R-CNN detections are loaded from
    ``results/mrcnn_results``. Every detection that has at least 32 masked
    depth points and matches ``--select_class`` is cropped, back-projected to a
    point cloud and fed to the estimator in one batch. Two poses are stored per
    instance: one from ``estimateSimilarityTransform`` on the predicted
    coordinates (``pred_RTs``) and one from the network's pose head
    (``pred_RTs_pose``). Skipped detections keep an identity pose and the size
    of the category mean mesh.

    One ``results_*.pkl`` per image is written to ``result_dir`` and summary
    statistics are written to ``eval_logs.txt`` (overwriting any previous log).

    Raises:
        ValueError: if ``--select_class`` is neither ``'all'`` nor a known
            category name.

    NOTE: all pickles read here are assumed to be trusted local files.
    """
    # Resolve the category filter once so that a misspelled --select_class
    # fails before the model and data are loaded.
    if opt.select_class == 'all':
        target_cat = None
    else:
        target_cat = cat_names.index(opt.select_class)
    # resume model
    if opt.version in ('v3', 'v4'):
        estimator = PoseNetV3(opt)
    elif opt.version == 'v2':
        estimator = PoseNetV2(opt)
    else:
        estimator = PoseNet(opt)
    estimator.cuda()
    estimator.load_state_dict(torch.load(opt.model))
    estimator.eval()
    # get test data list
    with open(os.path.join(opt.data_dir, file_path)) as list_file:
        img_list = [os.path.join(file_path.split('/')[0], line.rstrip('\n'))
                    for line in list_file]
    # frame by frame test
    t_inference = 0.0
    t_umeyama = 0.0
    inst_count = 0
    img_count = 0
    cd_dist_total = 0.0
    t_start = time.time()
    for path in tqdm(img_list):
        img_path = os.path.join(opt.data_dir, path)
        raw_rgb = cv2.imread(img_path + '_color.png')[:, :, :3]
        raw_rgb = raw_rgb[:, :, ::-1]  # OpenCV loads BGR; flip to RGB
        raw_depth = load_depth(img_path)
        # load mask-rcnn detection results
        img_path_parsing = img_path.split('/')
        mrcnn_path = os.path.join('results/mrcnn_results', opt.data, 'results_{}_{}_{}.pkl'.format(
            opt.data.split('_')[-1], img_path_parsing[-2], img_path_parsing[-1]))
        with open(mrcnn_path, 'rb') as f:
            mrcnn_result = cPickle.load(f)
        num_insts = len(mrcnn_result['class_ids'])
        f_sRT = np.zeros((num_insts, 4, 4), dtype=float)
        f_size = np.zeros((num_insts, 3), dtype=float)
        f_sRT_pose = np.zeros((num_insts, 4, 4), dtype=float)
        # prepare frame data
        f_points, f_rgb, f_choose, f_catId, f_prior = [], [], [], [], []
        f_box_s, f_box_c, f_cam, f_bbox = [], [], [], []
        f_select_pts_2d = []
        f_models = []
        valid_inst = []
        for i in range(num_insts):
            cat_id = mrcnn_result['class_ids'][i] - 1
            verts = mean_meshes[cat_id]
            y1, x1, y2, x2 = mrcnn_result['rois'][i]
            # np.int was removed in NumPy 1.24; the builtin int is what it aliased.
            bbox = np.array([x1, y1, x2-x1, y2-y1]).astype(int)
            c, s = xywh_to_cs(bbox, 1.5, s_max=max(640, 480))
            rgb, c_h_, c_w_, s_, crop_bbox = zoom_in(raw_rgb, c, s, opt.img_size)
            rmin, rmax, cmin, cmax = crop_bbox[0], crop_bbox[1], crop_bbox[2], crop_bbox[3]
            box_c = np.array([c_w_, c_h_])
            box_s = s_
            mask = np.logical_and(mrcnn_result['masks'][:, :, i], raw_depth > 0)
            choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
            # There is no depth observation for background in the CAMERA dataset.
            # Because of how the crop box is computed, a Mask R-CNN false positive
            # (a region that is mostly background) can leave too few foreground
            # points after cropping the mask. Such detections, and detections of
            # a category other than the selected one, get a placeholder result.
            skip = len(choose) < 32 or (target_cat is not None and cat_id != target_cat)
            if skip:
                f_sRT[i] = np.identity(4, dtype=float)
                f_size[i] = 2 * np.amax(np.abs(verts), axis=0)
                f_sRT_pose[i] = np.identity(4, dtype=float)
                continue
            valid_inst.append(i)
            # Only look up the reference model for instances that are evaluated.
            model = find_model(img_path, mrcnn_result['rois'][i])
            # process objects with valid depth observation:
            # sample exactly n_pts indices (random subset, or wrap-padding)
            if len(choose) > opt.n_pts:
                c_mask = np.zeros(len(choose), dtype=int)
                c_mask[:opt.n_pts] = 1
                np.random.shuffle(c_mask)
                choose = choose[c_mask.nonzero()]
            else:
                choose = np.pad(choose, (0, opt.n_pts-len(choose)), 'wrap')
            depth_masked = raw_depth[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis]
            xmap_masked = xmap[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis]
            ymap_masked = ymap[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis]
            # Pixel coordinates of the sampled points. They are only needed by the
            # PnP alternative, which is not enabled in this function.
            select_pts_2d = [xmap_masked[:, 0], ymap_masked[:, 0]]
            # back-project the sampled pixels through the pinhole model
            pt2 = depth_masked / norm_scale
            pt0 = (xmap_masked - cam_cx) * pt2 / cam_fx
            pt1 = (ymap_masked - cam_cy) * pt2 / cam_fy
            points = np.concatenate((pt0, pt1, pt2), axis=1)
            rgb = norm_color(rgb)
            # remap the crop-relative flat indices onto the img_size x img_size grid
            crop_w = cmax - cmin
            crop_h = rmax - rmin
            ratio_w = opt.img_size / float(crop_w)
            ratio_h = opt.img_size / float(crop_h)
            col_idx = choose % crop_w
            row_idx = choose // crop_w
            choose = (np.floor(row_idx * ratio_h) * opt.img_size + np.floor(col_idx * ratio_w)).astype(np.int64)
            # np.float was removed in NumPy 1.24; the builtin float is what it aliased.
            select_pts_2d = np.array(select_pts_2d, dtype=float).T
            # concatenate instances
            f_points.append(points)
            f_rgb.append(rgb)
            f_choose.append(choose)
            f_catId.append(cat_id)
            f_prior.append(verts)
            f_box_c.append(box_c)
            f_box_s.append(box_s)
            f_cam.append(cam)
            f_bbox.append(bbox)
            f_select_pts_2d.append(select_pts_2d)
            f_models.append(model)
        if valid_inst:
            f_points = torch.cuda.FloatTensor(f_points)
            f_rgb = torch.stack(f_rgb, dim=0).cuda()
            f_choose = torch.cuda.LongTensor(f_choose)
            f_catId = torch.cuda.LongTensor(f_catId)
            f_prior = torch.cuda.FloatTensor(f_prior)
            f_box_c = torch.cuda.FloatTensor(f_box_c)
            f_box_s = torch.cuda.FloatTensor(f_box_s)
            f_cam = torch.cuda.FloatTensor(f_cam)
            f_bbox = torch.cuda.FloatTensor(f_bbox)
            f_models = torch.cuda.FloatTensor(f_models)
            # inference (synchronise so the wall-clock timing covers the GPU work)
            torch.cuda.synchronize()
            t_now = time.time()
            outputs = estimator(f_points, f_rgb, f_choose, f_catId, f_prior)
            if 'deltas' in outputs.keys():
                deltas = outputs['deltas']
                inst_shape = f_prior + deltas
            elif 'deformed_shape' in outputs.keys():
                inst_shape = outputs['deformed_shape'].float()
            if opt.version in ('v3', 'v4'):
                # for these variants 'assign_mat' is used directly as coordinates
                f_coords = outputs['assign_mat']
            else:
                assign_mat = outputs['assign_mat']
                assign_mat = F.softmax(assign_mat, dim=2)
                f_coords = torch.bmm(assign_mat, inst_shape)  # bs x n_pts x 3
            # Weight by batch size so the final value is a per-instance average
            # (presumably cd_dist is a batch mean -- confirm against ChamferLoss).
            cd_dist, _, _ = chamferloss(inst_shape.detach(), f_models)
            cd_dist_total += cd_dist * len(valid_inst)
            pred_scales, pred_trans, pred_rots = outputs['pose']
            pred_pose_trans = compute_T(pred_trans, f_box_s, f_box_c, f_bbox, f_cam)
            pred_rots = pred_rots.detach().cpu().numpy()
            pred_scales = pred_scales.detach().cpu().numpy()
            pred_pose_trans = pred_pose_trans.detach().cpu().numpy()

            torch.cuda.synchronize()
            t_inference += (time.time() - t_now)
            f_coords = f_coords.detach().cpu().numpy()
            f_points = f_points.cpu().numpy()
            f_choose = f_choose.cpu().numpy()
            f_insts = inst_shape.detach().cpu().numpy()
            f_cam = f_cam.cpu().numpy()
            t_now = time.time()
            for i, inst_idx in enumerate(valid_inst):
                # keep one correspondence per distinct pixel index
                choose = f_choose[i]
                _, choose = np.unique(choose, return_index=True)
                nocs_coords = f_coords[i, choose, :]
                pts_2d = f_select_pts_2d[i][choose]
                f_size[inst_idx] = 2 * np.amax(np.abs(f_insts[i]), axis=0)
                points = f_points[i, choose, :]
                _, _, _, pred_sRT = estimateSimilarityTransform(nocs_coords, points)
                if pred_sRT is None:
                    pred_sRT = np.identity(4, dtype=float)
                f_sRT[inst_idx] = pred_sRT
                # pose assembled from the network's scale / rotation / translation heads
                pred_sRT_pose = np.identity(4)
                pred_sRT_pose[:3, :3] = quaternion_matrix(pred_rots[i])[:3, :3] * pred_scales[i]
                pred_sRT_pose[:3, 3] = pred_pose_trans[i]
                pred_sRT_pose[3, 3] = 1.0
                f_sRT_pose[inst_idx] = pred_sRT_pose
            # Per-image bookkeeping must happen once per image. Inside the instance
            # loop it counted images per instance and instances quadratically.
            t_umeyama += (time.time() - t_now)
            img_count += 1
            inst_count += len(valid_inst)

        # save results
        result = {}
        with open(img_path + '_label.pkl', 'rb') as f:
            gts = cPickle.load(f)
        result['gt_class_ids'] = gts['class_ids']
        result['gt_bboxes'] = gts['bboxes']
        result['gt_RTs'] = gts['poses']
        result['gt_scales'] = gts['size']
        result['gt_handle_visibility'] = gts['handle_visibility']

        result['pred_class_ids'] = mrcnn_result['class_ids']
        result['pred_bboxes'] = mrcnn_result['rois']
        result['pred_scores'] = mrcnn_result['scores']
        result['pred_RTs'] = f_sRT
        result['pred_RTs_pose'] = f_sRT_pose
        result['pred_scales'] = f_size

        image_short_path = '_'.join(img_path_parsing[-3:])
        save_path = os.path.join(result_dir, 'results_{}.pkl'.format(image_short_path))
        with open(save_path, 'wb') as f:
            cPickle.dump(result, f)
    # write statistics; the denominators are clamped so that a run without any
    # valid instance reports zeros instead of raising ZeroDivisionError
    n_img = max(img_count, 1)
    n_inst = max(inst_count, 1)
    messages = []
    messages.append("Total images: {}".format(len(img_list)))
    messages.append("Valid images: {},  Total instances: {},  Average: {:.2f}/image".format(
        img_count, inst_count, inst_count/n_img))
    messages.append("Inference time: {:06f}  Average: {:06f}/image".format(t_inference, t_inference/n_img))
    messages.append("Umeyama time: {:06f}  Average: {:06f}/image".format(t_umeyama, t_umeyama/n_img))
    messages.append("Total time: {:06f}".format(time.time() - t_start))
    messages.append("Total Chamfer Distance of {} : {}".format(opt.select_class, cd_dist_total/n_inst))
    with open('{0}/eval_logs.txt'.format(result_dir), 'w') as fw:
        for msg in messages:
            print(msg)
            fw.write(msg + '\n')


def evaluate():
    """Compute mAP / accuracy from the saved result pickles and log them.

    Loads every ``results_*.pkl`` in ``result_dir``, fills in a missing
    ``gt_handle_visibility`` entry with ones, runs ``compute_mAP`` and appends
    the headline numbers to ``eval_logs.txt``. The published NOCS curves are
    then loaded from ``results/nocs_results/<data>/mAP_Acc.pkl`` and plotted
    next to this run's curves.

    Raises:
        AssertionError: if no result pickle is found, or if a result's
            ``gt_handle_visibility`` and ``gt_class_ids`` differ in length.
        TypeError: if a pickle holds something other than a dict or a list of
            dicts.

    NOTE: pickles are assumed to be trusted local files.
    """
    degree_thres_list = list(range(0, 61, 1))
    shift_thres_list = [i / 2 for i in range(21)]
    iou_thres_list = [i / 100 for i in range(101)]
    # predictions
    print(result_dir)
    result_pkl_list = sorted(glob.glob(os.path.join(result_dir, 'results_*.pkl')))
    assert len(result_pkl_list), 'no results_*.pkl found in {}'.format(result_dir)
    pred_results = []
    for pkl_path in result_pkl_list:
        with open(pkl_path, 'rb') as f:
            result = cPickle.load(f)
        # A pickle may hold one result dict or a list of them. Normalise to a
        # list first: the handle-visibility fix-up below used to index the
        # loaded object as a dict, which raised TypeError for lists.
        if isinstance(result, dict):
            entries = [result]
        elif isinstance(result, list):
            entries = result
        else:
            raise TypeError('unsupported result type {} in {}'.format(type(result), pkl_path))
        for entry in entries:
            if 'gt_handle_visibility' not in entry:
                entry['gt_handle_visibility'] = np.ones_like(entry['gt_class_ids'])
            else:
                assert len(entry['gt_handle_visibility']) == len(entry['gt_class_ids']), "{} {}".format(
                    entry['gt_handle_visibility'], entry['gt_class_ids'])
        pred_results += entries
    # To be consistent with NOCS, set use_matches_for_pose=True for mAP evaluation
    iou_aps, pose_aps, iou_acc, pose_acc = compute_mAP(pred_results, result_dir, degree_thres_list, shift_thres_list,
                                                       iou_thres_list, iou_pose_thres=0.1, use_matches_for_pose=True,
                                                       select_class=opt.select_class, use_pose_reg=opt.use_pose_reg)
    # metric
    iou_25_idx = iou_thres_list.index(0.25)
    iou_50_idx = iou_thres_list.index(0.5)
    iou_75_idx = iou_thres_list.index(0.75)
    degree_05_idx = degree_thres_list.index(5)
    degree_10_idx = degree_thres_list.index(10)
    shift_02_idx = shift_thres_list.index(2)
    shift_05_idx = shift_thres_list.index(5)

    def _summary(title, iou_table, pose_table):
        # Headline numbers are read from the last row of each table, in percent.
        return [
            title,
            '3D IoU at 25: {:.1f}'.format(iou_table[-1, iou_25_idx] * 100),
            '3D IoU at 50: {:.1f}'.format(iou_table[-1, iou_50_idx] * 100),
            '3D IoU at 75: {:.1f}'.format(iou_table[-1, iou_75_idx] * 100),
            '5 degree, 2cm: {:.1f}'.format(pose_table[-1, degree_05_idx, shift_02_idx] * 100),
            '5 degree, 5cm: {:.1f}'.format(pose_table[-1, degree_05_idx, shift_05_idx] * 100),
            '10 degree, 2cm: {:.1f}'.format(pose_table[-1, degree_10_idx, shift_02_idx] * 100),
            '10 degree, 5cm: {:.1f}'.format(pose_table[-1, degree_10_idx, shift_05_idx] * 100),
        ]

    messages = _summary('mAP:', iou_aps, pose_aps) + _summary('Acc:', iou_acc, pose_acc)
    # Append, so the statistics written by detect() in the same log are kept.
    with open('{0}/eval_logs.txt'.format(result_dir), 'a') as fw:
        for msg in messages:
            print(msg)
            fw.write(msg + '\n')
    # load NOCS results
    pkl_path = os.path.join('results/nocs_results', opt.data, 'mAP_Acc.pkl')
    with open(pkl_path, 'rb') as f:
        nocs_results = cPickle.load(f)
    nocs_iou_aps = nocs_results['iou_aps'][-1, :]
    nocs_pose_aps = nocs_results['pose_aps'][-1, :, :]
    # Append the NOCS reference as an extra row so it is plotted with this run.
    iou_aps = np.concatenate((iou_aps, nocs_iou_aps[None, :]), axis=0)
    pose_aps = np.concatenate((pose_aps, nocs_pose_aps[None, :, :]), axis=0)
    # plot
    plot_mAP(iou_aps, pose_aps, result_dir, iou_thres_list, degree_thres_list, shift_thres_list)


def main():
    """Script entry point: detection (unless --only_eval), then evaluation."""
    print('Detecting ...')
    skip_detection = opt.only_eval
    if not skip_detection:
        detect()
    print('Evaluating ...')
    evaluate()


if __name__ == '__main__':
    main()
