import numpy as np
import gym
from pybullet_envs.gym_locomotion_envs import WalkerBaseBulletEnv
from mocca_envs.env_base import EnvBase
from symmetry.env_utils import MirrorIndicesEnv


class FatigueEnv(gym.Wrapper):
    """Gym wrapper that models joint fatigue and adds a fatigue term to the reward.

    Each action component is split into its positive part and the magnitude of its
    negative part (``process_action``). Both parts are blended into
    ``self.fatigue_level`` (shape ``action_space.shape + (2,)``, last axis =
    (positive, negative)) by ``compute_fatigue_level``. Every step the reward
    receives ``fatigue_reward_weight * norm(fatigue_level)``, which is also reported
    as ``info["fat_rew"]``.

    With ``observable=True`` the fatigue levels are appended to the observation in
    the layout ``[positive fatigue of all actions, negative fatigue of all actions]``.
    The first ``MirrorIndicesEnv`` found in the wrapped chain also gets its
    observation mirror-index lists extended to cover the new entries.
    """

    def __init__(self, env: gym.Env,
                 fatigue_param: float = 0.2,
                 fatigue_reward_weight: float = -1.0,
                 dt: float = 1 / 60,
                 instantaneous_effort_reward=False,
                 observable=False):
        """
        Args:
            env: environment to wrap; its action space must be a ``gym.spaces.Box``.
            fatigue_param: rate constant of the fatigue update.
            fatigue_reward_weight: multiplier on the fatigue norm added to the reward
                (negative values penalize fatigue).
            dt: time step used in the fatigue update.
            instantaneous_effort_reward: if False, the wrapped environment's
                stall-torque effort term is cancelled in ``step``.
            observable: if True, append the fatigue levels to the observation
                (requires 1-D Box observation and action spaces).
        """
        super().__init__(env)
        assert isinstance(self.action_space, gym.spaces.Box)
        self.fatigue_param = fatigue_param
        self.fr_weight = fatigue_reward_weight
        self.dt = dt
        # One fatigue value per action component and rotation direction, so the shape
        # is the *action* shape with an extra trailing axis of size 2 (positive, negative).
        self.fatigue_level = np.zeros(tuple(env.action_space.shape) + (2,))
        self.observable = observable
        self.instantaneous_effort_reward = instantaneous_effort_reward
        if observable:
            self._extend_observation_space(env)

    def _extend_observation_space(self, env):
        """Append fatigue bounds to the 1-D Box observation space and patch mirror indices."""
        assert isinstance(self.observation_space, gym.spaces.Box)
        assert len(self.observation_space.shape) == 1
        assert len(self.action_space.shape) == 1
        org_obs_size = self.observation_space.shape[0]
        org_act_size = self.action_space.shape[0]

        # Fatigue is never negative. Its upper bound is the largest action magnitude in
        # each direction. The layout is [positive block, negative block], which must
        # match _append_fatigue.
        fatigue_low = np.zeros(2 * org_act_size)
        fatigue_high = np.concatenate((np.abs(self.action_space.high), np.abs(self.action_space.low)))
        obs_low = np.concatenate((self.observation_space.low, fatigue_low))
        obs_high = np.concatenate((self.observation_space.high, fatigue_high))
        self.observation_space = gym.spaces.Box(obs_low, obs_high, dtype=np.float32)
        self._patch_mirror_indices(env, org_obs_size, org_act_size)

    @staticmethod
    def _patch_mirror_indices(env, org_obs_size, org_act_size):
        """Extend the first MirrorIndicesEnv's observation mirror indices with the fatigue entries.

        Walks down the wrapper chain starting at ``env`` and mutates that wrapper's
        ``minds`` dict in place. If no MirrorIndicesEnv is found, nothing happens.

        Raises:
            NotImplementedError: if the mirror env declares ``sideneg_act_inds``.
        """
        while isinstance(env, gym.Wrapper):
            if isinstance(env, MirrorIndicesEnv):
                minds = env.minds
                if minds['sideneg_act_inds']:
                    raise NotImplementedError
                com_act = np.array(minds['com_act_inds'], dtype=np.int32)
                left_act = np.array(minds['left_act_inds'], dtype=np.int32)
                right_act = np.array(minds['right_act_inds'], dtype=np.int32)
                neg_act = np.array(minds['neg_act_inds'], dtype=np.int32)

                # Within the fatigue block, action i has its positive fatigue at i and
                # its negative fatigue at i + org_act_size.
                fatigue_com = np.concatenate((com_act, com_act + org_act_size))
                # Fatigue is never negative. A sign-flipped action on the symmetry axis
                # is therefore mirrored by swapping its positive and negative entries.
                fatigue_left = np.concatenate((left_act, left_act + org_act_size, neg_act))
                fatigue_right = np.concatenate((right_act, right_act + org_act_size,
                                                neg_act + org_act_size))

                # Shift into observation coordinates (fatigue follows the original obs).
                minds['com_obs_inds'] = minds['com_obs_inds'] + (fatigue_com + org_obs_size).tolist()
                minds['left_obs_inds'] = minds['left_obs_inds'] + (fatigue_left + org_obs_size).tolist()
                minds['right_obs_inds'] = minds['right_obs_inds'] + (fatigue_right + org_obs_size).tolist()
                return
            env = env.env

    def _append_fatigue(self, obs):
        """Return ``obs`` with the fatigue levels appended as [positive block, negative block]."""
        # Flattening in column-major order turns the (n, 2) fatigue array into
        # [pos_0..pos_{n-1}, neg_0..neg_{n-1}]. That is the layout assumed by the
        # observation bounds and mirror indices built in __init__. The default C order
        # would interleave (pos_0, neg_0, pos_1, ...).
        fatigue = self.fatigue_level.flatten(order="F")
        # Cast to the declared observation dtype (float32) so the float64 fatigue does
        # not upcast the whole observation.
        return np.concatenate((obs, fatigue)).astype(self.observation_space.dtype, copy=False)

    def _remove_instantaneous_effort(self, reward, action):
        """Cancel the stall-torque effort term that the base environment included in ``reward``.

        Raises:
            RuntimeError: if the unwrapped env is neither a pybullet walker nor an EnvBase.
        """
        base = self.env.unwrapped
        if isinstance(base, WalkerBaseBulletEnv):
            # NOTE(review): the two env families use opposite operators here. Presumably
            # pybullet adds a signed (negative) stall_torque_cost term while EnvBase
            # subtracts a positive penalty; confirm if either environment changes.
            return reward - base.stall_torque_cost * float(np.square(action).mean())
        if isinstance(base, EnvBase):
            return reward + base.stall_torque_cost * float(np.square(action).mean())
        raise RuntimeError('Not Supported Environment')

    def reset(self, **kwargs):
        """Reset the wrapped env and clear all accumulated fatigue."""
        self.fatigue_level.fill(0)
        obs = self.env.reset(**kwargs)
        if self.observable:
            obs = self._append_fatigue(obs)
        return obs

    def step(self, action: np.ndarray):
        """Step the wrapped env, update the fatigue state and adjust the reward.

        Returns the ``(obs, reward, done, info)`` tuple of the wrapped env.
        ``info["fat_rew"]`` holds the fatigue reward added this step.
        """
        obs, reward, done, info = self.env.step(action)
        self.fatigue_level = compute_fatigue_level(
            process_action(action), self.fatigue_level, self.fatigue_param, self.dt)
        # With the default ord, np.linalg.norm is the 2-norm over all entries
        # (the Frobenius norm for the 2-D fatigue array).
        fatigue_reward = self.fr_weight * np.linalg.norm(self.fatigue_level)
        info["fat_rew"] = fatigue_reward
        reward += fatigue_reward
        if not self.instantaneous_effort_reward:
            reward = self._remove_instantaneous_effort(reward, action)
        if self.observable:
            obs = self._append_fatigue(obs)
        return obs, reward, done, info


def process_action(action: np.ndarray):
    """Split an action into its positive part and the magnitude of its negative part.

    Args:
        action: action array (any array-like is accepted).

    Returns:
        Array of shape ``action.shape + (2,)``. ``[..., 0]`` holds ``max(action, 0)``
        and ``[..., 1]`` holds ``max(-action, 0)``. Both halves are non-negative for
        finite and infinite inputs; NaN propagates into both.
    """
    action = np.asarray(action)
    # np.maximum avoids the `x * mask` idiom, which gives NaN for infinite
    # components (e.g. -inf * False == nan).
    pos_action = np.maximum(action, 0)
    neg_action = np.maximum(-action, 0)
    return np.stack((pos_action, neg_action), axis=-1)


def compute_fatigue_level(action, old_fatigue_level, fatigue_param, dt):
    """Advance the fatigue state by one time step.

    The previous fatigue is blended toward the current (non-negative) effort with
    weight ``fatigue_param * dt``::

        new = (1 - fatigue_param * dt) * old + fatigue_param * dt * action

    Args:
        action: non-negative effort array, e.g. the output of ``process_action``.
        old_fatigue_level: fatigue array from the previous step.
        fatigue_param: rate constant of the update.
        dt: time step.

    Returns:
        The new fatigue array (broadcast of ``action`` and ``old_fatigue_level``).

    Raises:
        AssertionError: if either array argument is not an ``np.ndarray`` or
            ``action`` has negative entries (only when asserts are enabled).
    """
    both_arrays = isinstance(action, np.ndarray) and isinstance(old_fatigue_level, np.ndarray)
    assert both_arrays
    assert np.all(action >= 0)
    blend = fatigue_param * dt
    return (1 - blend) * old_fatigue_level + blend * action


def get_bullet_env_dt(env: "gym.Env", default_dt: float = 0.0165):
    """Return the physics time step of a PyBullet-backed environment.

    Reads ``fixedTimeStep`` from the bullet client stored on ``env.unwrapped._p``.
    If the unwrapped env has no ``_p`` attribute, or it is ``None`` (e.g. the client
    has not been created yet), ``default_dt`` is returned instead of raising.

    Args:
        env: the (possibly wrapped) environment.
        default_dt: fallback time step. The default 0.0165 preserves the previous
            hard-coded fallback.

    Returns:
        The ``fixedTimeStep`` engine parameter, or ``default_dt``.
    """
    client = getattr(env.unwrapped, "_p", None)
    if client is None:
        return default_dt
    return client.getPhysicsEngineParameters()["fixedTimeStep"]
