from typing import Dict

import gymnasium as gym
import numpy as np

# Names of the behavioral measures that are currently enabled. Each entry is
# one of the info-dict keys produced by
# HalfCheetahBehavioralWrapper.get_current_measures(); presumably this list is
# consumed by callers outside this file (verify against the caller).
# The commented-out names are still computed and written to `info` by the
# wrapper -- they are only left out of this list.
HALFCHEETAH_MEASURE_NAMES = [
    "back_foot_freq",
    "front_foot_freq",
    # "height_measure",
    # "gait_coordination",
]


class HalfCheetahBehavioralWrapper(gym.Wrapper):
    """Wrapper for HalfCheetah that tracks per-episode behavioral measures.

    After every ``reset`` and ``step`` the following keys are written into the
    returned ``info`` dict, each a float in [0, 1]:

    - ``back_foot_freq``: fraction of steps in which the back foot touched
      the floor.
    - ``front_foot_freq``: fraction of steps in which the front foot touched
      the floor.
    - ``height_measure``: mean recorded height, rescaled from
      ``[min_height, max_height]`` to [0, 1] and clipped.
    - ``gait_coordination``: fraction of steps in which exactly one of the
      two feet touched the floor.

    Before any step has been taken the two frequencies are 0.0 and
    ``height_measure`` is 0.5; ``gait_coordination`` stays at 0.5 until at
    least two steps have been recorded.
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env`` and look up the geoms used for contact detection.

        Args:
            env: Environment whose unwrapped model exposes geoms named
                ``"bfoot"``, ``"ffoot"`` and ``"floor"``.
        """
        super().__init__(env)
        # Initialize episode tracking variables
        self.reset_episode_stats()

        # Get foot geom IDs for contact detection
        model = self.env.unwrapped.model
        self.back_foot_id = model.geom("bfoot").id
        self.front_foot_id = model.geom("ffoot").id
        self.floor_id = model.geom("floor").id

        # Bounds used to rescale the mean height to [0, 1]
        self.min_height = 0.2  # Minimum expected height
        self.max_height = 1.2  # Maximum expected height

    def reset_episode_stats(self) -> None:
        """Reset all episode statistics."""
        self.n_steps = 0
        # Number of steps with each foot on the floor
        self.back_foot_contacts = 0
        self.front_foot_contacts = 0
        # Per-step history, kept so it can be inspected from outside
        self.heights = []
        self.back_contact_phases = []
        self.front_contact_phases = []
        # Running aggregates of the history above. They are updated in
        # step() so the measures can be reported every step without
        # re-reducing the whole episode history each time.
        self._height_sum = 0.0
        self._single_contact_steps = 0

    def reset(self, **kwargs):
        """Reset the environment and the episode statistics.

        Args:
            **kwargs: Forwarded unchanged to the wrapped env's ``reset``.

        Returns:
            The ``(obs, info)`` pair from the wrapped env, with the initial
            measures merged into ``info``.
        """
        obs, info = self.env.reset(**kwargs)
        self.reset_episode_stats()
        # Add initial measures to info
        info.update(self.get_current_measures())
        return obs, info

    def check_foot_contact(self, foot_id: int) -> bool:
        """Check if a foot is in contact with the ground.

        Args:
            foot_id: Geom id of the foot to test.

        Returns:
            True if any active contact pairs ``foot_id`` with the floor geom
            (in either order), otherwise False.
        """
        # Resolve the wrapper chain once instead of once per contact.
        data = self.env.unwrapped.data
        floor_id = self.floor_id
        for i in range(data.ncon):
            contact = data.contact[i]
            first, second = contact.geom1, contact.geom2
            if (first == foot_id and second == floor_id) or (
                second == foot_id and first == floor_id
            ):
                return True
        return False

    def compute_height_measure(self) -> float:
        """Compute normalized height measure.

        Returns:
            The mean recorded height mapped linearly from
            ``[min_height, max_height]`` to [0, 1] and clipped to that range,
            or 0.5 when no height has been recorded yet.
        """
        n_samples = len(self.heights)
        if n_samples == 0:
            return 0.5  # Default to middle value

        avg_height = self._height_sum / n_samples
        # Normalize to [0, 1]
        norm_height = (avg_height - self.min_height) / (
            self.max_height - self.min_height
        )
        return float(np.clip(norm_height, 0, 1))

    def compute_gait_coordination(self) -> float:
        """Compute how well the front and back feet alternate.

        Returns:
            Fraction of recorded steps in which exactly one of the two feet
            was in contact with the floor, or 0.5 when fewer than two steps
            have been recorded.
        """
        n_samples = len(self.back_contact_phases)
        if n_samples < 2 or len(self.front_contact_phases) < 2:
            return 0.5  # Default to middle value with insufficient data

        return self._single_contact_steps / n_samples

    def get_current_measures(self) -> Dict[str, float]:
        """Calculate current behavioral measures, all normalized to [0, 1].

        Returns:
            Dict with keys ``back_foot_freq``, ``front_foot_freq``,
            ``height_measure`` and ``gait_coordination``.
        """
        # max(1, ...) avoids dividing by zero before the first step
        n_steps = max(1, self.n_steps)
        return {
            "back_foot_freq": self.back_foot_contacts / n_steps,
            "front_foot_freq": self.front_foot_contacts / n_steps,
            "height_measure": self.compute_height_measure(),
            "gait_coordination": self.compute_gait_coordination(),
        }

    def step(self, action):
        """Step the environment and update behavioral measures.

        Args:
            action: Forwarded unchanged to the wrapped env's ``step``.

        Returns:
            The 5-tuple from the wrapped env, with the up-to-date measures
            merged into ``info``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Sample the post-step simulator state
        back_contact = self.check_foot_contact(self.back_foot_id)
        front_contact = self.check_foot_contact(self.front_foot_id)
        # NOTE(review): qpos[1] is treated as the z coordinate (height). If the
        # model defines this joint coordinate as a displacement from the
        # initial pose rather than an absolute height, the
        # [min_height, max_height] range will clip most values to 0 --
        # confirm against the model definition.
        height = self.env.unwrapped.data.qpos[1]

        # Update episode statistics
        self.n_steps += 1
        if back_contact:
            self.back_foot_contacts += 1
        if front_contact:
            self.front_foot_contacts += 1

        # Store per-step history
        self.back_contact_phases.append(back_contact)
        self.front_contact_phases.append(front_contact)
        self.heights.append(height)

        # Keep the running aggregates in sync with the history
        if bool(back_contact) != bool(front_contact):
            # Exactly one foot is on the floor this step
            self._single_contact_steps += 1
        self._height_sum += float(height)

        # Add measures to info dict
        info.update(self.get_current_measures())

        return obs, reward, terminated, truncated, info
