import numpy as np
import torch

def obs_unnormalization(termination_fn, obs_mean, obs_std):
    """Wrap ``termination_fn`` so it sees de-normalized observations.

    Both ``obs`` and ``next_obs`` are mapped through
    ``x * obs_std + obs_mean`` before being handed to ``termination_fn``;
    ``act`` is forwarded unchanged.

    Args:
        termination_fn: callable taking ``(obs, act, next_obs)``.
        obs_mean: offset added after scaling.
        obs_std: multiplicative scale applied first.

    Returns:
        A callable with the same ``(obs, act, next_obs)`` signature.
    """
    def _denormalize(x):
        return x * obs_std + obs_mean

    def thunk(obs, act, next_obs):
        return termination_fn(_denormalize(obs), act, _denormalize(next_obs))

    return thunk

def termination_fn_dummy(obs, act, next_obs):
    """Never-terminating termination function with torch output.

    Args:
        obs, act, next_obs: 2-D batches; only ``obs.shape[0]`` and
            ``next_obs.device`` are read.

    Returns:
        Bool tensor of shape ``(batch, 1)`` filled with ``False``, allocated
        on ``next_obs.device``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    # Use an explicit dtype: ``torch.tensor([])`` (the empty-batch case of
    # ``torch.tensor([False] * n)``) would be float32, not bool.
    done = torch.zeros((obs.shape[0], 1), dtype=torch.bool, device=next_obs.device)
    return done

def termination_fn_halfcheetah(obs, act, next_obs):
    """Flag rows whose next observation leaves the open box (-100, 100).

    Torch-only: ``next_obs`` must be a ``torch.Tensor`` (reductions use
    ``dim=``).

    Returns:
        Bool tensor of shape ``(batch, 1)``. An entry is ``True`` when any
        element of that row of ``next_obs`` is ``<= -100`` or ``>= 100``.
        NaN fails both comparisons and is therefore treated as out of
        bounds.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    above_floor = (next_obs > -100).all(dim=-1)
    below_ceiling = (next_obs < 100).all(dim=-1)
    in_bounds = above_floor & below_ceiling
    return (~in_bounds).unsqueeze(-1)

def termination_fn_hopper(obs, act, next_obs):
    """Hopper-style termination on a torch batch.

    A row stays alive only if all of the following hold:
    - every entry of ``next_obs`` is finite;
    - every entry after column 0 has absolute value below 100;
    - ``next_obs[:, 0]`` (height) is above 0.7;
    - ``|next_obs[:, 1]|`` (angle) is below 0.2.

    Torch-only: ``next_obs`` must be a ``torch.Tensor``.

    Returns:
        Bool tensor of shape ``(batch, 1)``; ``True`` means done.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    height, angle = next_obs[:, 0], next_obs[:, 1]

    alive = torch.isfinite(next_obs).all(dim=-1)
    alive = alive & (next_obs[:, 1:].abs() < 100).all(dim=-1)
    alive = alive & (height > .7)
    alive = alive & (angle.abs() < 0.2)
    return (~alive)[:, None]

def termination_fn_halfcheetahveljump(obs, act, next_obs):
    """Never-terminating termination function for the half-cheetah vel/jump tasks.

    Returns a NumPy bool array when ``obs``, ``act`` and ``next_obs`` are all
    ``np.ndarray``. Otherwise it returns a torch bool tensor, placed on
    ``next_obs.device`` when ``next_obs`` is a tensor. The shape is always
    ``(batch, 1)`` and every entry is ``False``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    batch_size = len(obs)
    if isinstance(obs, np.ndarray) and isinstance(act, np.ndarray) and isinstance(next_obs, np.ndarray):
        return np.zeros((batch_size, 1), dtype=bool)

    # Allocate next to the model outputs. A CPU mask combined with CUDA
    # tensors would raise a device-mismatch error.
    device = next_obs.device if isinstance(next_obs, torch.Tensor) else None
    return torch.zeros((batch_size, 1), dtype=torch.bool, device=device)

def termination_fn_antangle(obs, act, next_obs):
    """Flag rows whose first next-observation coordinate leaves [0.2, 1.0].

    A row is *not* done only if every entry of ``next_obs`` is finite and
    ``0.2 <= next_obs[:, 0] <= 1.0``. ``next_obs[:, 0]`` is presumably the
    torso height (TODO confirm against the env's observation layout).

    Accepts NumPy arrays or torch tensors and returns the matching type.
    The result is bool with shape ``(batch, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    x = next_obs[:, 0]
    if isinstance(next_obs, torch.Tensor):
        # np.isfinite cannot consume CUDA tensors, so stay in torch here.
        finite = torch.isfinite(next_obs).all(dim=-1)
    else:
        finite = np.isfinite(next_obs).all(axis=-1)
    not_done = finite & (x >= 0.2) & (x <= 1.0)

    done = ~not_done
    return done[:, None]

def termination_fn_ant(obs, act, next_obs):
    """Flag rows whose first next-observation coordinate leaves [0.2, 1.0].

    A row is *not* done only if every entry of ``next_obs`` is finite and
    ``0.2 <= next_obs[:, 0] <= 1.0``. ``next_obs[:, 0]`` is presumably the
    torso height (TODO confirm against the env's observation layout).

    Accepts NumPy arrays or torch tensors and returns the matching type.
    The result is bool with shape ``(batch, 1)``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    x = next_obs[:, 0]
    if isinstance(next_obs, torch.Tensor):
        # np.isfinite cannot consume CUDA tensors, so stay in torch here.
        finite = torch.isfinite(next_obs).all(dim=-1)
    else:
        finite = np.isfinite(next_obs).all(axis=-1)
    not_done = finite & (x >= 0.2) & (x <= 1.0)

    done = ~not_done
    return done[:, None]

def termination_fn_walker2d(obs, act, next_obs):
    """Walker2d-style termination on a torch batch.

    A row stays alive only if all of the following hold:
    - every entry of ``next_obs`` lies strictly inside (-100, 100);
    - ``0.8 < next_obs[:, 0] < 2.0`` (height);
    - ``-1.0 < next_obs[:, 1] < 1.0`` (angle).

    Torch-only: ``next_obs`` must be a ``torch.Tensor``.

    Returns:
        Bool tensor of shape ``(batch, 1)``; ``True`` means done.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    height = next_obs[:, 0]
    angle = next_obs[:, 1]

    in_box = ((next_obs > -100) & (next_obs < 100)).all(dim=-1)
    height_ok = (height > 0.8) & (height < 2.0)
    angle_ok = (angle > -1.0) & (angle < 1.0)

    alive = in_box & height_ok & angle_ok
    return (~alive)[:, None]

def termination_fn_point2denv(obs, act, next_obs):
    """Never-terminating termination function for the 2-D point env.

    Always returns a NumPy bool array of shape ``(len(obs), 1)`` filled with
    ``False``, whether the inputs are NumPy arrays or torch tensors.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    n_rows = len(obs)
    return np.zeros((n_rows, 1), dtype=bool)

def termination_fn_point2dwallenv(obs, act, next_obs):
    """Never-terminating termination function for the 2-D point-with-wall env.

    Always returns a NumPy bool array of shape ``(len(obs), 1)`` filled with
    ``False``, whether the inputs are NumPy arrays or torch tensors.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    return np.full((len(obs), 1), False)

def termination_fn_pendulum(obs, act, next_obs):
    """Never-terminating termination function for pendulum.

    Returns a float64 NumPy array of zeros with shape ``(len(obs), 1)``.

    NOTE(review): unlike most functions in this module the result is float,
    not bool, so applying ``~`` to it would fail. Confirm whether callers
    rely on the float dtype before changing it.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    return np.zeros(shape=(len(obs), 1), dtype=np.float64)

def termination_fn_humanoid(obs, act, next_obs):
    """Flag rows whose ``next_obs[:, 0]`` falls outside [1.0, 2.0].

    ``next_obs[:, 0]`` is presumably the torso height (TODO confirm). Works
    on NumPy arrays or torch tensors and returns the same type. The result
    is bool with shape ``(batch, 1)``; ``True`` means done.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    height = next_obs[:, 0]
    too_low = height < 1.0
    too_high = height > 2.0
    return (too_low | too_high)[:, None]

def termination_fn_pen(obs, act, next_obs):
    """Flag rows where column 26 of ``next_obs`` drops below 0.075.

    Columns 24:27 are treated as the object position; the third component
    (column 26) is presumably the object's height (TODO confirm). The result
    is bool with shape ``(batch, 1)``, of the same array type as
    ``next_obs``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    object_z = next_obs[:, 24:27][:, 2]
    return (object_z < 0.075)[:, None]

def termination_fn_door(obs, act, next_obs):
    """Never-terminating termination function for the door task.

    Returns a NumPy bool array of shape ``(obs.shape[0], 1)`` filled with
    ``False``. The output is NumPy even for torch inputs.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    # Give an explicit dtype so an empty batch is still bool:
    # ``np.array([False] * 0)`` would default to float64.
    done = np.zeros((obs.shape[0], 1), dtype=bool)
    return done

# Adapted from the MOPO code used in ROMI.
def termination_fn_maze(obs, act, next_obs):
    """Never-terminating termination function for maze2d tasks.

    Returns a NumPy bool array when ``obs``, ``act`` and ``next_obs`` are all
    ``np.ndarray``. Otherwise it returns a torch bool tensor, placed on
    ``next_obs.device`` when ``next_obs`` is a tensor; this matches the
    slider/adroit functions. The shape is ``(batch, 1)`` and every entry is
    ``False``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    if isinstance(obs, np.ndarray) and isinstance(act, np.ndarray) and isinstance(next_obs, np.ndarray):
        return np.zeros((obs.shape[0], 1), dtype=bool)

    # Without an explicit device the mask lands on CPU. It then cannot be
    # combined with CUDA model outputs.
    device = next_obs.device if isinstance(next_obs, torch.Tensor) else None
    return torch.zeros((obs.shape[0], 1), dtype=torch.bool, device=device)

# Adapted from the MOPO code used in ROMI.
def termination_fn_antmaze(obs, act, next_obs, env):
    """Flag rows whose first two ``next_obs`` coordinates are within 0.5 of the goal.

    The Euclidean distance between ``next_obs[:, :2]`` and ``env.target_goal``
    is compared against 0.5 (inclusive).

    Args:
        obs, act, next_obs: 2-D batches. The NumPy path is taken only when
            all three are exactly ``np.ndarray``; otherwise torch is used.
        env: object exposing ``target_goal``. On the torch path it is
            converted with ``torch.tensor`` onto ``next_obs.device``.

    Returns:
        Bool array/tensor of shape ``(batch, 1)``; ``True`` means done.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    all_numpy = all(type(x) is np.ndarray for x in (obs, act, next_obs))
    xy = next_obs[:, :2]
    if all_numpy:
        distance = np.linalg.norm(xy - env.target_goal, axis=1)
    else:
        goal = torch.tensor(env.target_goal, device=next_obs.device)
        distance = torch.linalg.norm(xy - goal, dim=1)

    reached = distance <= 0.5
    return reached[:, None]

def termination_fn_slider(obs, act, next_obs):
    """Never-terminating termination function for the slider task.

    Returns a NumPy bool array when ``obs``, ``act`` and ``next_obs`` are all
    exactly ``np.ndarray``. Otherwise it returns a torch bool tensor on
    ``next_obs.device``. The shape is ``(batch, 1)`` and every entry is
    ``False``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    batch = obs.shape[0]
    if not all(type(x) is np.ndarray for x in (obs, act, next_obs)):
        return torch.zeros((batch, 1), dtype=torch.bool, device=next_obs.device)
    return np.zeros((batch, 1), dtype=bool)

def termination_fn_adroit(obs, act, next_obs):
    """Never-terminating termination function for Adroit tasks.

    Returns a NumPy bool array when ``obs``, ``act`` and ``next_obs`` are all
    exactly ``np.ndarray``. Otherwise it returns a torch bool tensor on
    ``next_obs.device``. The shape is ``(batch, 1)`` and every entry is
    ``False``.
    """
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2

    n_rows = obs.shape[0]
    numpy_inputs = type(obs) is np.ndarray and type(act) is np.ndarray and type(next_obs) is np.ndarray
    if numpy_inputs:
        return np.zeros((n_rows, 1), dtype=bool)
    return torch.zeros((n_rows, 1), dtype=torch.bool, device=next_obs.device)



def get_termination_fn(task, env=None):
    """Return the termination function for ``task`` by substring match.

    Matching is order-sensitive. A more specific name is always tested
    before any shorter name it contains, e.g. ``'halfcheetahvel'`` before
    ``'halfcheetah'``, and ``'antangle'``/``'antmaze'`` before ``'ant'``.
    Previously ``'ant'`` was tested first, which made the antmaze branch
    unreachable.

    Args:
        task: task / environment name.
        env: optional environment, used only for ``'antmaze'`` tasks, whose
            termination function reads ``env.target_goal``. When given, it
            is bound so the returned callable takes ``(obs, act, next_obs)``.
            When omitted, the four-argument ``termination_fn_antmaze`` is
            returned as-is.

    Returns:
        A termination callable.

    Raises:
        ValueError: if no known task name occurs in ``task``.
    """
    if 'halfcheetahvel' in task:
        return termination_fn_halfcheetahveljump
    elif 'halfcheetah' in task:
        return termination_fn_halfcheetah
    elif 'hopper' in task:
        return termination_fn_hopper
    elif 'antangle' in task:
        return termination_fn_antangle
    elif 'antmaze' in task:
        if env is None:
            return termination_fn_antmaze
        import functools
        return functools.partial(termination_fn_antmaze, env=env)
    elif 'ant' in task:
        return termination_fn_ant
    elif 'walker2d' in task:
        return termination_fn_walker2d
    elif 'point2denv' in task:
        return termination_fn_point2denv
    elif 'point2dwallenv' in task:
        return termination_fn_point2dwallenv
    elif 'pendulum' in task:
        return termination_fn_pendulum
    elif 'humanoid' in task:
        return termination_fn_humanoid
    elif 'maze2d' in task:
        return termination_fn_maze
    elif 'pen' in task:
        return termination_fn_pen
    elif 'door' in task:
        return termination_fn_door
    elif 'slider' in task:
        return termination_fn_slider
    elif 'adroit' in task:
        return termination_fn_adroit
    # The original ``raise np.zeros`` raised an unrelated TypeError
    # ("exceptions must derive from BaseException").
    raise ValueError(f"No termination function registered for task {task!r}")
