from typing import Dict

import gymnasium as gym
import numpy as np

# Measure keys selected from those that BipedalBehavioralWrapper writes into
# the ``info`` dict on every step. The wrapper additionally reports
# "normalized_hull_angle" and "normalized_height_var"; append those keys here
# to include them.
BIPEDAL_MEASURE_NAMES = ["left_contact_freq", "right_contact_freq"]


class BipedalBehavioralWrapper(gym.Wrapper):
    """Wrapper for BipedalWalker that tracks behavioral measures.

    After every ``step`` the measures for the current episode so far are
    merged into the returned ``info`` dict. All measures are normalized to
    [0, 1]:

    - ``left_contact_freq``: fraction of timesteps where ``legs[1]`` had
      ground contact
    - ``right_contact_freq``: fraction of timesteps where ``legs[3]`` had
      ground contact
    - ``normalized_hull_angle``: mean hull angle, clipped to
      [-MAX_ANGLE, MAX_ANGLE] and mapped linearly to [0, 1]
    - ``normalized_height_var``: population standard deviation of the hull
      height divided by MAX_HEIGHT_VAR, clipped to [0, 1]

    Means and standard deviations are maintained incrementally, so each step
    costs O(1) regardless of how long the episode has run.
    """

    def __init__(
        self,
        env: gym.Env,
        max_angle: float = np.pi / 4,
        max_height_var: float = 2.0,
    ):
        """Wrap ``env`` and initialize episode statistics.

        Args:
            env: Environment to wrap. ``env.unwrapped`` must expose ``legs``
                (indexable; items at 1 and 3 have ``ground_contact``) and
                ``hull`` (with ``angle`` and ``position``), as read in ``step``.
            max_angle: Half-width of the hull-angle range mapped onto [0, 1].
                Defaults to pi/4.
            max_height_var: Hull-height standard deviation that maps to 1.0.
                Defaults to 2.0.

        Raises:
            ValueError: If ``max_angle`` or ``max_height_var`` is not positive
                (either would cause a division by zero or an inverted range).
        """
        super().__init__(env)
        if max_angle <= 0 or max_height_var <= 0:
            raise ValueError("max_angle and max_height_var must be positive")

        # Constants for normalization: assumed "reasonable" ranges, not hard
        # physical limits (values outside are clipped).
        self.MAX_ANGLE = max_angle
        self.MAX_HEIGHT_VAR = max_height_var

        # Initialize episode tracking variables.
        self.reset_episode_stats()

    def reset_episode_stats(self):
        """Reset all per-episode statistics."""
        self.n_steps = 0
        self.left_contacts = 0  # Steps with legs[1] touching the ground
        self.right_contacts = 0  # Steps with legs[3] touching the ground
        # Raw per-step history, kept for callers that inspect the episode.
        self.hull_angles = []
        self.hull_heights = []
        # Running aggregates so get_current_measures is O(1) per call instead
        # of re-scanning the whole history every step.
        self._hull_angle_sum = 0.0
        self._hull_height_mean = 0.0
        self._hull_height_m2 = 0.0  # Welford sum of squared deviations

    def reset(self, **kwargs):
        """Reset the episode statistics, then reset the wrapped environment."""
        self.reset_episode_stats()
        return self.env.reset(**kwargs)

    def normalize_angle(self, angle: float) -> float:
        """Normalize angle from [-MAX_ANGLE, MAX_ANGLE] to [0, 1] (clipped)."""
        clipped = np.clip(angle, -self.MAX_ANGLE, self.MAX_ANGLE)
        return float((clipped + self.MAX_ANGLE) / (2 * self.MAX_ANGLE))

    def normalize_height_variation(self, std: float) -> float:
        """Normalize height variation from [0, MAX_HEIGHT_VAR] to [0, 1] (clipped)."""
        return float(np.clip(std / self.MAX_HEIGHT_VAR, 0, 1))

    def _record_hull_state(self, angle: float, height: float) -> None:
        """Append one step's hull angle/height and update the running aggregates."""
        self.hull_angles.append(angle)
        self.hull_heights.append(height)
        self._hull_angle_sum += angle

        # Welford's online update: numerically stable mean and sum of squared
        # deviations, equivalent to np.std(..., ddof=0) at the end.
        count = len(self.hull_heights)
        delta = height - self._hull_height_mean
        self._hull_height_mean += delta / count
        self._hull_height_m2 += delta * (height - self._hull_height_mean)

    def get_current_measures(self) -> Dict[str, float]:
        """Calculate current behavioral measures, all normalized to [0, 1].

        Returns:
            Dict with keys ``left_contact_freq``, ``right_contact_freq``,
            ``normalized_hull_angle`` and ``normalized_height_var``. Before the
            first step, neutral defaults are returned.
        """
        if self.n_steps == 0:
            return {
                "left_contact_freq": 0.0,
                "right_contact_freq": 0.0,
                "normalized_hull_angle": 0.5,  # Middle point when no data
                "normalized_height_var": 0.0,
            }

        # Contact frequencies are already in [0, 1].
        left_freq = self.left_contacts / self.n_steps
        right_freq = self.right_contacts / self.n_steps

        # Hull samples are appended together, one pair per step.
        n_hull = len(self.hull_heights)

        # Normalize mean hull angle.
        mean_angle = self._hull_angle_sum / n_hull
        norm_angle = self.normalize_angle(mean_angle)

        # Normalize hull-height variation (population std, ddof=0).
        height_std = float(np.sqrt(self._hull_height_m2 / n_hull))
        norm_height_var = self.normalize_height_variation(height_std)

        return {
            "left_contact_freq": left_freq,
            "right_contact_freq": right_freq,
            "normalized_hull_angle": norm_angle,
            "normalized_height_var": norm_height_var,
        }

    def step(self, action):
        """Step the environment, update statistics, and add measures to ``info``."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        base_env = self.env.unwrapped

        # Update episode statistics.
        self.n_steps += 1

        # Foot contacts: legs[1] is treated as left, legs[3] as right.
        self.left_contacts += int(bool(base_env.legs[1].ground_contact))
        self.right_contacts += int(bool(base_env.legs[3].ground_contact))

        # Hull orientation and vertical position for this step.
        self._record_hull_state(
            float(base_env.hull.angle), float(base_env.hull.position[1])
        )

        # Add measures to the info dict (mutated in place and returned).
        info.update(self.get_current_measures())

        return obs, reward, terminated, truncated, info
