from typing import List, Dict, Union, Tuple
import random
import numpy as np
import gym

from gym.wrappers.time_limit import TimeLimit
from .utils.half_cheetah import HalfCheetahEnv  # modified from gym


class HC_Dynamics_Wrapper(gym.Wrapper):
    """HalfCheetah (1000-step time limit) with rescalable dynamics parameters.

    ``param_dict`` maps each adjustable parameter name (a key of
    ``_PARAM_SLOTS``) to its sampling spec for ``resample_params``:
    a 2-element sequence is treated as a ``[low, high]`` range sampled
    uniformly, any other sequence as a set of discrete scales to choose from.
    A scale multiplies the model values captured at construction time, so a
    scale of 1 restores the nominal model.

    The optional ``'action noise'`` key is not a dynamics parameter: it is
    removed from the parameter set and used as the noise configuration for
    ``step`` (see ``_get_noise``).
    """

    # Model slots scaled by each parameter, as (model attribute, index into
    # that array). Every slot is scaled relative to its OWN initial value so
    # that a scale of 1 leaves the model unchanged.
    # NOTE(review): indices assume the model layout of HalfCheetahEnv.
    _PARAM_SLOTS = {
        'foot mass':          [('body_mass', 4), ('body_mass', 7)],
        'shin mass':          [('body_mass', 3), ('body_mass', 6)],
        'torso mass':         [('body_mass', 1)],
        'foot fric':          [('geom_friction', (5, 0)), ('geom_friction', (8, 0))],
        'damping':            [('dof_damping', i) for i in range(3, 9)],
        'mass':               [('body_mass', i) for i in range(1, 8)],
        'bt jnt lower limit': [('jnt_range', (3, 0))],
        'bt jnt upper limit': [('jnt_range', (3, 1))],
        'ff jnt lower limit': [('jnt_range', (8, 0))],
        'ff jnt upper limit': [('jnt_range', (8, 1))],
        'all fric':           [('geom_friction', Ellipsis)],
    }

    def __init__(self, param_dict: Union[Dict, None] = None):
        """Build the environment and record the initial value of every slot.

        Args:
            param_dict: parameter name -> sampling spec (see class docstring),
                plus an optional ``'action noise'`` configuration. Not modified.

        Raises:
            NotImplementedError: if a key of ``param_dict`` is not adjustable.
        """
        # Copy so popping 'action noise' never touches the caller's dict;
        # None replaces the former shared mutable default.
        param_dict = {} if param_dict is None else param_dict.copy()
        action_noise = param_dict.pop('action noise', None)
        super().__init__(TimeLimit(HalfCheetahEnv(), 1000))

        self.name = 'HalfCheetah'

        self.params = list(param_dict.keys())
        self.param_dict = param_dict
        # Initial model values per parameter, one entry per slot in _PARAM_SLOTS.
        self.initial_param_dict = {param: [] for param in self.params}

        self.current_param_scale = dict()

        # Action-noise configuration handed to _get_noise (None disables noise).
        self.noise_scale = action_noise
        self.max_action = self.env.action_space.high[0]

        for param in self.params:
            if param not in self._PARAM_SLOTS:
                raise NotImplementedError(f"{param} is not adjustable in HalfCheetah")
            self.initial_param_dict[param].extend(self._read_slots(param))
            self.current_param_scale[param] = 1

    def _read_slots(self, param: str) -> List:
        """Return the current model value of every slot belonging to ``param``."""
        model = self.unwrapped.model
        values = []
        for attr, index in self._PARAM_SLOTS[param]:
            value = getattr(model, attr)[index]
            # Array-valued slots (e.g. 'all fric') are views into the live
            # model; copy them so later writes don't alter the stored baseline.
            values.append(value.copy() if isinstance(value, np.ndarray) else value)
        return values

    def _write_slots(self, param: str, scale) -> None:
        """Set every slot of ``param`` to its own initial value times ``scale``."""
        model = self.unwrapped.model
        slots = self._PARAM_SLOTS[param]
        for (attr, index), initial in zip(slots, self.initial_param_dict[param]):
            getattr(model, attr)[index] = initial * scale

    def set_params(self, param_scales: Dict) -> None:
        """Apply one scale per configured parameter to the model.

        Keys that are not adjustable parameters are only recorded in
        ``current_param_scale``.
        """
        assert len(param_scales) == len(self.params), 'Length of new params must align the initilization params'
        for param, scale in list(param_scales.items()):
            if param in self._PARAM_SLOTS:
                self._write_slots(param, scale)
            self.current_param_scale[param] = scale

    def resample_params(self) -> None:
        """Draw a new scale for every parameter and apply it.

        A spec of exactly two values is a uniform ``[low, high]`` range;
        any other sequence is a set of discrete choices.
        """
        new_scales = {}
        for param, bound_or_possible_values in list(self.param_dict.items()):
            if len(bound_or_possible_values) == 2:
                new_scales[param] = random.uniform(
                    bound_or_possible_values[0],
                    bound_or_possible_values[1]
                )
            else:
                new_scales[param] = random.choice(bound_or_possible_values)
        self.set_params(new_scales)

    def reset(self, resample: bool = True) -> np.array:
        """Reset the wrapped env, first resampling the dynamics if ``resample``."""
        if resample:
            self.resample_params()
        return self.env.reset()

    def step(self, action):
        """Step the wrapped env, optionally adding noise to ``action`` first.

        Noisy actions are clipped to ``[-max_action, max_action]``.
        """
        if self.noise_scale is not None:
            # asarray so list actions (not only ndarrays) have a .shape.
            action = np.asarray(action)
            noise = self._get_noise(self.noise_scale, size=action.shape)
            action = (action + noise).clip(-self.max_action, self.max_action)

        return self.env.step(action)

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped env."""
        # If 'env' itself is missing (e.g. copy/pickle probing attributes on
        # an instance whose __init__ has not run), forwarding would recurse
        # forever; fail with a normal AttributeError instead.
        if name == 'env':
            raise AttributeError(name)
        return getattr(self.env, name)

    @property
    def current_param_scales(self) -> Dict:
        """Mapping of parameter name -> most recently applied scale."""
        return self.current_param_scale

    @property
    def current_flat_scale(self) -> List:
        """Most recently applied scales, in parameter insertion order."""
        return list(self.current_param_scale.values())

    @property
    def action_bound(self) -> float:
        """First component of the action-space upper bound."""
        return self.env.action_space.high[0]

    def _get_noise(self, noise_params, size):
        """Sample additive action noise of shape ``size`` from ``self.np_random``.

        ``noise_params`` is None (zero noise) or a dict containing either:
          * ``'uniform'``: ``{'low': ..., 'high': ...}`` (defaults -0.01 / 0.01);
          * ``'normal'``: a Gaussian mixture ``{'mean': [...], 'std': [...],
            'weights': [...]}``. Each element picks a component with
            probabilities ``weights`` (uniform over components when omitted)
            and is drawn from that component. Defaults: mean 0, std 0.01.

        Raises:
            ValueError: if neither 'uniform' nor 'normal' is present.
        """
        if noise_params is None:
            return np.zeros(size)

        if 'uniform' in noise_params:
            uniform_cfg = noise_params.get('uniform') or {}
            low_noise = uniform_cfg.get('low', -0.01)
            high_noise = uniform_cfg.get('high', 0.01)
            noise = self.np_random.uniform(
                low=low_noise, high=high_noise, size=size
            )
        elif 'normal' in noise_params:
            normal_cfg = noise_params.get('normal') or {}
            # atleast_1d lets a scalar mean/std describe a one-component mixture.
            means = np.atleast_1d(np.asarray(normal_cfg.get('mean', [0.0]), dtype=float))
            stds = np.atleast_1d(np.asarray(normal_cfg.get('std', [0.01]), dtype=float))
            weights = normal_cfg.get('weights')
            if weights is None:
                # Previously defaulted to [1.0], which fails for >1 component.
                weights = np.full(len(means), 1.0 / len(means))

            component_ids = self.np_random.choice(
                len(means), size=size, p=weights
            )
            noise = self.np_random.normal(
                loc=means[component_ids],
                scale=stds[component_ids],
                size=size
            )
        else:
            raise ValueError("Invalid noise parameters provided. Must contain 'uniform' or 'normal' keys.")

        return noise


class HC_Speed_and_Dynamics_Wrapper(gym.Wrapper):
    """HalfCheetah-v2 with rescalable dynamics, optional action noise and a
    minimum-speed reward condition.

    ``param_dict`` keys:
      * ``'speed'``: a scalar threshold; while ``info['reward_run']`` is below
        it, the step reward is replaced by ``info['reward_ctrl']``.
      * ``'action noise'``: sampling spec for the relative std of Gaussian
        action noise (std = ``max_action * scale``).
      * any key of ``_PARAM_SLOTS``: sampling spec for a dynamics scale.

    A sampling spec of exactly two values is a uniform ``[low, high]`` range;
    any other sequence is a set of discrete choices. Dynamics scales multiply
    the model values captured at construction, so 1 restores the nominal model.
    """

    # Model slots scaled by each dynamics parameter, as (model attribute,
    # index into that array). Every slot is scaled relative to its OWN initial
    # value so that a scale of 1 leaves the model unchanged.
    # NOTE(review): indices assume the HalfCheetah-v2 model layout.
    _PARAM_SLOTS = {
        'foot mass':          [('body_mass', 4), ('body_mass', 7)],
        'shin mass':          [('body_mass', 3), ('body_mass', 6)],
        'torso mass':         [('body_mass', 1)],
        'foot fric':          [('geom_friction', (5, 0)), ('geom_friction', (8, 0))],
        'damping':            [('dof_damping', i) for i in range(3, 9)],
        'mass':               [('body_mass', i) for i in range(1, 8)],
        'bt jnt lower limit': [('jnt_range', (3, 0))],
        'bt jnt upper limit': [('jnt_range', (3, 1))],
        'ff jnt lower limit': [('jnt_range', (8, 0))],
        'ff jnt upper limit': [('jnt_range', (8, 1))],
        'all fric':           [('geom_friction', Ellipsis)],
    }

    def __init__(self, param_dict: Union[Dict, None] = None):
        """Build the environment and record the initial value of every slot.

        Raises:
            NotImplementedError: if a key of ``param_dict`` is not adjustable.
        """
        # None replaces the former shared mutable default.
        if param_dict is None:
            param_dict = {}
        super().__init__(gym.make('HalfCheetah-v2'))
        self.name = 'HalfCheetah'

        self.params = list(param_dict.keys())
        self.param_dict = param_dict
        self.initial_param_dict = {param: [] for param in self.params}

        self.current_param_scale = dict()

        # (enabled, relative std) of Gaussian action noise; stays disabled
        # until an 'action noise' scale is applied by set_params.
        self.noise_scale = (False, 0.0)
        self.max_action = self.env.action_space.high[0]
        # Without a 'speed' entry there is no threshold; -inf means step()
        # never replaces the reward (previously step() raised AttributeError).
        self.speed_requirement = -np.inf

        for param in self.params:
            if param == 'speed':
                self.initial_param_dict[param].append(param_dict[param])
                self.speed_requirement = param_dict[param]
            elif param == 'action noise':
                self.initial_param_dict[param].append(0.0)
            elif param in self._PARAM_SLOTS:
                self.initial_param_dict[param].extend(self._read_slots(param))
            else:
                raise NotImplementedError(f"{param} is not adjustable in HalfCheetah")

            self.current_param_scale[param] = 1

    def _read_slots(self, param: str) -> List:
        """Return the current model value of every slot belonging to ``param``."""
        model = self.unwrapped.model
        values = []
        for attr, index in self._PARAM_SLOTS[param]:
            value = getattr(model, attr)[index]
            # Array-valued slots (e.g. 'all fric') are views into the live
            # model; copy them so later writes don't alter the stored baseline.
            values.append(value.copy() if isinstance(value, np.ndarray) else value)
        return values

    def _write_slots(self, param: str, scale) -> None:
        """Set every slot of ``param`` to its own initial value times ``scale``."""
        model = self.unwrapped.model
        slots = self._PARAM_SLOTS[param]
        for (attr, index), initial in zip(slots, self.initial_param_dict[param]):
            getattr(model, attr)[index] = initial * scale

    def step(self, action):
        """Step the env with optional action noise and the speed condition.

        Returns the wrapped env's 4-tuple ``(obs, reward, done, info)``; the
        reward becomes ``info['reward_ctrl']`` whenever ``info['reward_run']``
        is below ``speed_requirement``.
        """
        if self.noise_scale[0]:
            # asarray so list actions (not only ndarrays) have a .shape.
            action = np.asarray(action)
            noise = np.random.normal(0, self.max_action * self.noise_scale[1], size=action.shape)
            action = (action + noise).clip(-self.max_action, self.max_action)

        observation, reward, terminated, info = self.env.step(action)

        if info['reward_run'] < self.speed_requirement:
            reward = info['reward_ctrl']
        return observation, reward, terminated, info

    def set_params(self, param_scales: Dict) -> None:
        """Apply one scale per configured parameter.

        ``'action noise'`` enables noise with the given relative std; dynamics
        parameters rescale their model slots; other keys (e.g. ``'speed'``)
        are only recorded in ``current_param_scale``.
        """
        assert len(param_scales) == len(self.params), 'Length of new params must align the initilization params'
        for param, scale in list(param_scales.items()):
            if param == 'action noise':
                self.noise_scale = (True, scale)
            elif param in self._PARAM_SLOTS:
                self._write_slots(param, scale)

            self.current_param_scale[param] = scale

    def resample_params(self) -> None:
        """Draw a new scale for every sampled parameter and apply it."""
        new_scales = {}
        for param, bound_or_possible_values in list(self.param_dict.items()):
            if np.isscalar(bound_or_possible_values):
                # A scalar entry (e.g. the 'speed' threshold) is a fixed value,
                # not a sampling spec; len() on it used to raise TypeError.
                new_scales[param] = self.current_param_scale[param]
            elif len(bound_or_possible_values) == 2:
                new_scales[param] = random.uniform(
                    bound_or_possible_values[0],
                    bound_or_possible_values[1]
                )
            else:
                new_scales[param] = random.choice(bound_or_possible_values)
        self.set_params(new_scales)

    def reset(self, resample: bool = True) -> np.array:
        """Reset the wrapped env, first resampling parameters if ``resample``."""
        if resample:
            self.resample_params()
        return self.env.reset()

    @property
    def current_param_scales(self) -> Dict:
        """Mapping of parameter name -> most recently applied scale."""
        return self.current_param_scale

    @property
    def current_flat_scale(self) -> List:
        """Most recently applied scales, in parameter insertion order."""
        return list(self.current_param_scale.values())

    @property
    def action_bound(self) -> float:
        """First component of the action-space upper bound."""
        return self.env.action_space.high[0]


if __name__ == '__main__':
    # Smoke test: act randomly forever, starting a new episode whenever the
    # current one ends instead of stepping past termination.
    env = HC_Speed_and_Dynamics_Wrapper({})
    s = env.reset()
    while True:
        a = env.action_space.sample()
        s, r, done, info = env.step(a)
        if done:
            s = env.reset()
