import os
import sys

# Compute the grandparent of this file's directory (dirname applied three
# times to the absolute file path) and put it at the front of sys.path, so the
# project-local ``util.*`` imports below can be resolved.
# NOTE(review): presumably this directory is the repository root — confirm.
BASEDIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if BASEDIR not in sys.path:
    # Guarded so repeated imports do not add duplicate entries; index 0 gives
    # this directory precedence over other sys.path entries.
    sys.path.insert(0, BASEDIR)

import jax
import jax.numpy as jnp
import util.model_util as mutil
import util.transform_util as tutil
import util.latent_obj_util as loutil
from typing import Tuple, Dict


class ContinuousCollisionCostBase:
    """Base interface for continuous-collision cost callables.

    Subclasses must override :meth:`__call__`; the implementation here only
    raises ``NotImplementedError``. The base class itself holds no state.
    """

    def __init__(self):
        # No state of its own; chaining to super() keeps the class safe to
        # use in cooperative multiple-inheritance hierarchies.
        super().__init__()

    def __call__(
            self,
            moving_obj: loutil.LatentObjects,
            moving_obj_pqs: jnp.ndarray,
            fixed_obj: loutil.LatentObjects
        ) -> Tuple[jnp.ndarray, Dict]:
        """Evaluate the cost of ``moving_obj`` against ``fixed_obj``.

        Args:
            moving_obj: Latent object representation of the moving object(s).
            moving_obj_pqs: Poses for ``moving_obj``. NOTE(review): presumably
                position+quaternion ("pq") poses, possibly a sequence along a
                trajectory — confirm the expected shape with the subclasses.
            fixed_obj: Latent object representation of the fixed object(s).

        Returns:
            A ``(jnp.ndarray, dict)`` tuple. NOTE(review): presumably the cost
            value(s) plus auxiliary info; the exact contents are defined by
            each subclass.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Name the concrete class so a subclass that forgot to override
        # __call__ is easy to identify from the traceback.
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__()"
        )