import sys

sys.path.append('../corr/lib')
sys.path.append('../corr')
import torch
import numpy as np
import os
from PIL import Image, ImageDraw
from MeshUtils import *
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import look_at_view_transform
from pytorch3d import _C
from config import get_config, print_usage
from capsule import Network
from sklearn.neighbors import KDTree
import matplotlib.pyplot as plt
import point_cloud_utils as pcu
# import open3d as o3d
import trimesh
import json
from mpl_toolkits.mplot3d import Axes3D
import warnings
from pytorch3d.loss import chamfer_distance


def as_mesh(scene_or_mesh):
    """
    Convert a possible scene to a mesh.

    Args:
        scene_or_mesh: a ``trimesh.Scene`` or a ``trimesh.Trimesh``.

    Returns:
        For a scene: one ``trimesh.Trimesh`` concatenating all of its
        geometries, rebuilt from vertex and face data only (textures and other
        visuals are dropped); ``None`` if the scene has no geometry.
        For a ``trimesh.Trimesh``: the input object itself, unchanged.

    Raises:
        TypeError: if the input is neither a ``trimesh.Scene`` nor a
            ``trimesh.Trimesh``.
    """
    if isinstance(scene_or_mesh, trimesh.Scene):
        if len(scene_or_mesh.geometry) == 0:
            return None  # empty scene
        # we lose texture information here
        return trimesh.util.concatenate(
            tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
                  for g in scene_or_mesh.geometry.values()))
    # Validate the *argument* (there is no other mesh yet at this point).
    if not isinstance(scene_or_mesh, trimesh.Trimesh):
        raise TypeError(
            f'expected trimesh.Scene or trimesh.Trimesh, got {type(scene_or_mesh).__name__}')
    return scene_or_mesh


def PointFaceDistance(points, points_first_idx, tris, tris_first_idx, max_points, min_triangle_area=5e-3):
    """
    Point-to-triangle distances computed by PyTorch3D's compiled kernel
    (``_C.point_face_dist_forward``), for a packed batch of examples.

    Args:
        points: FloatTensor `(P, 3)` of packed query points.
        points_first_idx: LongTensor `(N,)`; entry `n` is the index in
            `points` where example `n` starts.
        tris: FloatTensor `(T, 3, 3)`; `tris[t]` holds the three corners
            `(tris[t, 0], tris[t, 1], tris[t, 2])` of triangle `t`.
        tris_first_idx: LongTensor `(N,)`; entry `n` is the index in `tris`
            where example `n` starts.
        max_points: scalar, the largest number of points in any example.
        min_triangle_area: (float, default 5e-3) triangles smaller than this
            are handled as degenerate points/lines.

    Returns:
        (dists, idxs):
        dists: FloatTensor `(P,)`; `dists[p]` is the squared euclidean distance
            from point `p` to the nearest triangle of its own example, i.e.
            `d(points[p], tris[idxs[p], 0], tris[idxs[p], 1], tris[idxs[p], 2])`.
        idxs: LongTensor `(P,)`; index of that nearest triangle.
    """
    kernel_args = (
        points,
        points_first_idx,
        tris,
        tris_first_idx,
        max_points,
        min_triangle_area,
    )
    squared_dists, nearest_face = _C.point_face_dist_forward(*kernel_args)
    return squared_dists, nearest_face

def fps(points, n_samples):
    """Farthest point sampling.

    Starting from index 0, greedily adds the point whose squared Euclidean
    distance to its nearest already-selected point is largest; ties go to the
    lowest index.

    Args:
        points: array-like of shape [N, D] (typically D == 3) holding the
            whole point cloud.
        n_samples: number of points to select, with 0 <= n_samples <= N
            (typically << N).

    Returns:
        (sampled_points, sample_inds): ``points[sample_inds]`` of shape
        [n_samples, D] and the selected integer indices of shape [n_samples].

    Raises:
        ValueError: if ``n_samples`` is negative or larger than N.
    """
    points = np.array(points)
    num_points = len(points)
    if n_samples < 0 or n_samples > num_points:
        raise ValueError(f'n_samples must be in [0, {num_points}], got {n_samples}')

    # sample_inds[0] == 0: the first sample is always point 0.
    sample_inds = np.zeros(n_samples, dtype='int')
    if n_samples == 0:
        return points[sample_inds], sample_inds

    # Squared distance from every point to its nearest selected point.
    # Selected points are pinned to -inf so argmax can never pick them again
    # (and unselected ties still resolve to the lowest index).
    min_dists = np.full(num_points, np.inf)
    min_dists[0] = -np.inf
    for i in range(1, n_samples):
        last_added = points[sample_inds[i - 1]]
        np.minimum(min_dists, ((points - last_added) ** 2).sum(-1), out=min_dists)
        selected = int(np.argmax(min_dists))
        sample_inds[i] = selected
        min_dists[selected] = -np.inf

    return points[sample_inds], sample_inds


def vis_pts_att(pts, label_map, fn="temp.png", marker=".", alpha=0.9):
    """Save a 3D scatter plot of a point cloud, optionally coloured per point.

    Args:
        pts: numpy array of shape (n, 3) with xyz coordinates.
        label_map: numpy array of shape (n,) of per-point values mapped through
            the "jet" colormap, or None to draw all points in the default colour.
        fn: path of the image file to write.
        marker: matplotlib marker style.
        alpha: point transparency.
    """
    assert pts.shape[1] == 3
    half_extent = 0.7
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    # Every axis gets the same fixed range [-0.7, 0.7].
    for set_limits in (ax.set_zlim, ax.set_xlim, ax.set_ylim):
        set_limits(-half_extent, half_extent)

    xs, ys, zs = pts[:, 0], pts[:, 1], pts[:, 2]
    if label_map is None:
        ax.scatter(xs, ys, zs, marker=marker, alpha=alpha, edgecolor="none")
    else:
        ax.scatter(xs, ys, zs, c=label_map, cmap="jet", marker=marker, alpha=alpha)

    # Hide the tick labels on all three axes.
    for set_tick_labels in (ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels):
        set_tick_labels([])
    plt.savefig(fn, bbox_inches='tight', pad_inches=0, dpi=300)
    plt.close()


# Device shared by the network, the renderer and every tensor created below.
device = 'cuda:0'

# Parse the capsule-network configuration; unrecognised CLI arguments are fatal.
config, unparsed = get_config()
if len(unparsed) != 0:
    print_usage()
    exit(1)

# Checkpoints are read relative to the capsule project's log directory.
config.res_dir = '../capsule/logs'

# Build the network, put its model in eval mode, then restore the weights.
network = Network(config)
network.model.eval()
network.load_checkpoint()


def capsule_decompose(pc, R=None, T=None):
    """Run the capsule network's ``decompose_one_pc`` on a single point cloud.

    Gradients are disabled for the call.

    Args:
        pc: point-cloud tensor without a batch dimension; a leading batch
            axis of size 1 is added before the call.
        R, T: optional rotation / translation forwarded unchanged to
            ``decompose_one_pc``.

    Returns:
        The ``(labels, feats)`` pair produced by ``decompose_one_pc``.
    """
    with torch.no_grad():
        labels, feats = network.model.decompose_one_pc(pc[None], R, T)
    return labels, feats


# ---- Renderer configuration -------------------------------------------------
# Image resolutions, given as (height, width).
render_image_size = (512, 512)
image_size = (512, 512)

blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
# Hard rasterisation: one face per pixel, no blur; bin_size=0 makes
# PyTorch3D use its naive (non-binned) rasteriser.
raster_settings = RasterizationSettings(
    image_size=render_image_size,
    blur_radius=0.0,
    faces_per_pixel=1,
    bin_size=0,
)
# A single point light at (2, 2, -2).
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))

# Screen-space camera (in_ndc=False): 3200 px focal length, principal point at
# the image centre given as (x, y) = (W // 2, H // 2).
_render_h, _render_w = render_image_size
cameras = PerspectiveCameras(
    focal_length=3200.0,
    principal_point=((_render_w // 2, _render_h // 2),),
    image_size=(render_image_size,),
    device=device,
    in_ndc=False,
)

_phong_rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
_phong_shader = HardPhongShader(device=device, lights=lights, cameras=cameras)
phong_renderer = MeshRenderer(rasterizer=_phong_rasterizer, shader=_phong_shader)



# Fixed rotation (shape 1x3x3) and translation (shape 1x3x1) passed to
# capsule_decompose for the reference and for every target instance.
# NOTE(review): presumably the canonical pose the capsule network expects — confirm.
R_can = torch.tensor(
    [[[0.3456, 0.5633, 0.7505],
      [-0.9333, 0.2898, 0.2122],
      [-0.0980, -0.7737, 0.6259]]],
    device=device,
)
T_can = torch.tensor([[[-0.0161], [-0.0014], [-0.0346]]], device=device)




# ---- Input / output locations -------------------------------------------------
root_path = '../data/CorrData/new_DST_part/aligned_CAD_models_new'
root_path2 = '../data/CorrData/new_DST_part/part_transfer_capsule'

save_path = os.path.join(root_path2, 'transfer')
os.makedirs(save_path, exist_ok=True)

# Create (or truncate) the results file up front; the full results dict is
# written to it once all categories are processed.
json_path = os.path.join(save_path, 'result.json')
with open(json_path, 'w') as f:
    pass

# Per-part metrics, filled as results[category][instance][part_id].
results = {}

# Start a fresh top-level distance log ('w' truncates any previous run's file).
with open(os.path.join(save_path, 'distance.txt'), 'w') as f:
    f.write('CD:\n')

info_path = "../data/CorrData/new_DST_part/infos.json"

with open(info_path, "r") as f:
    infos = json.load(f)

def _load_normalized_mesh(mesh_file):
    """Load an ``.obj`` or ``.glb`` file as a ``trimesh.Trimesh`` fitted to a unit box.

    The axis-aligned bounding box is centred on the origin and its longest side
    is scaled to 1. The mesh is built with ``process=False`` so vertex order,
    which the part-annotation JSON indexes into, is preserved.
    Returns None for any other extension.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        if mesh_file.endswith('.obj'):
            obj = load_obj(mesh_file)
            vertices = obj[0].numpy()
            faces = obj[1].verts_idx.numpy()
        elif mesh_file.endswith('.glb'):
            loaded = trimesh.load_mesh(mesh_file)
            if isinstance(loaded, trimesh.Scene):
                # Convert the scene to a single mesh
                loaded = trimesh.util.concatenate(loaded.dump())
            vertices = np.asarray(loaded.vertices)
            faces = np.asarray(loaded.faces)
        else:
            return None
    v_max = vertices.max(axis=0)
    v_min = vertices.min(axis=0)
    vertices = (vertices - (v_max + v_min) / 2) / (v_max - v_min).max()
    return trimesh.Trimesh(vertices=vertices, faces=faces, process=False)


def _vertex_labels(part_data, num_vertices):
    """Per-vertex part label = enumeration index of the JSON key listing that vertex.

    Unlisted vertices keep label 0; a vertex listed under several keys gets the
    last one. Out-of-range vertex ids raise IndexError.
    """
    labels = np.zeros(num_vertices)
    for key_id, vert_ids in enumerate(part_data.values()):
        if len(vert_ids) > 0:
            labels[np.asarray(vert_ids, dtype=np.int64)] = key_id
    return labels


def _face_labels(vertex_labels, faces):
    """Per-face label shared by at least two of its corners, else corner 0's label."""
    corners = vertex_labels[np.asarray(faces)]  # (F, 3)
    labels = corners[:, 0].copy()
    # The majority differs from corner 0 only when corners 1 and 2 agree with
    # each other but not with corner 0.
    take_second = (corners[:, 1] == corners[:, 2]) & (corners[:, 0] != corners[:, 1])
    labels[take_second] = corners[take_second, 1]
    return labels


def _scale_labels(labels):
    """Scale non-negative labels into [0, 1] for colouring; all-zero input is returned as-is."""
    peak = np.max(labels)
    return labels / peak if peak > 0 else labels


def _bbox_iou(pts_a, pts_b):
    """Axis-aligned bounding-box IoU of two point sets (0.0 if disjoint or degenerate)."""
    min_a, max_a = pts_a.min(axis=0), pts_a.max(axis=0)
    min_b, max_b = pts_b.min(axis=0), pts_b.max(axis=0)
    # Clamp each overlap extent at 0 so non-overlapping axes cannot multiply
    # into a spurious positive volume.
    overlap = np.clip(np.minimum(max_a, max_b) - np.maximum(min_a, min_b), 0, None)
    intersection = float(np.prod(overlap))
    union = float(np.prod(max_a - min_a)) + float(np.prod(max_b - min_b)) - intersection
    return intersection / union if union > 0 else 0.0


def _render_vertex_colored(verts, faces, vertex_rgb, out_file):
    """Render a per-vertex-coloured mesh with ``phong_renderer`` from a fixed view and save it as an image."""
    render_mesh = Meshes(
        verts=[torch.from_numpy(verts.astype(np.float32)).to(device)],
        faces=[torch.from_numpy(faces.astype(np.int32)).to(device)],
        textures=Textures(verts_features=torch.tensor(vertex_rgb, dtype=torch.float32)[None].to(device)),
    )
    R, T = look_at_view_transform(5.0, 30, 60, device=device)
    image = phong_renderer(meshes_world=render_mesh.clone(), R=R, T=T)
    image = image[0, ..., :3].detach().squeeze().cpu().numpy()
    peak = image.max()
    if peak > 0:  # avoid 0/0 on an all-black render
        image = image / peak
    Image.fromarray(np.array(image * 255).astype(np.uint8)).save(out_file)


for cate_id in os.listdir(root_path):
    print(f'processing {cate_id}')
    if cate_id == 'transfer' or cate_id == 'info.json':
        continue

    results[cate_id] = dict()
    mesh_path = os.path.join(root_path, cate_id)
    # Prefer the revised annotations ('new_json_files') when they exist.
    part_path = os.path.join('../data/CorrData/new_DST_part/CAD_models_v1.3', cate_id, 'new_json_files')
    if not os.path.exists(part_path):
        part_path = os.path.join('../data/CorrData/new_DST_part/CAD_models_v1.3', cate_id, 'json_files')

    # Reset per category so the error path never names (and nothing below ever
    # reuses) the previous category's reference instance.
    chosen_instance = None
    try:
        # Reference instance: from infos.json when listed, otherwise the first
        # annotation file found in part_path.
        if cate_id in infos:
            chosen_instance = infos[cate_id]["chosen_instance"]
        else:
            chosen_instance = next(
                (fn[:-5] for fn in os.listdir(part_path) if fn.endswith('.json')), None)
            if chosen_instance is None:
                raise FileNotFoundError(f'no .json annotation in {part_path}')

        save_category_path = os.path.join(save_path, cate_id)
        os.makedirs(save_category_path, exist_ok=True)
        with open(os.path.join(save_category_path, 'distance.txt'), 'w') as f:
            f.write(f'category: {cate_id}\n')
            f.write(f'instance: {chosen_instance}\n')

        obj_file = os.path.join(mesh_path, f'{chosen_instance}.obj')
        glb_file = os.path.join(mesh_path, f'{chosen_instance}.glb')
        if os.path.exists(obj_file):
            mesh = _load_normalized_mesh(obj_file)
        elif os.path.exists(glb_file):
            mesh = _load_normalized_mesh(glb_file)
        else:
            print(f'No mesh found for {chosen_instance}')
            continue

        # Surface samples of the reference and the face each sample lies on.
        points, face_indices = trimesh.sample.sample_surface(mesh, config.num_pts)

        with open(os.path.join(part_path, f'{chosen_instance}.json'), 'r') as f:
            part_data = json.load(f)
        part_num = len(part_data)
        valid_part_num = sum(1 for vert_ids in part_data.values() if len(vert_ids) > 0)
        print('valid_part_num: ', valid_part_num)

        part_label = _vertex_labels(part_data, len(mesh.vertices))
        face_label = _face_labels(part_label, mesh.faces)
    except Exception as exc:
        print(f"revised annotation is missing or incorrect for {chosen_instance}: {exc}")
        continue

    point_label = face_label[face_indices]
    vis_pts_att(points, _scale_labels(point_label), os.path.join(save_category_path, 'ref_segment.png'))

    # Also render the reference mesh coloured by its ground-truth parts.
    cmap = plt.get_cmap('jet')
    _render_vertex_colored(
        np.asarray(mesh.vertices), np.asarray(mesh.faces),
        cmap(part_label / part_num)[:, :3],
        os.path.join(save_category_path, f'{chosen_instance}_gt_render.png'))

    # Reference capsule features, L2-normalised along dim 0 so the matmul
    # below is a cosine similarity.
    # NOTE(review): the indexing assumes a 5-D feature tensor with feature
    # channels on axis 1 and points on axis 3 — confirm against decompose_one_pc.
    prev_x = torch.from_numpy(points).to(device, dtype=torch.float32)
    _, prev_feats = capsule_decompose(prev_x, R_can, T_can)
    prev_feats = prev_feats[0, :, 0, :, 0]
    prev_feats = prev_feats / torch.norm(prev_feats, dim=0)

    instance_distances = []
    for file_name in os.listdir(mesh_path):
        instance_id, ext = os.path.splitext(file_name)
        # Only mesh files are instances; match the reference id exactly so other
        # instances whose names merely contain it are not skipped.
        if ext not in ('.obj', '.glb') or instance_id == chosen_instance:
            continue

        part_fn = os.path.join(part_path, f'{instance_id}.json')
        if not os.path.exists(part_fn):
            continue
        with open(part_fn, 'r') as f:
            part_data = json.load(f)

        ins_mesh = _load_normalized_mesh(os.path.join(mesh_path, file_name))
        ins_points, ins_face_indices = trimesh.sample.sample_surface(ins_mesh, config.num_pts)

        num_ins_verts = len(ins_mesh.vertices)
        if any(vert_id >= num_ins_verts for vert_ids in part_data.values() for vert_id in vert_ids):
            print(f'part file is not correct for {instance_id}')
            continue

        save_instance_path = os.path.join(save_category_path, instance_id)
        os.makedirs(save_instance_path, exist_ok=True)

        ins_faces = np.asarray(ins_mesh.faces)
        ins_part_label = _vertex_labels(part_data, num_ins_verts)
        ins_point_label = _face_labels(ins_part_label, ins_faces)[ins_face_indices]

        # Label transfer: each instance point takes the label of the reference
        # point whose capsule feature is most cosine-similar.
        ins_v = torch.from_numpy(ins_points).to(device, dtype=torch.float32)
        _, feats = capsule_decompose(ins_v, R_can, T_can)
        feats = feats[0, :, 0, :, 0]
        feats = feats / torch.norm(feats, dim=0)
        similarity = torch.matmul(prev_feats.transpose(0, 1), feats)
        max_point = torch.argmax(similarity, dim=0)
        trans_part_label = point_label[max_point.cpu().numpy()]

        vis_pts_att(ins_points, _scale_labels(trans_part_label),
                    os.path.join(save_instance_path, f'{instance_id}.png'))

        results[cate_id][instance_id] = dict()
        part_distances = []
        with open(os.path.join(save_instance_path, 'distance.txt'), 'w') as dist_f:
            dist_f.write(f'instance: {instance_id}\n')
            for part_id in range(part_num):
                pred_mask = trans_part_label == part_id
                gt_mask = ins_point_label == part_id
                if not pred_mask.any() or not gt_mask.any():
                    continue
                part_vertices = ins_points[pred_mask]
                gt_part_vertices = ins_points[gt_mask]

                iou = _bbox_iou(part_vertices, gt_part_vertices)
                distance, _ = chamfer_distance(
                    torch.from_numpy(part_vertices.astype(np.float32)).to(device)[None],
                    torch.from_numpy(gt_part_vertices.astype(np.float32)).to(device)[None],
                    batch_reduction='mean')
                distance = distance.item()

                results[cate_id][instance_id][part_id] = {
                    'distance': distance,
                    'iou': iou
                }
                dist_f.write(f'part {part_id}: {distance}\n')
                part_distances.append(distance)

        if part_distances:
            # Mean chamfer distance per compared part (not per sampled point).
            instance_distance = float(np.mean(part_distances))
            instance_distances.append(instance_distance)
        else:
            instance_distance = float('nan')
            print(f'no part shared by transferred and GT labels for {instance_id}')
        with open(os.path.join(save_category_path, 'distance.txt'), 'a') as f:
            f.write(f'instance {instance_id}: {instance_distance}\n')

        # Render the instance mesh twice: with transferred labels (each vertex
        # copies its nearest sampled point's label) and with its GT labels.
        ins_ori_verts = np.asarray(ins_mesh.vertices)
        _, nearest_idx = KDTree(ins_points).query(ins_ori_verts, k=1)
        ins_ori_label = trans_part_label[nearest_idx[:, 0]]
        _render_vertex_colored(ins_ori_verts, ins_faces, cmap(ins_ori_label / part_num)[:, :3],
                               os.path.join(save_instance_path, f'{instance_id}_render.png'))
        _render_vertex_colored(ins_ori_verts, ins_faces, cmap(ins_part_label / part_num)[:, :3],
                               os.path.join(save_instance_path, f'{instance_id}_gt_render.png'))

    instance_number = len(instance_distances)
    category_distance = float(np.mean(instance_distances)) if instance_distances else float('nan')
    print(f'category {cate_id}: ', category_distance)
    with open(os.path.join(save_path, 'distance.txt'), 'a') as f:
        f.write(f'category {cate_id} {instance_number} instances: {category_distance}\n')


# Persist every per-part metric collected above, replacing the empty file
# created at start-up.
with open(json_path, 'w') as out_file:
    out_file.write(json.dumps(results, indent=4))