## Termination functions are mostly from MOBILE (https://github.com/yihaosun1124/mobile/blob/main/utils/termination_fns.py)
## For antmaze tasks, the formula is borrowed from the CBOP implementation (https://github.com/jihwan-jeong/CBOP/blob/3fb693e513547af87bb59dca9c34c35fc101b55b/rlkit/envs/termination_funcs.py)

import numpy as np
import jax.numpy as jnp
import jax


def obs_unnormalization(termination_fn, obs_mean, obs_std):
    """Wrap ``termination_fn`` so that it sees de-normalised observations.

    Both ``obs`` and ``next_obs`` are mapped back through
    ``x * obs_std + obs_mean`` before being handed to the wrapped function;
    ``act`` is forwarded untouched.

    Args:
        termination_fn: Callable ``(obs, act, next_obs) -> done``.
        obs_mean: Offset added after scaling (broadcast against observations).
        obs_std: Scale applied to normalised observations.

    Returns:
        A callable with the same ``(obs, act, next_obs)`` signature.
    """

    def _denormalize(x):
        return x * obs_std + obs_mean

    def thunk(obs, act, next_obs):
        return termination_fn(_denormalize(obs), act, _denormalize(next_obs))

    return thunk


def termination_fn_halfcheetah(obs, act, next_obs):
    """Done mask that only bounds-checks the next state.

    A row is done unless every entry of ``next_obs`` lies strictly inside
    ``(-100, 100)``. NaN entries fail both comparisons and so count as done.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    in_bounds = (next_obs > -100) & (next_obs < 100)
    return jnp.logical_not(jnp.all(in_bounds, axis=-1))[:, None]


def termination_fn_hopper(obs, act, next_obs):
    """Done mask for Hopper observations.

    A row is *not* done when all of the following hold for ``next_obs``:
    every entry is finite, every entry after column 0 has absolute value
    below 100, column 0 (``height``) is above 0.7, and column 1 (``angle``)
    has absolute value below 0.2.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    height = next_obs[:, 0]
    angle = next_obs[:, 1]
    not_done = (
        jnp.isfinite(next_obs).all(axis=-1)
        # Take |x| *before* comparing. The previous expression
        # ``jnp.abs(next_obs[:, 1:] < 100)`` took the absolute value of a
        # boolean, so only the upper bound was enforced and entries below
        # -100 were never flagged.
        * (jnp.abs(next_obs[:, 1:]) < 100).all(axis=-1)
        * (height > 0.7)
        * (jnp.abs(angle) < 0.2)
    )

    done = ~not_done
    return done[:, None]


def termination_fn_halfcheetahveljump(obs, act, next_obs):
    """Never terminate.

    Args:
        obs: 2-D array whose leading dimension is the batch size.
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array (only its rank is checked).

    Returns:
        All-False boolean array of shape ``(batch, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    return jnp.zeros((len(obs), 1), dtype=bool)


def termination_fn_antangle(obs, act, next_obs):
    """Done mask for the ant-angle task.

    A row is *not* done when every entry of ``next_obs`` is finite and
    column 0 lies in the closed interval ``[0.2, 1.0]``.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    first_col = next_obs[:, 0]
    alive = (
        jnp.isfinite(next_obs).all(axis=-1)
        & (first_col >= 0.2)
        & (first_col <= 1.0)
    )
    return jnp.logical_not(alive)[:, None]


def termination_fn_antmaze(obs, act, next_obs, center, radius):
    """Done mask for antmaze: terminate inside a disc around ``center``.

    A row is done when the squared Euclidean distance between the first two
    columns of ``next_obs`` and ``center`` is strictly less than
    ``radius ** 2``.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)`` with ``obs_dim >= 2``.
        center: 1-D array of length 2 (must support ``center[None, :]``).
        radius: Scalar disc radius.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    offset = next_obs[:, :2] - center[None, :]
    squared_dist = (offset ** 2).sum(axis=1)
    inside = squared_dist < radius * radius
    return inside[:, None]


def termination_fn_ant(obs, act, next_obs):
    """Done mask for Ant.

    A row is *not* done when every entry of ``next_obs`` is finite and
    column 0 lies in the closed interval ``[0.2, 1.0]``. Unlike most
    functions here, no rank check is performed.

    Args:
        obs: Unused.
        act: Unused.
        next_obs: 2-D array of shape ``(batch, obs_dim)``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    first_col = next_obs[:, 0]
    alive = (
        jnp.isfinite(next_obs).all(axis=-1)
        & (first_col >= 0.2)
        & (first_col <= 1.0)
    )
    return jnp.logical_not(alive)[:, None]


def termination_fn_walker2d(obs, act, next_obs):
    """Done mask for Walker2d.

    A row is *not* done when every entry of ``next_obs`` lies strictly in
    ``(-100, 100)``, column 0 (height) lies strictly in ``(0.8, 2.0)``, and
    column 1 (angle) lies strictly in ``(-1.0, 1.0)``.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    bounded = jnp.all((next_obs > -100) & (next_obs < 100), axis=-1)
    height_col, angle_col = next_obs[:, 0], next_obs[:, 1]
    healthy = (
        bounded
        & (0.8 < height_col)
        & (height_col < 2.0)
        & (-1.0 < angle_col)
        & (angle_col < 1.0)
    )
    return jnp.logical_not(healthy)[:, None]


def termination_fn_point2denv(obs, act, next_obs):
    """Never terminate.

    Args:
        obs: 2-D array whose leading dimension is the batch size.
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array (only its rank is checked).

    Returns:
        All-False boolean array of shape ``(batch, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    return jnp.zeros((len(obs), 1), dtype=bool)


def termination_fn_point2dwallenv(obs, act, next_obs):
    """Never terminate.

    Args:
        obs: 2-D array whose leading dimension is the batch size.
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array (only its rank is checked).

    Returns:
        All-False boolean array of shape ``(batch, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    return jnp.zeros((len(obs), 1), dtype=bool)


def termination_fn_pendulum(obs, act, next_obs):
    """Never terminate.

    Returns a boolean mask, like every other termination function in this
    module. It previously returned float zeros, on which ``~done`` (the
    negation idiom used throughout this file) raises.

    Args:
        obs: 2-D array whose leading dimension is the batch size.
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array (only its rank is checked).

    Returns:
        All-False boolean array of shape ``(batch, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    return jnp.zeros((len(obs), 1), dtype=bool)


def termination_fn_humanoid(obs, act, next_obs):
    """Done mask for Humanoid.

    A row is done when column 0 of ``next_obs`` (``z``) is below 1.0 or
    above 2.0. The boundary values 1.0 and 2.0 themselves are not done.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    z_col = next_obs[:, 0]
    out_of_range = (z_col < 1.0) | (z_col > 2.0)
    return out_of_range[:, None]


def termination_fn_pen(obs, act, next_obs):
    """Done mask for the pen task.

    Columns 24:27 of ``next_obs`` are taken as the object position (per the
    original ``obj_pos`` naming). A row is done when the third of those
    columns (index 26) is below 0.075.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)`` with ``obs_dim >= 27``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    obj_xyz = next_obs[:, 24:27]
    dropped = obj_xyz[:, 2] < 0.075
    return dropped[:, None]


def termination_fn_door(obs, act, next_obs):
    """Never terminate.

    Uses ``jnp.zeros(..., dtype=bool)`` so the mask stays boolean for an
    empty batch. ``jnp.array([False] * 0)`` would produce a float array.

    Args:
        obs: 2-D array whose leading dimension is the batch size.
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array (only its rank is checked).

    Returns:
        All-False boolean array of shape ``(batch, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    return jnp.zeros((obs.shape[0], 1), dtype=bool)


# Backward-compatible alias for the original, misspelled name. It is still
# referenced by ``get_termination_fn`` and possibly by external callers.
terminaltion_fn_door = termination_fn_door


def termination_fn_neorl_halfcheetah(obs, act, next_obs):
    """Never terminate.

    Uses ``jnp.zeros(..., dtype=bool)`` so the mask stays boolean for an
    empty batch. ``jnp.array([False] * 0)`` would produce a float array.

    Args:
        obs: 2-D array whose leading dimension is the batch size.
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array (only its rank is checked).

    Returns:
        All-False boolean array of shape ``(batch, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    return jnp.zeros((obs.shape[0], 1), dtype=bool)


def termination_fn_neorl_hopper(obs, act, next_obs):
    """Done mask for NeoRL Hopper observations.

    Column 0 of ``next_obs`` is not inspected. A row is *not* done when
    column 1 (``z``) lies strictly in ``(0.7, inf)``, column 2 (angle) lies
    strictly in ``(-0.2, 0.2)``, and every column from 3 onward lies
    strictly in ``(-100, 100)``.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    z_col = next_obs[:, 1:2]
    angle_col = next_obs[:, 2:3]
    remainder = next_obs[:, 3:]

    remainder_ok = jnp.all(
        (remainder > -100.0) & (remainder < 100.0), axis=-1, keepdims=True
    )
    z_ok = (z_col > 0.7) & (z_col < float("inf"))
    angle_ok = (angle_col > -0.2) & (angle_col < 0.2)

    healthy = remainder_ok & z_ok & angle_ok
    return jnp.logical_not(healthy).reshape(-1, 1)


def termination_fn_neorl_walker2d(obs, act, next_obs):
    """Done mask for NeoRL Walker2d observations.

    Column 0 of ``next_obs`` is not inspected. A row is *not* done when
    column 1 (``z``) lies strictly in ``(0.8, 2.0)``, column 2 (angle) lies
    strictly in ``(-1.0, 1.0)``, and every column from 3 onward lies
    strictly in ``(-100, 100)``.

    Args:
        obs: 2-D array (only its rank is checked).
        act: 2-D array (only its rank is checked).
        next_obs: 2-D array of shape ``(batch, obs_dim)``.

    Returns:
        Boolean array of shape ``(batch, 1)``; True marks termination.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    z_col = next_obs[:, 1:2]
    angle_col = next_obs[:, 2:3]
    remainder = next_obs[:, 3:]

    remainder_ok = jnp.all(
        (remainder > -100.0) & (remainder < 100.0), axis=-1, keepdims=True
    )
    z_ok = (z_col > 0.8) & (z_col < 2.0)
    angle_ok = (angle_col > -1.0) & (angle_col < 1.0)

    healthy = remainder_ok & z_ok & angle_ok
    return jnp.logical_not(healthy).reshape(-1, 1)


def termination_fn_kitchen(obs, act, next_obs):
    """Never terminate.

    Uses ``jnp.zeros(..., dtype=bool)`` so the mask stays boolean for an
    empty batch. ``jnp.array([False] * 0)`` would produce a float array.
    No rank check is performed.

    Args:
        obs: Array whose leading dimension is the batch size.
        act: Unused.
        next_obs: Unused.

    Returns:
        All-False boolean array of shape ``(batch, 1)``.
    """
    return jnp.zeros((obs.shape[0], 1), dtype=bool)


def termination_fn_dmc(obs, act, next_obs):
    """Never terminate.

    Uses ``jnp.zeros(..., dtype=bool)`` so the mask stays boolean for an
    empty batch. ``jnp.array([False] * 0)`` would produce a float array.
    No rank check is performed.

    Args:
        obs: Array whose leading dimension is the batch size.
        act: Unused.
        next_obs: Unused.

    Returns:
        All-False boolean array of shape ``(batch, 1)``.
    """
    return jnp.zeros((obs.shape[0], 1), dtype=bool)


def get_termination_fn(task):
    """Return the batched termination function for ``task``.

    Dispatch is by case-sensitive substring match and is order-sensitive.
    More specific names (``"halfcheetahvel"``, ``"antangle"``, ``"antmaze"``)
    are tested before the substrings they contain (``"halfcheetah"``,
    ``"ant"``). Case sensitivity is what separates the lower-case names from
    the capitalised ``"HalfCheetah"`` / ``"Hopper"`` / ``"Walker2d"`` names
    that map to the ``neorl`` variants.

    Args:
        task: Task / environment name.

    Returns:
        A callable ``fn(obs, act, next_obs)`` returning a done mask.

    Raises:
        ValueError: If ``task`` is an antmaze task with no recognised layout
            substring, or matches no known task at all.
    """
    if "dmc" in task:
        return termination_fn_dmc
    elif "halfcheetahvel" in task:
        return termination_fn_halfcheetahveljump
    elif "halfcheetah" in task:
        return termination_fn_halfcheetah
    elif "hopper" in task:
        return termination_fn_hopper
    elif "antangle" in task:
        return termination_fn_antangle
    elif "antmaze" in task:
        radius = 0.25
        # (layout substring, disc center) pairs. Every entry is checked, so a
        # later match overrides an earlier one, as with the original
        # independent ``if`` statements.
        maze_centers = (
            ("umaze", (0.75, 8.75)),
            ("medium", (20.75, 20.75)),
            ("large", (32.75, 24.75)),
            ("ultra", (52.75, 36.75)),
        )
        center = None
        for maze, maze_center in maze_centers:
            if maze in task:
                center = maze_center
        if center is None:
            # Without this check ``center`` would be unbound below and raise
            # an unhelpful UnboundLocalError.
            raise ValueError(
                f"Unknown antmaze layout in task {task!r}; expected one of "
                f"{[maze for maze, _ in maze_centers]}"
            )
        center = jax.device_put(np.array(center))
        radius = jax.device_put(radius)
        return lambda obs, action, next_obs: termination_fn_antmaze(
            obs, action, next_obs, center, radius
        )
    elif "ant" in task:
        return termination_fn_ant
    elif "walker2d" in task:
        return termination_fn_walker2d
    elif "point2denv" in task:
        return termination_fn_point2denv
    elif "point2dwallenv" in task:
        return termination_fn_point2dwallenv
    elif "pendulum" in task:
        # Must precede the "pen" check: "pendulum" contains "pen".
        return termination_fn_pendulum
    elif "humanoid" in task:
        return termination_fn_humanoid
    elif "pen" in task:
        return termination_fn_pen
    elif "door" in task:
        return terminaltion_fn_door
    elif "HalfCheetah" in task:
        return termination_fn_neorl_halfcheetah
    elif "Hopper" in task:
        return termination_fn_neorl_hopper
    elif "Walker2d" in task:
        return termination_fn_neorl_walker2d
    elif "kitchen" in task:
        return termination_fn_kitchen
    else:
        # ``raise jnp.zeros`` used to raise a confusing TypeError ("exceptions
        # must derive from BaseException") instead of reporting the task.
        raise ValueError(f"No termination function available for task {task!r}")
