# Based on
# https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py


import warnings
from functools import partial

import numpy as np
import ot as pot
import torch
from scipy.optimize import linear_sum_assignment

from src.models.ot_samplers.base import BaseSampler


# Theory: Section 3.2.3 of https://arxiv.org/pdf/2302.00482
# Adapted from: https://github.com/atong01/conditional-flow-matching/blob/main/runner/src/models/components/optimal_transport.py
class OTPlanSampler(BaseSampler):
    """Sample minibatch couplings from a squared-L2 optimal transport plan.

    The plan between two minibatches is computed with one of several POT
    solvers; index pairs are then drawn from it.

    Args:
        method: ``"exact"`` (``pot.emd``), ``"sinkhorn"`` (``pot.sinkhorn``),
            ``"unbalanced"`` (``pot.unbalanced.sinkhorn_knopp_unbalanced``) or
            ``"partial"`` (``pot.partial.entropic_partial_wasserstein``).
        reg: Entropic regularization passed to every solver except ``"exact"``.
        reg_m: Marginal relaxation, passed only to ``"unbalanced"``.
        normalize_cost: If True, divide the cost matrix by its maximum before
            solving.
        **kwargs: Stored on ``self.kwargs``; not read by any method of this
            class.

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """

    def __init__(
        self,
        method: str,
        reg: float = 0.05,
        reg_m: float = 1.0,
        normalize_cost=False,
        **kwargs,
    ):
        # ot_fn should take (a, b, M) as arguments where a, b are marginals and
        # M is a cost matrix
        if method == "exact":
            self.ot_fn = pot.emd
        elif method == "sinkhorn":
            self.ot_fn = partial(pot.sinkhorn, reg=reg)
        elif method == "unbalanced":
            self.ot_fn = partial(
                pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m
            )
        elif method == "partial":
            self.ot_fn = partial(
                pot.partial.entropic_partial_wasserstein, reg=reg
            )
        else:
            raise ValueError(f"Unknown method: {method}")
        self.reg = reg
        self.reg_m = reg_m
        self.normalize_cost = normalize_cost
        self.kwargs = kwargs
        # transport() resamples target rows, so callers receive new indices.
        self.modifies_target_indices = True

    def get_map(self, x0, x1):
        """Compute the OT plan between two minibatches.

        Both inputs are flattened to ``(batch, features)`` and the cost is the
        squared Euclidean distance between rows. Marginals are uniform.

        Args:
            x0: Source minibatch tensor; dim 0 is the batch dimension.
            x1: Target minibatch tensor; dim 0 is the batch dimension.

        Returns:
            numpy array of shape ``(len(x0), len(x1))``. If the solver returns
            non-finite values or a plan with (near-)zero total mass, a warning
            is emitted and a uniform plan is returned instead, so that
            ``sample_map`` can still draw valid samples.
        """
        a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
        x0 = x0.reshape(x0.shape[0], -1)
        x1 = x1.reshape(x1.shape[0], -1)
        M = torch.cdist(x0, x1) ** 2
        if self.normalize_cost:
            max_cost = M.max()
            # All-zero costs (e.g. identical batches) would otherwise yield a
            # NaN cost matrix.
            if max_cost > 0:
                M = M / max_cost
        p = self.ot_fn(a, b, M.detach().cpu().numpy())
        if not np.all(np.isfinite(p)) or np.abs(p.sum()) < 1e-8:
            warnings.warn(
                "Numerical errors in OT plan (cost mean "
                f"{M.mean().item():.4g}, max {M.max().item():.4g}); "
                "reverting to uniform plan."
            )
            p = np.ones_like(p) / p.size
        return p

    def sample_map(self, pi, batch_size, replace=True):
        """Draw index pairs ``(i, j)`` with probability proportional to ``pi``.

        Args:
            pi: 2-D array of non-negative plan weights.
            batch_size: Number of pairs to draw.
            replace: Whether pairs are drawn with replacement (default True,
                matching ``np.random.choice``'s default).

        Returns:
            Tuple ``(i, j)`` of integer arrays of length ``batch_size``.
        """
        p = pi.flatten()
        p = p / p.sum()
        choices = np.random.choice(
            pi.shape[0] * pi.shape[1], p=p, size=batch_size, replace=replace
        )
        # Convert flat indices back to (row, column) indices of pi.
        return np.divmod(choices, pi.shape[1])

    def transport(self, x0, x1):
        """Resample ``(x0, x1)`` pairs according to the OT plan.

        Returns:
            ``(x0[i], x1[j], j)``, where ``(i, j)`` are ``len(x0)`` index
            pairs drawn from the plan and ``j`` holds the sampled target
            indices.
        """
        pi = self.get_map(x0, x1)
        i, j = self.sample_map(pi, x0.shape[0])
        return x0[i], x1[j], j


class EquivariantOTSampler(BaseSampler):
    """Align ``x0`` to ``x1`` with an orthogonal rotation and a row permutation.

    ``transport`` applies the operations named in ``plan`` in order:

    * ``"r"``: Kabsch rotation of ``x0`` onto ``x1`` (about the origin).
    * ``"p"``: reorder the rows of ``x0`` by a minimum-cost assignment to
      the rows of ``x1`` (Euclidean distance).

    ``x1`` is never reordered, so no target indices are returned.

    Args:
        plan: Sequence of operation codes; defaults to ``"rp"``.
    """

    def __init__(self, plan="rp"):
        self.plan = plan
        self.plan_map = {
            "r": self._rotate,
            "p": self._permute,
        }
        self.modifies_target_indices = False

    def _rotate(self, x0, x1):
        """Rotate ``x0`` onto ``x1`` with the Kabsch algorithm.

        Computes the proper rotation ``R`` (``det(R) = +1``) minimizing
        ``||x0 @ R.T - x1||_F``. No centering is applied, so the rotation is
        about the origin. Both inputs must be 2-D with identical shapes; ``R``
        acts on the last (feature) dimension.

        Returns:
            ``(x0 @ R.T, x1)``.
        """
        assert x0.shape == x1.shape, "Tensor dimensions must match"
        H = torch.matmul(x0.transpose(0, 1), x1)
        U, _, Vt = torch.linalg.svd(H)
        V = Vt.transpose(0, 1)
        Ut = U.transpose(0, 1)
        if torch.det(torch.matmul(V, Ut)).item() < 0.0:
            # V @ U.T is a reflection: negate the last singular direction so
            # that R becomes a proper rotation (done out of place so that
            # autograd through the SVD outputs is not disturbed).
            V = torch.cat([V[:, :-1], -V[:, -1:]], dim=1)
        R = torch.matmul(V, Ut)
        return torch.matmul(x0, R.transpose(0, 1)), x1

    def _permute(self, x0, x1):
        """Reorder the rows of ``x0`` to best match the rows of ``x1``.

        Solves a linear assignment on pairwise Euclidean distances so that
        row ``k`` of the result is the row of ``x0`` assigned to ``x1[k]``.

        Returns:
            ``(x0[perm], x1)``.
        """
        assert x0.shape == x1.shape, "Tensor dimensions must match"
        # cost_matrix[k, m] = ||x0[m] - x1[k]||: rows index x1, columns x0.
        cost_matrix = (x0.unsqueeze(0) - x1.unsqueeze(1)).norm(dim=-1)
        # For a square cost matrix scipy returns row indices equal to
        # arange(n); the actual assignment is carried by the column indices.
        _, col_indices = linear_sum_assignment(
            cost_matrix.detach().cpu().numpy()
        )
        perm = torch.as_tensor(col_indices, dtype=torch.long, device=x0.device)
        return x0[perm], x1

    def transport(self, x0, x1):
        """Apply ``self.plan`` to align ``x0`` with ``x1``.

        Inputs with more than two dims are flattened to ``(batch, -1)`` first.
        NOTE(review): for ``(B, N, 3)`` inputs this means the rotation is an
        orthogonal map on flattened ``3N``-dim rows and the permutation
        reorders whole samples, not points within a sample -- confirm this is
        intended.

        Returns:
            ``(x0_aligned, x1, None)`` with both tensors reshaped to
            ``(batch, -1, 3)``; ``None`` because target order is unchanged.

        Raises:
            ValueError: If ``self.plan`` contains an unknown operation code.
        """
        for op in self.plan:
            if op not in self.plan_map:
                raise ValueError(f"Unknown plan operation: {op}")
        if x0.dim() > 2:
            x0 = x0.reshape(x0.shape[0], -1)
        if x1.dim() > 2:
            x1 = x1.reshape(x1.shape[0], -1)
        for op in self.plan:
            x0, x1 = self.plan_map[op](x0, x1)

        return (
            x0.reshape(x0.shape[0], -1, 3),
            x1.reshape(x1.shape[0], -1, 3),
            None,
        )
