# import sys
# from pathlib import Path

# PROJECT_ROOT = Path(__file__).resolve().parents[1]  # .../CPO
# sys.path.append(str(PROJECT_ROOT))

import gymnasium as gym
import numpy as np
import torch
from abstract_interpretation import domains
from typing import Any, Dict, Tuple


class AntEnv(gym.Env):
    """
    MuJoCo Ant-v4 wrapped as a CMDP that returns
    (obs, reward, cost, terminated, truncated, info).

    A state is unsafe when ``|state[SAFETY_OBS_INDEX]| > SAFETY_LIMIT``.
    On an unsafe step the cost is 1.0, the reward is reduced by
    ``UNSAFE_PENALTY`` and the episode terminates.
    """

    # Observation component watched by the safety constraint and its
    # symmetric bound. NOTE(review): the physical meaning of index 13 in the
    # Ant-v4 observation (presumably a velocity component) should be
    # confirmed against the Gymnasium Ant docs.
    SAFETY_OBS_INDEX = 13
    SAFETY_LIMIT = 2.3475
    # Reward subtracted on a step that lands in an unsafe state.
    UNSAFE_PENALTY = 100.0

    def __init__(self, state_processor=None, reduced_dim=None, safety=None):
        """Create the wrapped Ant-v4 env and build the safe/unsafe polytopes.

        Args:
            state_processor: Optional processor. When given, only the
                advertised ``observation_space`` changes to a
                ``reduced_dim``-dimensional Box in [-1, 1]. NOTE(review):
                ``reset``/``step`` never apply it and still return raw
                states — confirm whether callers project observations
                themselves.
            reduced_dim: Size of the reduced observation space. It is only
                used when ``state_processor`` is not None.
            safety: Unused. ``self.safety`` is always rebuilt by
                ``safety_constraints``.
        """
        super().__init__()                                # ← mandatory for Gym wrappers

        # ---------------------------------------------------------------- base env
        self.env = gym.make("Ant-v4", render_mode="rgb_array")
        self.action_space = self.env.action_space

        # observation space
        self.state_processor = state_processor
        if state_processor is None:
            self.observation_space = self.env.observation_space
        else:
            self.observation_space = gym.spaces.Box(
                low=-1, high=1, shape=(reduced_dim,), dtype=np.float32
            )

        # bookkeeping
        self._max_episode_steps = 1000
        self.step_counter = 0
        self.done = False
        # Seed registered via ``seed()``. It is consumed by the next
        # ``reset()``, because Gymnasium only seeds an env's PRNG through
        # ``reset(seed=...)``.
        self._pending_seed = None

        # containers expected elsewhere in your code
        self.safe_polys = []
        self.polys = []

        # constraint builders
        self.safety_constraints()
        self.unsafe_constraints()
        self._last_frame_shape = (256, 256, 3)

    # ------------------------------------------------------------------ constraints
    def safety_constraints(self):
        """Build the safe-set DeepPoly domain and its hyperplanes.

        The bounds start from ``self.observation_space``, with non-finite
        entries replaced by large sentinels. The safety dimension is then
        clamped to ``[-SAFETY_LIMIT, SAFETY_LIMIT]``. The results are stored
        in ``safety``/``original_safety`` and ``safe_polys``/
        ``original_safe_polys``.

        NOTE(review): if a reduced observation space with fewer than
        ``SAFETY_OBS_INDEX + 1`` dimensions is used, the clamp below raises
        IndexError.
        """
        # nan_to_num returns a fresh array, so the space's own bounds are
        # never mutated.
        lower_bounds = np.nan_to_num(
            self.observation_space.low,
            nan=-9999, posinf=33333333, neginf=-33333333,
        )
        upper_bounds = np.nan_to_num(
            self.observation_space.high,
            nan=-9999, posinf=33333333, neginf=-33333333,
        )

        # Previously explored per-dimension bounds (kept for reference, unused):
        # lower_bounds[:12] = [ -4.12, -18.4, 9.80, -0.63, -0.18, -0.1,     -0.1,     -0.1,    -3,    -0.5, -0.51,   -0.1,  ]
        # upper_bounds[:12] =  [ 4.01, 18.39,  9.82,  0.72,  0.15,  0.1,    0.1,    0.1,   3,    0.5,   0.51,  0.1,  ]
        idx = self.SAFETY_OBS_INDEX
        lower_bounds[idx] = -self.SAFETY_LIMIT
        upper_bounds[idx] = self.SAFETY_LIMIT

        input_deeppoly_domain = domains.DeepPoly(lower_bounds, upper_bounds)
        polys = input_deeppoly_domain.to_hyperplanes(self.env.observation_space)

        # Set the safety constraints using the DeepPolyDomain and the polys
        self.safety = input_deeppoly_domain
        self.original_safety = input_deeppoly_domain
        self.safe_polys = polys
        self.original_safe_polys = polys
        print(self.original_safety)

    def unsafe_constraints(self):
        """Store the complement of the safe polytope in ``self.polys``."""
        self.polys = self.safety.invert_polytope(self.env.observation_space)
        print(len(self.polys))

    # ------------------------------------------------------------ gym interface
    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Reset the wrapped env and the episode bookkeeping.

        A seed registered through ``seed()`` is forwarded to the wrapped
        env's ``reset`` unless the caller passes ``seed`` explicitly. It is
        used for one reset only.

        Returns:
            ``(state, info)`` where ``info["state_original"]`` is a copy of
            the raw state.
        """
        pending = getattr(self, "_pending_seed", None)
        if pending is not None and kwargs.get("seed") is None:
            kwargs["seed"] = pending
        self._pending_seed = None

        state, info = self.env.reset(**kwargs)
        self.step_counter = 0
        self.done = False
        info["state_original"] = state.copy()
        return state, info

    def step(
        self, action: np.ndarray
    ) -> Tuple[np.ndarray, float, float, bool, bool, Dict[str, Any]]:
        """Advance one step and attach the safety cost.

        Args:
            action: numpy array or torch tensor. A leading batch axis of
                size 1 is dropped. The action is flattened, cast to float32
                and clipped to ``action_space``.

        Returns:
            ``(state, reward, cost, terminated, truncated, info)``. ``cost``
            is 1.0 when the new state is unsafe and 0.0 otherwise. An unsafe
            step also subtracts ``UNSAFE_PENALTY`` from the reward and sets
            ``terminated``.
        """
        # ---- action sanitisation
        if isinstance(action, torch.Tensor):
            action = action.detach().cpu().numpy()
        action = np.asarray(action, dtype=np.float32)
        if action.ndim > 1 and action.shape[0] == 1:
            action = action[0]
        action = action.flatten()
        action = np.clip(action, self.action_space.low, self.action_space.high)

        # ---- env step
        state, reward, terminated, truncated, info = self.env.step(action)
        self.step_counter += 1

        # ---- safety cost
        is_unsafe = self.unsafe(state)
        cost = float(is_unsafe)
        if is_unsafe:
            reward -= self.UNSAFE_PENALTY
            terminated = True  # unsafe ⇒ terminate

        # ---- manual time-limit truncation
        if self.step_counter >= self._max_episode_steps:
            truncated = True

        # Keep ``done`` in sync so ``predict_done`` reflects the last step.
        self.done = bool(terminated or truncated)

        info["state_original"] = state.copy()
        return state, reward, cost, terminated, truncated, info

    # --------------------------------------------------------- convenience API
    def render(self):
        """Return an RGB frame, or a black frame of the last known shape."""
        frame = self.env.render()
        if frame is None:
            return np.zeros(self._last_frame_shape, dtype=np.uint8)
        self._last_frame_shape = getattr(frame, "shape", self._last_frame_shape)
        return frame

    def close(self):
        """Close the wrapped env."""
        self.env.close()

    def seed(self, seed=None):
        """Seed global NumPy, the spaces, and the next ``reset`` of the env.

        A ``None`` seed is a no-op. The wrapped env's PRNG can only be
        seeded through ``reset(seed=...)``, so the seed is remembered and
        applied on the next ``reset()``.
        """
        if seed is None:
            return
        np.random.seed(seed)
        self.env.action_space.seed(seed)
        self.env.observation_space.seed(seed)
        # The reduced observation space is a distinct object when a
        # state_processor is configured.
        if self.observation_space is not self.env.observation_space:
            self.observation_space.seed(seed)
        self._pending_seed = seed

    # ------------------------------------------------------------- OmniSafe-ish
    def predict_done(self, state: np.ndarray) -> bool:
        """Return whether the last ``step`` ended the episode (``state`` is ignored)."""
        return self.done

    def unsafe(self, state: np.ndarray, simulated: bool = False) -> bool:
        """Return True if ``state[SAFETY_OBS_INDEX]`` lies outside ``±SAFETY_LIMIT``.

        Because this is written as a negated in-range test, a NaN component
        counts as unsafe. ``simulated`` is accepted for API compatibility
        and ignored.
        """
        limit = self.SAFETY_LIMIT
        return not (-limit <= state[self.SAFETY_OBS_INDEX] <= limit)

    def get_cost_from_obs_tensor(
        self,
        obs,
        is_binary,
        idx: int = SAFETY_OBS_INDEX,
        limit: float = SAFETY_LIMIT,
    ) -> torch.Tensor:
        """Vectorized cost: unsafe iff |obs[..., idx]| > limit.

        Supports (D), (B,D), (T,B,D), (H,M,B,D). A 1-D input is treated as a
        batch of one, so it yields shape (1, 1). Otherwise the result has
        shape (..., 1), dtype float32 and the same device as the input.

        Args:
            obs: numpy array or torch tensor whose last axis is the
                observation.
            is_binary: Unused. The cost is always binary (0.0 / 1.0).
            idx: Observation component to check. Negative indices are
                allowed.
            limit: Symmetric bound on ``obs[..., idx]``.

        Raises:
            TypeError: if ``obs`` is neither a numpy array nor a tensor.
            IndexError: if ``idx`` is out of range for the last axis.
        """
        if isinstance(obs, np.ndarray):
            # from_numpy rejects negative-stride views, so make the array
            # contiguous first. This is a no-op for already-contiguous input.
            x = torch.from_numpy(np.ascontiguousarray(obs))
        elif torch.is_tensor(obs):
            x = obs
        else:
            raise TypeError(type(obs))
        x = x.to(dtype=torch.float32)
        if x.dim() == 1:  # (D,) → (1,D)
            x = x.unsqueeze(0)
        D = x.shape[-1]
        if not -D <= idx < D:
            raise IndexError(f"idx {idx} out of range for obs dim {D}")
        unsafe = x[..., idx].abs() > limit   # (...,)
        return unsafe.to(torch.float32).unsqueeze(-1)  # (...,1)

# env = AntEnv()

# print("Observation space:", env.observation_space)
# print("Observation shape:", env.observation_space.shape)

# obs, info = env.reset()
# print("Reset obs shape:", obs.shape)