from typing import Dict, Optional

import gymnasium as gym
import numpy as np

# Behavioral measure keys exported by this module. Each name matches a key
# that HopperBehavioralWrapper writes into the step/reset ``info`` dict. The
# wrapper also reports "height_stability" and "hop_frequency", which are
# currently excluded from this list.
HOPPER_MEASURE_NAMES = list(("foot_contact_freq",))


class HopperBehavioralWrapper(gym.Wrapper):
    """Wrapper for Hopper that tracks behavioral measures.

    After every ``reset`` and ``step`` the current measures are merged into
    the returned ``info`` dict. All measures are normalized to [0, 1]:

    - ``foot_contact_freq``: fraction of steps in which the foot geom is in
      contact with the floor geom.
    - ``height_stability``: ``1 - min(std(heights) / height_std_scale, 1)``;
      1.0 means the recorded height never varied.
    - ``hop_frequency``: landings (foot transitions from airborne to touching
      the floor) per simulated second, divided by ``max_hops_per_second`` and
      clipped to 1.

    Args:
        env: Hopper environment. Its model must contain geoms named
            ``"foot_geom"`` and ``"floor"``, and its ``step`` must report
            ``"z_distance_from_origin"`` in ``info``.
        max_hops_per_second: Hop rate that maps to ``hop_frequency == 1``.
        height_std_scale: Height standard deviation that maps to
            ``height_stability == 0``. If ``None``, the width of the env's
            ``_healthy_z_range`` is used when it is finite and positive,
            otherwise ``DEFAULT_HEIGHT_STD_SCALE``.

    Raises:
        ValueError: If ``max_hops_per_second`` or an explicit
            ``height_std_scale`` is not a finite positive number.
    """

    # Fallback normalizer for the height std when the healthy z-range is not
    # usable (missing, or a non-finite bound such as ``inf``). Dividing by an
    # infinite range width would pin height_stability at 1.0. Heuristic
    # value; override with ``height_std_scale`` for a specific setup.
    DEFAULT_HEIGHT_STD_SCALE = 0.5

    def __init__(
        self,
        env: gym.Env,
        *,
        max_hops_per_second: float = 4.0,
        height_std_scale: Optional[float] = None,
    ):
        super().__init__(env)
        # Initialize episode tracking variables
        self.reset_episode_stats()

        # Geom IDs used to detect foot/floor contacts.
        model = self.env.unwrapped.model
        self.foot_id = model.geom("foot_geom").id
        self.floor_id = model.geom("floor").id

        if not (np.isfinite(max_hops_per_second) and max_hops_per_second > 0):
            raise ValueError("max_hops_per_second must be a finite positive number")
        # Hop rate (hops per second) that maps to a normalized frequency of 1.
        self.max_hops_per_second = float(max_hops_per_second)

        if height_std_scale is None:
            height_std_scale = self._resolve_height_std_scale()
        elif not (np.isfinite(height_std_scale) and height_std_scale > 0):
            raise ValueError("height_std_scale must be a finite positive number")
        # Height std that maps to a stability of 0.
        self.height_std_scale = float(height_std_scale)

    def _resolve_height_std_scale(self) -> float:
        """Return the healthy z-range width, or the default if it is unusable."""
        z_range = getattr(self.env.unwrapped, "_healthy_z_range", None)
        if z_range is not None:
            width = float(z_range[1] - z_range[0])
            if np.isfinite(width) and width > 0:
                return width
        return self.DEFAULT_HEIGHT_STD_SCALE

    def reset_episode_stats(self):
        """Reset all per-episode counters and the recorded height history."""
        self.n_steps = 0
        # Track foot contacts
        self.foot_contacts = 0
        self.prev_contact = False
        self.n_hops = 0
        # Track height for stability
        self.heights = []

    def reset(self, **kwargs):
        """Reset the environment and the episode statistics.

        Returns the wrapped env's ``(obs, info)``, with the initial measures
        merged into ``info``.
        """
        obs, info = self.env.reset(**kwargs)
        self.reset_episode_stats()
        # Seed the previous-contact flag from the post-reset state. Without
        # this, a foot already resting on the floor is counted as a landing
        # on the first step. Assumes env.reset leaves data.contact up to date
        # (e.g. via mj_forward); TODO confirm for custom envs.
        self.prev_contact = self.check_foot_contact()
        # Add initial measures to info
        info.update(self.get_current_measures())
        return obs, info

    def check_foot_contact(self) -> bool:
        """Return True if any of the ``data.ncon`` contacts pairs foot and floor."""
        data = self.env.unwrapped.data
        # A contact may list the two geoms in either order.
        foot_floor_pairs = (
            (self.foot_id, self.floor_id),
            (self.floor_id, self.foot_id),
        )
        for i in range(data.ncon):
            contact = data.contact[i]
            if (contact.geom1, contact.geom2) in foot_floor_pairs:
                return True
        return False

    def compute_height_stability(self) -> float:
        """Return how steady the recorded height has been this episode.

        Computed as ``1 - min(std(heights) / height_std_scale, 1)``. Heights
        come from ``info["z_distance_from_origin"]``. The standard deviation
        is shift-invariant, so a constant offset in that value has no effect.
        Returns 1.0 when fewer than two samples have been recorded.
        """
        if len(self.heights) < 2:
            return 1.0  # Not enough samples to measure any variation.

        height_std = float(np.std(self.heights))
        return 1.0 - min(height_std / self.height_std_scale, 1.0)

    def compute_hop_frequency(self) -> float:
        """Return landings per simulated second, normalized to [0, 1]."""
        if self.n_steps == 0:
            return 0.0

        elapsed_seconds = self.n_steps * self.env.unwrapped.dt
        hops_per_second = self.n_hops / elapsed_seconds
        return min(hops_per_second / self.max_hops_per_second, 1.0)

    def get_current_measures(self) -> Dict[str, float]:
        """Calculate current behavioral measures, all normalized to [0, 1]."""
        # max(1, ...) avoids division by zero right after reset.
        contact_freq = self.foot_contacts / max(1, self.n_steps)

        return {
            "foot_contact_freq": contact_freq,
            "height_stability": self.compute_height_stability(),
            "hop_frequency": self.compute_hop_frequency(),
        }

    def step(self, action):
        """Step the environment and update the behavioral measures.

        Returns the wrapped env's 5-tuple, with the updated measures merged
        into ``info``. Requires ``info["z_distance_from_origin"]`` from the
        wrapped env.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)

        self.n_steps += 1

        current_contact = self.check_foot_contact()
        if current_contact:
            self.foot_contacts += 1

        # A hop is counted on each landing: the foot touches the floor after
        # being airborne on the previous step.
        if current_contact and not self.prev_contact:
            self.n_hops += 1
        self.prev_contact = current_contact

        self.heights.append(info["z_distance_from_origin"])

        info.update(self.get_current_measures())

        return obs, reward, terminated, truncated, info
