import os
import glob

import numpy as np
import gymnasium
import highway_env
from gymnasium.envs.registration import register

from .action_error import *
from .obs_error import *
from .confounding_error import *


class HighwayWrapper(gymnasium.Wrapper):
    """
    Adapt a highway-env gymnasium environment to the older 4-tuple ``step`` API and pin the
    target speed of the first road vehicle on every reset.

    ``step`` returns ``(obs, reward, done, info)`` with ``done = terminated or truncated``;
    ``reset`` returns only the observation (the reset info dict is discarded).
    """

    def __init__(self, env, target_speed=25.0, **kwargs):
        """
        :param env: the environment to wrap.
        :param target_speed: value assigned to ``road.vehicles[0].target_speed`` after each reset
            (presumably the ego vehicle — confirm against the highway-env version in use).
        :param kwargs: extra options kept on ``self.kwargs``. They are NOT forwarded to
            ``gymnasium.Wrapper.__init__``, which only accepts ``env`` and would raise TypeError.
        """
        super().__init__(env)  # also sets self.env
        self.kwargs = kwargs
        self.target_speed = target_speed

    def step(self, action):
        """Step the wrapped env and merge ``terminated``/``truncated`` into a single done flag."""
        next_obs, reward, terminated, truncated, info = self.env.step(action)
        return next_obs, reward, (terminated or truncated), info

    def reset(self, **kwargs):
        """Reset the wrapped env, re-apply the target speed, and return only the observation."""
        obs, _info = self.env.reset(**kwargs)
        # The road (and its vehicles) is rebuilt on reset, so the speed must be set each time.
        self.unwrapped.road.vehicles[0].target_speed = self.target_speed
        return obs


# ------------------------------------------------------------------------------------
# ------------------------ Dataset environments --------------------------------------
# ------------------------------------------------------------------------------------
# Offline dataset environments: each entry registers `id` as a HighwayOfflineEnv that loads the
# HDF5 dataset at `path`.
# NOTE(review): the 'env' key is not used by the registration below; presumably it records which
# online environment generated the dataset — confirm before relying on it.
available_paths = [
    {'env': 'highway-fast-v0',
     'path': 'data/highway_medium/dataset.hdf5',
     'id': 'highway-fast-medium-v0',
     },
    {'env': 'highway-fast-v0',
     'path': 'data/highway_expert/dataset.hdf5',
     'id': 'highway-fast-expert-v0',
     },
]

_OFFLINE_ENTRY_POINT = 'highway.offline:HighwayOfflineEnv'

for ds in available_paths:
    register(
        id=ds['id'],
        entry_point=_OFFLINE_ENTRY_POINT,
        kwargs={'ds_path': ds['path']},
    )


# ------------------------------------------------------------------------------------
# ------- Environment construction helpers --------
# ------------------------------------------------------------------------------------
def make_env(env_name, wrap=False, hidden_cars=False, hidden_cars_count=3):
    """
    Create a highway-env environment configured for continuous actions.

    :param env_name: the id passed to gymnasium.make().
    :param wrap: if True, wrap the result in HighwayWrapper (4-tuple step API, pinned target speed).
    :param hidden_cars: if True, limit the Kinematics observation to ``hidden_cars_count`` vehicles
        and wrap the env with ObsErrorCars.
    :param hidden_cars_count: ``vehicles_count`` of the Kinematics observation when ``hidden_cars``
        is True (defaults to the previously hard-coded 3).
    :return: gymnasium environment.
    """
    env = gymnasium.make(env_name)

    config = {"action": {"type": "ContinuousAction"}}
    if hidden_cars:
        config["observation"] = {"type": "Kinematics", "vehicles_count": hidden_cars_count}
    # Configure the base env directly: gymnasium >= 1.0 wrappers no longer forward unknown
    # attributes such as `configure` (0.29 only emits a deprecation warning for it).
    env.unwrapped.configure(config)

    if hidden_cars:
        env = ObsErrorCars(env)
    if wrap:
        env = HighwayWrapper(env)
    return env

def get_transformed_env(env_name, transform_list):
    """
    Make a gymnasium environment and apply a predetermined set of transformations to it.

    Transformations are applied in the order given. Each applied transformation wraps or
    reconfigures the env and appends a ``_<name>_<value>`` tag to the returned wandb run name.
    The options are:

    'obs_noise'       -> float magnitude of Gaussian noise to add to observations.
    'ds_noise'        -> float noise magnitude for DataErrorNoiseWrapper, which also receives a
                         pre-saved noise-vector ``.pkl`` path derived from ``env_name``.
    'ds_hidden_dims'  -> value passed to DataErrorHiddenDimsWrapper.
    'obs_hidden_dims' -> int / list of ints: indices of the observation to zero.
    'obs_hidden_cars' -> int ``vehicles_count`` of the Kinematics observation; the env is then
                         wrapped with ObsErrorCars.
    'obs_cars_type'   -> 0, 1 or 2: behavior class of the other vehicles
                         (LinearVehicle, AggressiveVehicle, DefensiveVehicle).
    'action_noise'    -> float magnitude of Gaussian noise to add to actions before applying them.
    'action_discrete' -> int number of digits to round the action to (discretizing actions).
    'action_delay'    -> int passed to ActionErrorRandomDelayWrapper.

    A "disabled" value is accepted and skipped (no wrapper, no name tag): <= 0 for 'obs_noise',
    'ds_noise', 'action_noise' and 'action_delay', and <= -1 for 'action_discrete'.

    :param env_name: the name to be used in gymnasium.make().
    :param transform_list: a list of (name, value) tuples,
        e.g. [('obs_noise', 0.05), ('action_delay', 3)].
    :return: tuple of (HighwayWrapper-wrapped environment, wandb run name).
    :raises ValueError: for an unknown transformation name or an invalid 'obs_cars_type' value.
    """
    other_vehicle_types = {
        0: 'highway_env.vehicle.behavior.LinearVehicle',
        1: 'highway_env.vehicle.behavior.AggressiveVehicle',
        2: 'highway_env.vehicle.behavior.DefensiveVehicle',
    }

    env = make_env(env_name)
    wandb_name = env_name
    for transform_tuple in transform_list:
        trans_name, trans_value = transform_tuple[0], transform_tuple[1]
        if trans_name == 'obs_noise':
            if trans_value > 0.0:
                env = ObsErrorNoiseWrapper(env, trans_value)
                wandb_name += f'_obs_noise_{trans_value}'
        elif trans_name == 'ds_noise':
            if trans_value > 0.0:
                # Noise-vector file name: the env name up to '-transition-error', or the env
                # name without its last '-' component (e.g. the version), lower-cased.
                if 'transition-error' in env_name:
                    noise_vec_path = f"{env_name.split('-transition-error')[0]}.pkl"
                else:
                    noise_vec_path = f'{"-".join(env_name.split("-")[:-1]).lower()}.pkl'
                env = DataErrorNoiseWrapper(env, trans_value, noise_vec_path)
                wandb_name += f'_ds_noise_{trans_value}'
        elif trans_name == 'ds_hidden_dims':
            env = DataErrorHiddenDimsWrapper(env, trans_value)
            wandb_name += f'_ds_hidden_dims_{trans_value}'
        elif trans_name == 'obs_hidden_dims':
            env, wandb_name = get_hidden_dims_env(env, trans_value, wandb_name)
        elif trans_name == 'obs_hidden_cars':
            # Configure the base env: gymnasium >= 1.0 wrappers do not forward `configure`.
            env.unwrapped.configure({"observation": {"type": "Kinematics", "vehicles_count": trans_value}})
            env = ObsErrorCars(env)
            wandb_name += f'_obs_hidden_cars_{trans_value}'
        elif trans_name == 'obs_cars_type':
            if trans_value not in other_vehicle_types:
                raise ValueError(
                    f"Invalid 'obs_cars_type' value {trans_value}; "
                    f"expected one of {sorted(other_vehicle_types)}."
                )
            env.unwrapped.configure({'other_vehicles_type': other_vehicle_types[trans_value]})
            wandb_name += f'_obs_cars_type_{trans_value}'
        elif trans_name == 'action_noise':
            if trans_value > 0.0:
                env = ActionErrorRandomNoiseWrapper(env, trans_value)
                wandb_name += f'_action_noise_{trans_value}'
        elif trans_name == 'action_discrete':
            if trans_value > -1:
                env = ActionErrorDiscreteWrapper(env, trans_value)
                wandb_name += f'_action_discrete_{trans_value}'
        elif trans_name == 'action_delay':
            if trans_value > 0:
                env = ActionErrorRandomDelayWrapper(env, trans_value)
                wandb_name += f'_action_delay_{trans_value}'
        else:
            raise ValueError(f'Transformation {trans_name} with value {trans_value} not available. Please read package manual.')
    env = HighwayWrapper(env)
    return env, wandb_name


def get_hidden_dims_env(env, hidden_dims, wandb_name):
    """
    Apply the 'obs_hidden_dims' transformation.

    Wraps ``env`` in ObsErrorHiddenDims and tags the run name with ``_obs_hidden_dims_<hidden_dims>``.

    :param env: environment to wrap.
    :param hidden_dims: int / list of ints naming the observation indices to zero.
    :param wandb_name: current wandb run name.
    :return: tuple (wrapped_env, tagged_wandb_name)
    """
    wrapped_env = ObsErrorHiddenDims(env, hidden_dims)
    tag = f'_obs_hidden_dims_{hidden_dims}'
    wandb_name += tag
    return wrapped_env, wandb_name
