import os
import sys

BASEDIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import jax
import jax.numpy as jnp
from functools import partial

import util.model_util as mutil
from .base import ContinuousCollisionCostBase
import util.latent_obj_util as loutil

def smooth_col_cost(col_logits, mu=0.1):
    """Map collision logits to a smooth, non-negative hinge-style cost.

    The cost is piecewise in ``col_logits`` (written ``x`` below):

    * ``x >= 0``        -> ``x + 0.5 * mu``             (linear)
    * ``-mu < x < 0``   -> ``0.5 / mu * (x + mu) ** 2`` (quadratic blend)
    * ``x <= -mu``      -> ``0``

    The linear and quadratic pieces both evaluate to ``0.5 * mu`` at
    ``x == 0``, and the quadratic piece evaluates to ``0`` at ``x == -mu``,
    so the pieces join continuously.

    Args:
        col_logits: array (or scalar) of collision logits.
        mu: width of the quadratic transition region. It is used as a
            divisor, so it must be non-zero.

    Returns:
        Array with the broadcast shape of ``col_logits`` holding the cost.
    """
    linear_part = col_logits + 0.5 * mu
    quadratic_part = 0.5 / mu * (col_logits + mu) ** 2
    # The boundaries are assigned explicitly (>= 0 is linear, <= -mu is zero).
    # With strict comparisons on both sides, x == 0 fell through to the raw
    # logit (0) and x == -mu returned the raw logit (-mu), i.e. a negative
    # cost.
    inside_margin = jnp.where(col_logits > -mu, quadratic_part, 0)
    return jnp.where(col_logits >= 0, linear_part, inside_margin)

class OursCCD(ContinuousCollisionCostBase):
    """Collision cost evaluated through the ``'col_decoder'`` model.

    Depending on ``is_continuous`` the cost is computed either on windows of
    consecutive poses of the moving objects (continuous mode, dispatched on
    ``broadphase_type``) or directly on the given poses ("stamping" mode).
    """

    def __init__(self, models: mutil.Models, collision_threshold: float, col_coef: float, reduce_k: int, is_continuous: bool, broad_phase_cls, broadphase_type: str, return_collision_loss_pair=False):
        """Store the configuration used by ``__call__``.

        Args:
            models: model container; ``models.apply('col_decoder', ...)`` is
                called to obtain collision logits.
            collision_threshold: passed as ``mu`` to ``smooth_col_cost``.
            col_coef: multiplier applied to the summed loss in the
                'segment' / 'traj' / 'independent' continuous modes.
            reduce_k: forwarded unchanged to the collision decoder.
            is_continuous: selects continuous mode (True) or stamping mode.
            broad_phase_cls: object providing ``path_broad_phase``.
            broadphase_type: string of the form ``"<name>_<mode>"``; in
                continuous mode the second ``_``-separated token selects the
                branch ('segment', 'traj', 'independent' or 'aabb').
            return_collision_loss_pair: when True, the raw decoder output is
                added to the auxiliary dict returned by ``__call__``.
        """
        self.models = models
        self.is_continuous = is_continuous
        self.reduce_k = reduce_k
        self.broad_phase_cls = broad_phase_cls
        self.broadphase_type = broadphase_type
        self.collision_threshold = collision_threshold
        self.col_coef = col_coef
        self.return_collision_loss_pair = return_collision_loss_pair

    def _stack_windows(self, moving_obj_pqs, mode):
        """Stack windows of consecutive poses along a new axis at position -4.

        ``moving_obj_pqs`` is indexed as (..., NSEG, NOBJA, 7); the result is
        indexed as (..., n_windows, window_len, NOBJA, 7).
        """
        if self.broadphase_type in ('naivebf_independent', 'naive_independent'):
            # Pair every pose with its predecessor; the later pose comes first.
            return jnp.stack([moving_obj_pqs[..., 1:, :, :], moving_obj_pqs[..., :-1, :, :]], axis=-3)

        num_poses = moving_obj_pqs.shape[-3]
        if mode == 'segment':
            # The number of windows is fixed to 4 here; consecutive windows
            # share one pose because seg_len == gap + 1.
            nseg = 4
            gap = num_poses // nseg
            seg_len = gap + 1
        else:  # independent: sliding window of 3 poses with stride 1
            gap, seg_len, nseg = 1, 3, num_poses - 2
        assert seg_len >= gap

        # The last window is anchored at the end of the trajectory so that
        # the final poses are always covered.
        windows = [
            moving_obj_pqs[..., -seg_len:, :, :] if i == nseg - 1
            else moving_obj_pqs[..., gap * i:gap * i + seg_len, :, :]
            for i in range(nseg)
        ]
        return jnp.stack(windows, axis=-4)

    def __call__(self, moving_obj:loutil.LatentObjects, moving_obj_pqs:jnp.ndarray, fixed_obj:loutil.LatentObjects, jkey, interpolation_num:int, visualize:bool):
        '''
        Compute the collision loss between moving and fixed objects.

        Args:
            moving_obj: LatentObjects, (NOBJA, ...)
            moving_obj_pqs: jnp.ndarray, (..., NSEG, NOBJA, 7)
            fixed_obj: LatentObjects, (NOBJB, ...)
            jkey: jnp.ndarray, (2,)
            interpolation_num: int, number of interpolation (not read by this method)
            visualize: bool, whether to visualize; forwarded as ``debug`` to the
                decoder in the 'traj' / 'segment' / 'independent' modes

        Returns:
            ``(collision_loss, aux)`` where ``aux`` is a dict with keys
            "collision_binary" and "collision_logits", plus
            "collision_loss_pair" when ``return_collision_loss_pair`` is set.
            Exception: in continuous 'traj' mode with ``visualize=True`` the
            raw decoder output is returned on its own.

        Raises:
            ValueError: in continuous mode, when the mode token of
                ``broadphase_type`` is not one of the supported values.
        '''
        if self.is_continuous:
            # broadphase_type is expected to look like "<name>_<mode>".
            mode = self.broadphase_type.split("_")[1]
            if mode in ('segment', 'traj', 'independent'):
                if mode == 'traj':
                    collision_loss_pair = self.models.apply(
                        'col_decoder', fixed_obj.merge()[None], moving_obj,
                        pq_transform_B=moving_obj_pqs,
                        reduce_k=self.reduce_k, broadphase_type=self.broadphase_type,
                        broadphase_func=self.broad_phase_cls.path_broad_phase,
                        path_check=True, jkey=jkey, debug=visualize)
                    if visualize:
                        return collision_loss_pair
                else:  # 'segment' or 'independent'
                    moving_obj_pqs = self._stack_windows(moving_obj_pqs, mode)
                    # Move the window axis to the front and flatten all outer
                    # axes into one for the decoder call.
                    moving_obj_pqs = jnp.moveaxis(moving_obj_pqs, -4, 0)
                    before_outer_shape = moving_obj_pqs.shape[:-3]
                    flat_pqs = moving_obj_pqs.reshape((-1,) + moving_obj_pqs.shape[-3:])
                    collision_loss_pair = self.models.apply(
                        'col_decoder', fixed_obj.merge()[None], moving_obj,
                        reduce_k=self.reduce_k, path_check=True, jkey=jkey,
                        broadphase_type=self.broadphase_type,
                        broadphase_func=self.broad_phase_cls.path_broad_phase,
                        pq_transform_B=flat_pqs,
                        debug=visualize,
                    )
                    # Restore the outer axes, then put the window axis last.
                    collision_loss_pair = collision_loss_pair.reshape(before_outer_shape + collision_loss_pair.shape[1:])
                    collision_loss_pair = jnp.moveaxis(collision_loss_pair.squeeze(-1), 0, -1)

                collision_binary = jnp.where(collision_loss_pair > -0.5, 1, 0)

                # Smooth cost plus a fixed penalty of 10 for every entry whose
                # logit exceeds -0.5.
                loss_per_batch = smooth_col_cost(collision_loss_pair, mu=self.collision_threshold) + 10*collision_binary
                collision_loss = self.col_coef*jnp.sum(loss_per_batch, axis=(-1,-2,-3))

                collision_logits = mutil.aggregate_cost(collision_loss_pair, axes=(-1,-2,-3), reduce_ops=jnp.sum)
            elif mode == 'aabb':
                collision_loss_pair = jax.vmap(
                    partial(self.models.apply, 'col_decoder', fixed_obj, moving_obj, reduce_k=self.reduce_k, path_check=True,
                            broadphase_type=self.broadphase_type, jkey=jkey)
                )(pq_transform_B=moving_obj_pqs)
                # This branch previously left collision_binary unassigned, so
                # the return statement failed; use the same threshold as the
                # other branches.
                collision_binary = jnp.where(collision_loss_pair > -0.5, 1, 0)
                collision_loss = jnp.sum(smooth_col_cost(collision_loss_pair, mu=self.collision_threshold), axis=(-1,-2,-3,-4))
                collision_logits = mutil.aggregate_cost(collision_loss_pair, axes=(-1,-2,-3,-4), reduce_ops=jnp.sum)
            else:
                raise ValueError(
                    f"Unsupported broadphase_type {self.broadphase_type!r} for continuous collision checking; "
                    "expected mode 'segment', 'traj', 'independent' or 'aabb'"
                )

        else: # stamping
            # collision check for each segment; the two patch-level outputs
            # of the decoder are not used here.
            collision_loss_pair, _col_logits_patch_A, _col_logits_patch_B = self.models.apply(
                'col_decoder', fixed_obj, moving_obj,
                pq_transform_B=moving_obj_pqs,
                merge=True, reduce_k=self.reduce_k, jkey=jkey)
            # Original note gave the shape as (NB, NSEG, 1) — TODO confirm
            # whether that describes the array before or after this squeeze.
            collision_loss_pair = collision_loss_pair.squeeze(-1)
            collision_binary = jnp.where(collision_loss_pair > -0.5, 1, 0)
            collision_loss = jnp.sum(smooth_col_cost(collision_loss_pair, mu=self.collision_threshold), axis=-1)
            collision_logits = mutil.aggregate_cost(collision_loss_pair, axes=(-1,), reduce_ops=jnp.sum)

        aux = {
            "collision_binary": collision_binary,
            "collision_logits": collision_logits,
        }
        if self.return_collision_loss_pair:
            aux["collision_loss_pair"] = collision_loss_pair
        return collision_loss, aux