import os
import glob

import numpy as np
import gym
from gym.envs.registration import register

from .action_error import *
from .obs_error import *
from .confounding_error import *


# ------------------------------------------------------------------------------------
# ----------------- Altered XMLs (transition error): environments --------------------
# ------------------------------------------------------------------------------------
# Discover every altered XML below sim/xmls/<EnvName>/. The pattern is relative,
# so what is found depends on the current working directory at import time.
xml_paths = glob.glob('sim/xmls/**/*.xml', recursive=True)
# XML file name without directory or extension; it becomes part of the env id.
file_names = [os.path.splitext(os.path.basename(xml_path))[0] for xml_path in xml_paths]
# Environment name = the path component directly below sim/xmls.
# The glob pattern is written with '/', so on platforms where os.path.sep is not
# '/' the returned paths mix both separators and a plain split on os.path.sep
# would pick the wrong component. Normalising first makes the component index
# reliable; on POSIX the result is identical to splitting the raw path.
env_names = [os.path.normpath(xml_path).split(os.path.sep)[2] for xml_path in xml_paths]

# One online environment per altered XML file.
for xml_path, file_name, env_name in zip(xml_paths, file_names, env_names):
    register(
        id=f'{env_name.lower()}-transition-error-{file_name}-v0',
        entry_point=f'sim.transition_error:{env_name}TransitionErrorEnv',
        max_episode_steps=1000,
        kwargs={'xml_path': xml_path}
    )

# ------------------------------------------------------------------------------------
# ----------------- Altered XMLs (transition error): d4rl datasets  ------------------
# ------------------------------------------------------------------------------------
# Register offline environments with transition errors, one per altered XML and D4RL dataset type
# Dataset flavours for which an offline variant of every altered-XML env is registered.
ds_types = ['random', 'medium', 'expert', 'medium-expert', 'medium-replay']
for xml_path, file_name, env_name in zip(xml_paths, file_names, env_names):
    env_key = env_name.lower()
    offline_entry_point = f'sim.transition_error:Offline{env_name}TransitionErrorEnv'
    for ds_type in ds_types:
        # NOTE(review): `d4rl` is not imported in the visible part of this module;
        # presumably one of the star imports above provides it -- confirm.
        dataset_url = d4rl.infos.DATASET_URLS[f'{env_key}-{ds_type}-v2']
        register(
            id=f'{env_key}-{ds_type}-transition-error-{file_name}-v2',
            entry_point=offline_entry_point,
            max_episode_steps=1000,
            kwargs={'xml_path': xml_path, 'dataset_url': dataset_url},
        )

# ------------------------------------------------------------------------------------
# ------------------------- Default XMLs; Datasets from path -------------------------
# ------------------------------------------------------------------------------------
# Register offline environments with default XMLs (no transition error)
# Datasets loaded from local files. For each entry:
#   'env'  -> environment name, selects the Offline<env>DSTransitionErrorEnv entry point
#   'path' -> dataset file passed to the environment as `ds_path`
#   'id'   -> gym id without the version suffix
available_paths = [
    dict(env='HalfCheetah',
         path='data/HalfCheetah-v3_obs_hidden_dims_16_stacked_frames_3/dataset.hdf5',
         id='halfcheetah-ds-obs-hidden-dims-16-stacked-3'),
    dict(env='HalfCheetah',
         path='data/HalfCheetah-v3_obs_hidden_dims_16/dataset.hdf5',
         id='halfcheetah-ds-obs-hidden-dims-16'),
    dict(env='HalfCheetah',
         path='data/HalfCheetah-v3_obs_hidden_dims_9/dataset.hdf5',
         id='halfcheetah-ds-obs-hidden-dims-9'),
    dict(env='HalfCheetah',
         path='data/HalfCheetah-v3_obs_hidden_dims_9_stacked_frames_3/dataset.hdf5',
         id='halfcheetah-ds-obs-hidden-dims-9-stacked-3'),
    dict(env='HalfCheetah',
         path='data/HalfCheetah-v3_obs_hidden_dims_9_stacked_frames_5/dataset.hdf5',
         id='halfcheetah-ds-obs-hidden-dims-9-stacked-5'),
]

# No 'xml_path' kwarg is passed here: these variants use the default XML.
for ds in available_paths:
    register(
        id='{}-v0'.format(ds['id']),
        entry_point='sim.transition_error:Offline{}DSTransitionErrorEnv'.format(ds['env']),
        max_episode_steps=1000,
        kwargs=dict(ds_path=ds['path']),
    )

# ------------------------------------------------------------------------------------
# ---------------- Altered XMLs (transition error): datasets from path ---------------
# ------------------------------------------------------------------------------------
# Register offline environments with transition errors for each altered XML and matching dataset in available_paths
# Pair every altered XML with each local dataset recorded for the same environment.
for xml_path, file_name, env_name in zip(xml_paths, file_names, env_names):
    ds_entry_point = f'sim.transition_error:Offline{env_name}DSTransitionErrorEnv'
    for ds in available_paths:
        if ds['env'] != env_name:
            # Dataset belongs to a different environment.
            continue
        register(
            id=f'{ds["id"]}-transition-error-{file_name}-v0',
            entry_point=ds_entry_point,
            max_episode_steps=1000,
            kwargs={'xml_path': xml_path, 'ds_path': ds['path']},
        )

# ------------------------------------------------------------------------------------
# ------- Added functions --------
# ------------------------------------------------------------------------------------
from joblib import Parallel, delayed
import copy

def calc_sim_next_obs(simulator, state, action):
    """
    Step the simulator from the given state(s) and return the resulting observation(s).

    Each state vector is split at index ``qpos.size - 1``: the leading part,
    prefixed with a single 0, is used as qpos and the remainder as qvel
    (presumably because the observation omits the first position coordinate -- confirm).
    The simulator is reset before every state is loaded and is left in the
    last stepped state afterwards.

    :param simulator: object exposing ``sim.data.qpos``, ``reset()``, ``set_state(qpos, qvel)``
                      and ``step(action)`` returning a 4-tuple whose first item is the observation.
    :param state: a single state (1-D) or a batch of states (first axis indexes the batch).
    :param action: a single action, or one action per batch row.
    :return: the observation returned by ``step`` for a single state; for a batch,
             an array created with ``np.zeros_like(state)`` holding one observation per row.
    """
    # Number of leading state entries that belong to qpos (qpos minus its first coordinate).
    split_at = simulator.sim.data.qpos.size - 1

    def _load(single_state):
        # Put the simulator into the configuration described by one state vector.
        simulator.reset()
        qpos = np.concatenate(([0], single_state[0:split_at]), axis=0)
        qvel = single_state[split_at:]
        simulator.set_state(qpos, qvel)

    if len(state.shape) <= 1:
        _load(state)
        single_obs, _, _, _ = simulator.step(action)
        return single_obs

    batch_obs = np.zeros_like(state)
    for row in range(state.shape[0]):
        _load(state[row])
        batch_obs[row], _, _, _ = simulator.step(action[row])
    return batch_obs


def get_transformed_env(env_name, transform_list):
    """
    Make a gym environment and wrap it with the requested transformations, in list order.

    Each item of ``transform_list`` is indexed as (name, value). Supported names:

    'obs_noise'        -> float, magnitude of Gaussian noise to add to SIMULATOR observations (applied only if > 0).
    'ds_noise'         -> float, magnitude of Gaussian noise to add to DATASET observations, a hidden
                          confounder (applied only if > 0). The wrapper also receives a '.pkl' file name
                          derived from env_name: the part before '-transition-error' when that marker is
                          present, otherwise env_name lower-cased with its last '-'-separated token dropped.
    'ds_hidden_dims'   -> int or list of ints, dimensions to zero in the DATASET observations (a confounder).
    'obs_hidden_dims'  -> int or list of ints, dimensions to hide from the SIMULATOR observations.
    'action_noise'     -> float, magnitude of Gaussian noise to add to actions before applying them (applied only if > 0).
    'action_discrete'  -> int, number of digits to round the action to (applied only if > -1).
    'action_delay'     -> int, mean of the delay added to each new action (applied only if > 0).

    :param env_name: the name to be used in gym.make()
    :param transform_list: a list of tuples of desired transformations. E.g., [('obs_noise', 0.05), ('action_delay', 3)]
    :return: gym environment.
    :raises ValueError: if a name is not listed above, or if a value-gated transformation
                        is given a value that does not pass its threshold.
    """
    env = gym.make(env_name)
    for spec in transform_list:
        kind, level = spec[0], spec[1]

        if kind == 'obs_noise' and level > 0.0:
            env = ObsErrorNoiseWrapper(env, level)
            continue

        if kind == 'ds_noise' and level > 0.0:
            # File name of the noise vector is derived from the base environment name.
            if 'transition-error' in env_name:
                stem = env_name.split('-transition-error')[0]
            else:
                stem = "-".join(env_name.split("-")[:-1]).lower()
            env = DataErrorNoiseWrapper(env, level, f'{stem}.pkl')
            continue

        if kind == 'ds_hidden_dims':
            env = DataErrorHiddenDimsWrapper(env, level)
            continue

        if kind == 'obs_hidden_dims':
            env = get_hidden_dims_env(env, level)
            continue

        if kind == 'action_noise' and level > 0.0:
            env = ActionErrorRandomNoiseWrapper(env, level)
            continue

        if kind == 'action_discrete' and level > -1:
            env = ActionErrorDiscreteWrapper(env, level)
            continue

        if kind == 'action_delay' and level > 0:
            env = ActionErrorRandomDelayWrapper(env, level)
            continue

        # Unknown name, or a known name whose value did not pass its gate.
        raise ValueError(f'Transformation {kind} with value {level} not available. Please read package manual.')
    return env


def get_hidden_dims_env(env, hidden_dims):
    """
    Apply the env transformation named 'obs_hidden_dims'.

    :param env: env to transform
    :param hidden_dims: int / list of ints representing which indices of observation to zero.
    :return: the env wrapped in ObsErrorHiddenDims.
    """
    return ObsErrorHiddenDims(env, hidden_dims)
