from typing import Dict

import gymnasium as gym
import numpy as np

# Keys of the behavioral measures selected from WalkerBehavioralWrapper's
# ``info`` dict. The wrapper also reports "gait_symmetry" and
# "height_stability", but those two are deliberately left out of this list.
WALKER_MEASURE_NAMES = ["right_contact_freq", "left_contact_freq"]


class WalkerBehavioralWrapper(gym.Wrapper):
    """Wrapper for Walker2d that tracks behavioral measures.

    The current measures are merged into the ``info`` dict returned by every
    ``reset`` and ``step``. All measures are normalized to [0, 1]:

    - ``right_contact_freq``: fraction of steps in which the right foot geom
      touches the floor geom.
    - ``left_contact_freq``: fraction of steps in which the left foot geom
      touches the floor geom.
    - ``gait_symmetry``: one minus the mean absolute difference between
      corresponding right/left joint angles divided by pi, clipped (1.0 means
      both legs held identical angles at every step).
    - ``height_stability``: one minus the standard deviation of the torso
      height divided by the width of the env's healthy z range, clipped.
    """

    # Walker2d qpos layout: [rootx, rootz, rooty, thigh, leg, foot,
    # thigh_left, leg_left, foot_left]. Reading joint angles from qpos keeps
    # the wrapper independent of the observation flags and of any wrapper
    # underneath that transforms observations.
    _ROOTZ_QPOS_INDEX = 1
    _LEG_JOINTS_QPOS = slice(3, 9)

    def __init__(self, env: gym.Env):
        super().__init__(env)
        # Initialize episode tracking variables
        self.reset_episode_stats()

        # Geom IDs used for foot/floor contact detection.
        model = self.env.unwrapped.model
        self.right_foot_id = model.geom("foot_geom").id
        self.left_foot_id = model.geom("foot_left_geom").id
        self.floor_id = model.geom("floor").id

    def reset_episode_stats(self):
        """Clear all per-episode statistics."""
        self.n_steps = 0
        # Number of steps in which each foot touched the floor.
        self.right_foot_contacts = 0
        self.left_foot_contacts = 0
        # One array of 6 leg joint angles per step (right: 0-2, left: 3-5).
        self.joint_angles = []
        # One torso height (offset from the initial pose) per step.
        self.heights = []

    def reset(self, **kwargs):
        """Reset the environment and the episode statistics.

        The (initial, all-default) measures are merged into the returned info.
        """
        obs, info = self.env.reset(**kwargs)
        self.reset_episode_stats()
        info.update(self.get_current_measures())
        return obs, info

    def check_foot_contact(self, foot_id: int) -> bool:
        """Return True if geom ``foot_id`` is currently in contact with the floor."""
        data = self.env.unwrapped.data
        floor_id = self.floor_id
        for i in range(data.ncon):
            contact = data.contact[i]
            g1, g2 = contact.geom1, contact.geom2
            # Contact pairs are unordered: the foot may be either geom.
            if (g1 == foot_id and g2 == floor_id) or (
                g1 == floor_id and g2 == foot_id
            ):
                return True
        return False

    def compute_gait_symmetry(self) -> float:
        """Compute symmetry between left and right leg joint angles.

        Returns ``1 - mean(|right - left|) / pi`` clipped to [0, 1], where the
        mean runs over all recorded steps and the three joint pairs. Returns
        1.0 when fewer than two steps have been recorded.
        """
        if len(self.joint_angles) < 2:
            return 1.0  # Default to perfect symmetry with insufficient data

        joint_angles = np.array(self.joint_angles)
        right_joints = joint_angles[:, :3]  # thigh, leg, foot
        left_joints = joint_angles[:, 3:]  # thigh_left, leg_left, foot_left

        angle_diff = np.mean(np.abs(right_joints - left_joints))
        # Normalize to [0, 1], assuming a maximum difference of pi.
        symmetry = 1.0 - min(angle_diff / np.pi, 1.0)
        return float(symmetry)

    def compute_height_stability(self) -> float:
        """Compute how steady the walker's torso height has been this episode.

        Returns ``1 - std(heights) / (z_max - z_min)`` clipped to [0, 1], where
        ``(z_min, z_max)`` is the env's ``_healthy_z_range``. Returns 1.0 when
        fewer than two heights have been recorded.
        """
        if len(self.heights) < 2:
            return 1.0  # Default to perfect stability with insufficient data

        # Heights are offsets from the initial pose, so their mean is
        # legitimately near zero; only the spread is used.
        height_std = float(np.std(self.heights))

        z_min, z_max = self.env.unwrapped._healthy_z_range
        z_range = z_max - z_min
        if not z_range > 0:
            # Degenerate range gives no scale: only a perfectly flat
            # trajectory counts as stable.
            return 1.0 if height_std == 0 else 0.0

        # NOTE(review): the std of values confined to an interval of width w
        # is at most w / 2, so while the walker stays in the healthy range this
        # score never drops below 0.5. Confirm whether that is intended.
        stability = 1.0 - min(height_std / z_range, 1.0)
        return float(stability)

    def get_current_measures(self) -> Dict[str, float]:
        """Calculate current behavioral measures, all normalized to [0, 1]."""
        # max(1, ...) avoids division by zero right after reset.
        right_freq = self.right_foot_contacts / max(1, self.n_steps)
        left_freq = self.left_foot_contacts / max(1, self.n_steps)

        measures = {
            "right_contact_freq": right_freq,
            "left_contact_freq": left_freq,
            "gait_symmetry": self.compute_gait_symmetry(),
            "height_stability": self.compute_height_stability(),
        }
        return measures

    def step(self, action):
        """Step the environment, update the statistics, and add measures to info."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        unwrapped = self.env.unwrapped

        self.n_steps += 1

        if self.check_foot_contact(self.right_foot_id):
            self.right_foot_contacts += 1
        if self.check_foot_contact(self.left_foot_id):
            self.left_foot_contacts += 1

        # qpos is a live view of the simulator state; copy before storing.
        self.joint_angles.append(
            unwrapped.data.qpos[self._LEG_JOINTS_QPOS].copy()
        )

        # Walker2d-v5 reports the torso height offset in info; when it is
        # absent (e.g. older versions), compute the same quantity from qpos.
        height = info.get("z_distance_from_origin")
        if height is None:
            height = (
                unwrapped.data.qpos[self._ROOTZ_QPOS_INDEX]
                - unwrapped.init_qpos[self._ROOTZ_QPOS_INDEX]
            )
        self.heights.append(float(height))

        info.update(self.get_current_measures())

        return obs, reward, terminated, truncated, info
