import gym
import numpy as np

class RewardHighVelocity(gym.Wrapper):
    """Penalize excessive forward velocity and log velocity as monitor value.

    On every step the absolute value of one observation entry (the velocity
    index) is reported through ``info['monitor_val']`` with
    ``info['monitor_name'] = 'velocity'``. When it exceeds ``max_vel`` the
    step is flagged with ``info['risky_state'] = True`` and, with probability
    ``prob_vel_penal``, ``cost_vel`` is added to the reward.

    Args:
        env: Environment to wrap; its ``step`` must return the 4-tuple
            ``(state, reward, done, info)``.
        env_name: Environment id, matched case-insensitively to pick the
            velocity index.
        max_vel: Absolute velocity above which a state counts as risky.
        prob_vel_penal: Probability of applying the penalty on a risky step.
        cost_vel: Value added to the reward when penalized (a negative value
            acts as a penalty).
        vel_idx: Optional explicit observation index of the velocity term.
            When ``None`` (default) it is derived from ``env_name``.
    """

    # (substring of the lower-cased env name, observation index of velocity).
    # First match wins, so the order matters.
    # NOTE(review): indices assume the standard gym MuJoCo observation layouts
    # for these envs — confirm if env versions change.
    _VEL_IDX_BY_ENV = (("hopper", 5), ("walker2d", 8), ("halfcheetah", 8))
    _DEFAULT_VEL_IDX = 8

    def __init__(self, env, env_name, max_vel=2.0, prob_vel_penal=0.3, cost_vel=-5.0,
                 vel_idx=None):
        super().__init__(env)
        self.env_name = env_name.lower()
        self.max_vel = max_vel
        self.prob_vel_penal = prob_vel_penal
        self.cost_vel = cost_vel
        # The index depends only on the env name, so resolve it once here
        # instead of on every step.
        self.vel_idx = self._resolve_vel_idx(self.env_name) if vel_idx is None else vel_idx
        # Emit the short-observation warning only once rather than every step.
        self._warned_short_state = False

    @classmethod
    def _resolve_vel_idx(cls, env_name):
        """Return the velocity index for a lower-cased env name."""
        for key, idx in cls._VEL_IDX_BY_ENV:
            if key in env_name:
                return idx
        return cls._DEFAULT_VEL_IDX

    def step(self, action):
        """Step the wrapped env, log velocity, and maybe apply the penalty."""
        state, reward, done, info = self.env.step(action)
        info.setdefault('risky_state', False)

        if len(state) > self.vel_idx:
            velocity = abs(state[self.vel_idx])
        else:
            # Observation too short to contain the velocity entry.
            velocity = 0.0
            if not self._warned_short_state:
                print(f"[Warning] state dim <= {self.vel_idx}, fallback velocity=0.0")
                self._warned_short_state = True
        info['monitor_val'] = velocity
        info['monitor_name'] = 'velocity'

        if velocity > self.max_vel:
            info['risky_state'] = True
            # Stochastic penalty: applied only with probability prob_vel_penal.
            if np.random.rand() < self.prob_vel_penal:
                reward += self.cost_vel

        return state, reward, done, info


class RewardUnhealthyPose(gym.Wrapper):
    """Penalize unhealthy pose and optionally terminate on large deviations.

    Logs the angle as monitor value (``info['monitor_val']`` with
    ``info['monitor_name'] = 'angle'``) and marks risky states when the angle
    is outside the healthy range.

    Args:
        env: Environment to wrap; its ``step`` must return the 4-tuple
            ``(state, reward, done, info)``.
        prob_pose_penal: Probability of adding ``cost_pose`` on a step whose
            angle lies outside ``healthy_angle_range``.
        cost_pose: Value added to the reward when penalized (a negative value
            acts as a penalty).
        healthy_angle_range: Inclusive ``(low, high)`` range of healthy angles.
        done_if_exceed_factor: The episode ends when ``|angle|`` exceeds this
            factor times ``max(|low|, |high|)``; e.g. 2.0 with ``(-0.5, 0.5)``
            terminates beyond 1.0.
        angle_idx: Observation index holding the angle (default 1).
    """
    def __init__(
        self,
        env,
        prob_pose_penal=0.3,
        cost_pose=-10.0,
        healthy_angle_range=(-0.5, 0.5),
        done_if_exceed_factor=2.0,  # e.g., 2.0 => if angle exceeds 2 * 0.5 = 1.0
        angle_idx=1,
    ):
        super().__init__(env)
        self.prob_pose_penal = prob_pose_penal
        self.cost_pose = cost_pose
        self.healthy_angle_range = healthy_angle_range
        self.done_if_exceed_factor = done_if_exceed_factor
        self.angle_idx = angle_idx
        # Emit the short-observation warning only once rather than every step.
        self._warned_short_state = False

    def step(self, action):
        """Step the wrapped env, log the angle, penalize and maybe terminate."""
        state, reward, done, info = self.env.step(action)
        info.setdefault('risky_state', False)

        # NOTE(review): assumes state[self.angle_idx] is the pitch angle — verify per env.
        if len(state) > self.angle_idx:
            pitch_angle = state[self.angle_idx]
        else:
            pitch_angle = 0.0
            if not self._warned_short_state:
                print(f"[Warning] Observed dimension <= {self.angle_idx}, fallback pitch_angle=0.0")
                self._warned_short_state = True
        info["monitor_val"] = pitch_angle
        info["monitor_name"] = "angle"

        # Outside the healthy range => risky, penalized with probability prob_pose_penal.
        low, high = self.healthy_angle_range
        if not (low <= pitch_angle <= high):
            info['risky_state'] = True
            if np.random.rand() < self.prob_pose_penal:
                reward += self.cost_pose  # cost_pose negative => penalty

        # Terminate on large deviations. Using the larger bound magnitude
        # ensures that (for factor >= 1) an angle still inside an asymmetric
        # healthy range never ends the episode.
        threshold_angle = max(abs(low), abs(high))
        if abs(pitch_angle) > self.done_if_exceed_factor * threshold_angle:
            done = True

        return state, reward, done, info

def _apply_risk_config(env, risk_prob, risk_penalty, label):
    """Configure risk parameters on ``env`` via ``set_risk`` or plain attributes."""
    if hasattr(env, "set_risk"):
        env.set_risk(risk_prob, risk_penalty)
        print(f"{label} environment created with risk_prob={risk_prob}, risk_penalty={risk_penalty}.")
    else:
        # gym wrappers (such as those added by gym.make) forward missing
        # attribute reads to the wrapped env, but an assignment lands on the
        # outer wrapper only. Write to the base env so it actually sees them.
        base_env = env.unwrapped
        base_env.risk_prob = risk_prob
        base_env.risk_penalty = risk_penalty
        print(f"{label} environment assigned risk_prob={risk_prob}, risk_penalty={risk_penalty}.")


def make_risky_env(env_name, risk_prob=0.9, risk_penalty=50.0, max_vel=2.0, prob_vel_penal=0.3, cost_vel=-5.0,
                   prob_pose_penal=0.3, cost_pose=-10.0, healthy_angle_range=(-0.5, 0.5), done_if_exceed_factor=2.0):
    """Create base env and attach appropriate risk wrappers per environment type.

    Dispatch is on a case-insensitive substring match of ``env_name``:

    - ``halfcheetah``: wrapped in ``RewardHighVelocity``.
    - ``hopper`` / ``walker2d``: wrapped in ``RewardUnhealthyPose``.
    - ``riskymaze`` / ``antnavigation``: ``risk_prob`` and ``risk_penalty`` are
      pushed into the env (via ``set_risk`` if available, otherwise as attributes).
    - anything else: returned as created by ``gym.make``.

    Returns:
        The (possibly wrapped) environment.
    """
    # NOTE: episode limits via TimeLimit were considered but are not applied
    # here (HalfCheetah 200, Hopper/Walker2d 500, RiskyMaze 200,
    # AntNavigation 600 steps).
    env = gym.make(env_name)
    name = env_name.lower()
    if 'halfcheetah' in name:
        # HalfCheetah -> only velocity penalty
        env = RewardHighVelocity(env, env_name, max_vel=max_vel, prob_vel_penal=prob_vel_penal, cost_vel=cost_vel)
        print("HalfCheetah environment created with velocity penalty.")
    elif 'hopper' in name or 'walker2d' in name:
        # Hopper / Walker2D -> only unhealthy pose penalty
        env = RewardUnhealthyPose(env, prob_pose_penal=prob_pose_penal, cost_pose=cost_pose,
                                  healthy_angle_range=healthy_angle_range,
                                  done_if_exceed_factor=done_if_exceed_factor)
        print("Hopper/Walker2d environment created with pitch angle penalty.")
    elif 'riskymaze' in name:
        _apply_risk_config(env, risk_prob, risk_penalty, "Risky PointMass")
    elif 'antnavigation' in name:
        # The pattern must be lower-case: `name` is lower-cased, so a
        # mixed-case literal such as 'AntNavigation-v0' could never match.
        _apply_risk_config(env, risk_prob, risk_penalty, "Ant Obstacles Navigation")
    # Other environments are returned unmodified.
    return env
