from typing import Dict

import gymnasium as gym
import numpy as np

# Keys of the behavioral measures reported by SwimmerBehavioralWrapper, in the
# same order the wrapper inserts them into the step/reset ``info`` dict.
SWIMMER_MEASURE_NAMES = ["angular_span", "phase_coordination", "straightness"]


class SwimmerBehavioralWrapper(gym.Wrapper):
    """Wrapper for Swimmer that tracks behavioral measures.

    After every ``reset`` and ``step`` the current measures are merged into
    the returned ``info`` dict. All measures are normalized to [0, 1]:

    - Angular span: how much the joints bend. This is the average
      peak-to-peak range of the two joint angles over the episode, divided
      by ``2 * max_angle``.
    - Phase coordination: the relation between the two joint-angle series,
      derived from their Pearson correlation. 0 means perfectly in-phase,
      1 means perfectly anti-phase, and 0.5 means uncorrelated or not enough
      information.
    - Straightness: one minus the standard deviation of the y-position,
      divided by ``max_y_dev`` and clipped. 1 means no lateral deviation.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        # Initialize episode tracking variables
        self.reset_episode_stats()

        # Constants for normalization
        self.max_angle = np.pi  # Maximum expected joint angle magnitude
        self.max_y_dev = 2.0  # Maximum expected y-deviation

    def reset_episode_stats(self) -> None:
        """Clear all per-episode statistics."""
        self.n_steps = 0
        # Joint angle history, one entry per step
        self.joint1_angles = []
        self.joint2_angles = []
        # y-position history, one entry per step, used for lateral deviation
        self.y_positions = []

    def reset(self, **kwargs):
        """Reset the environment and the episode statistics.

        Returns:
            ``(obs, info)`` from the wrapped env. ``info`` is extended with
            the initial (empty-history) measures.
        """
        obs, info = self.env.reset(**kwargs)
        self.reset_episode_stats()
        info.update(self.get_current_measures())
        return obs, info

    def compute_angular_span(self) -> float:
        """Return the normalized average peak-to-peak span of both joints.

        Returns 0.0 until at least two steps have been recorded.
        """
        if len(self.joint1_angles) < 2:
            return 0.0

        # Peak-to-peak amplitude for each joint over the episode so far
        j1_span = np.ptp(self.joint1_angles)
        j2_span = np.ptp(self.joint2_angles)

        avg_span = (j1_span + j2_span) / 2
        return float(np.clip(avg_span / (2 * self.max_angle), 0, 1))

    def compute_phase_coordination(self) -> float:
        """Return a [0, 1] phase measure from the joint-angle correlation.

        0 means the joints move in-phase (correlation +1). 1 means they move
        anti-phase (correlation -1). 0.5 is returned when there are fewer
        than 4 samples, or when the correlation is undefined because a joint
        angle never changed.
        """
        if len(self.joint1_angles) < 4:  # Need a few points to compute phase
            return 0.5

        j1 = np.asarray(self.joint1_angles, dtype=float)
        j2 = np.asarray(self.joint2_angles, dtype=float)
        # Pearson correlation is undefined (0/0 -> NaN plus a RuntimeWarning)
        # when either series is constant. Report the neutral value instead
        # of leaking NaN out of a measure documented as lying in [0, 1].
        if np.ptp(j1) == 0 or np.ptp(j2) == 0:
            return 0.5

        correlation = np.corrcoef(j1, j2)[0, 1]
        if not np.isfinite(correlation):
            return 0.5
        # Map [-1, 1] -> [0, 1]: +1 (in-phase) -> 0, -1 (anti-phase) -> 1
        return float(np.clip((1.0 - correlation) / 2.0, 0.0, 1.0))

    def compute_straightness(self) -> float:
        """Return 1 minus the normalized std-dev of the y-position.

        Returns 1.0 (perfectly straight) until at least two steps have been
        recorded.
        """
        if len(self.y_positions) < 2:
            return 1.0

        y_std = np.std(self.y_positions)
        # Convert to [0, 1] where 1 means perfectly straight
        return float(1.0 - np.clip(y_std / self.max_y_dev, 0, 1))

    def get_current_measures(self) -> Dict[str, float]:
        """Calculate current behavioral measures, all normalized to [0, 1]."""
        measures = {
            "angular_span": self.compute_angular_span(),
            "phase_coordination": self.compute_phase_coordination(),
            "straightness": self.compute_straightness(),
        }
        return measures

    def step(self, action):
        """Step the environment and update the behavioral measures.

        Returns:
            The wrapped env's ``(obs, reward, terminated, truncated, info)``.
            ``info`` is extended with the updated measures.

        Raises:
            KeyError: if the wrapped env's step ``info`` lacks ``"y_position"``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)

        self.n_steps += 1

        # Joint angles are part of qpos, after the root coordinates. Their
        # offset in obs shifts by 2 when the env leaves the current (x, y)
        # position in the observation.
        if self.env.unwrapped._exclude_current_positions_from_observation:
            joint1_angle = obs[1]  # motor1_rot
            joint2_angle = obs[2]  # motor2_rot
        else:
            joint1_angle = obs[3]  # motor1_rot
            joint2_angle = obs[4]  # motor2_rot

        self.joint1_angles.append(float(joint1_angle))
        self.joint2_angles.append(float(joint2_angle))

        # Track y-position for lateral deviation
        self.y_positions.append(info["y_position"])

        info.update(self.get_current_measures())

        return obs, reward, terminated, truncated, info
