from model.rotation2xyz import Rotation2xyz
import numpy as np
from trimesh import Trimesh
import os
import torch
from utils.simplify_loc2rot import joints2smpl


class Joints2Obj:
    """Fit SMPL parameters to joint trajectories and export the posed meshes.

    Attributes set by ``__init__``:
        rot2xyz:  ``Rotation2xyz`` instance (created on CPU).
        faces:    face indices taken from ``rot2xyz.smpl_model.faces``.
        j2s:      ``joints2smpl`` fitter sized for ``bs * nframes`` poses.
        motions:  fitted motion as a NumPy array (output of ``joint2smpl``).
        vertices: mesh vertices viewed as ``[bs, nframes, -1, 3]`` with the
                  root offset added.
        joints:   SMPL joints viewed as ``[bs, nframes, -1, 3]`` with the
                  root offset added.
    """

    def __init__(self, motions, device=0, cuda=True):
        """Run the joints -> SMPL fit and build per-frame vertices/joints.

        Args:
            motions: 4-D array/tensor whose last axis is time; only every
                10th frame starting at index 8 is used. The inline shape
                comments below are the original author's expected sizes.
            device: device id forwarded to ``joints2smpl`` as ``device_id``.
            cuda: forwarded to ``joints2smpl``.
        """
        # Temporal sub-sampling: frames 8, 18, 28, ... along the last axis.
        raw_joints = motions[:, :, :, 8::10]                            # [bs, 22, 3, 16]

        self.rot2xyz = Rotation2xyz(device='cpu')
        self.faces = self.rot2xyz.smpl_model.faces

        bs, _, _, nframes = raw_joints.shape
        self.j2s = joints2smpl(batch_size=bs*nframes, device_id=device, cuda=cuda)

        motion_tensor = self.j2s.joint2smpl(raw_joints)                 # [bs, 25, 6, 16]
        smpl_res = self.rot2xyz(motion_tensor.cpu(), mask=None,
                                pose_rep='rot6d', translation=True, glob=True,
                                jointstype='vertices',
                                vertstrans=False,
                                return_full=True)

        vertices = smpl_res['vertices'].view(bs, nframes, -1, 3)        # [bs, 16, 6890, 3]
        joints = smpl_res['smpl'].view(bs, nframes, -1, 3)              # [bs, 16, 24, 3]

        self.motions = motion_tensor.cpu().numpy()

        # ``.cpu()`` is a no-op for tensors already on the CPU, in which case
        # the slice below would be a view of ``motion_tensor`` (and therefore
        # of ``self.motions``). Clone so the in-place shift cannot leak into
        # the stored motion; this makes CPU and CUDA runs behave the same.
        root_loc = motion_tensor[:, -1, :3].permute(0, 2, 1).cpu().clone()     # [bs, 16, 3]

        # Shift the trajectory so that it starts at zero and then re-anchor it
        # such that the fitted root joint of frame 0 coincides with the input
        # root joint of frame 0.
        first_root = torch.as_tensor(raw_joints[:, 0, :, 0]).to(root_loc.device)
        root_loc += (first_root - joints[:, 0, 0] - root_loc[:, 0]).unsqueeze(1)
        root_loc = root_loc.unsqueeze(2)                                # [bs, 16, 1, 3]

        # Out-of-place adds: ``vertices``/``joints`` are views of ``smpl_res``.
        self.vertices = vertices + root_loc
        self.joints = joints + root_loc

    def save_obj(self, save_path):
        """Write one OBJ file per sample and frame.

        Files are written to ``<save_path>/mesh<n>/mesh<t>.obj`` where ``n``
        indexes ``self.vertices`` along axis 0 and ``t`` along axis 1.
        Directories are created as needed.
        """
        for n in range(self.vertices.shape[0]):
            sample_dir = os.path.join(save_path, "mesh" + str(n))
            os.makedirs(sample_dir, exist_ok=True)
            for t in range(self.vertices.shape[1]):
                mesh = Trimesh(vertices=self.vertices[n, t].squeeze().tolist(), faces=self.faces)
                with open(os.path.join(sample_dir, "mesh" + str(t) + ".obj"), 'w') as fw:
                    mesh.export(fw, 'obj')

    def save_npy(self, save_path):
        """Save the first sample's motion, faces and vertices with ``np.save``.

        The saved object is a dict with keys ``motion``, ``thetas`` (all rows
        of the motion except the last), ``root_translation`` (first three
        features of the last row), ``faces`` and ``vertices``. Because a dict
        is pickled by ``np.save``, loading requires ``allow_pickle=True``.
        """
        # ``self.motions`` is a plain ndarray, so index it directly.
        first_motion = self.motions[0]
        data_dict = {
            'motion': first_motion,
            'thetas': first_motion[:-1],
            'root_translation': first_motion[-1, :3],
            'faces': self.faces,
            'vertices': self.vertices[0],
        }
        np.save(save_path, data_dict)
