#!/usr/bin/python

import numpy as np
import torch
from scipy.spatial import ConvexHull
from helpers.util import denormalize_box_params
from helpers.util import fit_shapes_to_box
from cmath import rect, phase
import open3d as o3d
import trimesh
from shapely.geometry import Polygon

import extension.dist_chamfer as ext
# Chamfer-distance operator from the project's compiled extension
# (`extension.dist_chamfer`), instantiated once at import time and reused by
# validate_shape_accuracy. NOTE(review): because this runs at import, the
# extension must be built/importable for this whole module to load.
chamfer = ext.chamferDist()

def close_dis(corners1, corners2):
    """Return the smallest Euclidean distance between two point sets.

    The full pairwise squared-distance matrix is built with one matmul via
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.

    Args:
        corners1: (N, D) array-like of points.
        corners2: (M, D) array-like of points.

    Returns:
        The minimum pairwise distance as a numpy scalar.
    """
    corners1 = np.asarray(corners1, dtype=np.float64)
    corners2 = np.asarray(corners2, dtype=np.float64)
    sq_dist = -2 * np.matmul(corners1, corners2.T)
    sq_dist += np.sum(corners1 ** 2, axis=-1)[:, None]
    sq_dist += np.sum(corners2 ** 2, axis=-1)[None, :]
    # The expansion suffers from cancellation: for (near-)coincident points the
    # result can be slightly negative, which sqrt would turn into NaN.
    np.maximum(sq_dist, 0.0, out=sq_dist)
    return np.min(np.sqrt(sq_dist))

def cal_l2_distance(point_1, point_2):
    """Euclidean distance between two points, using only their first two coordinates."""
    dx = point_2[0] - point_1[0]
    dy = point_2[1] - point_1[1]
    return np.sqrt(dx ** 2 + dy ** 2)

def angular_distance(a, b):
    """Return the smallest absolute difference between two angles, in degrees.

    The result lies in [0, 180]. The wrapped difference is computed exactly,
    with no trigonometric round trip, so nearly equal angles keep their small
    separation instead of collapsing to 0 through arccos. Scalars and numpy
    arrays (elementwise) are both accepted. The inputs are never modified.

    Args:
        a: angle(s) in degrees (any real value; wrapped modulo 360).
        b: angle(s) in degrees (any real value; wrapped modulo 360).

    Returns:
        The unsigned shortest angular distance (numpy float or array).
    """
    diff = np.abs(np.subtract(a, b)) % 360.
    # Going the other way round the circle may be shorter.
    return np.minimum(diff, 360. - diff)


def anglebetween2vecs(va, vb):
    """Return the unsigned angle between two vectors, in degrees ([0, 180]).

    Both vectors are normalized first, so the result holds for vectors of any
    length (the dot product alone is only the cosine for unit vectors).

    Args:
        va: first vector (array-like).
        vb: second vector (array-like, same length as ``va``).

    Returns:
        The angle in degrees.

    Raises:
        ValueError: if either vector has zero length (angle undefined).
    """
    va = np.asarray(va, dtype=np.float64)
    vb = np.asarray(vb, dtype=np.float64)
    norm_prod = np.linalg.norm(va) * np.linalg.norm(vb)
    if norm_prod == 0:
        raise ValueError("anglebetween2vecs: angle is undefined for a zero-length vector")
    # Clip guards against |cos| drifting just past 1 through rounding.
    cos_angle = np.clip(np.dot(va, vb) / norm_prod, -1, 1)
    return np.rad2deg(np.arccos(cos_angle))


def rot2d(degrees):
    """Build the 2x2 rotation matrix [[cos, -sin], [sin, cos]] for an angle in degrees."""
    theta = np.deg2rad(degrees)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    first_row = [cos_t, -sin_t]
    second_row = [sin_t, cos_t]
    return np.asarray([first_row, second_row])


def estimate_angular_mean(deg):
    """Circular mean of a collection of angles, in degrees within [0, 360).

    Each angle becomes a unit complex number. The phase of their average is
    the mean direction, so 350 and 10 average to ~0 rather than 180.

    Args:
        deg: iterable of angles in degrees.

    Returns:
        The mean angle in degrees, wrapped to [0, 360).

    Raises:
        ValueError: if ``deg`` is empty (the mean is undefined).
    """
    angles = list(deg)
    if not angles:
        raise ValueError("estimate_angular_mean: need at least one angle")
    # Builtin sum: passing a generator to np.sum is deprecated in NumPy.
    resultant = sum(rect(1, np.deg2rad(d)) for d in angles) / len(angles)
    return np.rad2deg(phase(resultant)) % 360.


def estimate_angular_std(degs):
    """Circular standard deviation of angles, in degrees.

    Computed as the root-mean-square of the shortest angular distance from
    each angle to ``estimate_angular_mean(degs)``.
    """
    center = estimate_angular_mean(degs)
    sq_devs = np.fromiter(
        (angular_distance(angle, center) ** 2 for angle in degs),
        dtype=np.float64,
    )
    return np.sqrt(sq_devs.sum() / len(degs))


def denormalize(box_params, file=None, with_norm=True):
    """Optionally undo box-parameter normalization.

    If ``with_norm`` is false, ``box_params`` is returned unchanged (the same
    object). Otherwise the work is delegated to ``denormalize_box_params``
    (from helpers.util), with ``box_params.shape[0]`` passed as ``params``.
    """
    if not with_norm:
        return box_params
    n_params = box_params.shape[0]
    return denormalize_box_params(box_params, file=file, params=n_params)

###########################################################################################
def corners_from_box(box):
    """Return the 8 corners of an oriented 3D box as a float64 tensor.

    Args:
        box: 7 values ``(l, h, w, x, y, z, theta)``, given as any object with
            ``.tolist()`` (torch tensor, numpy array). ``theta`` is in degrees.
            In the construction below, ``w`` spans x, ``h`` spans y and ``l``
            spans z, ``(x, y, z)`` is the box centre, and the box is rotated
            by ``theta`` about the y axis.

    Returns:
        torch.float64 tensor of shape (8, 3). Corners 0-3 sit at ``y + h/2``
        and corners 4-7 at ``y - h/2`` (the y-axis rotation leaves y as is).
    """
    l, h, w, x, y, z, theta = box.tolist()
    f64 = torch.float64
    # Build everything in float64 directly. Creating these from Python floats
    # with the default dtype (float32) and casting later silently drops precision.
    xt = torch.tensor([w / 2, w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2], dtype=f64)
    yt = torch.tensor([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dtype=f64)
    zt = torch.tensor([l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2, l / 2], dtype=f64)
    rad = np.deg2rad(theta)
    cos_t, sin_t = float(np.cos(rad)), float(np.sin(rad))
    R = torch.tensor([[cos_t, 0.0, sin_t],
                      [0.0, 1.0, 0.0],
                      [-sin_t, 0.0, cos_t]], dtype=f64)
    local = torch.stack([xt, yt, zt], 1)                     # (8,3) box frame
    return local @ R.T + torch.tensor([x, y, z], dtype=f64)  # (8,3)

def box_volume_from_corners(c):
    """Volume of a box given its (8, 3) corner tensor, returned as a Python float.

    Multiplies the lengths of the edges c0-c1, c1-c2 and c0-c4. With the
    corner ordering of corners_from_box these are the l, w and h edges.
    """
    first = (c[0] - c[1]).norm()
    second = (c[1] - c[2]).norm()
    third = (c[0] - c[4]).norm()
    volume = first * second * third
    return volume.item()


def intersection_volume(c1: np.ndarray, c2: np.ndarray) -> float:
    """Intersection volume of two boxes given as (8, 3) corner arrays.

    The overlap is split into a footprint part and a height part:
      * footprint: the first four corners of each box are projected onto
        the x-z plane and intersected as polygons (shapely);
      * height: the overlap of the two boxes' y ranges.
    The result is footprint area * height overlap, or 0.0 when either is
    empty. This treats y as the only vertical axis, i.e. boxes may rotate
    only about y (as in corners_from_box).
    """
    def footprint(corners):
        # x-z projection of corners 0-3
        return Polygon(list(zip(corners[:4, 0], corners[:4, 2])))

    overlap_2d = footprint(c1).intersection(footprint(c2))
    if overlap_2d.is_empty:
        return 0.0
    base_area = overlap_2d.area

    top = min(c1[:, 1].max(), c2[:, 1].max())
    bottom = max(c1[:, 1].min(), c2[:, 1].min())
    height = max(0.0, top - bottom)
    return 0.0 if height == 0.0 else base_area * height

def obb_overlap_gt(box_pred, box_gt):
    """Fraction of the ground-truth box volume covered by the predicted box.

    Both boxes use the 7-value layout accepted by corners_from_box. The
    result is intersection_volume / (gt volume + 1e-6). It is not symmetric
    IoU, and the epsilon keeps degenerate GT boxes from dividing by zero.
    """
    gt_corners_t = corners_from_box(box_gt)
    pred_corners = corners_from_box(box_pred).numpy()
    gt_corners = gt_corners_t.numpy()
    gt_volume = box_volume_from_corners(torch.from_numpy(gt_corners))
    shared = intersection_volume(pred_corners, gt_corners)
    return shared / (gt_volume + 1e-6)


def validate_box_accuracy(boxes_pred: torch.Tensor,
                          gt_boxes: torch.Tensor):
    """Compare predicted boxes with ground-truth boxes row by row.

    Columns are read as (l, h, w, x, y, z, theta_deg): sizes in [:3],
    centres in [3:6], angle in column 6, matching corners_from_box.

    Returns a dict of mean metrics:
      * 'position': L2 distance between centres.
      * 'size': L2 distance between size vectors.
      * 'orientation': wrapped |delta theta| in degrees, in [0, 180].
      * 'Overlap_GT(%)': mean obb_overlap_gt, as a percentage.
      * '*_p(%)': the three errors as percentages, normalized by the size
        diagonal of the LAST gt box (position), each GT size-vector norm
        (size), and 360 (orientation).
    NOTE(review): the position normaliser uses only ``gt_boxes[-1]``,
    presumably a scene-level box; confirm against the callers.
    """
    def row_norm(t):
        return t.square().sum(1).sqrt()

    ref_diag = gt_boxes[-1, :3].square().sum().sqrt()   # sqrt(l^2 + h^2 + w^2)

    pos_err = row_norm(boxes_pred[:, 3:6] - gt_boxes[:, 3:6])
    size_err = row_norm(boxes_pred[:, :3] - gt_boxes[:, :3])
    wrapped = (boxes_pred[:, 6] - gt_boxes[:, 6]).abs() % 360.0
    ang_err = torch.minimum(wrapped, 360.0 - wrapped)
    gt_size_norm = row_norm(gt_boxes[:, :3]).clamp(min=1e-6)

    overlaps = torch.tensor(
        [obb_overlap_gt(bp.cpu(), bg.cpu()) for bp, bg in zip(boxes_pred, gt_boxes)],
        device=boxes_pred.device
    )

    def pct(t):
        return t.mean().item() * 100.

    return {
        'num_boxes'        : boxes_pred.shape[0],
        'position'         : pos_err.mean().item(),
        'size'             : size_err.mean().item(),
        'orientation'      : ang_err.mean().item(),
        'Overlap_GT(%)'    : pct(overlaps),
        'position_p(%)'    : pct(pos_err / ref_diag),
        'size_p(%)'        : pct(size_err / gt_size_norm),
        'orientation_p(%)' : pct(ang_err / 360.0),
    }


def _sample_points(mesh, n):
    """Sample ``n`` surface points from an Open3D or trimesh mesh.

    Returns an (n, 3) numpy array. Raises TypeError for any other mesh type.
    """
    if isinstance(mesh, o3d.geometry.TriangleMesh):
        cloud = mesh.sample_points_uniformly(n)
        return np.asarray(cloud.points)
    if isinstance(mesh, trimesh.Trimesh):
        points, _face_idx = trimesh.sample.sample_surface(mesh, n)   # (n,3)
        return points
    raise TypeError(f"unsupported mesh type: {type(mesh)}")


def _to_o3d(mesh):
    """Return ``mesh`` as an ``o3d.geometry.TriangleMesh``.

    Open3D meshes are returned as-is (no copy). trimesh meshes are converted
    by copying only their vertices and faces; other attributes such as
    normals or colours are not transferred.

    Raises:
        TypeError: for any other mesh type (message names the offending type,
            consistent with _sample_points).
    """
    if isinstance(mesh, o3d.geometry.TriangleMesh):
        return mesh
    if isinstance(mesh, trimesh.Trimesh):
        converted = o3d.geometry.TriangleMesh()
        converted.vertices = o3d.utility.Vector3dVector(mesh.vertices)
        converted.triangles = o3d.utility.Vector3iVector(mesh.faces)
        return converted
    raise TypeError(f"unsupported mesh type: {type(mesh)}")


def _f_score(d1, d2, f_thresh):
    """F-score from per-point nearest-neighbour distances; 0.0 when precision and recall are both 0."""
    prec = (d1 < f_thresh).float().mean()
    reca = (d2 < f_thresh).float().mean()
    denom = prec + reca
    if denom.item() == 0:
        # Harmonic mean of 0 and 0 is 0, not 0/0 = NaN.
        return 0.0
    return (2 * prec * reca / denom).item()


def _voxel_iou(mesh_p, mesh_g, voxel_size):
    """IoU of two Open3D meshes voxelized on one shared grid."""
    # create_from_triangle_mesh anchors each grid at its own mesh's min bound,
    # which makes grid indices of two separately built grids incomparable.
    # Voxelize both meshes inside the same (half-voxel padded) bounds instead.
    pad = np.full(3, 0.5 * voxel_size)
    min_bound = np.minimum(mesh_p.get_min_bound(), mesh_g.get_min_bound()) - pad
    max_bound = np.maximum(mesh_p.get_max_bound(), mesh_g.get_max_bound()) + pad
    build = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds
    vox_p = build(mesh_p, voxel_size, min_bound, max_bound)
    vox_g = build(mesh_g, voxel_size, min_bound, max_bound)
    idx_p = {tuple(v.grid_index) for v in vox_p.get_voxels()}
    idx_g = {tuple(v.grid_index) for v in vox_g.get_voxels()}
    inter = len(idx_p & idx_g)
    union = len(idx_p | idx_g) + 1e-6
    return inter / union


def validate_shape_accuracy(shapes_mesh_pred, shapes_mesh_gt,
                            n_sample=2048,
                            f_thresh=0.1,
                            voxel_size=0.2):
    """Mean Chamfer distance, F-score and voxel IoU over pairs of meshes.

    Pairs are formed with ``zip``, so extra meshes in the longer list are
    ignored. Each mesh may be an Open3D TriangleMesh or a trimesh.Trimesh.

    Per pair:
      * ``n_sample`` surface points are sampled from each mesh;
      * CD = mean(dist1) + mean(dist2) from the ``chamfer`` extension;
      * F-score uses sqrt(dist) < ``f_thresh`` as the hit criterion;
      * voxel IoU uses a shared grid of cell size ``voxel_size``.
    NOTE(review): the extension's outputs are presumably squared NN distances
    (the code takes their sqrt before thresholding); confirm.

    Returns:
        dict with 'num_shapes', 'CD', 'F-score', 'Vox_IoU'. With no pairs
        the means are NaN (np.mean of an empty list).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    cd_vals, f_vals, vox_vals = [], [], []
    for mesh_p, mesh_g in zip(shapes_mesh_pred, shapes_mesh_gt):
        pts_p = torch.from_numpy(_sample_points(mesh_p, n_sample)).float().unsqueeze(0).to(device)
        pts_g = torch.from_numpy(_sample_points(mesh_g, n_sample)).float().unsqueeze(0).to(device)

        dist1, dist2 = chamfer(pts_p, pts_g)               # (1,N) (1,N)
        cd_vals.append((dist1.mean() + dist2.mean()).item())
        f_vals.append(_f_score(dist1.sqrt(), dist2.sqrt(), f_thresh))
        vox_vals.append(_voxel_iou(_to_o3d(mesh_p), _to_o3d(mesh_g), voxel_size))

    return {
        'num_shapes': len(cd_vals),
        'CD'        : float(np.mean(cd_vals)),
        'F-score'   : float(np.mean(f_vals)),
        'Vox_IoU'   : float(np.mean(vox_vals)),
    }

#################################################################################################################
def _corners_from_box_legacy(box, param6=True, with_translation=False):
    """Axis-aligned (8, 3) corner array for the legacy box layout used by box3d_iou.

    ``box`` is ``[l, h, w, px, py, pz]`` when ``param6`` is true, or the same
    followed by an angle that is ignored. ``l`` spans z, ``h`` spans y and
    ``w`` spans x. ``(px, py, pz)`` is the bottom centre, so corners 0-3 lie
    at ``py + h`` and corners 4-7 at ``py``. The translation is applied only
    when ``with_translation`` is true. No rotation is applied.
    """
    values = [float(v) for v in box]
    expected = 6 if param6 else 7
    if len(values) != expected:
        raise ValueError(f"expected {expected} box values, got {len(values)}")
    l, h, w, px, py, pz = values[:6]
    tx, ty, tz = (px, py, pz) if with_translation else (0.0, 0.0, 0.0)

    x_corners = np.array([w/2, w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2]) + tx
    y_corners = np.array([h, h, h, h, 0.0, 0.0, 0.0, 0.0]) + ty
    z_corners = np.array([l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2, l/2]) + tz
    return np.stack([x_corners, y_corners, z_corners], axis=1)


def box3d_iou(box1, box2, param6=True, with_translation=False):
    """Overlap of two axis-aligned, y-up boxes in the legacy box layout.

    Args:
        box1, box2: ``[l, h, w, px, py, pz]`` (plus an ignored angle when
            ``param6`` is false); see _corners_from_box_legacy.
        param6: whether the boxes have 6 values (True) or 7 (False).
        with_translation: if false, both boxes are centred at the origin, so
            only their sizes are compared.

    Returns:
        (iou, iou_2d):
          * iou: intersection volume divided by the SMALLER box volume (not
            the union), 0.0 if that volume is zero;
          * iou_2d: bird's-eye (z-x plane) footprint IoU, 0.0 if the union
            area is zero.
    """
    corners1 = _corners_from_box_legacy(box1, param6, with_translation)
    corners2 = _corners_from_box_legacy(box2, param6, with_translation)

    # Footprints in the (z, x) plane. Corners 0-3 run counter-clockwise there,
    # which polygon_clip requires.
    rect1 = [(corners1[i, 2], corners1[i, 0]) for i in range(4)]
    rect2 = [(corners2[i, 2], corners2[i, 0]) for i in range(4)]

    area1 = poly_area(np.array(rect1)[:, 0], np.array(rect1)[:, 1])
    area2 = poly_area(np.array(rect2)[:, 0], np.array(rect2)[:, 1])

    _, inter_area = convex_hull_intersection(rect1, rect2)
    union_area = area1 + area2 - inter_area
    iou_2d = inter_area / union_area if union_area > 0 else 0.0

    # Corner 0 is on the top face, corner 4 on the bottom face.
    ymax = min(corners1[0, 1], corners2[0, 1])
    ymin = max(corners1[4, 1], corners2[4, 1])
    inter_vol = inter_area * max(0.0, ymax - ymin)

    volmin = min(box3d_vol(corners1), box3d_vol(corners2))
    # Deliberately normalised by the smaller volume, not the union
    # (vol1 + vol2 - inter_vol).
    iou = inter_vol / volmin if volmin > 0 else 0.0

    return iou, iou_2d


def convex_hull_intersection(p1, p2):
    """Intersect two convex polygons and measure the intersection's area.

    Args:
        p1: list of (x, y) vertices (the subject polygon for polygon_clip).
        p2: list of (x, y) vertices of a convex polygon. Both must be ordered
            counter-clockwise, as polygon_clip requires.

    Returns:
        ``(inter_p, area)``: the clipped vertex list and its area (Qhull's
        2-D "volume"). Returns ``(None, 0.0)`` when the polygons do not
        intersect. The area is 0.0 when the clipped polygon is degenerate
        (fewer than 3 vertices, or collinear), where Qhull would raise.
    """
    inter_p = polygon_clip(p1, p2)
    if inter_p is None:
        return None, 0.0
    if len(inter_p) < 3:
        return inter_p, 0.0
    try:
        hull_inter = ConvexHull(inter_p)
    except RuntimeError:
        # scipy's QhullError subclasses RuntimeError; it is raised for
        # degenerate (e.g. collinear, zero-area) input.
        return inter_p, 0.0
    return inter_p, hull_inter.volume

def poly_area(x, y):
    """Area of a simple polygon from its vertex coordinates (shoelace formula).

    Ref: http://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
    """
    x_prev = np.roll(x, 1)
    y_prev = np.roll(y, 1)
    twice_signed_area = np.dot(x, y_prev) - np.dot(y, x_prev)
    return 0.5 * np.abs(twice_signed_area)

def box3d_vol(corners):
    """Product of the edge lengths |c0-c1|, |c1-c2|, |c0-c4| of an (8, 3) corner array.

    No assumption is made about which axis each edge lies along.
    """
    def edge_length(i, j):
        delta = corners[i, :] - corners[j, :]
        return np.sqrt(np.sum(delta ** 2))

    a = edge_length(0, 1)
    b = edge_length(1, 2)
    c = edge_length(0, 4)
    return a * b * c

def polygon_clip(subjectPolygon, clipPolygon):
    """Clip one polygon against a convex polygon (Sutherland-Hodgman).

    Ref: https://rosettacode.org/wiki/Sutherland-Hodgman_polygon_clipping#Python

    Args:
      subjectPolygon: list of (x, y) points describing any polygon.
      clipPolygon: list of (x, y) points describing a *convex* polygon.

    Both polygons must be ordered counter-clockwise. A point survives a clip
    edge only if it lies strictly to the left of it, so points exactly on a
    clip edge count as outside.

    Returns:
      The intersection polygon as a list of vertices, or None as soon as a
      clipping pass leaves no vertices. Surviving subject vertices keep their
      original objects; newly created crossing points are 2-element lists.
    """
    def _left_of(pt, a, b):
        # Strict cross-product test: is pt left of the directed edge a -> b?
        return (b[0] - a[0]) * (pt[1] - a[1]) > (b[1] - a[1]) * (pt[0] - a[0])

    def _crossing(a, b, p, q):
        # Intersection of the line through a, b with the line through p, q.
        dcx, dcy = a[0] - b[0], a[1] - b[1]
        dpx, dpy = p[0] - q[0], p[1] - q[1]
        cross_ab = a[0] * b[1] - a[1] * b[0]
        cross_pq = p[0] * q[1] - p[1] * q[0]
        inv_det = 1.0 / (dcx * dpy - dcy * dpx)
        return [(cross_ab * dpx - cross_pq * dcx) * inv_det,
                (cross_ab * dpy - cross_pq * dcy) * inv_det]

    clipped = subjectPolygon
    edge_start = clipPolygon[-1]

    for edge_end in clipPolygon:
        ring = clipped
        clipped = []
        prev = ring[-1]

        for cur in ring:
            cur_in = _left_of(cur, edge_start, edge_end)
            prev_in = _left_of(prev, edge_start, edge_end)
            if cur_in != prev_in:
                # The segment prev -> cur crosses the clip edge.
                clipped.append(_crossing(edge_start, edge_end, prev, cur))
            if cur_in:
                clipped.append(cur)
            prev = cur

        edge_start = edge_end
        if not clipped:
            return None
    return clipped


def pointcloud_overlap(pclouds, objs, boxes, angles, triples, vocab, overlap_metric):
    """Append pairwise point-cloud overlap counts to ``overlap_metric`` and return it.

    For every unordered object pair (i < j), the pair is skipped when either
    object's class (text before the first newline) is structural, or when a
    triple relates them (in either direction) with a "touching" predicate
    where overlap is expected. Each remaining pair is fitted into its box
    (box params concatenated with ``angles``) via fit_shapes_to_box and
    scored with pointcloud_overlap_pair. ``overlap_metric`` is mutated in place.
    """
    def first_line(name):
        return name.split('\n')[0]

    obj_classes = vocab['object_idx_to_name']
    pred_classes = vocab['pred_idx_to_name']
    subj_obj = [(t[0].item(), t[2].item()) for t in triples]
    predicates = [t[1].item() for t in triples]
    pair2pred = dict(zip(subj_obj, predicates))
    structural = ['floor', 'wall', 'ceiling', '_scene_']
    touching = ['none', 'inside', 'attached to', 'part of', 'cover', 'belonging to', 'build in', 'connected to']
    boxes = torch.cat([boxes.float(), angles.view(-1, 1).float()], 1)

    def is_structural(idx):
        return first_line(obj_classes[objs[idx]]) in structural

    def touches(a, b):
        return (a, b) in pair2pred and first_line(pred_classes[pair2pred[(a, b)]]) in touching

    n_clouds = len(pclouds)
    for i in range(n_clouds - 1):
        for j in range(i + 1, n_clouds):
            if is_structural(i) or is_structural(j):
                continue  # structural objects are not considered
            if touches(i, j) or touches(j, i):
                continue  # overlap is naturally expected for these relations
            pc1 = fit_shapes_to_box(boxes[i].clone(), pclouds[i].clone())
            pc2 = fit_shapes_to_box(boxes[j].clone(), pclouds[j].clone())
            overlap_metric.append(pointcloud_overlap_pair(pc1, pc2))
    return overlap_metric


def pointcloud_overlap_pair(pc1, pc2):
    """Count nearest-neighbour links from points of ``pc1`` into ``pc2``.

    A 2-NN kd-tree is fitted on the concatenation of both clouds and queried
    with every point of ``pc1``. The first neighbour is normally the point
    itself; any returned index ``>= len(pc1)`` belongs to ``pc2``. Returns how
    many such indices occur across both neighbour slots (numpy integer). Per
    the original intent, more links suggest the clouds overlap.
    """
    from sklearn.neighbors import NearestNeighbors
    n_first = len(pc1)
    merged = np.concatenate([pc1, pc2], 0)
    knn = NearestNeighbors(n_neighbors=2, algorithm='kd_tree')
    knn.fit(merged)
    _dists, neighbour_idx = knn.kneighbors(pc1)
    return np.sum(neighbour_idx >= n_first)
