import os
import glob

import numpy as np
import gym
from gym.envs.registration import register

from .action_error import *
from .obs_error import *
from .confounding_error import *

# ------------------------------------------------------------------------------------
# ---------------------- Added HIV environment ------------------------------------------
# ------------------------------------------------------------------------------------
# The base HIV treatment env plus transition-error variants that differ only in the
# 'dosage_noise' kwarg; all share the same entry point and a 50-step horizon.
for _hiv_id, _dosage_noise in (
        ("HIV-v0", 0.15),
        ("HIV-transition-error-dosage-0.3-v0", 0.3),
        ("HIV-transition-error-dosage-0.1-v0", 0.1),
        ("HIV-transition-error-dosage-0.2-v0", 0.2),
):
    register(
        id=_hiv_id,
        entry_point='sim.added_envs.hiv:HIVTreatment',
        max_episode_steps=50,
        kwargs={'dosage_noise': _dosage_noise},
    )
# Keep the module namespace free of the loop temporaries.
del _hiv_id, _dosage_noise

# ------------------------------------------------------------------------------------
# ----------------- Altered XMLs (transition error): environments --------------------
# ------------------------------------------------------------------------------------
# Every XML found under sim/xmls/<EnvName>/ becomes a transition-error variant of <EnvName>.
# NOTE(review): the glob is relative to the current working directory, so these envs are
# only registered when the process is started from the repository root — confirm intended.
_XML_ROOT = os.path.join('sim', 'xmls')


def _env_name_from_xml_path(xml_path, root=_XML_ROOT):
    """Return the first directory below `root` that contains `xml_path`, or None if the file sits directly in `root`."""
    # relpath normalises separators, so this also works when glob mixes '/' and '\\' on Windows.
    parts = os.path.relpath(xml_path, root).split(os.sep)
    return parts[0] if len(parts) > 1 else None


# Sorted for a deterministic registration order (glob order is filesystem dependent).
# XMLs placed directly in sim/xmls/ have no <EnvName> folder and cannot map to an env class.
xml_paths = sorted(
    path for path in glob.glob(os.path.join(_XML_ROOT, '**', '*.xml'), recursive=True)
    if _env_name_from_xml_path(path) is not None
)
# Parallel lists (same index -> same XML), consumed by the registration sections below.
file_names = [os.path.splitext(os.path.basename(xml_path))[0] for xml_path in xml_paths]
env_names = [_env_name_from_xml_path(xml_path) for xml_path in xml_paths]

for xml_path, file_name, env_name in zip(xml_paths, file_names, env_names):
    register(
        id=f'{env_name.lower()}-transition-error-{file_name}-v0',
        entry_point=f'sim.transition_error:{env_name}TransitionErrorEnv',
        max_episode_steps=1000,
        kwargs={'xml_path': xml_path}
    )

# ------------------------------------------------------------------------------------
# ----------------- Altered XMLs (transition error): d4rl datasets  ------------------
# ------------------------------------------------------------------------------------
# Register offline environments with transition errors: every altered XML is combined
# with each dataset type below, using the URL found in d4rl.infos.DATASET_URLS under
# '<env>-<ds_type>-v2'.
# This section previously relied on `d4rl` already being in the module namespace
# (presumably re-exported by one of the star imports above); import it explicitly so it
# cannot fail with a NameError.
import d4rl.infos  # noqa: E402  (kept next to its only use)

ds_types = ['random', 'medium', 'expert', 'medium-expert', 'medium-replay']
for xml_path, file_name, env_name in zip(xml_paths, file_names, env_names):
    for ds_type in ds_types:
        register(
            id=f'{env_name.lower()}-{ds_type}-transition-error-{file_name}-v2',
            entry_point=f'sim.transition_error:Offline{env_name}TransitionErrorEnv',
            max_episode_steps=1000,
            kwargs={'xml_path': xml_path,
                    # KeyError here means d4rl has no dataset of this type for the env.
                    'dataset_url': d4rl.infos.DATASET_URLS[f'{env_name.lower()}-{ds_type}-v2']
                    }
        )

# ------------------------------------------------------------------------------------
# ------------------------- Default XMLs; Datasets from path -------------------------
# ------------------------------------------------------------------------------------
# Register offline environments with default XMLs (no transition error)
# Spec rows: (base env, dataset folder under data/, registered id prefix).
_LOCAL_DATASET_SPECS = (
    ('HalfCheetah',
     'HalfCheetah-v3_obs_hidden_dims_16_stacked_frames_3',
     'halfcheetah-ds-obs-hidden-dims-16-stacked-3'),
    ('HalfCheetah',
     'HalfCheetah-v3_obs_hidden_dims_16',
     'halfcheetah-ds-obs-hidden-dims-16'),
    ('HalfCheetah',
     'HalfCheetah-v3_obs_hidden_dims_9',
     'halfcheetah-ds-obs-hidden-dims-9'),
    ('HalfCheetah',
     'HalfCheetah-v3_obs_hidden_dims_9_stacked_frames_3',
     'halfcheetah-ds-obs-hidden-dims-9-stacked-3'),
    ('HalfCheetah',
     'HalfCheetah-v3_obs_hidden_dims_9_stacked_frames_5',
     'halfcheetah-ds-obs-hidden-dims-9-stacked-5'),
)
# One dict per dataset with keys 'env', 'path', 'id' (reused by the section below).
available_paths = [
    {'env': base_env, 'path': f'data/{folder}/dataset.hdf5', 'id': id_prefix}
    for base_env, folder, id_prefix in _LOCAL_DATASET_SPECS
]

for ds in available_paths:
    register(
        id=ds['id'] + '-v0',
        entry_point='sim.transition_error:Offline' + ds['env'] + 'DSTransitionErrorEnv',
        max_episode_steps=1000,
        kwargs={'ds_path': ds['path']},
    )

# ------------------------------------------------------------------------------------
# ---------------- Altered XMLs (transition error): datasets from path ---------------
# ------------------------------------------------------------------------------------
# Register offline environments with transition errors: each altered XML is paired with
# every local dataset recorded for the same base environment.
for xml_path, file_name, env_name in zip(xml_paths, file_names, env_names):
    same_env_datasets = (entry for entry in available_paths if entry['env'] == env_name)
    for ds in same_env_datasets:
        register(
            id=f'{ds["id"]}-transition-error-{file_name}-v0',
            entry_point=f'sim.transition_error:Offline{env_name}DSTransitionErrorEnv',
            max_episode_steps=1000,
            kwargs={
                'xml_path': xml_path,
                'ds_path': ds['path'],
            },
        )


# ------------------------------------------------------------------------------------
# ------- Added functions --------
# ------------------------------------------------------------------------------------


# Transformation names accepted by get_transformed_env; anything else is rejected
# before the environment is built.
_KNOWN_TRANSFORMS = (
    'obs_noise',
    'ds_noise',
    'ds_hidden_dims',
    'obs_hidden_dims',
    'action_noise',
    'action_discrete',
    'action_delay',
)


def _ds_noise_vec_path(env_name):
    """Return the .pkl path handed to DataErrorNoiseWrapper for the 'ds_noise' transformation."""
    if 'transition-error' in env_name:
        # e.g. 'halfcheetah-medium-transition-error-foo-v2' -> 'halfcheetah-medium.pkl'
        return f"{env_name.split('-transition-error')[0]}.pkl"
    # Drop the trailing version, e.g. 'halfcheetah-medium-v2' -> 'halfcheetah-medium.pkl'
    return f'{"-".join(env_name.split("-")[:-1]).lower()}.pkl'


def get_transformed_env(env_name, transform_list):
    """
    Make a gym environment and wrap it with a predetermined sequence of transformations.

    Transformations are applied in the order given, each wrapping the result of the previous
    one; every applied transformation appends a suffix to the returned wandb name.
    The options are:

    'obs_noise'       -> float, magnitude of Gaussian noise to add to observations.
                         Ignored when <= 0.
    'ds_noise'        -> float, noise magnitude for DataErrorNoiseWrapper, which also receives a
                         .pkl noise-vector path derived from env_name. Ignored when <= 0.
    'ds_hidden_dims'  -> value passed unchanged to DataErrorHiddenDimsWrapper.
    'obs_hidden_dims' -> int / list of ints, observation indices passed to ObsErrorHiddenDims.
    'action_noise'    -> float, magnitude of Gaussian noise to add to actions before applying them.
                         Ignored when <= 0.
    'action_discrete' -> int, how many digits to round the action to (discretizing actions).
                         Ignored when <= -1.
    'action_delay'    -> int, passed to ActionErrorRandomDelayWrapper. Ignored when <= 0.

    :param env_name: the name to be used in gym.make()
    :param transform_list: iterable of (name, value) pairs, e.g. [('obs_noise', 0.05), ('action_delay', 3)].
                           None or empty means no transformation.
    :return: tuple (env, wandb_name) - the wrapped gym environment and a run name built from env_name.
    :raises ValueError: if a transformation name is not one of the options above. This is checked
                        before gym.make() so a typo does not build an environment first.
    """
    # Materialise once: the list is walked twice (validation, then wrapping).
    transforms = list(transform_list) if transform_list else []
    for trans_name, trans_value in transforms:
        if trans_name not in _KNOWN_TRANSFORMS:
            raise ValueError(
                f'Transformation {trans_name} with value {trans_value} not available. Please read package manual.')

    env = gym.make(env_name)
    wandb_name = env_name
    for trans_name, trans_value in transforms:
        # A recognised transformation with a "disabled" value is a no-op, not an error.
        if trans_name == 'obs_noise':
            if trans_value > 0.0:
                env = ObsErrorNoiseWrapper(env, trans_value)
                wandb_name += f'_obs_noise_{trans_value}'
        elif trans_name == 'ds_noise':
            if trans_value > 0.0:
                env = DataErrorNoiseWrapper(env, trans_value, _ds_noise_vec_path(env_name))
                wandb_name += f'_ds_noise_{trans_value}'
        elif trans_name == 'ds_hidden_dims':
            env = DataErrorHiddenDimsWrapper(env, trans_value)
            wandb_name += f'_ds_hidden_dims_{trans_value}'
        elif trans_name == 'obs_hidden_dims':
            env, wandb_name = get_hidden_dims_env(env, trans_value, wandb_name)
        elif trans_name == 'action_noise':
            if trans_value > 0.0:
                env = ActionErrorRandomNoiseWrapper(env, trans_value)
                wandb_name += f'_action_noise_{trans_value}'
        elif trans_name == 'action_discrete':
            if trans_value > -1:
                env = ActionErrorDiscreteWrapper(env, trans_value)
                wandb_name += f'_action_discrete_{trans_value}'
        elif trans_name == 'action_delay':
            if trans_value > 0:
                env = ActionErrorRandomDelayWrapper(env, trans_value)
                wandb_name += f'_action_delay_{trans_value}'
    return env, wandb_name


def get_hidden_dims_env(env, hidden_dims, wandb_name):
    """
    Apply the 'obs_hidden_dims' transformation to an environment.

    :param env: env to wrap with ObsErrorHiddenDims
    :param hidden_dims: int / list of ints representing which indices of observation to zero.
    :param wandb_name: current wandb name; the returned name has '_obs_hidden_dims_<hidden_dims>' appended
    :return: tuple of (wrapped env, extended wandb name)
    """
    # Wrap first, then build the name, matching the order in which the values are produced.
    wrapped_env = ObsErrorHiddenDims(env, hidden_dims)
    extended_name = wandb_name + f'_obs_hidden_dims_{hidden_dims}'
    return wrapped_env, extended_name
