# envs/finsler_wrappers.py
import gym
import numpy as np

class FinslerRewardWrapper(gym.Wrapper):
    """Replace a MuJoCo environment's reward with a negative Finsler-style cost.

    The per-step cost is a weighted sum of three terms computed from the
    base velocity:

    * an energy term ``sqrt(v^T M v)`` (weight ``we``),
    * a drift term ``beta(x) * max(0, v_vertical)`` (weight ``wd``),
    * a lateral-friction term ``lambda_lat * |v_perp|`` (weight ``wf``).

    The reward returned by the wrapped environment is discarded.
    """

    def __init__(self, env, we=1.0, wd=1.0, wf=1.0, beta_coef=1.0,
                 lambda_lat=1.0, base_vel_dofs=2):
        """Store the cost weights and the number of base velocity DOFs.

        Args:
            env: Environment to wrap.
            we: Weight of the kinetic-energy term.
            wd: Weight of the drift (gravity) term.
            wf: Weight of the lateral-friction term.
            beta_coef: Base coefficient of the slope multiplier beta(x).
            lambda_lat: Coefficient of the lateral-friction term.
            base_vel_dofs: Number of leading ``qvel`` entries treated as the
                base velocity. The cost terms read index 0 as the forward
                component and index 1 as the vertical component.

        Raises:
            ValueError: If ``base_vel_dofs`` is smaller than 1.
        """
        super().__init__(env)
        if base_vel_dofs < 1:
            # A zero count would make the ``obs[-n:]`` fallback return the
            # entire observation instead of an empty slice.
            raise ValueError("base_vel_dofs must be >= 1")
        self.we = we   # weight for kinetic energy term
        self.wd = wd   # weight for drift (gravity) term
        self.wf = wf   # weight for lateral friction term
        self.beta_coef = beta_coef    # base coefficient for beta(x) slope multiplier
        self.lambda_lat = lambda_lat  # coefficient for lateral friction
        # Previously never assigned, which made every call to step() fail.
        self._base_vel_dofs = int(base_vel_dofs)

    def step(self, action):
        """Step the wrapped env and substitute the Finsler reward.

        Works with both the 4-tuple ``(obs, reward, done, info)`` and the
        5-tuple ``(obs, reward, terminated, truncated, info)`` step formats:
        only the reward slot is replaced, every other element is passed
        through unchanged.
        """
        result = self.env.step(action)
        obs = result[0]
        # Instantaneous velocity of the robot's base.
        velocity = self._get_base_velocity()
        # Individual Finsler components.
        F_energy = self._compute_energy_term(velocity)
        F_drift = self._compute_drift_term(velocity, obs)
        F_frict = self._compute_friction_term(velocity, obs)
        cost = self.we * F_energy + self.wd * F_drift + self.wf * F_frict
        # Reward is negative cost (maximizing reward = minimizing cost).
        reward = -cost
        return (obs, reward) + tuple(result[2:])

    def _get_base_velocity(self):
        """Return the first ``base_vel_dofs`` generalized velocities as an array.

        Reads ``sim.data.qvel`` from the unwrapped environment. If any
        attribute along that path is missing, falls back to the trailing
        entries of ``_get_obs()``.
        """
        n = self._base_vel_dofs
        base_env = self.env.unwrapped
        try:
            # The ``sim`` lookup is inside the try so that environments
            # without it reach the observation-based fallback.
            vel = base_env.sim.data.qvel[:n]
        except AttributeError:
            # NOTE(review): this takes the LAST n observation entries, which
            # assumes the base velocity is stored there -- confirm per env.
            obs = base_env._get_obs()
            vel = obs[-n:]
        # Planar envs are assumed to order qvel as [velocity_x, velocity_z, ...].
        return np.array(vel, dtype=float)

    def _compute_energy_term(self, velocity):
        """Return ``sqrt(v^T M v)`` with a diagonal placeholder metric ``M``.

        ``M`` is the identity except for a weight of 2.0 on index 1, which is
        assumed to be the vertical component of a planar model. A full model
        could substitute the inertia matrix or a learned metric.
        """
        v = np.asarray(velocity, dtype=float)
        M = np.eye(len(v))
        if len(v) > 1:
            M[1, 1] = 2.0  # example weighting: vertical motion costs double
        return np.sqrt(v @ M @ v)

    def _compute_drift_term(self, velocity, obs):
        """Return ``beta(x) * max(0, v_vertical)``, penalizing upward motion only.

        ``beta(x) = beta_coef * sin(slope)``, where the slope is read from a
        ``slope_angle`` attribute and converted from degrees. When no such
        attribute is found the slope is 0 and the term vanishes. ``obs`` is
        accepted for interface symmetry but not used.
        """
        # Index 1 is assumed to be the vertical velocity (planar env).
        v_vertical = velocity[1] if len(velocity) > 1 else 0.0
        # Look through the wrapper chain first (the attribute may live on an
        # intermediate wrapper), then on the unwrapped environment.
        slope = getattr(self.env, "slope_angle", None)
        if slope is None:
            slope = getattr(self.env.unwrapped, "slope_angle", 0.0)
        beta_x = self.beta_coef * np.sin(np.deg2rad(slope))
        return beta_x * max(0.0, v_vertical)

    def _compute_friction_term(self, velocity, obs):
        """Return ``lambda_lat * |v_perp|`` for the lateral velocity ``v_perp``.

        Planar tasks have no out-of-plane motion, so ``v_perp`` is fixed at 0
        and this term is currently always 0. A 3D environment would need the
        component of the velocity perpendicular to the forward heading here.
        ``velocity`` and ``obs`` are accepted for interface symmetry but not
        used.
        """
        v_perp = 0.0
        return self.lambda_lat * abs(v_perp)
