from typing import Any
import numpy as np

import gymnasium as gym
from compression_autoencoder.envs.multireward_wrapper import MultiRewardWrapper

class MultiRewardReacherWrapper(MultiRewardWrapper):
    """
    Computes multiple binary (0.0 / 1.0) rewards for Reacher-v5.

    Every reward is derived from the fingertip's planar position and velocity
    read from the underlying MuJoCo environment:

    * ``speed``        -- fingertip speed exceeds ``_SPEED_THRESHOLD``.
    * ``rot_ccw``      -- z-component of ``r x v`` exceeds ``_ROTATION_THRESHOLD``.
    * ``rot_cw``       -- z-component of ``r x v`` is below ``-_ROTATION_THRESHOLD``.
    * ``radial_speed`` -- absolute velocity component along the position
      direction exceeds ``_RADIAL_SPEED_THRESHOLD``.
    * ``quad_I`` .. ``quad_IV`` -- fingertip lies within
      ``distance_threshold`` of the matching entry in ``corner_goals``.

    The wrapper also removes the target position and the fingertip-to-target
    vector from the observation, leaving a 6-dimensional vector.
    """

    # Thresholds applied to the fingertip kinematics in `_calculate_rewards`.
    _SPEED_THRESHOLD = 6
    _ROTATION_THRESHOLD = 1
    _RADIAL_SPEED_THRESHOLD = 3

    def __init__(self, env: gym.Env, reward_type: str = "standard"):
        """
        Args:
            env: The environment to wrap (expected to be Reacher-v5).
            reward_type: Forwarded unchanged to ``MultiRewardWrapper``.
        """
        super().__init__(env, reward_type)

        # The observation exposed by this wrapper has the target-related
        # entries removed (see `_modify_observation`), hence 6 dimensions.
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(6,), dtype=np.float64
        )

        # NOTE: despite the "quad_*" reward names, the goal points sit on the
        # coordinate axes at distance `coord` from the origin, not in the
        # corners of the quadrants.
        coord = 0.21
        self.corner_goals = {
            "I": np.array([coord, 0.0]),
            "II": np.array([0.0, coord]),
            "III": np.array([0.0, -coord]),
            "IV": np.array([-coord, 0.0]),
        }
        # Radius around each goal inside which the occupancy reward fires.
        self.distance_threshold = 0.07

    @property
    def reward_keys(self) -> set[str]:
        """Names of all rewards produced by `_calculate_rewards`."""
        return {
            "speed", "rot_ccw", "rot_cw", "radial_speed",
            "quad_I", "quad_II", "quad_III", "quad_IV",
        }

    def _calculate_rewards(self, obs, reward, terminated, truncated, info, action) -> dict[str, float]:
        """
        Compute every auxiliary reward from the current physics state.

        The step arguments (``obs``, ``reward``, ``terminated``, ``truncated``,
        ``info``, ``action``) are accepted for interface compatibility but are
        not used; the state is read directly from the unwrapped environment.

        Returns:
            A dict mapping each name in ``reward_keys`` to ``0.0`` or ``1.0``.
        """
        # Get relevant physics state from the environment.
        fingertip_pos = self.unwrapped.get_body_com("fingertip")[:2]
        # XY linear velocity (entries 3:5 of the 6D `cvel` vector).
        # NOTE(review): MuJoCo documents `cvel` as a COM-based velocity --
        # confirm this matches the fingertip's world-frame velocity closely
        # enough for the thresholds used here.
        fingertip_vel = self.unwrapped.data.body("fingertip").cvel[3:5]

        px, py = fingertip_pos
        vx, vy = fingertip_vel

        rewards = {}

        # 1. Fingertip speed.
        fingertip_speed = np.linalg.norm(fingertip_vel)
        rewards["speed"] = 1.0 if fingertip_speed > self._SPEED_THRESHOLD else 0.0

        # 2. Rotational direction: z-component of r x v (positive means
        #    counter-clockwise). Computed explicitly because calling np.cross
        #    on 2-element vectors is deprecated as of NumPy 2.0.
        cross_z = px * vy - py * vx
        rewards["rot_ccw"] = 1.0 if cross_z > self._ROTATION_THRESHOLD else 0.0
        rewards["rot_cw"] = 1.0 if cross_z < -self._ROTATION_THRESHOLD else 0.0

        # 3. Fast extension & retraction: velocity projected on the radial
        #    direction. The epsilon guards against division by zero when the
        #    fingertip is exactly at the origin.
        pos_norm = np.linalg.norm(fingertip_pos) + 1e-8
        radial_velocity = np.dot(fingertip_vel, fingertip_pos / pos_norm)
        rewards["radial_speed"] = (
            1.0 if abs(radial_velocity) > self._RADIAL_SPEED_THRESHOLD else 0.0
        )

        # 4. Goal-region occupation: one reward per entry in `corner_goals`.
        for name, goal in self.corner_goals.items():
            distance = np.linalg.norm(fingertip_pos - goal)
            rewards[f"quad_{name}"] = (
                1.0 if distance < self.distance_threshold else 0.0
            )

        return rewards

    def _modify_observation(self, obs: np.ndarray) -> np.ndarray:
        """Helper function to strip target info from the observation."""
        # Original obs: [cos, sin, target_pos, qvel, to_target_vec]
        # Keep cos/sin (indices 0-3) and qvel (indices 6-7); drop target_pos
        # (indices 4-5) and to_target_vec (indices 8 onward).
        return np.concatenate([obs[:4], obs[6:8]])

    def step(self, action: np.ndarray):
        """
        Overrides the parent step method to modify the observation.

        Returns the parent's ``(obs, reward, terminated, truncated, info)``
        tuple with ``obs`` replaced by its target-free version.
        """
        original_obs, main_reward, term, trunc, info = super().step(action)

        modified_obs = self._modify_observation(original_obs)
        return modified_obs, main_reward, term, trunc, info

    def reset(self, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
        """
        Overrides the parent reset method to modify the initial observation.

        All keyword arguments are forwarded unchanged to the parent ``reset``.
        """
        original_obs, info = super().reset(**kwargs)

        modified_obs = self._modify_observation(original_obs)
        return modified_obs, info