import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env

# Camera attributes that viewer_setup() copies onto ``self.viewer.cam``:
# ndarray values are written in place, every other value is assigned directly.
DEFAULT_CAMERA_CONFIG = dict(
    trackbodyid=2,
    distance=3.0,
    lookat=np.array([0.0, 0.0, 1.15]),
    elevation=-20.0,
)

class HopperPiecewiseScaleEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """
    MuJoCo Hopper environment perturbed by a randomly drifting external force.

    Every step a uniform increment in ``[-s, s]`` is added to the x-component
    of the applied force, where ``s`` is the current noise scale. The episode
    is split into ``n_switches + 1`` segments, each with its own randomly drawn
    scale; after every switch the scale is blended linearly from the previous
    segment's value to the new one over ``transition_length`` steps.

    The per-episode randomization lives in reset_model() because the base
    class constructor may call step() before the first reset; step() therefore
    tolerates an unset schedule.
    """

    def __init__(
        self,
        xml_file="hopper.xml",
        forward_reward_weight=1.0,
        ctrl_cost_weight=1e-3,
        healthy_reward=1.0,
        terminate_when_unhealthy=True,
        healthy_state_range=(-100.0, 100.0),
        healthy_z_range=(0.7, float("inf")),
        healthy_angle_range=(-0.2, 0.2),
        reset_noise_scale=5e-3,
        exclude_current_positions_from_observation=True,
        max_episode_steps=1000,
        n_switches=3,            # number of scale changes
        transition_length=50,    # steps for partial blending
        scale_max=5.0,           # maximum noise scale
        seed=None,
    ):
        """
        Args:
            xml_file: model file forwarded to ``MujocoEnv.__init__``.
            forward_reward_weight: multiplier on the x-velocity reward term.
            ctrl_cost_weight: multiplier on the sum of squared actions.
            healthy_reward: per-step bonus (see the ``healthy_reward`` property).
            terminate_when_unhealthy: if True, the episode ends as soon as
                ``is_healthy`` is False.
            healthy_state_range: exclusive (min, max) bounds checked against
                ``state_vector()[2:]``.
            healthy_z_range: exclusive (min, max) bounds checked against ``qpos[1]``.
            healthy_angle_range: exclusive (min, max) bounds checked against ``qpos[2]``.
            reset_noise_scale: half-width of the uniform noise added to the
                initial qpos/qvel on reset.
            exclude_current_positions_from_observation: if True, ``qpos[0]`` is
                dropped from the observation.
            max_episode_steps: total steps in one episode
            n_switches: how many random times we change scale
            transition_length: number of steps to blend from old scale -> new scale;
                a value <= 0 makes every switch instantaneous
            scale_max: maximum uniform noise scale (also the clamp on the force)
            seed: if not None, used to seed the global numpy RNG
        """

        # Store constructor arguments for EzPickle (helps with env cloning).
        # Must stay the first statement: locals() has to contain only `self`
        # and the constructor arguments at this point.
        utils.EzPickle.__init__(**locals())

        # Everything below is defined before MujocoEnv.__init__ so that the
        # attributes exist if the base class calls step() during construction.
        self.seed_val = seed
        if seed is not None:
            np.random.seed(seed)

        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_state_range = healthy_state_range
        self._healthy_z_range = healthy_z_range
        self._healthy_angle_range = healthy_angle_range
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = exclude_current_positions_from_observation

        self.max_episode_steps = max_episode_steps
        self.n_switches = n_switches
        self.transition_length = transition_length
        self.scale_max = scale_max

        # Internal attributes that are properly set in reset_model()
        self.switch_times = None
        self.scales = None
        self.timestep_in_episode = 0
        self.name = 'hopper-medium-expert-v2'
        # For external forces
        self.force_body_index = 1  # body whose current force is read back in step()
        # NOTE(review): hard-coded; presumably matches the body count of
        # hopper.xml so the array fits data.xfrc_applied -- confirm for other models.
        self.n_bodies = 5
        self.current_force = np.zeros((self.n_bodies, 6))

        # Initialize MuJoCoEnv (loads the model, calls a "probe" step)
        mujoco_env.MujocoEnv.__init__(self, xml_file, 4)

    def reset_model(self):
        """
        Start a new episode and return its first observation.

        Clears the step counter and the applied external force, perturbs the
        initial qpos/qvel with uniform noise, then draws this episode's switch
        times and per-segment noise scales.
        """

        # Episode time
        self.timestep_in_episode = 0

        # Reset the external forces
        self.current_force[:] = 0.0
        self.data.xfrc_applied[:] = self.current_force

        # Randomize initial qpos, qvel
        noise_low = -self._reset_noise_scale
        noise_high = self._reset_noise_scale
        qpos = self.init_qpos + np.random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
        qvel = self.init_qvel + np.random.uniform(low=noise_low, high=noise_high, size=self.model.nv)
        self.set_state(qpos, qvel)

        # Choose distinct random switch times (sorted) within the episode,
        # e.g. with max_episode_steps=1000 they fall in [50, 949].
        # np.random.choice raises ValueError if n_switches exceeds the number
        # of available time steps.
        self.switch_times = sorted(
            np.random.choice(range(50, self.max_episode_steps - 50), self.n_switches, replace=False)
        )
        # Sentinel beyond the last step so the list always has n_switches + 1 entries.
        self.switch_times.append(self.max_episode_steps + 1)

        # For each segment, pick a random scale in [0, scale_max]
        self.scales = [np.random.uniform(0.0, self.scale_max) for _ in range(self.n_switches + 1)]

        # Return initial observation
        return self._get_obs()

    def _scale_schedule(self):
        """
        Compute the noise scale for the current ``timestep_in_episode``.

        Returns:
            (old_scale, new_scale, alpha, current_scale) where ``alpha`` in
            [0, 1] is the blend progress since the current segment started and
            ``current_scale = (1 - alpha) * old_scale + alpha * new_scale``.
        """
        t = self.timestep_in_episode

        # Segment index = number of real switch times already reached
        # (the trailing sentinel is excluded; the list is sorted).
        seg_idx = sum(1 for switch in self.switch_times[:-1] if switch <= t)

        old_scale = self.scales[max(seg_idx - 1, 0)]
        new_scale = self.scales[seg_idx]

        # The blend is measured from the start of the *current* segment, i.e.
        # the most recent switch time. Keying this on seg_idx (rather than on
        # the old-scale index) makes the first switch blend like all others.
        segment_start = 0 if seg_idx == 0 else self.switch_times[seg_idx - 1]
        steps_in_segment = float(t - segment_start)

        if self.transition_length > 0:
            alpha = min(steps_in_segment / self.transition_length, 1.0)
        else:
            # Non-positive transition length: switch instantly, avoid dividing by zero.
            alpha = 1.0

        current_scale = (1.0 - alpha) * old_scale + alpha * new_scale
        return old_scale, new_scale, alpha, current_scale

    def step(self, action):
        """
        Advance the simulation by one environment step.

        Returns:
            (obs, reward, done, info). ``info`` reports the x position and
            velocity plus the noise-scale schedule values used this step.
        """

        # If this is the "probe step" from MujocoEnv.__init__, the schedule is
        # not set yet: do a plain simulation step with no force update.
        if self.switch_times is None or self.scales is None:
            self.do_simulation(action, self.frame_skip)
            return self._get_obs(), 0.0, False, {}

        # Normal flow: we are in a real episode
        self.timestep_in_episode += 1
        old_scale, new_scale, alpha, current_scale = self._scale_schedule()

        # Sample uniform noise in [-current_scale, current_scale]; it is the
        # amount by which the force CHANGES this step (random-walk force).
        delta_f = np.random.uniform(-current_scale, current_scale)

        # Add to existing force, clamp in [-scale_max, scale_max]
        new_force = self.current_force[self.force_body_index, 0] + delta_f
        new_force = np.clip(new_force, -self.scale_max, self.scale_max)
        # NOTE(review): the force is written to column 0 of EVERY body row,
        # not only force_body_index -- confirm this is intended.
        self.current_force[:, 0] = new_force
        self.data.xfrc_applied[:] = self.current_force

        # Simulation step
        x_position_before = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_position_after = self.sim.data.qpos[0]
        x_velocity = (x_position_after - x_position_before) / self.dt

        # Reward = weighted forward velocity + healthy bonus - control cost
        reward = (
            self._forward_reward_weight * x_velocity
            + self.healthy_reward
            - self.control_cost(action)
        )

        obs = self._get_obs()
        done = self.done or (self.timestep_in_episode >= self.max_episode_steps)
        info = {
            "x_position": x_position_after,
            "x_velocity": x_velocity,
            "old_scale": old_scale,
            "new_scale": new_scale,
            "blended_scale": current_scale,
            "alpha": alpha,
        }

        return obs, reward, done, info

    # ------------------------
    # Helper / property methods
    # ------------------------

    @property
    def healthy_reward(self):
        """Healthy bonus: granted when healthy, or always if termination on unhealthy is enabled."""
        return float(self.is_healthy or self._terminate_when_unhealthy) * self._healthy_reward

    def control_cost(self, action):
        """Return ``ctrl_cost_weight`` times the sum of squared action components."""
        return self._ctrl_cost_weight * np.sum(np.square(action))

    @property
    def is_healthy(self):
        """True when qpos[1], qpos[2] and state_vector()[2:] all lie strictly inside their ranges."""
        z, angle = self.sim.data.qpos[1:3]
        state = self.state_vector()[2:]

        min_state, max_state = self._healthy_state_range
        min_z, max_z = self._healthy_z_range
        min_angle, max_angle = self._healthy_angle_range

        healthy_state = np.all((min_state < state) & (state < max_state))
        healthy_z = (min_z < z < max_z)
        healthy_angle = (min_angle < angle < max_angle)

        return all((healthy_state, healthy_z, healthy_angle))

    @property
    def done(self):
        """True only if termination on unhealthy is enabled and the state is unhealthy."""
        if self._terminate_when_unhealthy:
            return not self.is_healthy
        else:
            return False

    def _get_obs(self):
        """Return qpos (optionally without its first entry) concatenated with qvel clipped to [-10, 10]."""
        position = self.sim.data.qpos.flat.copy()
        velocity = np.clip(self.sim.data.qvel.flat.copy(), -10, 10)
        if self._exclude_current_positions_from_observation:
            position = position[1:]
        return np.concatenate((position, velocity)).ravel()

    def viewer_setup(self):
        """Copy DEFAULT_CAMERA_CONFIG onto the viewer camera."""
        for key, value in DEFAULT_CAMERA_CONFIG.items():
            if isinstance(value, np.ndarray):
                # Array-valued attributes are updated in place.
                getattr(self.viewer.cam, key)[:] = value
            else:
                setattr(self.viewer.cam, key, value)
