from abc import abstractmethod, ABC
import polars as pl
from .num import eps, eigen_eps
from collections.abc import Iterable, Iterator, Sequence
from torch.nn import functional as F
from dataclasses import dataclass
from functools import cache, reduce
import itertools as it
import math
import numpy as np
import torch as th
import typing as tp
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib import patheffects


def make_P_(device):
    """Build the 2x3 projection matrix with unit-norm rows on ``device``.

    The two rows are the contrasts ``(1, 1, -2)`` and ``(1, -1, 0)``, each
    divided by its Euclidean norm. Both rows sum to zero and are mutually
    orthogonal, so the result has orthonormal rows.
    """
    contrasts = th.tensor([[1, 1, -2], [1, -1.0, 0]], device=device)
    row_norms = th.linalg.norm(contrasts, dim=-1, keepdim=True)
    return contrasts / row_norms


def make_P(device: str | th.device = "cpu"):
    """Return a detached, freshly-cloned copy of the matrix from ``make_P_``.

    Args:
        device: Device on which the matrix is created (defaults to CPU).
    """
    base = make_P_(device)
    detached = base.detach()
    return detached.clone()


def bv(mu, sigma):
    """Return a function evaluating the log-density of a bivariate normal.

    Args:
        mu: Mean vector, shape ``(2,)``.
        sigma: Covariance matrix, shape ``(2, 2)``. The off-diagonal entries
            are averaged, so a slightly asymmetric matrix is tolerated.

    Returns:
        ``log_pdf_f(x1, x2=None)``. Called with a single argument, ``x1`` must
        have a trailing axis of size 2 holding both coordinates; called with
        two arguments, they are the two coordinates separately.

    Standard deviations, the correlation denominator and ``1 - rho**2`` are
    all clipped from below by ``eps`` to keep the logs and divisions finite.
    """
    assert mu.shape == (2,)
    assert sigma.shape == (2, 2)
    mean_a, mean_b = mu[0], mu[1]

    sd_a = th.sqrt(sigma[0, 0].clip(0)).clip(eps)
    sd_b = th.sqrt(sigma[1, 1].clip(0)).clip(eps)

    # Correlation from the symmetrised off-diagonal, kept inside [-1, 1].
    cov_sum = sigma[1, 0] + sigma[0, 1]
    corr = (cov_sum / (2 * sd_a * sd_b).clip(eps)).clip(-1, 1)
    resid_var = (1 - corr * corr).clip(eps)

    # log(2*pi*sd_a*sd_b*sqrt(1 - rho^2)), the log normalising constant.
    log_z = (
        math.log(2 * math.pi)
        + th.log(sd_a)
        + th.log(sd_b)
        + 0.5 * th.log(resid_var)
    )

    def log_pdf_f(x1, x2=None):
        if x2 is None:
            # Single-argument form: split the packed coordinates.
            assert x1.shape[-1] == 2
            x1, x2 = x1[..., 0], x1[..., 1]

        za = (x1 - mean_a) / sd_a
        zb = (x2 - mean_b) / sd_b
        quad = za * za - 2 * corr * za * zb + zb * zb
        return -quad / (2 * resid_var) - log_z

    return log_pdf_f


# A single ordering constraint: a chain of indices into the last axis of a
# utility tensor. `apply_constraint` requires each entry to be strictly
# greater than the next one in the chain (x[a] > x[b] for consecutive a, b).
Constraint = Sequence[int]
# A collection of chains; `apply_constraint` requires all of them to hold.
Constraints = Sequence[Constraint]


def pair_iter[T](x: Iterable[T]) -> Iterable[tuple[T, T]]:
    pi: T | None = None
    for i in x:
        if pi is not None:
            yield (pi, i)
        pi = i


def apply_constraint(constraint: Constraints, x: th.Tensor):
    """Return a boolean mask of the points in ``x`` satisfying every chain.

    Args:
        constraint: Chains of indices into the last axis of ``x``. A chain
            ``(a, b, c)`` requires ``x[..., a] > x[..., b] > x[..., c]``.
        x: Tensor whose last axis holds the values being compared.

    Returns:
        Boolean tensor of shape ``x.shape[:-1]``; an entry is True when all
        chains hold at that point.

    An empty set of chains, or a chain with fewer than two indices, imposes
    no restriction and therefore yields True (previously these raised
    ``TypeError`` from ``reduce`` on an empty sequence).
    """
    satisfied = th.ones(x.shape[:-1], dtype=th.bool, device=x.device)
    for chain in constraint:
        for upper, lower in it.pairwise(chain):
            satisfied = th.logical_and(satisfied, x[..., upper] > x[..., lower])
    return satisfied


def stopgrad(x):
    """Detach ``x`` from the autograd graph if it is a tensor.

    Non-tensor values are returned unchanged.
    """
    return x.detach() if isinstance(x, th.Tensor) else x


def LL(sigma):
    """Return a square-root factor ``L`` of ``sigma`` with ``L @ L.mT ~= sigma``.

    The factor is built from the symmetric eigendecomposition as
    ``V @ diag(sqrt(max(lambda, eigen_eps)))``, so it stays well defined when
    ``sigma`` is only positive semi-definite. It is generally not triangular.

    Args:
        sigma: Symmetric matrix of shape ``(..., n, n)``.

    Returns:
        Tensor of the same shape as ``sigma``.

    Scaling the eigenvector columns by broadcasting is equivalent to the
    former ``V @ th.diag(s)`` for a single matrix, but also works for batched
    input, which ``th.diag`` does not accept.
    """
    # Alternative kept for reference: th.linalg.cholesky_ex(sigma)[0]
    e = th.linalg.eigh(sigma)
    root = th.sqrt(e.eigenvalues.clamp(eigen_eps))
    # Column j of the eigenvector matrix is multiplied by root[..., j].
    return e.eigenvectors * root.unsqueeze(-2)


def prob_grid(
    c: Constraints, mu: th.Tensor, sigma: th.Tensor, N=1000, bound_width=6, device=None
):
    """Integrate a projected Gaussian over the region allowed by ``c``.

    ``mu`` and ``sigma`` are projected to two dimensions with the matrix from
    ``make_P`` (2x3, so three alternatives are expected). The resulting
    bivariate normal density is then integrated with the trapezoid rule on an
    ``N x N`` grid laid out in whitened coordinates.

    Args:
        c: Ordering constraints, checked on the grid points mapped back
            through the projection.
        mu: Mean vector.
        sigma: Covariance matrix.
        N: Number of grid points per axis.
        bound_width: Half-width of the grid in whitened units.
        device: Device for the projection and grid; defaults to ``mu.device``.

    Returns:
        Scalar tensor with the integrated density (a probability, not a log).

    The grid placement (centre and scale factor) is detached from autograd,
    so gradients reach ``mu`` and ``sigma`` only through the density.
    """
    target = mu.device if device is None else device
    proj = make_P(target)

    mean2 = proj @ mu
    cov2 = proj @ sigma @ proj.T
    # Symmetrise and add a small ridge for numerical safety.
    cov2 = (cov2 + cov2.T) / 2 + eps * th.eye(2, device=cov2.device)

    log_density = bv(mean2, cov2)
    factor = stopgrad(LL(cov2))
    centre = stopgrad(mean2)

    axis = th.linspace(-bound_width, bound_width, steps=N, device=target)
    whitened = th.stack(th.meshgrid(axis, axis, indexing="ij"), dim=-1)

    # Map whitened grid points into the projected space.
    points = centre + whitened @ factor.T
    inside = apply_constraint(c, points @ proj)
    density = log_density(points).clamp(max=20).exp()
    # |det factor| is the Jacobian of the change of variables.
    integrand = th.where(inside, density, 0) * th.abs(th.linalg.det(factor))

    partial = th.trapezoid(integrand, axis, dim=-2)
    return th.trapezoid(partial, axis, dim=-1)


def to_np(x):
    """Convert ``x`` to a NumPy array and cast it to ``float64``."""
    as_array = np.asarray(x)
    return as_array.astype(np.float64)


def prob_grid_rec(c: Constraints, mu: th.Tensor, sigma: th.Tensor, N=10000):
    """Monte Carlo estimate of the probability that ``c`` holds.

    Draws ``N`` samples from ``MultivariateNormal(mu, sigma)`` and returns the
    fraction satisfying the constraints.

    Args:
        c: Ordering constraints passed to ``apply_constraint``.
        mu: Mean vector(s).
        sigma: Covariance matrix/matrices.
        N: Number of samples.

    Returns:
        Tensor of estimated probabilities, one per batch element (a scalar
        tensor when ``mu`` is a single vector).
    """
    dist = th.distributions.MultivariateNormal(mu, sigma)
    draws = dist.rsample((N,))
    # rsample prepends the sample axis, so the count must reduce over axis 0.
    # Reducing over -1 only coincides with this for unbatched parameters.
    return apply_constraint(c, draws).sum(0) / N


class RUM(ABC):
    """Abstract interface shared by the random-utility models in this module."""

    @abstractmethod
    def sample(self, N: int) -> th.Tensor:
        """Draw ``N`` samples; ``welfare`` takes the max over their last axis."""

    @abstractmethod
    def diff_logprob(self, constraint: Constraints, N: int) -> th.Tensor:
        """Return a tensor scoring ``constraint`` under the model.

        ``N`` is a resolution/size parameter whose use is left to subclasses.
        """

    @abstractmethod
    def viz(self, keys=None) -> Figure:
        """Return a matplotlib figure visualising the model parameters."""

    @abstractmethod
    def csv(self, keys=None) -> dict[str, pl.DataFrame]:
        """Return the model parameters as named polars data frames."""

    def welfare(self, N: int) -> th.Tensor:
        """Estimate welfare as the mean over samples of the best utility."""
        draws = self.sample(N)
        best = draws.max(-1).values
        return best.mean()

    @abstractmethod
    def __getitem__(self, index: th.Tensor | slice) -> tp.Self:
        """Return a model of the same type restricted to ``index``."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the size of the model (number of alternatives)."""


def imshow_text(
    ax: Axes, x, *, alpha=None, xlabel=None, ylabel=None, vmin=None, vmax=None
):  # pyright: ignore
    """Draw ``x`` as an image on ``ax`` and print each cell value on top.

    Args:
        ax: Target axes.
        x: 2-D array; every entry is rendered with two decimals.
        alpha, vmin, vmax: Forwarded to ``ax.imshow``.
        xlabel: Optional tick labels for the x axis (rotated by -90 degrees).
        ylabel: Optional tick labels for the y axis.

    Returns:
        The image object returned by ``ax.imshow``.
    """
    image = ax.imshow(x, alpha=alpha, vmax=vmax, vmin=vmin)

    if xlabel is not None:
        ax.set_xticks(range(len(xlabel)), labels=xlabel, rotation=-90)
    if ylabel is not None:
        ax.set_yticks(range(len(ylabel)), labels=ylabel)

    cells = (
        (row, col) for row in range(x.shape[0]) for col in range(x.shape[1])
    )
    for row, col in cells:
        label = ax.text(
            col,
            row,
            f"{x[row, col]:.2f}",
            ha="center",
            va="center",
            color="w",
            size='x-large',
        )
        # Thin black outline keeps the white text legible on light cells.
        outline = [patheffects.withStroke(linewidth=1, foreground="black")]
        label.set_path_effects(outline)

    ax.set_aspect("equal")
    return image


@dataclass
class Probit3(RUM):
    mu: th.Tensor
    si: th.Tensor
    ll: th.Tensor

    def __init__(self, mu, si):
        self.mu = mu
        self.si = si
        self.ll = LL(si)

    @staticmethod
    def parse_phi_sig(phi, sig):
        match sig:
            case "none":
                pass
            case "tanh":
                phi = th.tanh(phi) * math.sqrt(1 / phi.shape[-1])
            case "norm":
                phi = F.layer_norm(phi, phi.shape[-1:])
            case "softplus":
                phi = th.nn.functional.softplus(phi)
            case "sigmoid":
                phi = th.nn.functional.sigmoid(phi)
            case _:
                raise NotImplementedError(sig)

        return phi

    @staticmethod
    def gen_si(mu, phi, mode):
        match mode:
            case "dot":
                si = phi @ phi.transpose(-2, -1)
            case "rbf":
                d = phi.unsqueeze(-2) - phi.unsqueeze(-3)
                si = th.exp(-d.square().sum(-1))
            case "spherical":
                assert phi is not None
                assert phi.shape[-1] == 1
                phi = phi[..., 0]
                si = th.diag_embed(phi)
            case "unit":
                assert phi is None
                si = th.eye(mu.shape[-1], device=mu.device).broadcast_to(
                    (*mu.shape, mu.shape[-1])
                )
            case _:
                raise NotImplementedError(mode)
        assert si.size(-1) == si.size(-2)
        I = th.eye(si.size(-1), dtype=phi.dtype, device=phi.device)
        si = (si + si.T) / 2 + eigen_eps * I
        return si

    @classmethod
    def from_flat(
        cls,
        mu,
        phi,
        *,
        phi_sig: str,
        mode: tp.Literal["dot", "rbf", "spherical", "unit"],
    ) -> tp.Self:
        assert mu.shape[-1] == 1
        assert mu.shape[:-1] == phi.shape[:-1]
        mu = mu[..., 0]
        mu = mu - th.mean(mu, dim=-1, keepdim=True)
        phi = cls.parse_phi_sig(phi, phi_sig)
        si = cls.gen_si(mu, phi, mode)
        return cls(mu, si)

    @property
    def std(self) -> tuple[th.Tensor, th.Tensor]:
        I = th.eye(self.si.shape[-1], device=self.device)
        In = I - th.ones_like(I) / I.shape[-1]
        si = In @ self.si @ In
        si = (si + si.T) / 2

        scale = th.diag(si).abs().max().clamp(1e-10)
        return self.mu / th.sqrt(scale), si / scale

    def __repr__(self):
        mu, si = map(lambda x: x.detach().cpu().numpy(), self.std)
        return f"""mu:
        {mu}
        sigma:
        {si}
"""

    @property
    def device(self) -> th.device:
        return self.mu.device

    @property
    def N(self) -> int:
        return self.mu.shape[-1]

    def sample(self, N: int) -> th.Tensor:
        return self.mu + th.randn((N, self.N), device=self.device) @ self.ll.T

    def diff_logprob(self, constraint: Constraints, N: int) -> th.Tensor:
        return prob_grid(constraint, self.mu, self.si, N=N).clamp(eps).log()

    def viz(self, keys=None) -> Figure:
        mu, si = map(lambda x: x.detach().cpu().numpy(), self.std)
        fig, axes = plt.subplots(
            6,
            1,
            figsize=(len(mu), 2 * len(mu) + 2),
            height_ratios=[len(mu), 0.2, 1, 0.2, len(mu), 0.2],
        )

        axes = tp.cast(list[Axes], axes)
        im0 = imshow_text(axes[0], si)
        axes[0].set_title("$\Sigma$")

        fig.colorbar(im0, shrink=0.5, location="bottom", cax=axes[1])

        im1 = imshow_text(axes[2], mu[None])
        axes[2].set_title("$\mu$")

        diag = np.sqrt(np.diag(si))

        im2 = imshow_text(axes[4], si / diag[:, None] / diag[None, :], vmin=-1, vmax=1)
        axes[4].set_title("correlations")
        fig.colorbar(im2, shrink=0.5, location="bottom", cax=axes[5])

        if keys is None:
            keys = np.arange(len(mu))
        else:
            assert len(keys) == len(mu)
        keys = list(map(str, keys))

        ticks = np.arange(len(keys))
        axes[0].set_xticks(ticks, labels=keys, rotation="vertical")
        axes[0].set_yticks(ticks, labels=keys)
        axes[2].set_xticks(ticks, labels=keys, rotation="vertical")

        axes[4].set_xticks(ticks, labels=keys, rotation="vertical")
        axes[4].set_yticks(ticks, labels=keys)

        fig.colorbar(im1, shrink=0.5, location="bottom", cax=axes[3])
        fig.tight_layout()
        return fig

    def csv(self, keys=None) -> dict[str, pl.DataFrame]:
        mu, si = map(lambda x: x.detach().cpu().numpy(), self.std)
        if keys is None:
            keys = np.arange(len(mu))
        keys = list(map(str, keys))
        df1 = pl.DataFrame({"key": keys})
        df2 = pl.DataFrame(si)
        df2.columns = keys

        return dict(mu=pl.DataFrame(dict(key=keys, mu=mu)), sigma=df1.with_columns(df2))

    def __len__(self) -> int:
        return self.mu.size(-1)

    def __getitem__(self, index):
        return type(self)(self.mu[..., index], self.si[..., index, :][..., :, index])


@dataclass
class Probit2(Probit3):
    """Probit model scored on a single pairwise comparison in closed form."""

    def __init__(self, mu, si):
        # Must stay explicit: without it @dataclass would synthesise an
        # __init__(mu, si, ll) for this subclass and break two-argument calls.
        super().__init__(mu, si)

    def diff_logprob(self, constraint: Constraints, N: int) -> th.Tensor:
        """Return ``log P(u_i > u_j)`` for the single pair ``(i, j)``.

        The utility difference is Gaussian with mean ``mu_i - mu_j`` and
        variance ``si_ii + si_jj - si_ij - si_ji`` (clipped from below by
        ``eigen_eps``), so the probability is a normal CDF. ``N`` is unused.

        Args:
            constraint: Exactly one chain containing exactly two indices.
            N: Ignored; present for interface compatibility.
        """
        assert len(constraint) == 1
        i, j = constraint[0]
        # Index the trailing axes so batched parameters are handled too;
        # identical to mu[i] / si[i, j] for a single model.
        gap = self.mu[..., i] - self.mu[..., j]
        var = (
            self.si[..., i, i]
            + self.si[..., j, j]
            - self.si[..., i, j]
            - self.si[..., j, i]
        ).clip(eigen_eps)
        return log_ndtr(gap / th.sqrt(var))


@dataclass
class ProbitN(Probit3):
    """Probit model scored by grid integration over all ``n`` utility dimensions.

    NOTE(review): the grid has ``N ** n`` points, so this is only practical
    for a small number of alternatives.
    """

    def __init__(self, mu, si):
        # Must stay explicit: without it @dataclass would synthesise an
        # __init__(mu, si, ll) for this subclass and break two-argument calls.
        super().__init__(mu, si)

    def diff_logprob(self, constraint: Constraints, N: int) -> th.Tensor:
        """Log-probability that a draw from ``N(mu, si)`` satisfies ``constraint``.

        Follows the same scheme as ``prob_grid``: a regular grid on [-6, 6] per
        axis in whitened coordinates is mapped through the detached factor
        ``ll``, the Gaussian density is evaluated there (with gradients), and
        the masked density is integrated with the trapezoid rule.

        The previous body could not run: it read a non-existent ``self.sl``,
        did not reduce the density over the coordinate axis, and indexed
        trapezoid dims past the shrinking tensor. It also returned a
        probability, whereas the sibling implementations return its log.

        Args:
            constraint: Ordering constraints on the utilities.
            N: Number of grid points per axis.
        """
        n = self.mu.size(-1)
        # Grid placement is detached so gradients flow only via the density.
        factor = stopgrad(self.ll)
        centre = stopgrad(self.mu)

        axis = th.linspace(-6, 6, steps=N, device=self.mu.device)
        whitened = th.stack(th.meshgrid(*([axis] * n), indexing="ij"), dim=-1)
        points = centre + whitened @ factor.T
        inside = apply_constraint(constraint, points)

        # log N(points; mu, ll @ ll.T) using the (non-detached) factor.
        resid = (points - self.mu) @ th.linalg.inv(self.ll).T
        log_pdf = (
            -0.5 * resid.square().sum(-1)
            - 0.5 * n * math.log(2 * math.pi)
            - th.linalg.slogdet(self.ll).logabsdet
        )

        # |det factor| is the Jacobian of the whitened -> utility map.
        v = th.where(inside, log_pdf.clamp(max=20).exp(), 0) * th.abs(
            th.linalg.det(factor)
        )
        # Each trapezoid call removes the last axis, so always integrate -1.
        for _ in range(n):
            v = th.trapezoid(v, axis, dim=-1)
        return v.clamp(eps).log()


# https://botorch-dev.readthedocs.io/en/latest/_modules/botorch/utils/probability/utils.html
def log_ndtr(x: th.Tensor) -> th.Tensor:
    """Log of the standard normal CDF, computed as ``log_erfc(-x / sqrt(2)) - log(2)``."""
    scaled = -1 / math.sqrt(2) * x
    return log_erfc(scaled) - math.log(2)


def log_erfc(x: th.Tensor) -> th.Tensor:
    """Numerically stable ``log(erfc(x))``.

    For positive inputs the scaled function ``erfcx(x) = exp(x^2) * erfc(x)``
    is used, giving ``log(erfcx(x)) - x^2``; this avoids the underflow of
    ``erfc`` for large ``x``. Non-positive inputs use ``log(erfc(x))``
    directly. See https://github.com/pytorch/pytorch/issues/31945 for the
    upstream feature request.

    Args:
        x: An input tensor with dtype torch.float32 or torch.float64.

    Returns:
        A tensor of values of the same type and shape as x containing log(erfc(x)).
    """
    positive = x > 0
    # Each branch only sees inputs from its own domain (0 elsewhere), so the
    # branch that is not selected cannot produce inf/nan.
    only_pos = x.masked_fill(positive.logical_not(), 0)
    only_neg = x.masked_fill(positive, 0)

    stable_branch = th.log(th.special.erfcx(only_pos)) - only_pos.square()
    direct_branch = th.log(th.special.erfc(only_neg))
    return th.where(positive, stable_branch, direct_branch)


@dataclass
class Logit(RUM):
    """Logit random-utility model: ``logit`` plus centred Gumbel noise of scale ``temp``."""

    logit: th.Tensor  # deterministic utilities, last axis indexes alternatives
    temp: float  # scale of the Gumbel noise used by ``sample`` and ``welfare``
    eps: float  # clamp applied to uniform draws before taking logs

    @classmethod
    def from_flat(cls, x, *, temp, eps=1e-10):
        """Build a model from ``x`` with a trailing axis of size 1.

        The trailing axis is dropped and the utilities are centred to zero
        mean across alternatives.
        """
        assert x.shape[-1] == 1
        x = x[..., 0]
        x = x - th.mean(x, dim=-1, keepdim=True)
        return cls(x, temp, eps)

    @property
    def device(self) -> th.device:
        """Device holding the model parameters."""
        return self.logit.device

    def sample(self, N: int) -> th.Tensor:
        """Draw ``N`` utility vectors, shape ``(*batch, N, n)``.

        Noise is ``temp * (g - euler_gamma)`` with ``g`` standard Gumbel, i.e.
        zero-mean Gumbel noise of scale ``temp``. The mean correction is scaled
        by ``temp`` so the expected maximum matches the closed form returned by
        ``welfare``; for ``temp == 1`` this equals the previous behaviour.
        """
        u = th.rand(
            (*self.logit.shape[:-1], N, self.logit.shape[-1]), device=self.device
        ).clamp(self.eps, 1 - self.eps)
        gumbel = -(-u.log()).log()
        # unsqueeze(-2) lines the utilities up with the inserted sample axis;
        # for an unbatched model this broadcasts exactly as before.
        return self.logit.unsqueeze(-2) + self.temp * (gumbel - np.euler_gamma)

    def diff_logprob(self, constraint: Constraints, N: int) -> th.Tensor:
        """Log-probability of the ranking given by the single chain in ``constraint``.

        Uses the sequential-choice (Plackett-Luce) factorisation: at each stage
        the next item is chosen by softmax among the items not yet ranked.
        ``N`` is unused.

        NOTE(review): ``temp`` is not applied here although ``sample`` and
        ``welfare`` use it; confirm whether ``logit / temp`` is intended.
        """
        assert len(constraint) == 1
        order = list(next(iter(constraint)))
        ranked = self.logit[..., order]
        # Stage-k normaliser = logsumexp over ranked[k:]. Computing it as a
        # reversed cumulative logsumexp avoids the subtractive update
        # log(exp(lse) - exp(l)), which cancels catastrophically when one
        # logit dominates (e.g. logits [100, 0, 0] produced +inf).
        tail_lse = th.logcumsumexp(ranked.flip(-1), dim=-1).flip(-1)
        # The final stage has a single candidate and contributes log(1) = 0.
        return (ranked - tail_lse)[..., :-1].sum(-1)

    def viz(self, keys=None) -> Figure:
        """Plot the utilities as a one-column annotated image."""
        if keys is None:
            keys = np.arange(len(self))
        fig, ax = plt.subplots()
        im = imshow_text(ax, self.logit[..., None].detach().cpu().numpy())
        ax.set_yticks(np.arange(len(keys)), labels=keys)
        ax.set_title("logit")
        fig.colorbar(im)
        return fig

    def csv(self, keys=None) -> dict[str, pl.DataFrame]:
        """Return the utilities as a ``{"logit": DataFrame}`` mapping."""
        if keys is None:
            keys = np.arange(len(self))
        return {
            "logit": pl.DataFrame(
                {"name": keys, "logit": self.logit.cpu().detach().numpy()}
            )
        }

    def welfare(self, N) -> th.Tensor:
        """Closed-form expected maximum utility; ``N`` is unused."""
        return self.temp * th.logsumexp(self.logit / self.temp, dim=-1)

    def __len__(self) -> int:
        return self.logit.size(-1)

    def __getitem__(self, index) -> tp.Self:
        """Return the sub-model over the alternatives selected by ``index``."""
        # Index the alternatives axis (last), matching __len__ and Probit3.
        return type(self)(self.logit[..., index], temp=self.temp, eps=self.eps)
