from __future__ import annotations
try:
    import pybullet as p
except:
    print('off pybullet')
import numpy as np
import jax.numpy as jnp
import os


# import util.render_util as rutil
# import util.io_util as ioutil
# import util.cvx_util as cxutil


class ObjCls(object):
    """PyBullet body built from mesh files and size-normalized on construction.

    ``__init__`` first creates the body at unit mesh scale, reads its AABB,
    removes it, and re-creates it with ``scale / base_length``, where
    ``base_length`` is the largest AABB extent measured at unit scale. For a
    uniform ``scale`` ``s`` this makes the body's largest AABB extent roughly
    ``s``.

    Assumes a PyBullet physics client is already connected.

    NOTE(ssh): Need to clean up unused arguments...
    """

    def __init__(
            self,
            obj_finename=None,
            scale=np.array([1., 1., 1.]),
            primitive_params=None,
            pbcolobj=None,
            pbvisobj=None,
            o3dobj=None,
            pcdobj=None,
            shift_com=False
    ):
        """Store the inputs and create the size-normalized PyBullet body.

        Args:
            obj_finename: Stored as ``self.obj_filename``; not otherwise used
                here. (Misspelled keyword kept for backward compatibility.)
            scale: Target scale, a scalar or a length-3 per-axis sequence.
                The default array is only read, never mutated.
            primitive_params: Stored only. Per the original note:
                0: sphere, 1: box, 2: cylinder, ...
            pbcolobj: Path to the collision mesh file, or None for no
                collision shape.
            pbvisobj: Path to the visual mesh file, or None for no visual
                shape.
            o3dobj, pcdobj, shift_com: Stored only; unused in this class.

        Raises:
            AssertionError: If a configured mesh path does not exist.
            ValueError: If the unit-scale AABB has no positive, finite
                extent, so the size cannot be normalized.
        """
        self.obj_filename = obj_finename
        self.primitive_params = primitive_params  # 0: sphere, 1: box, 2: cylinder ....
        self.pbcolobj = pbcolobj
        self.pbvisobj = pbvisobj
        self.o3dobj = o3dobj
        self.pcdobj = pcdobj
        self.shift_com = shift_com

        # Build once at unit scale purely to measure the mesh's native size.
        self.create_pbobj(1)
        aabb_min, aabb_max = p.getAABB(self.pbobj)
        base_length = float(np.max(np.asarray(aabb_max) - np.asarray(aabb_min)))
        # Remove the measuring body before any error so it is never leaked.
        p.removeBody(self.pbobj)
        if not (np.isfinite(base_length) and base_length > 0):
            raise ValueError(
                f"cannot normalize object size: degenerate AABB {aabb_min}, {aabb_max}"
            )

        # np.asarray lets plain lists/tuples work too; never mutates the input.
        scale = np.asarray(scale, dtype=float) / base_length
        self.create_pbobj(scale)

        self.scale = scale

    @staticmethod
    def _as_mesh_scale(scale):
        """Return ``scale`` as a 3-element list of floats for ``meshScale``.

        A scalar (or length-1 array) is repeated on all three axes; a
        length-3 sequence/array is used per axis. Any other shape raises
        ``ValueError`` from NumPy broadcasting.
        """
        per_axis = np.broadcast_to(np.asarray(scale, dtype=float), (3,))
        return [float(v) for v in per_axis]

    def create_pbobj(self, scale):
        """Create the PyBullet multibody from ``pbcolobj`` / ``pbvisobj``.

        Args:
            scale: Mesh scale, either a scalar applied to all axes or a
                length-3 per-axis sequence/array.

        Sets ``self.pbobj`` (the id returned by ``p.createMultiBody``) and
        ``self.baseInertialFramePosition``.

        Raises:
            AssertionError: If a configured mesh path does not exist.
        """
        # PyBullet expects three plain floats. Building ``[scale] * 3`` would
        # nest arrays whenever ``scale`` is already per-axis.
        mesh_scale = self._as_mesh_scale(scale)

        if self.pbcolobj is not None:
            assert os.path.exists(self.pbcolobj), f"pbcolobj {self.pbcolobj} does not exist"
            cshape = p.createCollisionShape(p.GEOM_MESH, fileName=self.pbcolobj, meshScale=mesh_scale)
        else:
            cshape = -1  # -1: no collision shape

        if self.pbvisobj is not None:
            assert os.path.exists(self.pbvisobj), f"pbvisobj {self.pbvisobj} does not exist"
            vshape = p.createVisualShape(p.GEOM_MESH, fileName=self.pbvisobj, meshScale=mesh_scale)
        else:
            vshape = -1  # -1: no visual shape

        self.baseInertialFramePosition = np.zeros(3)
        self.pbobj = p.createMultiBody(
            baseMass=0.3,  # fixed mass; not configurable here
            baseInertialFramePosition=self.baseInertialFramePosition,
            baseCollisionShapeIndex=cshape,
            baseVisualShapeIndex=vshape,
        )
        




def cal_center_from_obj_file(fileName):
    """Return the center of the axis-aligned bounding box of an OBJ's vertices.

    Only geometric vertex records (lines whose first token is ``v``) are
    read; texture/normal records (``vt``, ``vn``) and all other lines are
    ignored. The first three coordinates of each vertex are used, so an
    optional ``w`` or per-vertex colour values are tolerated.

    Args:
        fileName: Path to a Wavefront OBJ file.

    Returns:
        np.ndarray of shape (3,): ``(min + max) / 2`` over all vertices.

    Raises:
        ValueError: If a ``v`` line has fewer than three coordinates or a
            non-numeric coordinate, or if the file contains no vertices.
    """
    vtx = []
    # ``with`` guarantees the file is closed even if parsing raises.
    with open(fileName) as f:
        for line in f:
            # split() handles tabs, repeated spaces, and a missing trailing
            # newline on the last line.
            parts = line.split()
            if not parts or parts[0] != "v":
                continue
            if len(parts) < 4:
                raise ValueError(f"malformed vertex line in {fileName!r}: {line.rstrip()!r}")
            vtx.append((float(parts[1]), float(parts[2]), float(parts[3])))

    if not vtx:
        raise ValueError(f"no vertices found in {fileName!r}")

    points = np.asarray(vtx, dtype=float)
    return (points.min(axis=0) + points.max(axis=0)) * 0.5