import sys

sys.path.append('../corr/lib')
sys.path.append('../corr')
import torch
import numpy as np
import os
from PIL import Image, ImageDraw
from MeshUtils import *
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.io import load_obj
from pytorch3d import _C
from config import get_config, print_usage
from capsule import Network
from sklearn.neighbors import KDTree
import matplotlib.pyplot as plt
import point_cloud_utils as pcu
# import open3d as o3d
import trimesh
import json
from mpl_toolkits.mplot3d import Axes3D
import warnings
from pytorch3d.loss import chamfer_distance
import shutil

def vis_pts_att(pts, label_map, fn="temp.png", marker=".", alpha=0.9):
    """Save a 3-D scatter plot of a point cloud to an image file.

    Args:
        pts: array indexed as ``pts[:, k]`` whose second dimension must be 3
            (x, y, z columns).
        label_map: per-point values passed to matplotlib as the scatter
            colour (``c``) with the ``"jet"`` colormap, or ``None`` to draw
            the points without a colour mapping.
        fn: path of the image written by ``plt.savefig``.
        marker: matplotlib marker style.
        alpha: marker opacity.

    All three axes are clamped to [-0.7, 0.7] and their tick labels are
    blanked. The figure is closed after saving.
    """
    assert pts.shape[1] == 3

    axis_limit = 0.7
    figure = plt.figure()
    axes = figure.add_subplot(111, projection="3d")
    axes.set_zlim(-axis_limit, axis_limit)
    axes.set_xlim(-axis_limit, axis_limit)
    axes.set_ylim(-axis_limit, axis_limit)

    # Options shared by both variants; the rest depends on whether a
    # colour value per point was supplied.
    scatter_options = {"marker": marker, "alpha": alpha}
    if label_map is None:
        scatter_options["edgecolor"] = "none"
    else:
        scatter_options["c"] = label_map
        scatter_options["cmap"] = "jet"
    axes.scatter(pts[:, 0], pts[:, 1], pts[:, 2], **scatter_options)

    # Hide the tick labels on every axis.
    axes.set_xticklabels([])
    axes.set_yticklabels([])
    axes.set_zticklabels([])

    plt.savefig(fn, bbox_inches='tight', pad_inches=0, dpi=300)
    plt.close()


# Device on which point clouds and the R_can / T_can tensors are placed.
device = 'cuda:0'

# prepare network
config, unparsed = get_config()
if unparsed:
    # Anything left in `unparsed` is treated as a usage error.
    print_usage()
    # sys.exit() instead of the site-provided exit(): the latter is meant for
    # the interactive shell and is missing when `site` is not loaded.
    sys.exit(1)

# Number of consecutive reference points matched against one sampled
# instance point cloud before the instance is sampled again.
sample_interval = 50

config.res_dir = '../capsules/logs'

network = Network(config)
network.model.eval()
network.load_checkpoint()


def capsule_decompose(pc, R=None, T=None):
    """Run the capsule network on a single point cloud without gradients.

    Args:
        pc: point cloud; a leading axis is added with ``pc[None]`` before
            it is handed to the model.
        R: forwarded unchanged to ``decompose_one_pc`` (may be ``None``).
        T: forwarded unchanged to ``decompose_one_pc`` (may be ``None``).

    Returns:
        The two values produced by ``network.model.decompose_one_pc``, in
        order: labels and features.
    """
    with torch.no_grad():
        part_labels, part_feats = network.model.decompose_one_pc(pc[None], R, T)
        return part_labels, part_feats


# # render config
# render_image_size = (512, 512)
# image_size = (512, 512)

# raster_settings = RasterizationSettings(
#     image_size=render_image_size,
#     blur_radius=0.0,
#     faces_per_pixel=1,
#     bin_size=0
# )
# # We can add a point light in front of the object.
# lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))

# # prepare camera
# cameras = PerspectiveCameras(focal_length=1.0 * 3200,
#                              principal_point=((render_image_size[1] // 2, render_image_size[0] // 2),),
#                              image_size=(render_image_size,), device=device, in_ndc=False)

# phong_renderer = MeshRenderer(
#     rasterizer=MeshRasterizer(
#         cameras=cameras,
#         raster_settings=raster_settings
#     ),
#     shader=HardPhongShader(device=device, lights=lights, cameras=cameras),
# )

# Fixed (1, 3, 3) and (1, 3, 1) tensors passed as R and T to
# capsule_decompose for every point cloud.
R_can = torch.tensor(
    [[[0.3456, 0.5633, 0.7505],
      [-0.9333, 0.2898, 0.2122],
      [-0.0980, -0.7737, 0.6259]]],
    device=device,
)
T_can = torch.tensor([[[-0.0161], [-0.0014], [-0.0346]]], device=device)

# Dataset root and the directory whose entries are iterated as category ids.
root_path = '../data/CorrData/new_DST_part'
mesh_path = os.path.join(root_path, 'aligned_CAD_models_new')

file_name_list = os.listdir(mesh_path)
# file_name_list = ['n02835271', 'n03345487', 'n03417042', 'n03496892', 'n03599486', 'n03642806', 'n03649909', 'n03670208', 'n03673027', 'n03785016', 'n03947888', 'n04065272', 'n04146614', 'n04147183', 'n04252225', 'n04285008', 'n04482393', 'n04483307', 'n04509417', 'n04552348', 'n04612504']

# Per-category metadata; each entry is read for its 'chosen_instance' key.
# info_path = os.path.join(root_path, 'new_infos.json')
info_path = os.path.join(root_path, 'infos.json')
with open(info_path, 'r') as info_file:
    infos = json.load(info_file)

def _load_normalized_mesh(mesh_dir, mesh_file):
    """Load a mesh file and rescale it into a unit bounding box.

    The bounding-box centre is moved to the origin and every coordinate is
    divided by the longest bounding-box side. UserWarnings raised while
    loading are suppressed.

    Args:
        mesh_dir: directory containing the mesh.
        mesh_file: file name ending in ``.obj`` (read with pytorch3d
            ``load_obj``) or ``.glb`` (read with ``trimesh.load_mesh``; a
            ``trimesh.Scene`` is flattened into a single mesh).

    Returns:
        A ``trimesh.Trimesh`` built with ``process=False``, or ``None`` when
        the file name ends in neither ``.obj`` nor ``.glb``.
    """
    mesh_file_path = os.path.join(mesh_dir, mesh_file)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        if mesh_file.endswith('.obj'):
            loaded = load_obj(mesh_file_path)
            verts = loaded[0]
            faces = loaded[1].verts_idx
            # torch reductions along an axis return (values, indices).
            upper = verts.max(axis=0)[0]
            lower = verts.min(axis=0)[0]
        elif mesh_file.endswith('.glb'):
            loaded = trimesh.load_mesh(mesh_file_path)
            if isinstance(loaded, trimesh.Scene):
                # Convert the scene to a single mesh
                loaded = trimesh.util.concatenate(loaded.dump())
            verts = loaded.vertices
            faces = loaded.faces
            upper = verts.max(axis=0)
            lower = verts.min(axis=0)
        else:
            return None

        center = (upper + lower) / 2
        scale = (upper - lower).max()
        # process=False: hand the vertex array to trimesh as-is, because
        # extra vertices are appended later and addressed by position.
        return trimesh.Trimesh(vertices=(verts - center) / scale, faces=faces, process=False)


def _unit_point_features(point_array):
    """Return unit-norm capsule features for a sampled point cloud.

    Args:
        point_array: numpy array of sampled surface points.

    Returns:
        The ``[0, :, 0, :, 0]`` slice of the feature tensor returned by
        ``capsule_decompose`` (called with ``R_can`` / ``T_can``), with each
        column divided by its L2 norm.
        NOTE(review): presumably (feature_dim, num_points) - confirm against
        the network.
    """
    cloud = torch.from_numpy(point_array).to(device, dtype=torch.float32)
    _, raw_feats = capsule_decompose(cloud, R_can, T_can)
    feats = raw_feats[0, :, 0, :, 0]
    return feats / torch.norm(feats, dim=0)


for file_name in file_name_list:
    # When config.cat_type is 'car' every category directory is processed;
    # otherwise only the directory named config.cat_type.
    if config.cat_type != 'car' and file_name != config.cat_type:
        continue

    cate_id = file_name

    print('cate_id: ', file_name)
    if cate_id not in infos:
        continue
    # Coerce to str so the stem comparison in the instance loop agrees with
    # the f-string paths built from this id.
    chosen_instance = str(infos[cate_id]['chosen_instance'])

    mesh_path = os.path.join(root_path, f'aligned_CAD_models_new/{cate_id}')
    index_path = os.path.join(root_path, f'corr_index/{cate_id}/{chosen_instance}')
    # Uncomment to skip categories whose index directory already exists.
    # if os.path.exists(index_path):
    #     print(f'index_path already exists: {index_path}')
    #     continue

    new_mesh_path = os.path.join(root_path, f'corr_mesh/{cate_id}/{chosen_instance}')
    visual_path = os.path.join(root_path, f'corr_visual/{cate_id}/{chosen_instance}')

    for output_dir in (new_mesh_path, index_path, visual_path):
        os.makedirs(output_dir, exist_ok=True)

    # Locate the reference mesh, preferring .obj over .glb.
    ref_file_name = None
    for extension in ('.obj', '.glb'):
        candidate = f'{chosen_instance}{extension}'
        if os.path.exists(os.path.join(mesh_path, candidate)):
            ref_file_name = candidate
            break
    if ref_file_name is None:
        print(f'No mesh found for {chosen_instance}')
        continue
    mesh = _load_normalized_mesh(mesh_path, ref_file_name)

    # get point_cloud from sampling
    points, _ = trimesh.sample.sample_surface(mesh, config.num_pts)

    # Colour value per point index, shared by all visualisations of this
    # category so matched points get the same colour.
    point_colors = np.arange(config.num_pts) / config.num_pts
    vis_pts_att(points, point_colors, os.path.join(visual_path, 'ref_segment.png'))

    # decompose using point cloud
    prev_feats = _unit_point_features(points)

    # save the index: positions the sampled points will occupy once they are
    # appended after the mesh's own vertices
    index = np.arange(len(points)) + len(mesh.vertices)
    ref_index_dir = os.path.join(index_path, chosen_instance)
    os.makedirs(ref_index_dir, exist_ok=True)
    np.save(os.path.join(ref_index_dir, 'index.npy'), index)

    # append the points to the mesh
    mesh.vertices = np.concatenate((mesh.vertices, points), axis=0)

    # save the mesh
    mesh.export(os.path.join(new_mesh_path, ref_file_name))

    for ins_file_name in os.listdir(mesh_path):
        instance_id = os.path.splitext(ins_file_name)[0]
        # Skip the reference instance itself (exact stem match, so ids that
        # merely contain the reference id are still processed).
        if instance_id == chosen_instance:
            continue

        ins_mesh = _load_normalized_mesh(mesh_path, ins_file_name)
        if ins_mesh is None:
            # Not a .obj/.glb file: nothing to process for this entry.
            continue

        # save index
        index = np.arange(config.num_pts) + len(ins_mesh.vertices)
        ins_index_dir = os.path.join(index_path, instance_id)
        os.makedirs(ins_index_dir, exist_ok=True)
        np.save(os.path.join(ins_index_dir, 'index.npy'), index)

        append_points = []
        for chunk_start in range(0, config.num_pts, sample_interval):
            # A fresh sample of the instance surface is drawn for every
            # group of `sample_interval` reference points.
            ins_points, _ = trimesh.sample.sample_surface(ins_mesh, config.num_pts)
            feats = _unit_point_features(ins_points)

            chunk_stop = min(chunk_start + sample_interval, config.num_pts)
            for point_id in range(chunk_start, chunk_stop):
                # Dot product of unit vectors; the instance point with the
                # highest value is taken as the match.
                similarity = torch.matmul(prev_feats[:, point_id], feats)
                best_match = torch.argmax(similarity, dim=0)
                append_points.append(ins_points[best_match.cpu().numpy()])

        append_points = np.array(append_points)

        ins_mesh.vertices = np.concatenate((ins_mesh.vertices, append_points), axis=0)

        # save the mesh
        ins_mesh.export(os.path.join(new_mesh_path, ins_file_name))

        # visualize correspondence
        vis_pts_att(append_points, point_colors, os.path.join(visual_path, f'{instance_id}.png'))
        print(ins_file_name, 'finished!')
