from typing import Dict

import gymnasium as gym

# Names of the per-foot contact measures, in the order front-left,
# front-right, back-left, back-right. Each name is "<leg position>_contact".
ANT_MEASURE_NAMES = [
    f"{leg}_contact"
    for leg in ("front_left", "front_right", "back_left", "back_right")
]


class AntBehavioralWrapper(gym.Wrapper):
    """Wrapper for Ant that tracks behavioral measures.

    After every ``reset`` and ``step`` the ``info`` dict is extended with one
    entry per foot. Each entry is the fraction of steps taken so far in the
    current episode during which that foot's ankle geom was in contact with
    the floor geom, so all measures lie in [0, 1]:

    - ``front_left_contact``: front left foot contact frequency
    - ``front_right_contact``: front right foot contact frequency
    - ``back_left_contact``: back left foot contact frequency
    - ``back_right_contact``: back right foot contact frequency
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env`` and look up the geom ids used for contact detection.

        Args:
            env: Environment whose unwrapped form exposes a MuJoCo ``model``
                containing the geoms ``left_ankle_geom``,
                ``right_ankle_geom``, ``third_ankle_geom``,
                ``fourth_ankle_geom`` and ``floor``.
        """
        super().__init__(env)
        # Initialize episode tracking variables
        self.reset_episode_stats()

        # Each leg's final segment (ankle) geom. Store the integer ``.id`` of
        # every geom: the contact records are compared against these values
        # with ``==``, and a geom accessor object would never compare equal
        # to an integer geom index.
        model = self.env.unwrapped.model
        self.fl_foot_id = model.geom("left_ankle_geom").id
        self.fr_foot_id = model.geom("right_ankle_geom").id
        self.bl_foot_id = model.geom("third_ankle_geom").id
        self.br_foot_id = model.geom("fourth_ankle_geom").id
        self.floor_id = model.geom("floor").id

    def reset_episode_stats(self):
        """Reset the step counter and all per-foot contact counters to zero."""
        self.n_steps = 0
        # Number of steps on which each foot touched the floor
        self.fl_contacts = 0  # front left
        self.fr_contacts = 0  # front right
        self.bl_contacts = 0  # back left
        self.br_contacts = 0  # back right

    def reset(self, **kwargs):
        """Reset the environment and the episode statistics.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset``.

        Returns:
            The ``(obs, info)`` pair from the wrapped environment, with the
            (all-zero) initial measures merged into ``info``.
        """
        obs, info = self.env.reset(**kwargs)
        self.reset_episode_stats()
        # Add initial measures to info
        info.update(self.get_current_measures())
        return obs, info

    def check_foot_contact(self, foot_id: int) -> bool:
        """Check if a foot is in contact with the ground.

        Args:
            foot_id: Geom id of the foot to test.

        Returns:
            True if any active contact pairs ``foot_id`` with the floor geom,
            False otherwise.
        """
        data = self.env.unwrapped.data
        # Only the first ``ncon`` contact entries are scanned.
        for i in range(data.ncon):
            contact = data.contact[i]
            # The foot may appear on either side of the contact pair; the
            # geom on the opposite side must then be the floor.
            if contact.geom1 == foot_id:
                other_geom = contact.geom2
            elif contact.geom2 == foot_id:
                other_geom = contact.geom1
            else:
                continue
            if other_geom == self.floor_id:
                return True
        return False

    def get_current_measures(self) -> Dict[str, float]:
        """Calculate current behavioral measures, all normalized to [0, 1].

        Returns:
            Mapping from measure name to the fraction of steps in the current
            episode during which the corresponding foot touched the floor.
        """
        # Guard against division by zero right after a reset (n_steps == 0);
        # the counters are zero then as well, so every measure is 0.0.
        denom = max(1, self.n_steps)
        return {
            "front_left_contact": self.fl_contacts / denom,
            "front_right_contact": self.fr_contacts / denom,
            "back_left_contact": self.bl_contacts / denom,
            "back_right_contact": self.br_contacts / denom,
        }

    def step(self, action):
        """Step the environment and update behavioral measures.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple from the
            wrapped environment, with the updated measures merged into
            ``info``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Update episode statistics
        self.n_steps += 1

        # Check each foot's contact against the post-step simulation state
        if self.check_foot_contact(self.fl_foot_id):
            self.fl_contacts += 1
        if self.check_foot_contact(self.fr_foot_id):
            self.fr_contacts += 1
        if self.check_foot_contact(self.bl_foot_id):
            self.bl_contacts += 1
        if self.check_foot_contact(self.br_foot_id):
            self.br_contacts += 1

        # Add measures to info dict
        info.update(self.get_current_measures())

        return obs, reward, terminated, truncated, info
