import gymnasium as gym
import numpy as np

import shapely
from shapely import Point

from typing import Generic, Optional
from numpy.typing import NDArray

from gymnasium.core import ActType, ObsType
from gymnasium import Env

from molecule_movement.AngleSymmetry import AngleSymmetry
from molecule_movement.sampling import Sampler, CircularSampler
from molecule_movement.shapes import ATOM, random_convex_polygon
from molecule_movement import Obstacle

from loguru import logger

from enum import Enum

class ObstacleType(Enum):
    """Kind of obstacle that ``RandomObstaclesWrapper`` can place.

    Members keep explicit integer values; their declaration order is also
    their iteration order.
    """

    # Fixed single-atom shape (the ``ATOM`` shape constant).
    ATOM = 1
    # Randomly generated convex polygon.
    MOLECULE = 2

class RandomObstaclesWrapper(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Place randomly sampled obstacles into the wrapped environment on every reset.

    The obstacles are assigned to ``env.unwrapped.obstacles`` before the
    wrapped environment's ``reset`` is called. There are two placement modes:

    * count mode (used when ``percentage`` is falsy or <= 0): exactly
      ``num_obstacles`` obstacles are sampled;
    * coverage mode (used when ``percentage > 0``): obstacles are added until
      their union covers ``percentage`` of the free surface area, or until a
      hard cap on the obstacle count is reached.
    """

    def __init__(self,
                 env: Env[ObsType, ActType],
                 num_obstacles: int = 0,
                 percentage: float = 0.0,
                 obstacle_types: list[ObstacleType] | ObstacleType = ObstacleType.ATOM,
                 sampler: Optional[Sampler] = None,
                 seed: Optional[int] = None):
        """Create the wrapper.

        Args:
            env: Environment to wrap. It must expose ``surface_width``,
                ``surface_height``, ``goals`` and ``_set_goals`` via
                ``get_wrapper_attr``.
            num_obstacles: Number of obstacles to place in count mode.
            percentage: Target coverage fraction in coverage mode (clipped to
                [0, 1]). A value > 0 selects coverage mode.
            obstacle_types: A single type, or a non-empty list of types to
                draw from uniformly (duplicates are ignored).
            sampler: Position sampler. If None, a default ``Sampler`` is built
                that keeps positions away from the goals.
            seed: Seed for the default sampler. It is ignored when ``sampler``
                is given. NOTE(review): ``reset`` calls ``sampler.set_seed``
                with reset's own seed (possibly None); whether that overrides
                this seed depends on ``Sampler.set_seed``, so confirm.

        Raises:
            ValueError: if the obstacle configuration is unusable (see
                ``__post_init__``).
        """
        assert isinstance(env, Env)
        gym.utils.RecordConstructorArgs.__init__(
            self,
            num_obstacles=num_obstacles,
            percentage=percentage,
            obstacle_types=str(obstacle_types),
            seed=seed,
        )
        # Must run before get_wrapper_attr below, which reads self.env.
        gym.Wrapper.__init__(self, env)

        self.num_obstacles = num_obstacles
        self.percentage = percentage
        self.obstacle_types = obstacle_types
        self.multiple_types = isinstance(obstacle_types, list)
        # Not a dataclass, so the hook is not invoked automatically.
        self.__post_init__()

        self.angles = AngleSymmetry(1, 360)
        if sampler is not None:
            self.sampler = sampler
        else:
            # self.goals is only assigned in reset(), before any sampling happens.
            self.sampler = Sampler(lambda x: 0.1 < x < 0.90,
                                   lambda y: 0.1 < y < 0.90,
                                   rejection=lambda p: any(p.distance(goal.polygon) <= 8 for goal in self.goals),
                                   min_distance=6,
                                   width=self.get_wrapper_attr("surface_width"),
                                   height=self.get_wrapper_attr("surface_height"),
                                   seed=seed)

    def __post_init__(self):
        """Validate the obstacle configuration.

        This is not a dataclass hook here; ``__init__`` calls it explicitly.

        Raises:
            ValueError: if ``num_obstacles`` is negative; if it is None while
                no positive ``percentage`` is given (count mode would have
                nothing to iterate over); or if ``obstacle_types`` is an
                empty list.
        """
        coverage_mode = bool(self.percentage and self.percentage > 0.0)
        if self.num_obstacles is None and not coverage_mode:
            raise ValueError("Either 'num_obstacles' or a positive 'percentage' must be provided.")
        if self.num_obstacles is not None and self.num_obstacles < 0:
            raise ValueError(f"'num_obstacles' must be non-negative, got {self.num_obstacles}.")
        if self.multiple_types and not self.obstacle_types:
            raise ValueError("'obstacle_types' must not be an empty list.")

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[dict[str, NDArray], dict]:
        """Re-sample goals and obstacles, then reset the wrapped environment.

        Args:
            seed: Seed forwarded to the obstacle generators and to ``env.reset``.
            options: Forwarded unchanged to ``env.reset``.

        Returns:
            The ``(obs, info)`` pair returned by the wrapped environment.
        """
        self.get_wrapper_attr("_set_goals")()
        self.goals = self.get_wrapper_attr("goals")
        if self.percentage and self.percentage > 0.0:
            obstacles = self.__create_random_obstacles_until_percentage(seed)
        else:
            obstacles = self.__create_random_obstacles(seed)
        self.unwrapped.obstacles = obstacles

        return self.env.reset(seed=seed, options=options)

    @staticmethod
    def _make_next_seed(rng: np.random.Generator, seed: Optional[int]):
        """Return a zero-argument callable that yields per-draw sub-seeds.

        The callable returns None when ``seed`` is None, so downstream draws
        stay unseeded. Otherwise it returns a value drawn from ``rng``.
        """
        def next_seed() -> Optional[int]:
            # NOTE: draws from [0, 255) only, so sub-seeds repeat often. The
            # range is kept so that existing seeded runs stay reproducible.
            return int(rng.integers(0, 2**8 - 1)) if seed is not None else None

        return next_seed

    def _make_shape_sampler(self, rng: np.random.Generator, next_seed, **polygon_kwargs):
        """Return a zero-argument callable that draws one obstacle shape.

        Args:
            rng: Generator used to pick the obstacle type when several types
                are allowed.
            next_seed: Callable giving the seed for each random molecule polygon.
            **polygon_kwargs: Extra keyword arguments for
                ``random_convex_polygon`` (e.g. ``min``/``max`` size).

        Raises:
            ValueError: if ``obstacle_types`` holds an unsupported value or is
                an empty list.
        """
        def molecule() -> object:
            return random_convex_polygon(seed=next_seed(), **polygon_kwargs)

        if self.multiple_types:
            if not self.obstacle_types or not all(isinstance(t, ObstacleType) for t in self.obstacle_types):
                raise ValueError(f"Unsupported obstacle type: {self.obstacle_types}")
            # Distinct types in enum-value order, so that [ATOM, MOLECULE] and
            # [MOLECULE, ATOM] behave identically.
            choices = sorted(set(self.obstacle_types), key=lambda t: t.value)
            if len(choices) > 1:
                def shape_sampler() -> object:
                    kind = choices[int(rng.integers(0, len(choices)))]
                    return ATOM if kind is ObstacleType.ATOM else molecule()

                return shape_sampler
            kind = choices[0]
        else:
            kind = self.obstacle_types

        if kind == ObstacleType.ATOM:
            return lambda: ATOM
        if kind == ObstacleType.MOLECULE:
            return molecule
        raise ValueError(f"Unsupported obstacle type: {self.obstacle_types}")

    def __create_random_obstacles_until_percentage(self, seed: Optional[int] = None) -> list[Obstacle]:
        """Sample obstacles until their union covers ``percentage`` of the free area.

        The free area is the surface area minus twice the (surface-clipped)
        goal area. Sampling stops when the target is reached, after 100
        obstacles, or after 200 000 attempts, whichever comes first.

        Args:
            seed: Seed for the local RNG and for ``self.sampler``. If None,
                draws are unseeded.

        Returns:
            The sampled obstacles.
        """
        rng = np.random.default_rng(seed)
        next_seed = self._make_next_seed(rng, seed)

        self.sampler.set_seed(seed)

        # ---- surface and free-space computation
        width = float(self.get_wrapper_attr("surface_width"))
        height = float(self.get_wrapper_attr("surface_height"))
        surface = shapely.geometry.box(0.0, 0.0, width, height)

        goal_polys = []
        for goal in getattr(self, "goals", []):
            goal_poly = getattr(goal, "polygon", None)
            if goal_poly is None:
                continue
            # Clip to the surface so that goal parts outside it are not counted.
            clipped = goal_poly.intersection(surface)
            if not clipped.is_empty:
                goal_polys.append(clipped)

        goals_area = shapely.unary_union(goal_polys).area if goal_polys else 0.0
        # NOTE(review): the goal area is subtracted twice, presumably to
        # account for the clearance the sampler keeps around goals. Confirm.
        effective_area = max(0.0, surface.area - 2.0 * goals_area)
        percentage = float(np.clip(self.percentage, 0.0, 1.0))
        target_area = percentage * effective_area

        # Shrink the molecule size range when the target area is small
        # (sqrt <= 12), presumably so one obstacle does not overshoot it.
        min_obstacle_size, max_obstacle_size = 5, 12
        sqrt_target_area = np.sqrt(np.ceil(target_area))
        if sqrt_target_area <= 12.0:
            min_obstacle_size, max_obstacle_size = max(1, sqrt_target_area - 5), sqrt_target_area

        shape_sampler = self._make_shape_sampler(
            rng, next_seed, min=min_obstacle_size, max=max_obstacle_size
        )

        # ---- sampling loop
        obstacles: list[Obstacle] = []
        obstacle_polys = []
        union_area = 0.0

        # Hard caps so that an unreachable target cannot loop forever. Every
        # proposal is currently accepted, so max_accepts is the binding cap.
        max_accepts = 100
        max_attempts = 200_000
        attempts = 0

        while union_area < target_area and attempts < max_attempts and len(obstacles) < max_accepts:
            attempts += 1

            position = self.sampler.sample_position(obstacles)
            shape = shape_sampler()
            angle = self.angles.random_angle(next_seed())
            obstacle = Obstacle(position, shape, angle)

            obstacles.append(obstacle)
            obstacle_polys.append(obstacle.polygon)

            # Exact union area, so overlapping obstacles are not double-counted.
            union_area = shapely.unary_union(obstacle_polys).area
        logger.info(union_area)
        logger.info(len(obstacles))
        return obstacles

    def __create_random_obstacles(self, seed: Optional[int] = None) -> list[Obstacle]:
        """Sample exactly ``num_obstacles`` obstacles.

        Args:
            seed: Seed for the local RNG and for ``self.sampler``. If None,
                draws are unseeded.

        Returns:
            The sampled obstacles.
        """
        rng = np.random.default_rng(seed)
        next_seed = self._make_next_seed(rng, seed)
        shape_sampler = self._make_shape_sampler(rng, next_seed)

        self.sampler.set_seed(seed)

        obstacles: list[Obstacle] = []
        for _ in range(self.num_obstacles):
            # Existing obstacles are passed so the sampler can space new positions.
            position = self.sampler.sample_position(obstacles)
            shape = shape_sampler()
            angle = self.angles.random_angle(next_seed())
            obstacles.append(Obstacle(position, shape, angle))

        return obstacles
