import gymnasium as gym
import numpy as np

from dataclasses import dataclass

from typing import Generic, Optional
from numpy.typing import NDArray

from gymnasium.core import ActType, ObsType, WrapperObsType, WrapperActType, Wrapper
from gymnasium import Env

from molecule_movement import VerticalAction, LateralAction, Matching

import shapely
from shapely import Point

from loguru import logger


class SensorObservationsWrapper(
    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
    """Append per-sensor distance readings to the wrapped env's observation.

    Each observation from the wrapped env is concatenated with one reading per
    sensor of the current molecule.

    - Every reading starts at ``MAX_SENSOR_RANGE``.
    - When a sensor geometry intersects another molecule's or an obstacle's
      polygon, the reading is lowered to the distance from the sensing
      molecule's polygon to that intersection.

    NOTE(review): the bounds below assume the wrapped env emits a 5-element
    observation — confirm against the base env.
    """

    # Reading reported by a sensor that hits nothing; also the upper bound of
    # every sensor slot in ``observation_space``.
    MAX_SENSOR_RANGE = 10.0

    # Bounds of the wrapped env's own observation entries. Their meaning is
    # only inferred from the bounds (two values in [0, 360], two unbounded,
    # one non-negative) — TODO confirm with the base env.
    _BASE_LOW = np.array([0.0, 0.0, -np.inf, -np.inf, 0.0])
    _BASE_HIGH = np.array([360.0, 360.0, np.inf, np.inf, np.inf])

    def __init__(self,
                 env: Env[ObsType, ActType],
                 num_sensors: int):
        """Wrap ``env`` and extend its observation space with sensor slots.

        Args:
            env: The environment to wrap.
            num_sensors: Number of sensor readings appended to each
                observation; it is also published via ``set_wrapper_attr``
                under the name ``"num_sensors"``.
        """
        assert isinstance(env, Env)
        # gym.Wrapper.__init__ (first in the MRO) sets self.env; calling it a
        # second time is unnecessary.
        super().__init__(env)
        gym.utils.RecordConstructorArgs.__init__(self, num_sensors=num_sensors)
        self.set_wrapper_attr("num_sensors", num_sensors)

        low = np.concatenate((self._BASE_LOW, np.zeros(num_sensors)))
        high = np.concatenate(
            (self._BASE_HIGH, np.full(num_sensors, self.MAX_SENSOR_RANGE))
        )
        self.observation_space = gym.spaces.Box(
            low, high, shape=(self._BASE_LOW.size + num_sensors,)
        )

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[NDArray, dict]:
        """Reset the wrapped env and return its observation plus sensor readings.

        Caches the env's ``current_molecule`` and ``obstacles`` for use by
        :meth:`sensor_readings`, and calls :meth:`render` before returning.

        Returns:
            ``(observation, info)`` where ``observation`` is the wrapped env's
            observation followed by the sensor readings, cast to
            ``observation_space.dtype``.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        # Both lookups go through self.env: self.get_wrapper_attr would return
        # this wrapper's own cached attribute (from the previous episode) once
        # it exists.
        # NOTE(review): these are cached for the whole episode; if the env
        # reassigns current_molecule/obstacles during step(), they go stale.
        self.molecule = self.env.get_wrapper_attr("current_molecule")
        self.obstacles = self.env.get_wrapper_attr("obstacles")
        self.render()
        return self._augment(obs), info

    def step(self, action: ActType):
        """Step the wrapped env and append fresh sensor readings to its observation."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._augment(obs), reward, terminated, truncated, info

    def _augment(self, obs) -> NDArray:
        """Concatenate ``obs`` with current sensor readings in the space's dtype."""
        # The readings are float64 while Box defaults to float32; without the
        # cast, observation_space.contains() rejects every observation.
        return np.concatenate((obs, self.sensor_readings())).astype(
            self.observation_space.dtype, copy=False
        )

    def sensor_readings(self) -> NDArray:
        """Compute one distance reading per sensor of the cached molecule.

        Each reading starts at ``MAX_SENSOR_RANGE``. It is lowered to the
        distance between the molecule's polygon and the intersection of the
        sensor with any other molecule's or obstacle's polygon that the sensor
        intersects. The result is also stored on ``self._sensor_readings``.

        NOTE(review): the array length comes from ``molecule.num_sensors``,
        not the constructor's ``num_sensors``; they must agree for the
        observation to match ``observation_space`` — confirm.
        """
        molecule = self.molecule
        # Candidate targets do not depend on the sensor; build the list once.
        targets = [*self.env.get_wrapper_attr("molecules"), *self.obstacles]
        readings = np.full(molecule.num_sensors, self.MAX_SENSOR_RANGE)
        for j, sensor in enumerate(molecule.sensors):
            for target in targets:
                if molecule == target:  # a molecule does not sense itself
                    continue
                if sensor.intersects(target.polygon):
                    hit = shapely.intersection(sensor, target.polygon)
                    readings[j] = min(readings[j], molecule.polygon.distance(hit))
        self._sensor_readings = readings
        return readings
