import os
import shutil
import numpy as np
import pandas as pd
import torch
from pytorch3d.io import load_objs_as_meshes, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.transforms import Rotate
import trimesh
from pytorch3d.transforms import RotateAxisAngle
import json

# Device for meshes and transforms: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def load_mesh(model_path):
    """Load one mesh file as a single-mesh pytorch3d ``Meshes`` on the module-level ``device``.

    ``.obj`` files are read with pytorch3d's ``load_objs_as_meshes``. ``.glb`` files are
    read with trimesh (``process=False``). A multi-geometry scene is flattened into one
    mesh by concatenating its triangle-mesh geometries, and only vertices and faces are kept.
    The extension check is case-insensitive.

    Args:
        model_path: Path to a ``.obj`` or ``.glb`` file.

    Returns:
        A ``Meshes`` object holding one mesh.

    Raises:
        ValueError: If the extension is neither ``.obj`` nor ``.glb``, or if a ``.glb``
            scene contains no triangle-mesh geometry.
    """
    filename = os.path.basename(model_path)
    lowered = filename.lower()
    if lowered.endswith('.obj'):
        return load_objs_as_meshes([model_path], device=device)
    if not lowered.endswith('.glb'):
        raise ValueError(f"Unsupported file extension: {filename.split('.')[-1]}")

    loaded = trimesh.load_mesh(model_path, process=False)
    if isinstance(loaded, trimesh.Scene):
        # Keep only triangle meshes. Point clouds / paths have no faces, and passing them
        # to concatenate can yield an object without `.faces` (behavior varies by trimesh
        # version). Scene.dump presumably bakes node transforms in; see trimesh docs.
        parts = [g for g in loaded.dump() if isinstance(g, trimesh.Trimesh)]
        if not parts:
            raise ValueError(f"No triangle mesh geometry found in {model_path}")
        loaded = trimesh.util.concatenate(parts)

    verts = torch.from_numpy(loaded.vertices.astype(np.float32)).unsqueeze(0)
    # pytorch3d documents padded faces as a long (int64) tensor.
    faces = torch.from_numpy(loaded.faces.astype(np.int64)).unsqueeze(0)
    return Meshes(verts=verts.to(device), faces=faces.to(device))

def save_mesh(verts, faces, out_path):
    """Write one triangle mesh to ``out_path``, choosing the format from its extension.

    Only vertices and faces are passed to the writers. Textures, UVs and other
    attributes are not saved.

    Args:
        verts: Vertex tensor (assumed (V, 3), e.g. from ``Meshes.verts_packed()``).
        faces: Face-index tensor (assumed (F, 3), e.g. from ``Meshes.faces_packed()``).
        out_path: Destination path ending in ``.obj`` or ``.glb`` (case-insensitive).

    Raises:
        ValueError: If the extension is neither ``.obj`` nor ``.glb``.
    """
    filename = os.path.basename(out_path)
    lowered = filename.lower()
    if lowered.endswith('.obj'):
        save_obj(out_path, verts, faces)
    elif lowered.endswith('.glb'):
        # detach() so tensors that require grad can still be converted to numpy.
        mesh = trimesh.Trimesh(
            vertices=verts.detach().cpu().numpy(),
            faces=faces.detach().cpu().numpy(),
            process=False,
        )
        mesh.export(out_path)
    else:
        raise ValueError(f"Unsupported file extension: {filename.split('.')[-1]}")


if __name__ == "__main__":
    # All paths below are relative to the current working directory.
    # Load metadata. Column 0 of each CSV row is the category id (nid); the remaining
    # columns are unpacked below as (model_id, distance, azimuth, elevation, strength, sampling).
    raw_model_data = pd.read_csv('./3d-dst-models.csv')
    cad_model_dir = "../data/CorrData/new_DST_part/corr_mesh"
    recover_cad_model_dir = "../data/CorrData/new_DST_part/corr_recover_CAD_models"

    # Category ids: every entry of cad_model_dir that is not a .json file.
    nids = [x for x in os.listdir(cad_model_dir) if not x.endswith('.json')]
    # Manual override of the category list (kept for reference):
    # nids = ['n02835271', 'n03345487', 'n03417042', 'n03496892', 'n03599486', 'n03642806', 'n03649909', 'n03670208', 'n03673027', 'n03785016', 'n03947888', 'n04065272', 'n04146614', 'n04147183', 'n04252225', 'n04285008', 'n04482393', 'n04483307', 'n04509417', 'n04552348', 'n04612504']
    nid_set = set(nids)  # O(1) membership test in the CSV loop below

    # infos[nid]['chosen_instance'] names the per-category sub-directory to process.
    info_path = '../data/CorrData/new_DST_part/infos.json'
    with open(info_path, 'r') as f:
        infos = json.load(f)

    # Build a lookup table: nid -> list of CSV rows (all columns after the nid).
    model_data = {}
    for _, row in raw_model_data.iterrows():
        # Positional access goes through .iloc. Integer keys on a label-indexed Series
        # are deprecated as positional lookups in pandas 2.x.
        nid = str(row.iloc[0])
        if nid not in nid_set:
            continue
        model_data.setdefault(nid, []).append(row.iloc[1:])

    for nid in nids:
        if nid == 'transfer':
            continue
        models = model_data.get(nid)
        if not models:
            # A category directory with no CSV rows previously aborted the run with KeyError.
            print(f"No models listed for {nid}, skipping")
            continue
        chosen_instance = infos[nid]['chosen_instance']
        src_dir = os.path.join(cad_model_dir, nid, chosen_instance)
        out_dir = os.path.join(recover_cad_model_dir, nid, chosen_instance)
        os.makedirs(out_dir, exist_ok=True)
        for model_entry in models:
            model_id, _distance, azimuth, elevation, _strength, _sampling = model_entry
            model_id = str(model_id)

            # Prefer .obj, fall back to .glb; the output keeps the source extension.
            model_path = out_path = None
            for ext in ('.obj', '.glb'):
                candidate = os.path.join(src_dir, f'{model_id}{ext}')
                if os.path.exists(candidate):
                    model_path = candidate
                    out_path = os.path.join(out_dir, f'{model_id}{ext}')
                    break
            if model_path is None:
                print(f"Model {candidate} not found")
                continue

            azimuth = torch.tensor(azimuth / 180.0 * np.pi)      # Blender azimuth, degrees -> radians
            elevation = torch.tensor(elevation / 180.0 * np.pi)  # Blender elevation, degrees -> radians

            # Pitch about X by -elevation, then yaw about Y by -azimuth.
            # Transform3d.compose applies R_pitch first, then R_yaw.
            R_pitch = RotateAxisAngle(angle=-elevation, axis="X", degrees=False)
            R_yaw = RotateAxisAngle(angle=-azimuth, axis="Y", degrees=False)
            R_transform = R_pitch.compose(R_yaw)

            # The inverse of the composed rotation is what is applied to the mesh.
            R_transform = R_transform.inverse().to(device)

            mesh = load_mesh(model_path)
            rotated_verts = R_transform.transform_points(mesh.verts_padded())
            rotated_mesh = mesh.update_padded(rotated_verts)

            # NOTE: save_mesh writes only vertices and faces, so any texture/UV data on
            # the loaded mesh is not carried into the output file.
            save_mesh(rotated_mesh.verts_packed(), rotated_mesh.faces_packed(), out_path)