import numpy as np
from .half_cheetah import HalfCheetahEnv


class CheetahNoFlipEnv(HalfCheetahEnv):
    """HalfCheetahEnv variant that ends an episode on head/floor contact.

    The ``done`` flag returned by ``step`` is computed by
    ``check_termination`` instead of being taken from the parent class.
    """

    def check_termination(self):
        """Return True if any active contact pairs the head with the floor.

        Only the first ``self.data.ncon`` entries of ``self.data.contact``
        are examined. A contact matches when the names of its two geoms
        (looked up in ``self.model.geom_names``) are 'floor' and 'head',
        in either order.
        """
        head_and_floor = {'floor', 'head'}
        for idx in range(self.data.ncon):
            touch = self.data.contact[idx]
            # Unordered pair of the two geom names involved in this contact.
            pair = {
                self.model.geom_names[touch.geom1],
                self.model.geom_names[touch.geom2],
            }
            if head_and_floor <= pair:
                return True
        return False

    def step(self, action):
        """Advance the parent environment and override its ``done`` flag.

        Returns the parent's ``(observation, reward, _, info)`` tuple with
        the third element replaced by ``self.check_termination()``.
        """
        # The parent's own termination flag is deliberately ignored.
        observation, reward, _parent_done, info = super().step(action)
        return observation, reward, self.check_termination(), info

    @staticmethod
    def done(states):
        """Evaluate the termination condition for each state in ``states``.

        Every state is loaded into the module-level ``_done_test_env`` via
        ``set_state_from_obs`` (defined outside this class; presumably it
        expects observations in this environment's format -- confirm) and
        then checked with ``check_termination``.

        Returns a NumPy array with one entry per input state.
        """
        def _terminated(state):
            # Mutates the shared scratch environment before querying it.
            _done_test_env.set_state_from_obs(state)
            return _done_test_env.check_termination()

        return np.array([_terminated(state) for state in states])


# Shared scratch environment used by CheetahNoFlipEnv.done(): each queried
# state is loaded into this instance before its termination check is run.
# NOTE: it is constructed at import time, so importing this module
# instantiates a full environment.
_done_test_env = CheetahNoFlipEnv()