"""Environment which consists of a point-mass that has to navigate in
a two-dimensional Cartesian plane made up of three terrains with
different characteristics. Actions are two-dimensional floating-point
vectors which model pseudo-impulses applied to the point-mass. Actions
are bounded.

Sand is the top terrain, then water, then rock.

```
      ------
       .
         .    <-- Sand
       .
   -  ------
   |    --
 s |    --    <-- Water
   |    --
   _  ------
       +
         +    <-- Rock
           +
      ------
```

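The middle of the water terrain is at `y=0`, and the water band has height `s`.

NOTE: the implementation below only distinguishes sand (positions inside one
of the rectangular `islands`) from water (every other position); no rock
terrain is currently modelled.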
"""
from dataclasses import dataclass
import torch
import enum

from terrain_mass.control_affine_system import ControlAffineSystem
from terrain_mass.control_affine_system import step
from terrain_mass.island import Rectangle


class Terrain(enum.Enum):
    """Kind of ground underneath the point mass."""

    SAND = 1
    WATER = 2


@dataclass
class EnvironmentInstance:
    """One concrete configuration of the point-mass terrain environment.

    The state of the point mass is the vector
    ``x = [x_pos, y_pos, x_vel, y_vel]``. The mass is on sand while its
    position lies inside any of ``islands`` and in water otherwise.
    """

    # Linear drag coefficient on the velocity while the mass is on sand.
    sand_drag_constant: float
    # Linear drag coefficient on the velocity while the mass is in water.
    water_drag_constant: float
    # Lower bound every action component is clipped to in `step`.
    action_min: float
    # Upper bound every action component is clipped to in `step`.
    action_max: float
    # (x, y) position of the mass in the initial state (velocity starts at 0).
    initial_mass_position: tuple[float, float]
    # Sand regions; every position outside all of them is water.
    islands: list[Rectangle]

    @staticmethod
    def get_pos(x: torch.Tensor) -> torch.Tensor:
        """Return the (x, y)-coordinate of the given state vector."""
        return x[0:2]

    @staticmethod
    def get_x_pos(x: torch.Tensor) -> torch.Tensor:
        """Return the x-coordinate of the given state vector."""
        return x[0]

    @staticmethod
    def get_y_pos(x: torch.Tensor) -> torch.Tensor:
        """Return the y-coordinate of the given state vector."""
        return x[1]

    @staticmethod
    def get_x_vel(x: torch.Tensor) -> torch.Tensor:
        """Return the x-coordinate velocity of the given state vector."""
        return x[2]

    @staticmethod
    def get_y_vel(x: torch.Tensor) -> torch.Tensor:
        """Return the y-coordinate velocity of the given state vector."""
        return x[3]

    def get_terrain(self, x: torch.Tensor) -> Terrain:
        """Return the terrain under the mass described by state ``x``.

        The mass is on sand if its position is inside any island (as decided
        by each island's ``is_inside``), and in water otherwise.
        """
        pos = self.get_pos(x).tolist()
        if any(island.is_inside(*pos) for island in self.islands):
            return Terrain.SAND
        return Terrain.WATER

    def get_terrain_drag_constant(self, terrain: Terrain) -> float:
        """Return the drag coefficient for ``terrain``.

        Any terrain other than sand uses the water coefficient.
        """
        if terrain is Terrain.SAND:
            return self.sand_drag_constant
        return self.water_drag_constant

    @property
    def control_affine_system(self) -> ControlAffineSystem:
        """The control affine system induced by the environment.

        The state is a vector `x = [x_pos, y_pos, x_vel, y_vel]`.

        The system is modelled as a simple two-dimensional double-integrator
        with an additional linear, terrain-dependent drag on the velocity:

            dx/dt = A x - c(terrain) * [0, 0, x_vel, y_vel] + B u

        The system matrices should have a straight-forward derivation, but can
        also be found in e.g., "Accelerating Kinodynamic RRT* Through
        Dimensionality Reduction", 2021.

        A new system object is built on every access. All tensors produced by
        the vector fields share the dtype and device of the state ``x`` they
        are evaluated at.
        """
        def drag(x: torch.Tensor) -> torch.Tensor:
            # Drag acts only on the velocity rows; the position rows get no
            # contribution.
            terrain_constant = self.get_terrain_drag_constant(
                self.get_terrain(x)
            )
            velocity = torch.stack([self.get_x_vel(x), self.get_y_vel(x)])
            # Build the vector from ``x`` itself (rather than a fresh
            # ``torch.tensor`` at the default dtype on CPU) so dtype, device
            # and autograd history all follow the state.
            effect = torch.cat([x.new_zeros(2), velocity])
            return -terrain_constant * effect

        def inertia(x: torch.Tensor) -> torch.Tensor:
            # Double-integrator kinematics: positions change with velocity.
            A = torch.tensor(
                [
                    [0.0, 0.0, 1.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0],
                ],
                dtype=x.dtype,
                device=x.device,
            )
            return torch.matmul(A, x)

        def drift_field(x: torch.Tensor) -> torch.Tensor:
            return drag(x) + inertia(x)

        def control_field(x: torch.Tensor) -> torch.Tensor:
            # The two control components act directly on the two velocities.
            return torch.tensor(
                [
                    [0.0, 0.0],
                    [0.0, 0.0],
                    [1.0, 0.0],
                    [0.0, 1.0],
                ],
                dtype=x.dtype,
                device=x.device,
            )

        system = ControlAffineSystem(
            drift_field=drift_field,
            control_fields=[control_field],
        )
        return system

    def get_initial_state(self) -> torch.Tensor:
        """Return the initial state: the configured position, at rest."""
        return torch.tensor([
            self.initial_mass_position[0],
            self.initial_mass_position[1],
            0.0,
            0.0,
        ])

    def step(
        self,
        x: torch.Tensor,
        action: torch.Tensor,
        dt: float,
    ) -> torch.Tensor:
        """Simulate forward a state, given a state and action.

        On sand the action is negated (the controls are inverted). The
        resulting action is then clipped element-wise to
        ``[action_min, action_max]`` and applied for a step of length ``dt``.

        Args:
            x: current state ``[x_pos, y_pos, x_vel, y_vel]``.
            action: the action; presumably two components to match the
                4x2 control field (TODO confirm with callers).
            dt: time step passed to the integrator.

        Returns:
            The next state as computed by the module-level ``step`` function.
        """
        # Change control according to terrain
        if self.get_terrain(x) is Terrain.SAND:
            action = -action

        # Convert the action into a control for the system. Casting to the
        # state's dtype/device keeps it compatible with the control field,
        # which is built in the state's dtype/device.
        bounded_action = action.clip(
            min=self.action_min,
            max=self.action_max,
        ).to(dtype=x.dtype, device=x.device)
        next_state = step(
            x=x,
            u=[bounded_action],
            system=self.control_affine_system,
            dt=dt,
        )
        return next_state
