from abc import ABC, abstractmethod
from functools import lru_cache
import numpy as np
import scipy
import torch
from dataclasses import dataclass
from ai_control_games.other_utils import floats_are_close
from icecream import ic
import collections

# Re-expose ``Callable`` on the ``collections`` namespace by aliasing it to
# ``collections.abc.Callable``.
# NOTE(review): presumably a compatibility shim for a dependency that still
# looks up ``collections.Callable`` -- confirm which import needs it.
collections.Callable = collections.abc.Callable


# TODO - Relax the assumption that it's 0-1 (maybe)
class DistOver0to1(ABC):
    """Abstract base class for probability distributions on the interval [0, 1].

    Subclasses must implement ``cdf`` and ``inv_cdf``.  They may also define
    a ``pdf`` method (picked up by ``estimate_pdf``) and override
    ``many_pdf``.

    ``__init__`` runs ``validate_dist``.  A subclass that ends up with its own
    ``__init__`` (hand-written or dataclass-generated) only gets validated if
    that ``__init__`` calls this one.
    """

    @abstractmethod
    def cdf(self, x: float) -> float:
        """Return P(X <= x) for ``x`` in [0, 1]."""
        raise NotImplementedError

    @abstractmethod
    def inv_cdf(self, p: float) -> float:
        """Return the point ``x`` in [0, 1] whose cdf equals ``p``."""
        raise NotImplementedError

    @property
    def vec_icdf(self):
        """Return ``inv_cdf`` wrapped so that it applies element-wise to arrays."""
        # torch has no ``vectorize``; numpy's wrapper maps the scalar function
        # over every element of an array-like argument.
        return np.vectorize(self.inv_cdf)

    def many_pdf(self, xs: np.ndarray) -> np.ndarray:
        """Vectorised pdf; only available on subclasses that override it."""
        raise NotImplementedError("Not all distributions have a vectorised pdf")

    def estimate_pdf(self, x: float) -> float:
        """Return the density at ``x``.

        Uses the subclass's own ``pdf`` method when it has one; otherwise
        falls back to a cached finite-difference estimate taken from ``cdf``.
        """
        # Check if there is a method for pdf
        if hasattr(self, "pdf"):
            return self.pdf(x)

        # Round to nearest 0.01: that leaves at most 101 distinct points in
        # [0, 1], which is exactly what the lru_cache below is sized for.
        x = round(x, 2)
        return self._estimate_pdf(x)

    # NOTE: the cache is keyed on (self, x, eps) and shared by all instances,
    # so instances must be hashable and are kept alive while cached.
    @lru_cache(maxsize=101)
    def _estimate_pdf(self, x: float, eps: float = 1e-2) -> float:
        """Estimate the density at ``x`` as the average cdf slope over
        ``[x - eps, x + eps]``, with the window clipped to [0, 1]."""
        low = max(x - eps, 0.0)
        high = min(x + eps, 1.0)
        return (self.cdf(high) - self.cdf(low)) / (high - low)

    def validate_dist(self):
        """Assert basic consistency of ``cdf`` and ``inv_cdf``.

        Checks that ``cdf(0) == 0``, that ``cdf(1)`` is close to 1, and that
        ``inv_cdf(cdf(x))`` is close to ``x`` on a 100-point grid wherever the
        estimated density is non-zero.

        Raises:
            AssertionError: if any of the checks fails.
        """
        assert self.cdf(0.0) == 0.0, self.cdf(0.0)
        assert floats_are_close(1.0, self.cdf(1.0)), self.cdf(1.0)
        for grid_point in torch.linspace(0.0, 1.0, 100):
            x = float(grid_point)
            p = self.cdf(x)
            x2 = self.inv_cdf(p)
            # Where the density is zero the cdf is flat, so the inverse is not
            # unique and the round trip cannot be expected to return x.
            if self.estimate_pdf(x) != 0.0:
                assert floats_are_close(x, x2, rtol=1e-2), (x, p, x2)

    def sample(self) -> float:
        """Draw one sample by pushing a Uniform(0, 1) draw through ``inv_cdf``."""
        u = float(torch.distributions.uniform.Uniform(0.0, 1.0).sample())
        return self.inv_cdf(u)

    def __init__(self):
        self.validate_dist()
        super().__init__()


class DiscreteDist(DistOver0to1):
    """Discrete distribution with atoms at the midpoints of equal-width bars.

    [0, 1] is split into ``len(probs)`` equal-width bars and ``probs[i]`` is
    the probability mass placed at ``vals[i]``, the midpoint of bar ``i``.
    """

    def __init__(self, probs: np.ndarray):
        """Build the distribution from per-bar probabilities.

        Raises:
            AssertionError: if ``probs`` does not sum to (approximately) 1 or
                the resulting distribution fails ``validate_dist``.
        """
        assert floats_are_close(sum(probs), 1.0), sum(probs)
        self.probs = probs
        self.cum_probs = np.cumsum(probs)
        bar_edges = np.linspace(0, 1, len(probs) + 1)
        # Take the width from the first two edges rather than from
        # ``start_vals[1]`` so that a single-bar distribution also works.
        self.bar_width = bar_edges[1] - bar_edges[0]
        self.vals = bar_edges[:-1] + (self.bar_width / 2)
        super().__init__()

    def __repr__(self):
        return f"DiscreteDist(probs={self.probs})"

    def pdf(self, x: float) -> float:
        """Always raises: a discrete distribution has no density."""
        raise NotImplementedError("Discrete distributions do not have a pdf")

    def cdf(self, x: float) -> float:
        """Return the total mass of the atoms located at or below ``x``."""
        if x < self.vals[0]:
            return 0.0
        if x >= self.vals[-1]:
            return 1.0

        # Bar containing x; its atom (the midpoint) counts only if x has
        # reached it.  ``int`` works for plain floats and numpy scalars alike.
        index = int(x / self.bar_width)
        if x < self.vals[index]:
            # index == 0 cannot occur here because x >= vals[0]; guard anyway
            # so a negative index can never wrap around.
            return self.cum_probs[index - 1] if index > 0 else 0.0
        return self.cum_probs[index]

    def inv_cdf(self, p: float) -> float:
        """Return the first atom whose cumulative probability reaches ``p``.

        ``p == 0`` maps to 0.0 and ``p == 1`` maps to the last atom.
        """
        if p == 0.0:
            return 0.0
        if p == 1.0:
            return self.vals[-1]
        reached = self.cum_probs >= p
        # If rounding leaves cum_probs[-1] just below p, nothing matches and
        # argmax would silently return 0; fall back to the last atom instead.
        index = int(np.argmax(reached)) if reached.any() else len(self.vals) - 1
        return self.vals[index]

    def validate_dist(self):
        """Assert consistency of ``cdf`` and ``inv_cdf`` for a discrete law.

        On a 100-point grid, ``inv_cdf(cdf(x))`` must not exceed ``x`` and
        must have the same cdf value as ``x``.

        Raises:
            AssertionError: if any of the checks fails.
        """
        assert self.cdf(0.0) == 0.0, self.cdf(0.0)
        assert floats_are_close(1.0, self.cdf(1.0)), self.cdf(1.0)
        for grid_point in torch.linspace(0.0, 1.0, 100):
            x = float(grid_point)
            p = self.cdf(x)
            x2 = self.inv_cdf(p)
            assert x2 <= x, (x, p, x2)
            assert self.cdf(x2) == p, (x, p, x2, self.cdf(x2))


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)
class BetaDistOver0to1(DistOver0to1):
    """Beta(a, b) distribution on [0, 1], evaluated through ``scipy.stats.beta``."""

    a: float
    b: float

    def _beta_kwargs(self) -> dict:
        """Keyword arguments shared by every ``scipy.stats.beta`` call."""
        return {"a": self.a, "b": self.b, "loc": 0, "scale": 1}

    @property
    @lru_cache(maxsize=1)
    def dist(self):
        """Frozen ``scipy.stats.beta`` object for these parameters."""
        return scipy.stats.beta(**self._beta_kwargs())

    def pdf(self, xs: float) -> float:
        """Density of the beta distribution at ``xs``."""
        return scipy.stats.beta.pdf(xs, **self._beta_kwargs())

    def many_pdf(self, xs: np.ndarray) -> np.ndarray:
        """Density evaluated element-wise over the array ``xs``."""
        return scipy.stats.beta.pdf(xs, **self._beta_kwargs())

    def cdf(self, x: float) -> float:
        """Cumulative probability of the beta distribution at ``x``."""
        return scipy.stats.beta.cdf(x, **self._beta_kwargs())

    def inv_cdf(self, p: float) -> float:
        """Quantile of the beta distribution at ``p`` (scipy's ``ppf``)."""
        return scipy.stats.beta.ppf(p, **self._beta_kwargs())


class ContFromDiscrete(DistOver0to1):
    """Piecewise-uniform (histogram) distribution on [0, 1].

    [0, 1] is split into ``len(probs)`` equal-width bars and the mass
    ``probs[i]`` is spread uniformly over bar ``i``.
    """

    def __init__(self, probs: np.ndarray):
        """Build the distribution from per-bar probabilities.

        Raises:
            AssertionError: if the distribution fails ``validate_dist``.
        """
        self.probs = probs
        self.cum_probs = np.cumsum(probs)
        self.bar_edges = np.linspace(0, 1, len(probs) + 1)
        self.bar_width = self.bar_edges[1] - self.bar_edges[0]
        super().__init__()

    def __repr__(self):
        return f"ContFromDiscrete(probs={self.probs})"

    def _bar_index(self, x: float) -> int:
        """Return the index of the bar containing ``x``, clamped to a valid bar.

        Looking the position up in ``bar_edges`` (instead of dividing by
        ``bar_width``) keeps the index consistent with the edges, and the
        clamp sends ``x == 1.0`` to the last bar rather than one past it.
        """
        index = int(np.searchsorted(self.bar_edges, x, side="right")) - 1
        return min(max(index, 0), len(self.probs) - 1)

    def pdf(self, x: float) -> float:
        """Return the (constant) density of the bar containing ``x``.

        The end point ``x == 1.0`` is assigned to the last bar.
        """
        return self.probs[self._bar_index(x)] / self.bar_width

    def cdf(self, x: float) -> float:
        """Return the mass of all complete bars below ``x`` plus the covered
        fraction of the bar containing ``x``."""
        index = self._bar_index(x)
        prev_bar_cum_prob = self.cum_probs[index - 1] if index > 0 else 0.0
        # Fraction of the current bar lying at or below x; clipped so that
        # rounding (or an x outside [0, 1]) cannot leave the range [0, 1].
        fraction = (x - self.bar_edges[index]) / self.bar_width
        fraction = min(max(fraction, 0.0), 1.0)
        return prev_bar_cum_prob + self.probs[index] * fraction

    def inv_cdf(self, p: float) -> float:
        """Return the point ``x`` with ``cdf(x) == p``.

        Raises:
            ValueError: if ``p`` lands in a bar that carries zero probability.
        """
        if p == 0.0:
            return 0.0
        reached = self.cum_probs >= p
        # If rounding leaves cum_probs[-1] just below p, nothing matches and
        # argmax would silently pick bar 0; use the last bar instead.
        index = int(np.argmax(reached)) if reached.any() else len(self.probs) - 1
        x_base = self.bar_edges[index]
        prev_bar_cum_prob = self.cum_probs[index - 1] if index > 0 else 0.0
        cur_bar_cum_prob = p - prev_bar_cum_prob
        if self.probs[index] == 0.0:
            ic(self.probs, index, p, cur_bar_cum_prob)
            raise ValueError("Cannot invert cdf for a bar with zero probability")
        remainder = cur_bar_cum_prob * self.bar_width / self.probs[index]
        # Never step past the right edge of the selected bar.
        return min(x_base + remainder, self.bar_edges[index + 1])


class UniformDistOver0to1(DistOver0to1):
    """Uniform distribution on [0, 1]: constant density, identity cdf."""

    def pdf(self, _) -> float:
        """Return the density, which is 1.0 regardless of the argument."""
        return 1.0

    def cdf(self, x: float) -> float:
        """Return ``x`` unchanged (the cdf is the identity)."""
        return x

    def inv_cdf(self, p: float) -> float:
        """Return ``p`` unchanged (the inverse cdf is the identity)."""
        return p


# Module-level instance of the uniform distribution.  Constructing it goes
# through DistOver0to1.__init__, so validate_dist() runs when this module is
# imported.
UNIFORM_DIST = UniformDistOver0to1()


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)  # , config=PydanticDCCfg)
class SimplePolyDistOver0to1(DistOver0to1):
    """Distribution on [0, 1] whose cdf is the monomial ``x ** power``."""

    power: int

    def cdf(self, x: float) -> float:
        """Return ``x`` raised to ``power``."""
        exponent = self.power
        return x ** exponent

    def pdf(self, x: float) -> float:
        """Return the derivative of the cdf, ``power * x ** (power - 1)``."""
        exponent = self.power
        return exponent * x ** (exponent - 1)

    def inv_cdf(self, p: float) -> float:
        """Return the ``power``-th root of ``p``."""
        root = 1 / self.power
        return p ** root


@dataclass(frozen=True, eq=True, order=True, unsafe_hash=False)  # , config=PydanticDCCfg)
class NormalDistributionsNormalised(DistOver0to1):
    """Distribution on [0, 1] built from two normal distributions.

    ``cdf`` maps a probability through the quantile function of the "sf"
    normal and then through the cdf of the "bd" normal; ``inv_cdf`` applies
    the two steps in the opposite order with the roles swapped.
    """

    mean_sf: float
    std_sf: float
    mean_bd: float
    std_bd: float

    @property
    def _sf_normal(self):
        """Torch normal distribution with the ``*_sf`` parameters."""
        return torch.distributions.normal.Normal(loc=self.mean_sf, scale=self.std_sf)

    @property
    def _bd_normal(self):
        """Torch normal distribution with the ``*_bd`` parameters."""
        return torch.distributions.normal.Normal(loc=self.mean_bd, scale=self.std_bd)

    def cdf(self, x_uniform: float) -> float:
        """Return the "bd" cdf evaluated at the "sf" quantile of ``x_uniform``."""
        sf_quantile = self._sf_normal.icdf(torch.tensor(x_uniform))
        bd_prob = self._bd_normal.cdf(sf_quantile)
        return float(bd_prob.item())

    def inv_cdf(self, p: float) -> float:
        """Return the "sf" cdf evaluated at the "bd" quantile of ``p``."""
        bd_quantile = self._bd_normal.icdf(torch.tensor(p))
        sf_prob = self._sf_normal.cdf(bd_quantile)
        return float(sf_prob.item())


# Module-level NormalDistributionsNormalised instance with fixed parameters.
# NOTE(review): "RW" in the name is not explained in this file -- confirm what
# it refers to and where these means / standard deviations come from.
# NOTE: the dataclass-generated __init__ replaces DistOver0to1.__init__, so
# validate_dist() is not run when this instance is created.
NORMALISED_NORMAL_APPROX_OF_RW = NormalDistributionsNormalised(
    mean_sf=-1.0, std_sf=2.0,
    mean_bd=2.5, std_bd=2.0
)

# TODO: write some unit tests for the above classes
# TODO: it might also be good to simplify them somewhat
