from abc import ABCMeta, abstractmethod

import torch
from jaxtyping import Float
from torch import Tensor
from torch.distributions.multivariate_normal import MultivariateNormal

from app.utils import sqrtm


class SchrodingerBridge(metaclass=ABCMeta):
    """
    Abstract interface of a Schrodinger bridge process.

    Concrete subclasses must provide samplers for the source, target and
    intermediate-time marginals, plus the two drift evaluations below.
    """

    @abstractmethod
    def sample_from_source(self, size: int) -> Float[Tensor, "size feature"]:
        """Draw ``size`` samples from the source distribution."""

    @abstractmethod
    def sample_from_target(self, size: int) -> Float[Tensor, "size feature"]:
        """Draw ``size`` samples from the target distribution."""

    @abstractmethod
    def sample_at_time_moment(self, t: float, size: int) -> Float[Tensor, "size feature"]:
        """Draw ``size`` samples from the process at the time moment ``t``."""

    @abstractmethod
    def get_drift(
        self,
        x: Float[Tensor, "size feature"],
        t: Float[Tensor, "size"] | float,
    ) -> Float[Tensor, "size feature"]:
        """Evaluate the drift of the process at points ``x`` and time ``t``."""

    @abstractmethod
    def get_drift_without_div_time(
        self,
        x: Float[Tensor, "size feature"],
        t: Float[Tensor, "size"] | float,
    ) -> Float[Tensor, "size feature"]:
        """Evaluate the drift at ``x`` and ``t`` without dividing by time."""


class GaussianSB(SchrodingerBridge):
    """
    Schrodinger bridge between the zero-mean Gaussians N(0, A) and N(0, B).

    All marginals and the drift are computed in closed form from the two
    covariance matrices and the scalar ``eps``.
    NOTE(review): the formulas look like the closed-form Gaussian Schrodinger
    bridge with ``eps`` playing the role of the reference noise variance --
    confirm against the derivation used by the project.
    """

    def __init__(
        self,
        A: Float[Tensor, "feature feature"],
        B: Float[Tensor, "feature feature"],
        eps: float,
    ) -> None:
        """
        Store the problem data and precompute the matrix ``Ceps``.

        Args:
            A: Covariance matrix of the source distribution.
            B: Covariance matrix of the target distribution.
            eps: Scalar coefficient entering ``_Depsilon``, ``_Cepsilon``,
                ``_Sigma_t`` and the drift.
        """
        # All helper tensors are created on the device (and dtype) of ``A``.
        self.device = A.device
        self.A = A
        self.B = B
        self.dim = A.shape[0]
        self.eps = eps
        self.Ceps = self._Cepsilon()

    def _eye(self) -> Float[Tensor, "feature feature"]:
        """Return the identity matrix matching the size, device and dtype of ``A``."""
        return torch.eye(self.dim, device=self.device, dtype=self.A.dtype)

    def _zero_mean(self) -> Float[Tensor, "feature"]:
        """Return the zero mean vector matching the size, device and dtype of ``A``."""
        # The mean must share the covariance dtype: a default-dtype mean mixed
        # with e.g. a float64 covariance is not handled by the sampler.
        return torch.zeros(self.dim, device=self.device, dtype=self.A.dtype)

    @staticmethod
    def _scalar_time(t: Float[Tensor, "size"] | float) -> float:
        """
        Reduce a time argument to a single Python scalar.

        A tensor is flattened and its first entry is used, so every sample in
        a batch is treated as sharing that time. Flattening first also makes
        0-dim tensors (``torch.tensor(0.5)``) valid, which plain ``t[0]``
        indexing rejects. Non-tensor inputs are returned unchanged.
        """
        if isinstance(t, torch.Tensor):
            return t.reshape(-1)[0].item()
        return t

    def _Depsilon(self) -> Float[Tensor, "feature feature"]:
        """Return ``sqrtm(sqrtm(A) @ B @ sqrtm(A) + (eps^2 / 4) * I)``."""
        sqrtA = sqrtm(self.A)
        return sqrtm((sqrtA @ self.B @ sqrtA) + ((self.eps**2) / 4) * self._eye())

    def _Cepsilon(self) -> Float[Tensor, "feature feature"]:
        """Return ``sqrtm(A) @ Deps @ inv(sqrtm(A)) - (eps / 2) * I``."""
        Deps = self._Depsilon()
        Asqrt = sqrtm(self.A)
        Asqrtinv = torch.linalg.inv(Asqrt)
        return (Asqrt @ Deps @ Asqrtinv) - (self.eps / 2) * self._eye()

    def _Sigma_t(self, t: float) -> Float[Tensor, "feature feature"]:
        """
        Return the covariance matrix of the marginal at time ``t``.

        The value is ``(1-t)^2 A + t^2 B + t(1-t) (Ceps + Ceps^T) + eps t(1-t) I``,
        which reduces to ``A`` at ``t = 0`` and to ``B`` at ``t = 1``.
        """
        return (
            ((1 - t) ** 2) * self.A
            + (t**2) * self.B
            + (1 - t) * t * (self.Ceps + self.Ceps.T)
            + self.eps * t * (1 - t) * self._eye()
        )

    def get_drift(
        self,
        x: Float[Tensor, "size feature"],
        t: Float[Tensor, "size"] | float,
    ) -> Float[Tensor, "size feature"]:
        """
        Evaluate the drift at points ``x`` for a single time ``t``.

        Args:
            x: Points at which the drift is evaluated, one per row.
            t: Time moment. If a tensor is given, only its first entry is
                used for the whole batch; 0-dim tensors are accepted.

        Returns:
            ``x @ inv(Sigma_t)^T @ S_t`` with
            ``S_t = P_t - Q_t^T - eps * t * I``.
        """
        t = self._scalar_time(t)
        sigma_t = self._Sigma_t(t)
        P_t = t * self.B + (1 - t) * self.Ceps
        Q_t = (1 - t) * self.A + t * self.Ceps
        S_t = P_t - Q_t.T - (self.eps * t) * self._eye()
        # Row-vector form of S_t^T @ inv(Sigma_t) @ x applied to every row of x.
        return x @ torch.linalg.inv(sigma_t).T @ S_t

    def get_drift_without_div_time(
        self,
        x: Float[Tensor, "size feature"],
        t: Float[Tensor, "size"] | float,
    ) -> Float[Tensor, "size feature"]:
        """
        Return the drift from ``get_drift`` multiplied by ``(1 - t)``.

        Args:
            x: Points at which the drift is evaluated, one per row.
            t: Time moment. If a tensor is given, only its first entry is
                used for the whole batch; 0-dim tensors are accepted.
        """
        t = self._scalar_time(t)
        return (1 - t) * self.get_drift(x, t)

    def sample_from_source(self, size: int) -> Float[Tensor, "size feature"]:
        """Draw ``size`` samples from N(0, A)."""
        dist = MultivariateNormal(self._zero_mean(), self.A)
        return dist.sample((size,))

    def sample_from_target(self, size: int) -> Float[Tensor, "size feature"]:
        """Draw ``size`` samples from N(0, B)."""
        dist = MultivariateNormal(self._zero_mean(), self.B)
        return dist.sample((size,))

    def sample_at_time_moment(self, t: float, size: int) -> Float[Tensor, "size feature"]:
        """Draw ``size`` samples from N(0, Sigma_t) for the time moment ``t``."""
        dist = MultivariateNormal(self._zero_mean(), self._Sigma_t(t))
        return dist.sample((size,))
