from typing import Callable

import numpy as np
import scipy

from action_masking.rlsampling.sets.zonotope import Zonotope

# Signature of the target function evaluated by RDHRSampler's Metropolis filter: it is
# called with a 1-D slice of a sample point (``p[:, 0]``) and the results for the proposed
# and the current point are divided to form the acceptance ratio.
Function = Callable[[np.ndarray], np.ndarray]


class BaseZonoRandomWalkSampler:
    """Template for random-direction walks inside a zonotope.

    One step of the walk draws a direction, intersects the line through the current
    point with the zonotope boundary, proposes a new point on that chord and passes the
    proposal through a Metropolis filter. Subclasses customise the walk by overriding
    ``initial_point_sampling``, ``generate_random_direction``, ``sample_new_line_point``
    and ``metropolis_filter``.
    """

    def __init__(self, zonotope: Zonotope):
        """Store the zonotope to sample from and cache its dimension (``zonotope.d``)."""
        self.zonotope = zonotope
        self.dim = zonotope.d

    def sample(self, walk_length: int = 1, start_point: np.ndarray = None) -> np.ndarray:
        """Run the random walk and return the point it ends at.

        Args:
            walk_length: Number of *accepted* steps to perform. Proposals rejected by
                ``metropolis_filter`` are retried and do not count towards this number.
                For ``walk_length <= 0`` the start point is returned unchanged.
            start_point: Point to start the walk from; when ``None`` it is obtained from
                ``initial_point_sampling``. Presumably a ``(dim, 1)`` column vector inside
                the zonotope -- TODO confirm against ``Zonotope.boundary_point``.

        Returns:
            The current point after ``walk_length`` accepted steps.
        """
        p = self.initial_point_sampling() if start_point is None else start_point

        accepted_steps = 0
        while accepted_steps < walk_length:
            d = self.generate_random_direction(p)

            # Boundary points of the zonotope reached from p along +d and -d
            b1 = self.zonotope.boundary_point(d, p)
            b2 = self.zonotope.boundary_point(-d, p)

            # Values of t for b1 and b2 on the line x = p + t * d (d is expected to be
            # a unit vector, so projecting onto d recovers t).
            u_scalar = d.T @ (b1 - p)
            l_scalar = d.T @ (b2 - p)

            line_point_new = self.sample_new_line_point(u_scalar, l_scalar, d, p)
            p_new = p + line_point_new * d

            if self.metropolis_filter(p, p_new):
                p = p_new
                accepted_steps += 1

        # Return only once the requested number of steps has been taken; returning from
        # inside the loop would end every walk after its first iteration.
        return p

    def initial_point_sampling(self) -> np.ndarray:
        """Return a start point ``c + G @ xi`` with ``xi`` drawn uniformly from ``[-1, 1]^g``."""
        return self.zonotope.c + self.zonotope.G @ np.random.uniform(-1, 1, size=(self.zonotope.g, 1))

    def generate_random_direction(self, p: np.ndarray) -> np.ndarray:
        """Return a walk direction for the current point ``p``.

        The default ignores ``p`` and draws a direction from the unit sphere in
        ``self.dim`` dimensions.
        """
        return uniform_sample_unit_sphere(self.dim)

    def sample_new_line_point(self, u_scalar: float, l_scalar: float, d: np.ndarray, p: np.ndarray) -> float:
        """Return the step size ``t`` of the proposal ``p + t * d``; must be overridden.

        Args:
            u_scalar: Value of ``t`` at the boundary point in direction ``+d``.
            l_scalar: Value of ``t`` at the boundary point in direction ``-d``.
            d: Direction of the current step.
            p: Current point of the walk.

        Raises:
            NotImplementedError: Always; the base class defines no proposal distribution.
        """
        raise NotImplementedError

    def metropolis_filter(self, p: np.ndarray, p_new: np.ndarray) -> bool:
        """Decide whether the proposal ``p_new`` replaces ``p``; the default always accepts."""
        return True


class RDHRSampler(BaseZonoRandomWalkSampler):
    """Random-direction hit-and-run sampler with a user-supplied target function.

    Proposals are drawn uniformly on the chord through the current point and then
    accepted or rejected based on the ratio of ``function`` at the proposed and the
    current point.
    """

    def __init__(self, zonotope: Zonotope, function: Function):
        """Set up the walk on ``zonotope`` and remember the target ``function``."""
        super().__init__(zonotope)
        self.function = function

    def sample_new_line_point(self, u_scalar: float, l_scalar: float, d: np.ndarray, p: np.ndarray) -> np.ndarray:
        """Draw the step size uniformly between the two chord end points."""
        return np.random.uniform(low=l_scalar, high=u_scalar)

    def metropolis_filter(self, p: np.ndarray, p_new: np.ndarray) -> bool:
        """Accept ``p_new`` with probability ``min(function(p_new) / function(p), 1)``."""
        # The function is evaluated on the first column of each point.
        value_proposed = self.function(p_new[:, 0])
        value_current = self.function(p[:, 0])
        acceptance_probability = min(value_proposed / value_current, 1)
        return acceptance_probability >= np.random.uniform()


class GaussianRDHRSampler(BaseZonoRandomWalkSampler):
    """Random-direction hit-and-run sampler targeting a Gaussian restricted to a zonotope.

    Step sizes are drawn from a one-dimensional normal distribution truncated to the
    chord through the current point; proposals are then filtered with the density ratio
    of the multivariate normal ``N(mean, cov)``.
    """

    def __init__(self, zonotope: Zonotope, mean: np.ndarray, cov: np.ndarray):
        """Set up the walk.

        Args:
            zonotope: Set to sample from.
            mean: Mean of the Gaussian as a column vector of shape ``(n, 1)``.
            cov: Covariance matrix handed to ``scipy.stats.multivariate_normal``.
        """
        assert len(mean.shape) == 2 and mean.shape[1] == 1, "mean must be a column vector."
        super().__init__(zonotope)
        self.mean = mean
        self.cov = cov
        self.distribution = scipy.stats.multivariate_normal(mean.flatten(), cov)

    def initial_point_sampling(self) -> np.ndarray:
        """Start at the mean if the zonotope contains it, otherwise at a random point."""
        if self.zonotope.contains_point(self.mean):
            return self.mean
        # Pseudo-uniformly sample a point in the zonotope (same rule as the base class).
        return super().initial_point_sampling()

    def generate_random_direction(self, p: np.ndarray) -> np.ndarray:
        """Return a direction drawn from the unit sphere; ``p`` is currently unused."""
        # Idea kept from the original author: draw the direction from the Gaussian itself,
        #   np.random.multivariate_normal(self.mean[:, 0], self.cov)[:, np.newaxis]
        # The trick is fine, but it still needs a way to be incorporated when p is not
        # the mean.
        return uniform_sample_unit_sphere(self.dim)

    def sample_new_line_point(self, u_scalar: float, l_scalar: float, d: np.ndarray, p: np.ndarray) -> float:
        """Draw the step size from a normal distribution truncated to ``[l_scalar, u_scalar]``.

        The normal has mean ``d.T @ (mean - p)`` and standard deviation
        ``sqrt(d.T @ cov @ d)``; truncation is done by inverse-CDF sampling.
        """
        mu_line = d.T @ (self.mean - p)
        sigma_line = np.sqrt(d.T @ self.cov @ d)
        dist = scipy.stats.norm(mu_line, sigma_line)
        # Uniform draw between the CDF values of the chord ends, mapped back through the
        # quantile function.
        line_point_new = dist.ppf(np.random.uniform(dist.cdf(l_scalar), dist.cdf(u_scalar)))
        return line_point_new

    def metropolis_filter(self, p: np.ndarray, p_new: np.ndarray) -> bool:
        """Accept ``p_new`` with probability ``min(pdf(p_new) / pdf(p), 1)``.

        The ratio is formed in log space: far away from the mean both densities
        underflow to zero, and the plain quotient would become NaN and reject every
        proposal.
        """
        log_ratio = self.distribution.logpdf(p_new.flatten()) - self.distribution.logpdf(p.flatten())
        # exp(min(log_ratio, 0)) == min(ratio, 1) and cannot overflow.
        return np.exp(min(log_ratio, 0.0)) >= np.random.uniform()


class GaussianRDHRSamplerIntegralEstimate(GaussianRDHRSampler):
    """Gaussian hit-and-run sampler that also records an integral estimate per step."""

    def __init__(self, zonotope: Zonotope, mean: np.ndarray, cov: np.ndarray) -> None:
        """Initialise the Gaussian sampler; no estimate is available before the first step."""
        super().__init__(zonotope, mean, cov)
        self.current_integral_estimate = None

    def get_integral_estimate(self) -> float:
        """This method should be called after a sample step to get the latest integral estimate."""
        return self.current_integral_estimate

    def sample_new_line_point(self, u_scalar: float, l_scalar: float, d: np.ndarray, p: np.ndarray) -> float:
        """Draw the truncated-normal step size and update ``current_integral_estimate``.

        The stored estimate is the squared probability mass that the line distribution
        assigns to ``[l_scalar, u_scalar]``.
        NOTE(review): the reason for squaring is not visible in this file -- confirm.
        """
        line_mean = d.T @ (self.mean - p)
        line_std = np.sqrt(d.T @ self.cov @ d)
        dist = scipy.stats.norm(line_mean, line_std)

        # CDF values at both chord ends; used for the draw and for the estimate.
        lower_mass = dist.cdf(l_scalar)
        upper_mass = dist.cdf(u_scalar)

        line_point_new = dist.ppf(np.random.uniform(lower_mass, upper_mass))
        self.current_integral_estimate = (upper_mass - lower_mass) ** 2

        return line_point_new


class UniformRDHRSampler(BaseZonoRandomWalkSampler):
    """Random-direction hit-and-run sampler with uniform proposals on each chord.

    The Metropolis filter of the base class is kept, so every proposal is accepted.
    """

    def __init__(self, zonotope: Zonotope) -> None:
        """Set up the walk on ``zonotope``."""
        # Delegate to the base initialiser (like the sibling samplers) rather than
        # copying its attribute assignments, so base-class state stays in one place.
        super().__init__(zonotope)

    def sample_new_line_point(self, u_scalar: float, l_scalar: float, d: np.ndarray, p: np.ndarray) -> float:
        """Draw the step size uniformly between ``l_scalar`` and ``u_scalar``."""
        return np.random.uniform(l_scalar, u_scalar)


class BilliardWalkSampler(BaseZonoRandomWalkSampler):
    """Placeholder for a billiard-walk sampler; sampling is not implemented yet."""

    def __init__(self, zonotope: Zonotope, max_reflections: int, trajectory_length: float) -> None:
        """Set up the sampler.

        Args:
            zonotope: Set to sample from.
            max_reflections: Stored as ``self.max_reflections`` for a future implementation.
            trajectory_length: Stored as ``self.trajectory_length`` for a future implementation.
        """
        super().__init__(zonotope)
        # Keep the configuration instead of silently discarding the arguments.
        self.max_reflections = max_reflections
        self.trajectory_length = trajectory_length

    def sample(self, walk_length: int = 1, start_point: np.ndarray = None) -> np.ndarray:
        """Not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("BilliardWalkSampler.sample is not implemented yet.")


def uniform_sample_unit_sphere(dim: int) -> np.ndarray:
    """Return a random unit vector of shape ``(dim, 1)``.

    A standard-normal vector is drawn and scaled to Euclidean length one, which yields
    a direction distributed uniformly on the unit sphere.
    """
    gaussian_vector = np.random.randn(dim)
    unit_vector = gaussian_vector / np.linalg.norm(gaussian_vector)
    # Present the result as a column vector.
    return unit_vector.reshape(-1, 1)
