import sys
sys.path.append('../')
sys.path.append('./')
import os
import numpy as np
from pytorch3d.structures import Meshes
from Models.ShapeNeRS import HarmonicEmbedding, EncodeNetwork
import torch
from pytorch3d.io import load_obj, save_obj
import trimesh
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
import json

def load_mesh(file_path):
    """Load a triangle mesh from an OBJ or GLB file.

    The extension check is case-insensitive (``.obj``/``.OBJ``, ``.glb``/``.GLB``).

    Args:
        file_path: Path to a ``.obj`` or ``.glb`` file.

    Returns:
        For OBJ files: the tuple returned by ``pytorch3d.io.load_obj`` with
        textures disabled, i.e. ``(verts, faces, aux)``; the face indices are in
        ``faces.verts_idx``.
        For GLB files: ``(vertices, faces, None)`` where ``vertices`` is a
        float32 tensor and ``faces`` is an int64 tensor built from trimesh's
        ``vertices``/``faces`` arrays.

    Raises:
        ValueError: if the file extension is not supported, or if a GLB file
            yields no triangle faces.
    """
    lowered = file_path.lower()
    if lowered.endswith('.obj'):
        return load_obj(file_path, load_textures=False)
    if lowered.endswith('.glb'):
        mesh = trimesh.load(file_path)
        # GLB files may load as a Scene; merge all of its geometry into one mesh.
        if isinstance(mesh, trimesh.Scene):
            mesh = trimesh.util.concatenate(mesh.dump())
        # An empty scene or non-triangle geometry would otherwise surface later as an
        # opaque AttributeError/IndexError; report it clearly here instead.
        if not hasattr(mesh, 'faces') or len(mesh.faces) == 0:
            raise ValueError(f"No triangle geometry found in {file_path}")
        # Convert to PyTorch3D-style tensors.
        vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
        faces = torch.tensor(mesh.faces, dtype=torch.int64)
        return vertices, faces, None
    raise ValueError(f"Unsupported file format: {file_path}")

def parse_args():
    """Parse the command-line options of the deformation-training script.

    Returns:
        argparse.Namespace with ``cate_id`` (str or None), ``iteration`` (int),
        ``exp_name`` (str) and the loss weights ``laplacian_weight``,
        ``normal_weight``, ``edge_weight`` and ``deform_weight``.
    """
    import argparse

    # (flag, converter, default) for every option. Defaults are kept exactly as
    # written, so the int defaults 1 and 0 stay ints when the flag is omitted.
    option_table = (
        ('--cate_id', str, None),
        ('--iteration', int, 500),
        ('--exp_name', str, 'deform'),
        ('--laplacian_weight', float, 0.1),
        ('--normal_weight', float, 1),
        ('--edge_weight', float, 0),
        ('--deform_weight', float, 1.0),
    )
    cli = argparse.ArgumentParser()
    for flag, converter, fallback in option_table:
        cli.add_argument(flag, type=converter, default=fallback)
    return cli.parse_args()

# Parse command-line options once; the per-category loop below reads them.
args = parse_args()

# Process either the single requested category or every category folder under corr_mesh.
if args.cate_id is None:
    mesh_path = '../data/CorrData/new_DST_part/corr_mesh'
    cate_ids = os.listdir(mesh_path)
else:
    cate_ids = [args.cate_id]

root_path = '../data/CorrData/new_DST_part/'
info_path = '../data/CorrData/new_DST_part/infos.json'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Per-category metadata; the loop below reads info[cate_id]['chosen_instance'].
with open(info_path, 'r') as info_file:
    info = json.load(info_file)

def _resolve_mesh_file(mesh_dir, instance_id):
    """Return ``instance_id``'s mesh path in ``mesh_dir``, preferring OBJ over GLB.

    The GLB path is returned whenever the OBJ file is missing, even if the GLB
    does not exist either; callers check existence where it matters.
    """
    obj_file = os.path.join(mesh_dir, f'{instance_id}.obj')
    if os.path.exists(obj_file):
        return obj_file
    return os.path.join(mesh_dir, f'{instance_id}.glb')


def _load_index(index_dir, instance_id):
    """Load ``<index_dir>/<instance_id>/index.npy`` (``[()]`` unwraps 0-d object arrays)."""
    # NOTE(review): allow_pickle=True can execute code embedded in the file; only use on trusted data.
    return np.load(os.path.join(index_dir, instance_id, 'index.npy'), allow_pickle=True)[()]


def _latent_code(active, size, n_rows, device):
    """Build an (n_rows, size) latent whose rows are a unit-L2 vector with ones at ``active``.

    With no active positions the latent is all zeros (no division by a zero norm).
    """
    code = torch.zeros((1, size), device=device)
    active = list(active)
    if active:
        code[0, active] = 1.0
        code = code / torch.norm(code, dim=1, keepdim=True)
    return code.expand(n_rows, -1).contiguous()


# For each category: fit a latent-conditioned deformation network that moves the chosen
# template mesh's correspondence points onto those of every other CAD model.
for cate_id in cate_ids:
    if cate_id not in info:
        continue
    chosen_instance = info[cate_id]['chosen_instance']

    new_mesh_path = os.path.join(root_path, f'corr_recover_CAD_models/{cate_id}/{chosen_instance}')
    index_path = os.path.join(root_path, f'corr_index/{cate_id}/{chosen_instance}')
    save_path = os.path.join(root_path, f'deform/{cate_id}/{chosen_instance}/{args.exp_name}')
    os.makedirs(save_path, exist_ok=True)

    # Template (chosen) mesh: reuse load_mesh instead of duplicating the OBJ/GLB logic.
    chosen_mesh_file = _resolve_mesh_file(new_mesh_path, chosen_instance)
    if not os.path.exists(chosen_mesh_file):
        raise ValueError(f"No mesh file found for {chosen_instance}")
    chosen_vertices, chosen_faces, _ = load_mesh(chosen_mesh_file)
    if not torch.is_tensor(chosen_faces):
        # load_obj returns a Faces namedtuple; the GLB branch already yields a tensor.
        chosen_faces = chosen_faces.verts_idx
    chosen_faces = chosen_faces.to(torch.int64)              # CPU copy, used by save_obj
    chosen_faces_device = chosen_faces.to(device)            # device copy, used by Meshes
    chosen_vertices = chosen_vertices.to(device=device, dtype=torch.float32)
    chosen_index = _load_index(index_path, chosen_instance)
    chosen_points = chosen_vertices[chosen_index]
    V = chosen_vertices.shape[0]

    # All other CAD models (OBJ or GLB, deduplicated by file stem and sorted so the
    # instance -> latent-index assignment is reproducible), then the template last.
    other_ids = sorted({
        os.path.splitext(name)[0]
        for name in os.listdir(new_mesh_path)
        if name.endswith(('.obj', '.glb'))
    } - {chosen_instance})
    instance_ids = other_ids + [chosen_instance]
    number_of_cad = len(instance_ids)

    # Target offsets do not change during training: read every mesh/index pair once
    # here rather than re-reading the files on every optimisation step.
    gt_offsets = []
    for instance_id in instance_ids:
        vertices, _, _ = load_mesh(_resolve_mesh_file(new_mesh_path, instance_id))
        index = _load_index(index_path, instance_id)
        gt_offsets.append(vertices[index].to(device) - chosen_points)

    deform_encoder = HarmonicEmbedding().to(device)
    deform_net = EncodeNetwork(n_inputs=4, n_lantern=1, n_output=2, input_size=60, lantern_size=number_of_cad - 1,
                               hidden_size=32, output_size=3).to(device)
    # Only deform_net is optimised; deform_encoder's parameters (if any) are not updated.
    optimizer = torch.optim.Adam(list(deform_net.parameters()), lr=1e-4)

    deform_net.train()

    for i in range(args.iteration + 1):
        total_loss = 0
        for j, gt_get in enumerate(gt_offsets):
            # Instance j < number_of_cad - 1 gets one-hot code j; the template (last) gets the
            # all-zero code, and its target offset is zero by construction.
            active = [j] if j < number_of_cad - 1 else []
            one_hot_latent = _latent_code(active, number_of_cad - 1, V, device)
            get = deform_net(deform_encoder(chosen_vertices), one_hot_latent)

            # L1 error of the predicted offsets at the correspondence points.
            loss_deform = torch.mean(torch.abs(get[chosen_index] - gt_get))

            # Regularise the full deformed template mesh.
            new_mesh = Meshes(verts=[chosen_vertices + get], faces=[chosen_faces_device])
            loss_normal = mesh_normal_consistency(new_mesh)
            loss_laplacian = mesh_laplacian_smoothing(new_mesh, method="uniform")
            loss_edge = mesh_edge_loss(new_mesh)

            loss = (loss_deform * args.deform_weight
                    + loss_normal * args.normal_weight
                    + loss_laplacian * args.laplacian_weight
                    + loss_edge * args.edge_weight)

            # One optimiser step per instance.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        total_loss /= number_of_cad
        print(f'Iteration {i}, Loss: {total_loss}')

        # Save every 500 iterations and after the final one, so runs whose length is not
        # a multiple of 500 still leave a checkpoint of the trained network.
        if i > 0 and (i % 500 == 0 or i == args.iteration):
            print(f'Saving model at iteration {i}')

            torch.save(deform_net.state_dict(), os.path.join(save_path, f'deform_net_{i}.pth'))
            torch.save(deform_encoder.state_dict(), os.path.join(save_path, f'deform_encoder_{i}.pth'))

            # Export preview meshes in eval mode without building autograd graphs, then
            # restore train mode before optimisation continues.
            deform_net.eval()
            with torch.no_grad():
                for mesh_id in range(min(number_of_cad // 3, 5)):
                    instance_id_1 = instance_ids[mesh_id]
                    instance_id_2 = instance_ids[mesh_id + 1]

                    # Reconstruction of instance mesh_id from its own one-hot code.
                    latent_1 = _latent_code([mesh_id], number_of_cad - 1, V, device)
                    get_1 = deform_net(deform_encoder(chosen_vertices), latent_1)
                    vertices_1 = (chosen_vertices + get_1).cpu()
                    mesh_file_1 = os.path.join(save_path, f'{i}_mesh_{instance_id_1}.obj')
                    save_obj(mesh_file_1, vertices_1, chosen_faces)

                    # Blend of two neighbouring instances via a normalised two-hot code.
                    mixed_latent = _latent_code([mesh_id, mesh_id + 1], number_of_cad - 1, V, device)
                    get_mixed = deform_net(deform_encoder(chosen_vertices), mixed_latent)
                    mixed_vertices = (chosen_vertices + get_mixed).cpu()

                    mesh_file = os.path.join(save_path, f'{i}_mixed_mesh_{instance_id_1}_{instance_id_2}.obj')

                    save_obj(mesh_file, mixed_vertices, chosen_faces)
                    print(f'Exported mixed mesh {mesh_id} to {mesh_file}')
            deform_net.train()

