import math

import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor

from app.SB import SchrodingerBridge


class Sinkhorn(torch.nn.Module, SchrodingerBridge):
    """Entropic OT between two weighted point clouds and its Schrödinger bridge.

    On construction, log-domain Sinkhorn iterations are run for the cost
    ``C(x, y) = |x - y|^2 / 2`` with entropic regularisation ``eps``. The
    resulting dual potentials ``u`` (on ``x_spt``) and ``v`` (on ``y_spt``)
    define the coupling ``pi = exp((u + v - C) / eps)`` and the drift of the
    associated bridge. The bridge's reference noise has variance ``eps`` per
    unit time, which is the noise level integrated by ``sample_euler_maruyama``.

    Note: the tensors given to ``__init__`` and the computed ``C``, ``u``, ``v``
    are stored as plain attributes, not registered buffers, so
    ``Module.to(...)`` does not move them.
    """

    def getC(
        self, x: Float[Tensor, "data_X feature"], y: Float[Tensor, "data_Y feature"]
    ) -> Float[Tensor, "data_X data_Y"]:
        """Return the pairwise cost matrix ``C[i, j] = |x_i - y_j|^2 / 2``."""
        return torch.cdist(x, y, p=2.0) ** 2 / 2

    def __init__(
        self,
        x_w: Float[Tensor, "data_X"],
        x_spt: Float[Tensor, "data_X feature"],
        y_w: Float[Tensor, "data_Y"],
        y_spt: Float[Tensor, "data_Y feature"],
        eps: float,
        n_iter: int,
        stopThr: float = 1e-5,
    ):
        """Store both measures and immediately solve the entropic OT problem.

        Args:
            x_w: Weights of the source points. Their logarithm is taken, so they
                should be strictly positive (presumably summing to 1 — confirm
                with callers).
            x_spt: Source support points.
            y_w: Weights of the target points (same caveats as ``x_w``).
            y_spt: Target support points.
            eps: Entropic regularisation strength.
            n_iter: Maximum number of Sinkhorn iterations.
            stopThr: Early-stopping tolerance on the L2 error of the row marginals.
        """
        super().__init__()

        self.n_iter = n_iter
        self.x_w = x_w
        self.x_spt = x_spt
        self.y_w = y_w
        self.y_spt = y_spt
        self.dim = x_spt.shape[1]  # feature dimension
        self.stopThr = stopThr
        self.eps = eps

        self.sinkhorn_()

    def sinkhorn(
        self,
        C: Float[Tensor, "data_X data_Y"],
        eps: float,
        x_w: Float[Tensor, "data_X"],
        y_w: Float[Tensor, "data_Y"],
        n_iter: int,
    ):
        """Run log-domain Sinkhorn and return the dual potentials ``(u, v)``.

        Prints how many iterations were actually performed.

        Returns:
            ``(u, v)`` scaled back by ``eps``, so that the coupling is
            ``exp((u[:, None] + v[None, :] - C) / eps)``.
        """
        log_ = self._sinkhorn_log(
            x_w=x_w,
            y_w=y_w,
            C=C,
            eps=eps,
            n_iter=n_iter,
            stopThr=self.stopThr,
        )

        print(f"stopped in {log_['niter']}/{n_iter} iterations")

        # _sinkhorn_log works with potentials divided by eps; undo that scaling.
        return eps * log_["log_u"], eps * log_["log_v"]

    @staticmethod
    def _sinkhorn_log(
        x_w: Float[Tensor, "data_X"],
        y_w: Float[Tensor, "data_Y"],
        C: Float[Tensor, "data_X data_Y"],
        eps: float,
        n_iter: int,
        stopThr: float,
    ):
        """Log-domain Sinkhorn iterations.

        Returns:
            A dict with ``"log_u"`` / ``"log_v"`` (potentials divided by ``eps``)
            and ``"niter"`` (iterations performed; 0 when ``n_iter == 0``).
        """
        Mr = -C / eps
        loga = x_w.log()
        logb = y_w.log()

        # Initial log-potentials; u is recomputed from v before it is ever read.
        u = torch.ones_like(x_w)
        v = torch.ones_like(y_w)

        niter = 0  # stays 0 if the loop body never runs
        for niter in range(1, n_iter + 1):
            u = loga - torch.logsumexp(Mr + v[None, :], dim=1)
            v = logb - torch.logsumexp(Mr + u[:, None], dim=0)

            # After the v-update the column marginals equal y_w exactly, so
            # convergence is judged on the row marginals only.
            row_marginal = torch.exp(Mr + u[:, None] + v[None, :]).sum(dim=1)
            if torch.norm(row_marginal - x_w) < stopThr:
                break

        return {
            "log_u": u,
            "log_v": v,
            "niter": niter,
        }

    def sinkhorn_(self):
        """Compute ``self.C`` and solve for ``self.u`` / ``self.v``.

        The potentials are computed under ``torch.no_grad()``.
        """
        self.C: Float[Tensor, "data_X data_Y"] = self.getC(self.x_spt, self.y_spt)

        with torch.no_grad():
            self.u, self.v = self.sinkhorn(
                self.C,
                self.eps,
                self.x_w,
                self.y_w,
                self.n_iter,
            )

    def get_coupling(
        self,
        C: Float[Tensor, "data_X data_Y"],
        u: Float[Tensor, "data_X"],
        v: Float[Tensor, "data_Y"],
    ) -> Float[Tensor, "data_X data_Y"]:
        """Return the primal coupling ``exp((u_i + v_j - C_ij) / eps)``."""
        return torch.exp((u.view(-1, 1) + v.view(1, -1) - C) / self.eps)

    def coupling(self) -> Float[Tensor, "data_X data_Y"]:
        """Return the coupling built from the stored cost and potentials."""
        return self.get_coupling(self.C, self.u, self.v)

    @staticmethod
    def _as_time_tensor(
        x: Float[Tensor, "batch_drift feature"], t: Float[Tensor, "batch_drift"] | float
    ) -> Float[Tensor, "batch_drift"]:
        """Broadcast a scalar time (int or float) to one entry per row of ``x``.

        Tensors are returned unchanged.
        """
        if isinstance(t, Tensor):
            return t
        return torch.full((x.shape[0],), float(t), dtype=x.dtype, device=x.device)

    def get_drift(
        self, x: Float[Tensor, "batch_drift feature"], t: Float[Tensor, "batch_drift"] | float
    ) -> Float[Tensor, "batch_drift feature"]:
        """Return the bridge drift ``get_drift_without_div_time(x, t) / (1 - t)``.

        ``t`` may be a scalar (applied to every row) or a per-row tensor. The
        drift is singular at ``t = 1``.
        """
        t = self._as_time_tensor(x, t)
        return self.get_drift_without_div_time(x, t) / (1 - t).view(-1, 1)

    def get_drift_without_div_time(
        self, x: Float[Tensor, "batch_drift feature"], t: Float[Tensor, "batch_drift"] | float
    ) -> Float[Tensor, "batch_drift feature"]:
        """Return ``m(x, t) - x``.

        ``m(x, t)`` is the average of ``y_spt`` weighted by
        ``softmax_j((v_j - |x - y_j|^2 / (2 (1 - t))) / eps)``.
        """
        t = self._as_time_tensor(x, t)

        sq_dist = torch.sum((x.unsqueeze(1) - self.y_spt.unsqueeze(0)) ** 2, dim=-1)
        M = sq_dist / (2 * (1 - t)).view(-1, 1)
        logits = (self.v.unsqueeze(0) - M) / self.eps
        # softmax performs the max-shift internally, so large logits are safe.
        entmap = torch.softmax(logits, dim=1) @ self.y_spt

        return entmap - x

    def sample_from_source(self, size: int) -> Float[Tensor, "size feature"]:
        """Draw ``size`` source support points uniformly with replacement (weights ignored)."""
        return self.x_spt[torch.randint(0, self.x_spt.shape[0], (size,))]

    def sample_from_target(self, size: int) -> Float[Tensor, "size feature"]:
        """Draw ``size`` target support points uniformly with replacement (weights ignored)."""
        return self.y_spt[torch.randint(0, self.y_spt.shape[0], (size,))]

    def sample_euler_maruyama(
        self, x: Float[Tensor, "batch_drift feature"], n_steps: int, end_time: float = 0.99
    ) -> Float[Tensor, "batch_drift n_steps + 1 feature"]:
        """Integrate the bridge SDE from ``t = 0`` to ``end_time`` and keep every state.

        Each step is ``x += dt * drift + sqrt(dt * eps) * N(0, I)``. The drift
        divides by ``1 - t``, so by default integration stops short of ``t = 1``.

        Returns:
            Tensor of shape ``(batch, n_steps + 1, feature)``, with the initial
            state at index 0.
        """
        dt = end_time / n_steps
        noise_scale = math.sqrt(dt * self.eps)
        trajectory = [x]

        for k in range(n_steps):
            # k * dt avoids the round-off that accumulates with repeated t += dt.
            x = x + dt * self.get_drift(x, k * dt) + noise_scale * torch.randn_like(x)
            trajectory.append(x)

        return torch.stack(trajectory, dim=1)

    def sample_from_drift(
        self, x: Float[Tensor, "batch_drift feature"], n_steps: int, time: float = 1.0
    ) -> Float[Tensor, "batch_drift feature"]:
        """Integrate the bridge SDE from ``t = 0`` to ``time`` and return only the final state.

        The drift is evaluated at ``k * dt`` for ``k < n_steps``, so it is never
        evaluated exactly at ``time``.
        """
        dt = time / n_steps
        noise_scale = math.sqrt(dt * self.eps)

        for k in range(n_steps):
            x = x + dt * self.get_drift(x, k * dt) + noise_scale * torch.randn_like(x)

        return x

    def sample_map(
        self, sample_x_size: int, sample_y_per_x_size: int = 1
    ) -> tuple[Float[Tensor, "sample_size feature"], Float[Tensor, "sample_size feature"]]:
        """Draw paired ``(x, y)`` samples from the entropic coupling.

        With ``sample_y_per_x_size == 1``, ``sample_x_size`` index pairs are
        drawn directly from the joint ``pi``. Otherwise, ``sample_x_size``
        source indices are drawn from the row sums of ``pi``. For each of them,
        ``sample_y_per_x_size`` targets are drawn from that row, and each source
        point is repeated so the outputs stay aligned: the output has
        ``sample_x_size * sample_y_per_x_size`` rows.
        """
        pi = self.coupling()

        if sample_y_per_x_size == 1:
            idx = torch.multinomial(pi.reshape(-1), num_samples=sample_x_size, replacement=True)
            row = idx // pi.shape[1]
            col = idx % pi.shape[1]
            return self.x_spt[row], self.y_spt[col]

        x_idx = torch.multinomial(pi.sum(dim=1), num_samples=sample_x_size, replacement=True)
        # On a 2-D input, multinomial samples each row independently from that
        # row's (unnormalised) weights, i.e. from pi(y | x_i).
        y_idx = torch.multinomial(pi[x_idx], num_samples=sample_y_per_x_size, replacement=True)
        # Row-major flattening of y_idx matches repeat_interleave's ordering.
        x_idx = torch.repeat_interleave(x_idx, sample_y_per_x_size)

        return self.x_spt[x_idx], self.y_spt[y_idx.reshape(-1)]

    def sample_at_time_moment(
        self, t: float, sample_size: int, with_coupling: bool = False
    ) -> (
        Float[Tensor, "sample_size feature"]
        | tuple[
            Float[Tensor, "sample_size feature"],
            Float[Tensor, "sample_size feature"],
            Float[Tensor, "sample_size feature"],
        ]
    ):
        """Sample ``x_t`` from the Brownian bridge between coupled endpoints.

        Returns ``x_t``, or ``(x0, x_t, x1)`` when ``with_coupling`` is true.
        """
        x0, x1 = self.sample_map(sample_size)

        # Brownian bridge with variance eps per unit time (the level integrated
        # by sample_euler_maruyama) has std sqrt(eps * t * (1 - t)).
        std = math.sqrt(max(self.eps * t * (1 - t), 0.0))
        xt = (1 - t) * x0 + t * x1 + std * torch.randn_like(x0)

        if with_coupling:
            return x0, xt, x1

        return xt

    def sample_at_times_moment(
        self, t: Float[Tensor, "sample_size"], with_coupling: bool = False
    ) -> (
        Float[Tensor, "sample_size feature"]
        | tuple[
            Float[Tensor, "sample_size feature"],
            Float[Tensor, "sample_size feature"],
            Float[Tensor, "sample_size feature"],
        ]
    ):
        """Like ``sample_at_time_moment``, but with one time per sample.

        Returns ``x_t``, or ``(x0, x_t, x1)`` when ``with_coupling`` is true.
        """
        x0, x1 = self.sample_map(t.shape[0])
        t = t.view(-1, 1)

        # Per-sample Brownian-bridge std sqrt(eps * t * (1 - t)).
        std = (self.eps * t * (1 - t)).clamp_min(0.0).sqrt()
        xt = (1 - t) * x0 + t * x1 + std * torch.randn_like(x0)

        if with_coupling:
            return x0, xt, x1

        return xt

    def sample_at_time_moment_given_x(
        self, x: Float[Tensor, "feature"], t: float, n_steps: int = 1000
    ) -> Float[Tensor, "feature"]:
        """Simulate the bridge SDE from the single point ``x`` up to time ``t``."""
        y = self.sample_from_drift(x.unsqueeze(0), n_steps, t)
        return y[0]
