import os

import h5py
from gym.envs.mujoco.half_cheetah import HalfCheetahEnv as HalfCheetahEnvGT
from gym.envs.mujoco.hopper import HopperEnv as HopperEnvGT
from gym.envs.mujoco.humanoid import HumanoidEnv as HumanoidEnvGT
from gym.envs.mujoco.walker2d import Walker2dEnv as Walker2dEnvGT
from gym.envs.mujoco import mujoco_env
from gym import utils


# ------------------------------------------------------------------------
# ------------------- Transition error environments ----------------------
# ------------------------------------------------------------------------
class HalfCheetahTransitionErrorEnv(HalfCheetahEnvGT):
    """HalfCheetah environment that can be simulated from a custom MuJoCo XML.

    The constructor calls ``MujocoEnv.__init__`` directly instead of the
    parent ``HalfCheetahEnvGT.__init__`` so that the model file can be chosen
    by the caller. Judging by the class name, the custom XML is meant to
    describe dynamics that differ from the ground-truth model.
    """

    def __init__(self, xml_path=None):
        """Build the simulator.

        Args:
            xml_path: Path to a MuJoCo XML model, joined onto the current
                working directory. When ``None``, the stock
                ``'half_cheetah.xml'`` name is handed to ``MujocoEnv``
                (presumably resolved against gym's bundled assets).
        """
        if xml_path is None:
            model_path = 'half_cheetah.xml'
        else:
            model_path = os.path.join(os.getcwd(), xml_path)

        # Second positional argument is the frame skip in gym's MujocoEnv.
        mujoco_env.MujocoEnv.__init__(self, model_path, 5)

        # EzPickle re-creates the object from the recorded constructor
        # arguments on unpickling/deepcopy. Record the *original* argument
        # (not the resolved path) so a restored copy uses the same model
        # rather than silently falling back to the default XML.
        utils.EzPickle.__init__(self, xml_path=xml_path)


class HopperTransitionErrorEnv(HopperEnvGT):
    """Hopper environment that can be simulated from a custom MuJoCo XML.

    The constructor calls ``MujocoEnv.__init__`` directly instead of the
    parent ``HopperEnvGT.__init__`` so that the model file can be chosen by
    the caller. Judging by the class name, the custom XML is meant to
    describe dynamics that differ from the ground-truth model.
    """

    def __init__(self, xml_path=None):
        """Build the simulator.

        Args:
            xml_path: Path to a MuJoCo XML model, joined onto the current
                working directory. When ``None``, the stock ``'hopper.xml'``
                name is handed to ``MujocoEnv`` (presumably resolved against
                gym's bundled assets).
        """
        if xml_path is None:
            model_path = 'hopper.xml'
        else:
            model_path = os.path.join(os.getcwd(), xml_path)

        # Second positional argument is the frame skip in gym's MujocoEnv.
        mujoco_env.MujocoEnv.__init__(self, model_path, 4)

        # EzPickle re-creates the object from the recorded constructor
        # arguments on unpickling/deepcopy. Record the *original* argument
        # (not the resolved path) so a restored copy uses the same model
        # rather than silently falling back to the default XML.
        utils.EzPickle.__init__(self, xml_path=xml_path)


class Walker2dTransitionErrorEnv(Walker2dEnvGT):
    """Walker2d environment that can be simulated from a custom MuJoCo XML.

    The constructor calls ``MujocoEnv.__init__`` directly instead of the
    parent ``Walker2dEnvGT.__init__`` so that the model file can be chosen by
    the caller. Judging by the class name, the custom XML is meant to
    describe dynamics that differ from the ground-truth model.
    """

    def __init__(self, xml_path=None):
        """Build the simulator.

        Args:
            xml_path: Path to a MuJoCo XML model, joined onto the current
                working directory. When ``None``, the stock
                ``'walker2d.xml'`` name is handed to ``MujocoEnv``
                (presumably resolved against gym's bundled assets).
        """
        if xml_path is None:
            model_path = 'walker2d.xml'
        else:
            model_path = os.path.join(os.getcwd(), xml_path)

        # Second positional argument is the frame skip in gym's MujocoEnv.
        mujoco_env.MujocoEnv.__init__(self, model_path, 4)

        # EzPickle re-creates the object from the recorded constructor
        # arguments on unpickling/deepcopy. Record the *original* argument
        # (not the resolved path) so a restored copy uses the same model
        # rather than silently falling back to the default XML.
        utils.EzPickle.__init__(self, xml_path=xml_path)


class HumanoidTransitionErrorEnv(HumanoidEnvGT):
    """Humanoid environment that can be simulated from a custom MuJoCo XML.

    The constructor calls ``MujocoEnv.__init__`` directly instead of the
    parent ``HumanoidEnvGT.__init__`` so that the model file can be chosen by
    the caller. Judging by the class name, the custom XML is meant to
    describe dynamics that differ from the ground-truth model.
    """

    def __init__(self, xml_path=None):
        """Build the simulator.

        Args:
            xml_path: Path to a MuJoCo XML model, joined onto the current
                working directory. When ``None``, the stock
                ``'humanoid.xml'`` name is handed to ``MujocoEnv``
                (presumably resolved against gym's bundled assets).
        """
        if xml_path is None:
            model_path = 'humanoid.xml'
        else:
            model_path = os.path.join(os.getcwd(), xml_path)

        # Second positional argument is the frame skip in gym's MujocoEnv.
        mujoco_env.MujocoEnv.__init__(self, model_path, 5)

        # EzPickle re-creates the object from the recorded constructor
        # arguments on unpickling/deepcopy. Record the *original* argument
        # (not the resolved path) so a restored copy uses the same model
        # rather than silently falling back to the default XML.
        utils.EzPickle.__init__(self, xml_path=xml_path)

# ------------------------------------------------------------------------
# -------------- D4RL environments with transition error -----------------
# ------------------------------------------------------------------------
from d4rl import offline_env


class OfflineHopperTransitionErrorEnv(HopperTransitionErrorEnv, offline_env.OfflineEnv):
    """Hopper with an optional custom XML, combined with D4RL's ``OfflineEnv``."""

    def __init__(self, xml_path=None, **kwargs):
        """Initialise the simulator and then the offline-dataset mixin.

        Args:
            xml_path: Forwarded to ``HopperTransitionErrorEnv.__init__``.
            **kwargs: Forwarded unchanged to ``offline_env.OfflineEnv.__init__``.
        """
        HopperTransitionErrorEnv.__init__(self, xml_path)
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # Record the complete constructor arguments for EzPickle. This is
        # done last so that no earlier initializer can overwrite the record;
        # otherwise a pickled/deep-copied env would be rebuilt without its
        # custom XML and without the OfflineEnv keyword arguments.
        utils.EzPickle.__init__(self, xml_path=xml_path, **kwargs)


class OfflineHalfCheetahTransitionErrorEnv(HalfCheetahTransitionErrorEnv, offline_env.OfflineEnv):
    """HalfCheetah with an optional custom XML, combined with D4RL's ``OfflineEnv``."""

    def __init__(self, xml_path=None, **kwargs):
        """Initialise the simulator and then the offline-dataset mixin.

        Args:
            xml_path: Forwarded to ``HalfCheetahTransitionErrorEnv.__init__``.
            **kwargs: Forwarded unchanged to ``offline_env.OfflineEnv.__init__``.
        """
        HalfCheetahTransitionErrorEnv.__init__(self, xml_path)
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # Record the complete constructor arguments for EzPickle. This is
        # done last so that no earlier initializer can overwrite the record;
        # otherwise a pickled/deep-copied env would be rebuilt without its
        # custom XML and without the OfflineEnv keyword arguments.
        utils.EzPickle.__init__(self, xml_path=xml_path, **kwargs)


class OfflineWalker2dTransitionErrorEnv(Walker2dTransitionErrorEnv, offline_env.OfflineEnv):
    """Walker2d with an optional custom XML, combined with D4RL's ``OfflineEnv``."""

    def __init__(self, xml_path=None, **kwargs):
        """Initialise the simulator and then the offline-dataset mixin.

        Args:
            xml_path: Forwarded to ``Walker2dTransitionErrorEnv.__init__``.
            **kwargs: Forwarded unchanged to ``offline_env.OfflineEnv.__init__``.
        """
        Walker2dTransitionErrorEnv.__init__(self, xml_path)
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # Record the complete constructor arguments for EzPickle. This is
        # done last so that no earlier initializer can overwrite the record;
        # otherwise a pickled/deep-copied env would be rebuilt without its
        # custom XML and without the OfflineEnv keyword arguments.
        utils.EzPickle.__init__(self, xml_path=xml_path, **kwargs)


class OfflineHumanoidTransitionErrorEnv(HumanoidTransitionErrorEnv, offline_env.OfflineEnv):
    """Humanoid with an optional custom XML, combined with D4RL's ``OfflineEnv``."""

    def __init__(self, xml_path=None, **kwargs):
        """Initialise the simulator and then the offline-dataset mixin.

        Args:
            xml_path: Forwarded to ``HumanoidTransitionErrorEnv.__init__``.
            **kwargs: Forwarded unchanged to ``offline_env.OfflineEnv.__init__``.
        """
        HumanoidTransitionErrorEnv.__init__(self, xml_path)
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # Record the complete constructor arguments for EzPickle. This is
        # done last so that no earlier initializer can overwrite the record;
        # otherwise a pickled/deep-copied env would be rebuilt without its
        # custom XML and without the OfflineEnv keyword arguments.
        utils.EzPickle.__init__(self, xml_path=xml_path, **kwargs)


# ----------------------------------------------------------------------------------------------
# ------------- Offline environments with dataset from path and transition error ---------------
# ----------------------------------------------------------------------------------------------
class OfflineDSEnv(offline_env.OfflineEnv):
    """``OfflineEnv`` variant whose dataset is always read from ``ds_path``.

    Attributes set by ``__init__``:
        ds_path: The dataset path given by the caller.
        hidden_dims: Value of the file's ``'obs_hidden_dims'`` attribute.
            Only assigned when the HDF5 file carries that attribute;
            otherwise the attribute is not created on the instance.
    """

    def __init__(self, ds_path, **kwargs):
        """Store the dataset path, read optional metadata, init ``OfflineEnv``.

        Args:
            ds_path: Path to an HDF5 dataset file; opened read-only here.
            **kwargs: Forwarded unchanged to ``offline_env.OfflineEnv.__init__``.
        """
        self.ds_path = ds_path

        # Read optional metadata from the dataset file: if it declares an
        # 'obs_hidden_dims' attribute, expose its value as ``hidden_dims``.
        with h5py.File(self.ds_path, 'r') as dataset_file:
            file_attrs = dataset_file.attrs
            if 'obs_hidden_dims' in file_attrs:
                self.hidden_dims = file_attrs['obs_hidden_dims']

        offline_env.OfflineEnv.__init__(self, **kwargs)

    def get_dataset(self, h5path=None):
        """Return the dataset loaded from ``self.ds_path``.

        Args:
            h5path: Accepted for signature compatibility with
                ``OfflineEnv.get_dataset`` but ignored; the stored
                ``ds_path`` is always used.
        """
        return offline_env.OfflineEnv.get_dataset(self, h5path=self.ds_path)


class OfflineHopperDSTransitionErrorEnv(HopperTransitionErrorEnv, OfflineDSEnv):
    """Hopper with an optional custom XML and a dataset loaded from ``ds_path``."""

    def __init__(self, ds_path, xml_path=None, **kwargs):
        """Initialise the simulator and then the path-based offline dataset.

        Args:
            ds_path: Path to the HDF5 dataset, forwarded to ``OfflineDSEnv``.
            xml_path: Forwarded to ``HopperTransitionErrorEnv.__init__``.
            **kwargs: Forwarded unchanged to ``OfflineDSEnv.__init__``.
        """
        # Assigned before the simulator is built; OfflineDSEnv.__init__
        # assigns the same value again below.
        self.ds_path = ds_path
        HopperTransitionErrorEnv.__init__(self, xml_path)
        OfflineDSEnv.__init__(self, ds_path, **kwargs)

        # Record the complete constructor arguments for EzPickle, last so
        # that no earlier initializer can overwrite the record. Without
        # this, unpickling/deepcopy re-invokes the constructor with no
        # arguments and fails because ``ds_path`` is required.
        utils.EzPickle.__init__(self, ds_path=ds_path, xml_path=xml_path, **kwargs)


class OfflineHalfCheetahDSTransitionErrorEnv(HalfCheetahTransitionErrorEnv, OfflineDSEnv):
    """HalfCheetah with an optional custom XML and a dataset loaded from ``ds_path``."""

    def __init__(self, ds_path, xml_path=None, **kwargs):
        """Initialise the simulator and then the path-based offline dataset.

        Args:
            ds_path: Path to the HDF5 dataset, forwarded to ``OfflineDSEnv``.
            xml_path: Forwarded to ``HalfCheetahTransitionErrorEnv.__init__``.
            **kwargs: Forwarded unchanged to ``OfflineDSEnv.__init__``.
        """
        # Assigned before the simulator is built; OfflineDSEnv.__init__
        # assigns the same value again below.
        self.ds_path = ds_path
        HalfCheetahTransitionErrorEnv.__init__(self, xml_path)
        OfflineDSEnv.__init__(self, ds_path, **kwargs)

        # Record the complete constructor arguments for EzPickle, last so
        # that no earlier initializer can overwrite the record. Without
        # this, unpickling/deepcopy re-invokes the constructor with no
        # arguments and fails because ``ds_path`` is required.
        utils.EzPickle.__init__(self, ds_path=ds_path, xml_path=xml_path, **kwargs)


class OfflineWalker2dDSTransitionErrorEnv(Walker2dTransitionErrorEnv, OfflineDSEnv):
    """Walker2d with an optional custom XML and a dataset loaded from ``ds_path``."""

    def __init__(self, ds_path, xml_path=None, **kwargs):
        """Initialise the simulator and then the path-based offline dataset.

        Args:
            ds_path: Path to the HDF5 dataset, forwarded to ``OfflineDSEnv``.
            xml_path: Forwarded to ``Walker2dTransitionErrorEnv.__init__``.
            **kwargs: Forwarded unchanged to ``OfflineDSEnv.__init__``.
        """
        # Assigned before the simulator is built; OfflineDSEnv.__init__
        # assigns the same value again below.
        self.ds_path = ds_path
        Walker2dTransitionErrorEnv.__init__(self, xml_path)
        OfflineDSEnv.__init__(self, ds_path, **kwargs)

        # Record the complete constructor arguments for EzPickle, last so
        # that no earlier initializer can overwrite the record. Without
        # this, unpickling/deepcopy re-invokes the constructor with no
        # arguments and fails because ``ds_path`` is required.
        utils.EzPickle.__init__(self, ds_path=ds_path, xml_path=xml_path, **kwargs)


class OfflineHumanoidDSTransitionErrorEnv(HumanoidTransitionErrorEnv, OfflineDSEnv):
    """Humanoid with an optional custom XML and a dataset loaded from ``ds_path``."""

    def __init__(self, ds_path, xml_path=None, **kwargs):
        """Initialise the simulator and then the path-based offline dataset.

        Args:
            ds_path: Path to the HDF5 dataset, forwarded to ``OfflineDSEnv``.
            xml_path: Forwarded to ``HumanoidTransitionErrorEnv.__init__``.
            **kwargs: Forwarded unchanged to ``OfflineDSEnv.__init__``.
        """
        # Assigned before the simulator is built; OfflineDSEnv.__init__
        # assigns the same value again below.
        self.ds_path = ds_path
        HumanoidTransitionErrorEnv.__init__(self, xml_path)
        OfflineDSEnv.__init__(self, ds_path, **kwargs)

        # Record the complete constructor arguments for EzPickle, last so
        # that no earlier initializer can overwrite the record. Without
        # this, unpickling/deepcopy re-invokes the constructor with no
        # arguments and fails because ``ds_path`` is required.
        utils.EzPickle.__init__(self, ds_path=ds_path, xml_path=xml_path, **kwargs)
