from __future__ import annotations

import scipy.stats as ss
from scipy.stats.qmc import PoissonDisk
from scipy.integrate import quad
import numpy as np

from typing import Optional, Callable

from shapely import Point
import shapely

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from molecule_movement.Molecule import Molecule

from loguru import logger

def normal_dist_pdf(mean: float = 0.5, scale: float = 0.5) -> Callable[[float], float]:
    """Build the Gaussian (normal) density with the given mean and standard deviation.

    Args:
        mean: centre of the bell curve.
        scale: standard deviation; only ``scale ** 2`` enters the formula.

    Returns:
        A function ``x -> exp(-(x - mean)**2 / (2 * scale**2)) / sqrt(2 * pi * scale**2)``.
        It is built from NumPy ufuncs, so ``x`` may also be an array.
    """
    def density(x):
        variance = scale**2
        exponent = -(x - mean) ** 2 / (2 * variance)
        return np.exp(exponent) / np.sqrt(2 * np.pi * variance)

    return density

def gaussian_mixture(mean1: float = 0.25, scale1: float =.15, mean2: float = 0.75, scale2: float = 0.15, alpha: float = 1/2):
    """Build a two-component Gaussian mixture density.

    The returned function evaluates
    ``alpha * N(x; mean1, scale1) + (1 - alpha) * N(x; mean2, scale2)``,
    where each ``N`` is produced by ``normal_dist_pdf``.

    Args:
        mean1, scale1: mean and standard deviation of the first component.
        mean2, scale2: mean and standard deviation of the second component.
        alpha: weight of the first component; the second gets ``1 - alpha``.
    """
    def mixture(x):
        first = normal_dist_pdf(mean1, scale1)(x)
        second = normal_dist_pdf(mean2, scale2)(x)
        return alpha * first + (1 - alpha) * second

    return mixture

def triangular(mean: float = 0.5) -> Callable[[float], float]:
    """Build the unnormalized linear ramp ``x -> 2 * x / mean``.

    NOTE(review): ``mean`` only rescales the ramp by a constant factor. Once the
    function is normalized over a fixed interval, the resulting density no
    longer depends on ``mean``. Confirm this is the intended behaviour.
    """
    def ramp(x):
        return 2 * x / mean

    return ramp

def pol2cart(angle: float, radius: float) -> tuple[float, float]:
    """Convert polar coordinates (``angle`` in radians, ``radius``) to Cartesian ``(x, y)``."""
    x = radius * np.cos(angle)
    y = radius * np.sin(angle)
    return x, y

class ContinuousAxisSampler(ss.rv_continuous):
    """Continuous distribution on [a, b] built from an arbitrary, possibly unnormalized density.

    ``pdf`` is divided by its integral over the support, so the resulting
    distribution integrates to one. The CDF, quantiles and sampling come from
    scipy's generic rv_continuous machinery, which works numerically from ``_pdf``.

    Freezing the distribution (calling the instance, e.g. ``sampler(scale=...)``)
    is not supported. scipy re-creates frozen distributions from
    rv_continuous's own constructor arguments, which this signature does not
    accept. Pass ``scale=`` directly to methods such as ``rvs`` instead.
    """

    def __init__(self, pdf: Callable[[float], float], seed: Optional[int] = None, a: float | None = None, b: float | None = None) -> None:
        """
        Args:
            pdf: non-negative (unnormalized) density; may be called with arrays.
            seed: seed for the distribution's default random state.
            a: lower bound of the support; ``None`` means -inf.
            b: upper bound of the support; ``None`` means +inf.

        Raises:
            ValueError: if the integral of ``pdf`` over [a, b] is not a
                positive finite number (the density could not be normalized).
        """
        super().__init__(a=a, b=b, seed=seed)
        # rv_continuous turns a=None / b=None into -inf / +inf, so integrate
        # over the support it actually stored rather than the raw arguments.
        self.scale, _ = quad(pdf, self.a, self.b)
        if not np.isfinite(self.scale) or self.scale <= 0:
            # Dividing by this would yield inf/nan/negative densities.
            raise ValueError(
                f"pdf must integrate to a positive finite value over "
                f"[{self.a}, {self.b}], got {self.scale!r}"
            )
        self.dist = pdf

    def _pdf(self, x):
        """Normalized density: the user's ``pdf`` divided by its integral."""
        return self.dist(x) / self.scale

class DiscreteAxisSampler(ss.rv_discrete):
    """Discrete distribution on the integers a, a + 1, ..., b built from an unnormalized mass function.

    ``pdf`` is evaluated at every integer of the support and divided by the sum
    of those values, so the probabilities add up to one. The CDF, quantiles and
    sampling come from scipy's generic rv_discrete machinery, which is driven
    by ``_pmf``.

    Freezing the distribution (calling the instance, e.g. ``sampler(loc=...)``)
    is not supported. scipy re-creates frozen distributions from rv_discrete's
    own constructor arguments, which this signature does not accept.
    """

    def __new__(cls, *args, **kwargs):
        # rv_discrete.__new__ binds its arguments against rv_discrete's own
        # signature (a, b, name, ...), which clashes with the leading ``pdf``
        # argument here. Allocate the instance directly instead; the
        # ``values=`` dispatch that rv_discrete.__new__ performs is not used.
        return super(ss.rv_discrete, cls).__new__(cls)

    def __init__(self, pdf: Callable[[float], float], seed: Optional[int] = None, a: int | None = None, b: int | None = None) -> None:
        """
        Args:
            pdf: non-negative mass function; called with arrays of integers.
            seed: seed for the distribution's default random state.
            a: smallest value of the support; ``None`` means 0 (rv_discrete's default).
            b: largest value of the support. It must be given: normalization
                sums over every integer in [a, b], which needs a finite bound.

        Raises:
            ValueError: if a bound is infinite or missing, or if the masses do
                not sum to a positive finite value.
        """
        lower = 0 if a is None else a
        upper = np.inf if b is None else b
        if not (np.isfinite(lower) and np.isfinite(upper)):
            raise ValueError(
                f"DiscreteAxisSampler needs finite bounds to normalize, got a={a!r}, b={b!r}"
            )
        super().__init__(a=lower, b=upper, seed=seed)
        self.dist = pdf
        # Normalize by summing (not integrating) over the integer support.
        support = np.arange(self.a, self.b + 1)
        self.scale = float(np.sum(self._raw_mass(support)))
        if not np.isfinite(self.scale) or self.scale <= 0:
            raise ValueError(
                f"pdf must sum to a positive finite value over "
                f"[{self.a}, {self.b}], got {self.scale!r}"
            )

    def _raw_mass(self, x):
        """Evaluate the user's ``pdf`` at ``x`` and broadcast the result to ``x``'s shape."""
        # Broadcasting keeps constant mass functions such as ``lambda x: 1``
        # working; rv_discrete sums _pmf over arrays of support points.
        return np.broadcast_to(np.asarray(self.dist(x), dtype=float), np.shape(x))

    def _pmf(self, x):
        """Normalized probability mass at the integer(s) ``x``."""
        return self._raw_mass(x) / self.scale


class Sampler():
    """Draw 2-D positions in a ``width`` x ``height`` rectangle from per-axis densities.

    ``x_distribution`` and ``y_distribution`` are (possibly unnormalized)
    densities on [0, 1]. Each is wrapped in a ContinuousAxisSampler, and its
    draws are stretched to the rectangle with ``scale=width`` / ``scale=height``.
    All draws consume the single ``self.random_state`` stream, so ``seed`` (or
    ``set_seed``) determines the sequence of positions.
    """

    def __init__(self,
                 x_distribution: Callable[[float], float],
                 y_distribution: Callable[[float], float],
                 width: int,
                 height: int,
                 seed: Optional[int] = None,
                 min_distance: float = -np.inf,
                 rejection: Callable[[Point], bool] | bool = True) -> None:
        """
        Args:
            x_distribution: density on [0, 1] for the x coordinate.
            y_distribution: density on [0, 1] for the y coordinate.
            width: horizontal extent; x draws are scaled by it.
            height: vertical extent; y draws are scaled by it.
            seed: seed for ``self.random_state``; also used to derive seeds
                for the per-axis samplers.
            min_distance: a candidate closer than this to any molecule polygon
                is redrawn. The default ``-inf`` disables the check.
            rejection: if callable, any candidate for which it returns a truthy
                value is redrawn. A non-callable value disables this check.
        """
        self.x_sampler = ContinuousAxisSampler(x_distribution, a=0, b=1, seed=self._derive_seed(seed, 2, 3))
        self.y_sampler = ContinuousAxisSampler(y_distribution, a=0, b=1, seed=self._derive_seed(seed, 3, 2))
        self.width = width
        self.height = height
        self.rejection = rejection
        self.min_distance = min_distance
        self.random_state = np.random.RandomState(seed)

    @staticmethod
    def _derive_seed(seed: Optional[int], factor: int, power: int) -> Optional[int]:
        """Map ``seed`` to ``(factor * seed) ** power``, reduced into NumPy's valid seed range."""
        if seed is None:
            return None
        # RandomState only accepts seeds in [0, 2**32 - 1]. Without the modulo,
        # seed=813 already gives (2 * 813) ** 3 > 2**32 and construction raises.
        return (int(factor * seed) ** power) % 2**32

    def sample_x(self) -> float:
        """Draw one x coordinate in [0, width] from ``self.random_state``."""
        return float(self.x_sampler.rvs(size=1, scale=self.width, random_state=self.random_state)[0])

    def sample_y(self) -> float:
        """Draw one y coordinate in [0, height] from ``self.random_state``."""
        return float(self.y_sampler.rvs(size=1, scale=self.height, random_state=self.random_state)[0])

    def _is_rejected(self, position: Point, molecules: list[Molecule]) -> bool:
        """Return True if ``position`` fails the rejection predicate or is too close to a molecule."""
        if callable(self.rejection) and self.rejection(position):
            return True
        # any() stops at the first molecule that is too close.
        return any(position.distance(m.polygon) < self.min_distance for m in molecules)

    def sample_position(self, molecules: list[Molecule] | None = None) -> Point:
        """Draw candidate points (x first, then y) until one is accepted, and return it.

        Args:
            molecules: existing molecules whose ``polygon`` must stay at least
                ``min_distance`` away. ``None`` means no molecules.

        NOTE: loops forever if no candidate can ever be accepted.
        """
        others = [] if molecules is None else molecules
        while True:
            candidate = Point(self.sample_x(), self.sample_y())
            if not self._is_rejected(candidate, others):
                return candidate

    def set_seed(self, seed: Optional[int] = None) -> None:
        """Reset ``self.random_state``, the stream used by all sampling methods."""
        self.random_state = np.random.RandomState(seed)

class PolarSampler():
    """Draw 2-D positions on a disk from independent angle and radius densities.

    ``angle_distribution`` and ``radius_distribution`` are (possibly
    unnormalized) densities on [0, 1]. Angle draws are stretched to [0, 2*pi],
    radius draws to [0, ``radius``], and the resulting point is shifted by
    (``x_offset``, ``y_offset``). All draws consume ``self.random_state``.

    NOTE: the radius density is over distance from the centre, not over area.
    A flat ``radius_distribution`` therefore puts proportionally more points
    near the centre than a uniform-over-the-disk sampler would.
    """

    def __init__(self,
                 angle_distribution: Callable[[float], float],
                 radius_distribution: Callable[[float], float],
                 radius: float,
                 x_offset: float,
                 y_offset: float,
                 seed: Optional[int] = None,
                 min_distance: float = -np.inf,
                 rejection: Callable[[Point], bool] | bool = True) -> None:
        """
        Args:
            angle_distribution: density on [0, 1] for the angle (scaled to 2*pi).
            radius_distribution: density on [0, 1] for the radius (scaled to ``radius``).
            radius: maximum distance from the centre.
            x_offset, y_offset: centre of the disk.
            seed: seed for ``self.random_state``; also used to derive seeds
                for the per-axis samplers.
            min_distance: a candidate closer than this to any molecule polygon
                is redrawn. The default ``-inf`` disables the check.
            rejection: if callable, any candidate for which it returns a truthy
                value is redrawn. A non-callable value disables this check.
        """
        self.angle_sampler = ContinuousAxisSampler(angle_distribution, a=0, b=1, seed=self._derive_seed(seed, 2, 3))
        self.radius_sampler = ContinuousAxisSampler(radius_distribution, a=0, b=1, seed=self._derive_seed(seed, 3, 2))
        self.radius = radius
        self.x_offset = x_offset
        self.y_offset = y_offset
        self.rejection = rejection
        self.min_distance = min_distance
        self.random_state = np.random.RandomState(seed)

    @staticmethod
    def _derive_seed(seed: Optional[int], factor: int, power: int) -> Optional[int]:
        """Map ``seed`` to ``(factor * seed) ** power``, reduced into NumPy's valid seed range."""
        if seed is None:
            return None
        # RandomState only accepts seeds in [0, 2**32 - 1]. Without the modulo,
        # seed=813 already gives (2 * 813) ** 3 > 2**32 and construction raises.
        return (int(factor * seed) ** power) % 2**32

    def sample_radius(self) -> float:
        """Draw one distance from the centre in [0, radius]."""
        return float(self.radius_sampler.rvs(size=1, scale=self.radius, random_state=self.random_state)[0])

    def sample_angle(self) -> float:
        """Draw one angle in [0, 2*pi] radians."""
        return float(self.angle_sampler.rvs(size=1, scale=2*np.pi, random_state=self.random_state)[0])

    def _draw_candidate(self) -> Point:
        """Draw an angle, then a radius, and return the corresponding point shifted to the disk centre."""
        angle = self.sample_angle()
        radius = self.sample_radius()
        x, y = pol2cart(angle, radius)
        # Adding the offsets directly is a pure translation. It avoids relying
        # on the ``shapely.affinity`` submodule having been imported elsewhere.
        return Point(x + self.x_offset, y + self.y_offset)

    def _is_rejected(self, position: Point, molecules: list[Molecule]) -> bool:
        """Return True if ``position`` fails the rejection predicate or is too close to a molecule."""
        if callable(self.rejection) and self.rejection(position):
            return True
        # any() stops at the first molecule that is too close.
        return any(position.distance(m.polygon) < self.min_distance for m in molecules)

    def sample_position(self, molecules: list[Molecule] | None = None) -> Point:
        """Draw candidate points until one is accepted, and return it.

        Args:
            molecules: existing molecules whose ``polygon`` must stay at least
                ``min_distance`` away. ``None`` means no molecules.

        NOTE: loops forever if no candidate can ever be accepted.
        """
        others = [] if molecules is None else molecules
        while True:
            candidate = self._draw_candidate()
            if not self._is_rejected(candidate, others):
                return candidate

    def set_seed(self, seed: Optional[int] = None) -> None:
        """Reset ``self.random_state``, the stream used by all sampling methods."""
        self.random_state = np.random.RandomState(seed)

class PoissonDiskSampler():
    """Draw well-separated 2-D positions in a ``width`` x ``height`` rectangle.

    Points are generated by scipy's PoissonDisk engine (with Lloyd
    optimization) on the unit square, then stretched by (width, height). When
    width != height, the minimum separation in the final coordinates therefore
    depends on direction.
    """
    DIMENSION = 2

    def __init__(self,
                 width: int,
                 height: int,
                 seed: Optional[int] = None) -> None:
        """
        Args:
            width: horizontal extent of the sampling rectangle.
            height: vertical extent of the sampling rectangle.
            seed: seed for ``self.random_state``, which is handed to PoissonDisk.
        """
        self.width = width
        self.height = height
        self.random_state = np.random.RandomState(seed)
        # Both are (re)built by sample_positions(); they are initialised here
        # so the attributes exist before the first call.
        self.min_distance = None
        self.sampler = None

    def sample_positions(self, n: int) -> list[Point]:
        """Return up to ``n`` Poisson-disk positions scaled to the rectangle.

        The disk radius, in unit-square coordinates, is
        ``0.75 * min(width, height) / ceil(sqrt(n)) / max(width, height)``.
        A new PoissonDisk engine is built on every call, with
        ``self.random_state`` passed as its seed.

        NOTE: scipy documents that PoissonDisk.random can return fewer samples
        than requested when the domain fills up, so the result may be shorter
        than ``n``.
        """
        if n <= 0:
            # ceil(sqrt(0)) == 0 would divide by zero below; nothing to sample.
            return []
        n_per_side = np.ceil(np.sqrt(n))
        self.min_distance = np.min([self.width/n_per_side, self.height/n_per_side])/np.max([self.width, self.height]) * 0.75
        self.sampler = PoissonDisk(d=self.DIMENSION, radius=self.min_distance, seed=self.random_state, optimization="lloyd")
        positions = self.sampler.random(n)
        return [Point(p) for p in positions * np.array([self.width, self.height])]

    def set_seed(self, seed: Optional[int] = None) -> None:
        """Reset ``self.random_state``; it is used by the next ``sample_positions`` call."""
        self.random_state = np.random.RandomState(seed)


if __name__ == "__main__":
    import time

    # Demo: uniform densities on both axes over a 20 x 200 rectangle.
    sampler = Sampler(lambda x: 1, lambda x: 1, width=20, height=200, seed=1)
    # NOTE(review): y_sampler's support is [0, 1], so cdf(20) is always 1.0;
    # cdf(20, scale=sampler.height) may have been intended.
    logger.warning(sampler.y_sampler.cdf(20))
    while True:
        # No molecules exist in this demo, so pass an empty list explicitly.
        logger.info(sampler.sample_position([]))
        time.sleep(0.45)

