from __future__ import annotations


import numpy as np
from numpy.typing import NDArray

from shapely import Point, LineString, Polygon

from dataclasses import dataclass, field

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from molecule_movement.Molecule import Molecule
    from molecule_movement.Objects import Goal

from loguru import logger


# Upper bound on segment length used when densifying a Matching path with
# shapely's ``segmentize`` (that densification step is currently commented out).
MAX_SEGMENT_LENGTH: float = 10.0
# Radius of the buffer drawn around a Matching's current waypoint; a molecule
# whose polygon intersects this buffer counts as having reached the waypoint.
WAYPOINT_BUFFER: float = 1.5

@dataclass(frozen=True, eq=True)
class Pair:
    """Immutable association of a molecule with a goal.

    Iterating a pair yields its field values in declaration order, so it can
    be unpacked as ``molecule, goal = pair``.
    """

    molecule: Molecule
    goal: Goal

    def __iter__(self):
        """Yield the instance's field values (molecule first, then goal)."""
        yield from self.__dict__.values()

    def __str__(self):
        """Render as ``(<molecule name> - <goal position>)``."""
        name = self.molecule.name
        position = self.goal.position
        return f"({name} - {position})"

@dataclass(frozen=False, eq=True)
class Matching:
    """A path assigned from a molecule to its goal, optionally via waypoints.

    The path is a shapely ``LineString`` through the anchors
    ``[molecule.center, *waypoints, goal.position]``. A cursor tracks the
    anchor the molecule is currently heading for; :meth:`next_waypoint`
    advances it as anchors are reached, stopping at the goal.

    Raises:
        ValueError: at construction if waypoints are given and the resulting
            path crosses itself.
    """

    molecule: Molecule
    goal: Goal
    waypoints: list[Point] = field(default_factory=list)
    # Name-mangled to ``_Matching__linestring`` and set in __post_init__. As a
    # dataclass field it participates in the generated __eq__ and __repr__.
    __linestring: LineString = field(init=False)

    def __post_init__(self) -> None:
        """Build the path and point the cursor at the first anchor after the start."""
        self.__set_linestring()
        self.__set_initial_waypoint()
        # When set via set_inf(), ``length`` reports infinity.
        self.__inf = False

    def __set_linestring(self) -> None:
        """Rebuild anchors and path, starting from the molecule's *current* centre.

        Raises:
            ValueError: if there is at least one waypoint and the path is not
                simple (crosses itself).
        """
        self.__anchors = [self.molecule.center, *self.waypoints, self.goal.position]
        self.__linestring = LineString(self.__anchors)
        # Densification is disabled; uncomment to split the path into segments
        # no longer than MAX_SEGMENT_LENGTH.
        # self.__segmentize_linestring()
        if self.waypoints and not self.__linestring.is_simple:
            raise ValueError(f"The linestring for {self.molecule.center=} - {self.waypoints=} - {self.goal.position=} crosses itself.")

    def __segmentize_linestring(self) -> None:
        """Densify the path so that no segment exceeds MAX_SEGMENT_LENGTH.

        Every vertex except the first (molecule centre) and the last (goal)
        becomes a waypoint.
        """
        self.__linestring = self.__linestring.segmentize(max_segment_length=MAX_SEGMENT_LENGTH)
        self.__anchors = [Point(*p) for p in self.__linestring.coords]
        # Exclude only the start and the goal ([1:-2] also dropped the last
        # intermediate vertex).
        self.waypoints = self.__anchors[1:-1]

    def __set_initial_waypoint(self) -> None:
        """Point the cursor at the first anchor after the start."""
        self.__current_waypoint_index = 1
        self.__sync_current_waypoint()

    def __sync_current_waypoint(self) -> None:
        """Refresh the cached current waypoint and its segment length.

        The index is clamped to the last anchor (the goal), so the cursor
        stays valid after advancing past the goal or after the anchor list
        shrinks.
        """
        last_index = len(self.__anchors) - 1
        self.__current_waypoint_index = min(self.__current_waypoint_index, last_index)
        self.__current_waypoint = self.__anchors[self.__current_waypoint_index]
        self.compute_waypoint_distance()

    def next_waypoint(self, molecule: Molecule) -> Point:
        """Return the anchor the molecule should head for, advancing if reached.

        The current waypoint counts as reached when ``molecule.polygon``
        intersects a buffer of radius WAYPOINT_BUFFER around it. The cursor
        then moves to the next anchor, but never beyond the goal.
        """
        if molecule.polygon.intersects(self.__current_waypoint.buffer(WAYPOINT_BUFFER)):
            self.__current_waypoint_index += 1
            self.__sync_current_waypoint()
        return self.__current_waypoint

    def compute_waypoint_distance(self) -> float:
        """Cache and return the length of the segment ending at the current waypoint."""
        previous_anchor = self.__anchors[self.__current_waypoint_index - 1]
        self.__current_waypoint_distance = self.__current_waypoint.distance(previous_anchor)
        return self.__current_waypoint_distance

    def distance_to_waypoint(self, molecule: Molecule) -> float:
        """Return the distance from ``molecule.center`` to the current waypoint."""
        return molecule.center.distance(self.__current_waypoint)

    def normalized_distance_to_waypoint(self, molecule: Molecule) -> float:
        """Return the distance to the current waypoint divided by the segment length.

        NOTE(review): if two consecutive anchors coincide, the segment length
        is 0 and this division presumably raises ZeroDivisionError. Confirm
        whether callers guard against that.
        """
        current_distance = self.distance_to_waypoint(molecule)
        return current_distance / self.__current_waypoint_distance

    @property
    def has_waypoints(self) -> bool:
        """Whether the path has any intermediate waypoints."""
        return bool(self.waypoints)

    def __len__(self):
        """Number of path points: start + waypoints + goal."""
        return 2 + len(self.waypoints)

    @property
    def length(self) -> float:
        """Path length, or ``np.inf`` once set_inf() has been called."""
        if self.__inf:
            return np.inf
        return self.__linestring.length

    @property
    def matching_line(self) -> LineString:
        """The path as a LineString (same object as ``linestring``)."""
        return self.__linestring

    @property
    def anchors(self) -> list[Point]:
        """The path's anchor points (the internal list, not a copy)."""
        return self.__anchors

    @property
    def centroid(self) -> Point:
        """Return a representative midpoint of the path.

        For a straight path (two anchors) this is the line's centroid.
        Otherwise it is the midpoint of the segment between anchors
        ``n // 2 - 1`` and ``n // 2``, where ``n`` is the number of anchors.
        """
        num_anchors = len(self.__anchors)
        if num_anchors == 2:
            return self.__linestring.centroid
        middle = num_anchors // 2
        return LineString([self.__anchors[middle], self.__anchors[middle - 1]]).centroid

    def add_waypoint(self, waypoint: Point, index: int=0) -> None:
        """Insert ``waypoint`` at position ``index`` of the waypoint list.

        The path is rebuilt from the molecule's current centre. The cursor is
        then re-synchronised with the new anchors at its current index
        (clamped to the goal).

        Raises:
            ValueError: if the new path crosses itself. In that case the
                matching is left exactly as it was before the call.
        """
        previous_waypoints = list(self.waypoints)
        previous_anchors = self.__anchors
        previous_linestring = self.__linestring
        # list.insert appends when the list is empty, so no special case is needed.
        self.waypoints.insert(index, waypoint)
        try:
            self.__set_linestring()
        except ValueError:
            # Roll back in place so a rejected waypoint does not leave a
            # self-crossing path (or a stray waypoint) behind.
            self.waypoints[:] = previous_waypoints
            self.__anchors = previous_anchors
            self.__linestring = previous_linestring
            raise
        self.__sync_current_waypoint()

    def drop_waypoints(self) -> None:
        """Remove all waypoints: the path becomes a straight line from the molecule's current centre to the goal."""
        self.waypoints = []
        self.__set_linestring()
        self.__sync_current_waypoint()

    @property
    def coordinates(self) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
        """Return the x and y coordinates of the path's vertices.

        NOTE(review): this forwards shapely's ``coords.xy``, which may return
        ``array.array`` objects rather than the annotated NumPy arrays.
        Confirm before relying on NumPy-specific behaviour.
        """
        return self.__linestring.coords.xy

    @property
    def linestring(self) -> LineString:
        """The path as a LineString."""
        return self.__linestring

    def set_inf(self) -> None:
        """Mark the matching so that ``length`` reports infinity."""
        self.__inf = True

    def __str__(self) -> str:
        """Render as ``<molecule> - <goal>``, plus the waypoint count if there are any."""
        if self.has_waypoints:
            return f"{self.molecule} - {self.goal} ({len(self) - 2} waypoints)"
        return f"{self.molecule} - {self.goal}"
