import functools

import numpy as np
import torch
import ot


class OptimalTransportSolver:
    """
    Pair two sets of samples by sampling from an optimal transport plan.

    The cost matrix is chosen by ``mode`` ('time' or 'rmsd') and the plan is
    computed with the POT solver selected by ``method``.
    """

    def __init__(self, mode, method="emd", reg=None, reg_m=None):
        """
        Initialize the optimal transport solver.

        Parameters:
        -----------
        mode : str, {'time', 'rmsd'}
            Cost used by ``solve``:
              - 'time': squared Euclidean distance between times (or positions).
              - 'rmsd': RMSD between centroid-centered structures.
            The value is only checked when ``solve`` is called.
        method : str, {'emd', 'sinkhorn', 'unbalanced_sinkhorn'}
            Specifies the POT method to use:
              - 'emd': uses ot.emd.
              - 'sinkhorn': uses ot.sinkhorn (requires reg).
              - 'unbalanced_sinkhorn': uses ot.unbalanced.sinkhorn_knopp_unbalanced
                (requires reg and reg_m).
        reg : float, optional
            Regularization parameter for Sinkhorn (if applicable).
        reg_m : float, optional
            Mass regularization parameter for unbalanced Sinkhorn (if applicable).

        Raises:
        -------
        ValueError
            If ``method`` is unknown or a parameter it requires is missing.
        """
        self.mode = mode
        self.method = method
        self.reg = reg
        self.reg_m = reg_m

        # functools.partial (unlike a lambda) keeps instances picklable, e.g.
        # when the solver is shipped to multiprocessing / DataLoader workers.
        # Every ot_fn is called as ot_fn(a, b, M).
        if method == "emd":
            self.ot_fn = ot.emd
        elif method == "sinkhorn":
            if reg is None:
                raise ValueError("Parameter 'reg' must be provided for sinkhorn")
            self.ot_fn = functools.partial(ot.sinkhorn, reg=reg)
        elif method == "unbalanced_sinkhorn":
            if reg is None or reg_m is None:
                raise ValueError(
                    "Both 'reg' and 'reg_m' must be provided for unbalanced_sinkhorn"
                )
            self.ot_fn = functools.partial(
                ot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m
            )
        else:
            raise ValueError(f"Invalid method: {method}")

    @staticmethod
    def unif(n):
        """Return a uniform probability vector of length n."""
        return np.ones(n) / n

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a NumPy array, detaching torch tensors and moving them to the CPU."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    @staticmethod
    def _squared_distance_matrix(x0, x1):
        """Return pairwise squared Euclidean distances between rows of x0 (n, d) and x1 (m, d).

        1-D inputs are treated as column vectors, i.e. shape (n, 1).
        """
        if x0.ndim == 1:
            x0 = x0.reshape(-1, 1)
        if x1.ndim == 1:
            x1 = x1.reshape(-1, 1)
        # (n, 1, d) - (1, m, d) -> (n, m, d); summing over d gives (n, m).
        # For d == 1 this equals the original (x0 - x1.T) ** 2 exactly.
        diff = x0[:, None, :] - x1[None, :, :]
        return np.sum(diff**2, axis=-1)

    @staticmethod
    def _sample_pairs(plan):
        """Sample one column per row of ``plan`` with probability proportional to its mass.

        Returns (indices0, indices1) as lists of Python ints.
        """
        # Clip tiny negative round-off so np.random.choice accepts each row
        # as a probability vector.
        plan = np.clip(plan, 0.0, None)
        n_cols = plan.shape[1]
        indices0 = list(range(plan.shape[0]))
        indices1 = []
        for row in plan:
            total = row.sum()
            if total > 0:
                j = np.random.choice(n_cols, p=row / total)
            else:
                # Row carries no mass (or is NaN): fall back to a uniform pick.
                j = np.random.randint(0, n_cols)
            indices1.append(int(j))
        return indices0, indices1

    def solve(self, x0, x1):
        """
        Compute an OT plan between x0 and x1 and sample a partner for every x0 item.

        Parameters:
        -----------
        x0 : np.ndarray or torch.Tensor
            First set. For mode 'time': times (or positions) of shape (n,),
            (n, 1) or (n, d). For mode 'rmsd': structures of shape (n, N, 3).
        x1 : np.ndarray or torch.Tensor
            Second set, with shape (m,), (m, 1), (m, d) or (m, N, 3) to match x0.

        Returns:
        --------
        indices0 : list of int
            Row indices into x0; always [0, 1, ..., n - 1].
        indices1 : list of int
            indices1[k] is the index into x1 paired with x0[indices0[k]],
            sampled from the normalized k-th row of the transport plan. Rows
            are sampled independently, so an x1 index may appear more than once.
            Uses the global ``np.random`` state.

        Raises:
        -------
        ValueError
            If ``self.mode`` is not 'time' or 'rmsd'.
        """
        # Convert for every mode: the cost functions, ``astype`` and the POT
        # solvers below all operate on NumPy arrays.
        x0 = self._to_numpy(x0)
        x1 = self._to_numpy(x1)

        if self.mode == "time":
            M = self._squared_distance_matrix(x0, x1)
        elif self.mode == "rmsd":
            M = self.compute_rmsd_matrix_vectorized(x0, x1)
        else:
            raise ValueError(f"Invalid mode: {self.mode}")

        M = M.astype(np.float64)
        a = self.unif(x0.shape[0])
        b = self.unif(x1.shape[0])
        plan = np.asarray(self.ot_fn(a, b, M))
        return self._sample_pairs(plan)

    def compute_rmsd_matrix_vectorized(self, structures0, structures1):
        """
        Compute the pairwise RMSD cost matrix between two sets of structures.

        Each structure is first centered on its centroid, which removes
        translation. No rotation is fitted (no Kabsch alignment), so the inputs
        are assumed to be prealigned rotationally.

        Parameters:
        -----------
        structures0 : np.ndarray or torch.Tensor, shape (n, N, 3)
        structures1 : np.ndarray or torch.Tensor, shape (m, N, 3)

        Returns:
        --------
        rmsd_matrix : np.ndarray, shape (n, m)
            RMSD values between each pair of structures.

        Notes:
        ------
        The broadcast difference tensor has shape (n, m, N, 3), so memory use
        grows as O(n * m * N).
        """
        P = self._to_numpy(structures0)
        Q = self._to_numpy(structures1)
        P_centered = P - P.mean(axis=1, keepdims=True)
        Q_centered = Q - Q.mean(axis=1, keepdims=True)

        # (n, 1, N, 3) - (1, m, N, 3) -> (n, m, N, 3)
        diff = P_centered[:, None, :, :] - Q_centered[None, :, :, :]
        per_atom_sq = np.sum(diff**2, axis=-1)  # (n, m, N)
        return np.sqrt(np.mean(per_atom_sq, axis=-1))

    # def compute_rmsd_matrix(self, structures0, structures1):
    #     """
    #     Compute the pairwise RMSD cost matrix between two sets of backbone structures.

    #     Parameters:
    #     structures0: np.ndarray of shape (n, N, 3)
    #     structures1: np.ndarray of shape (m, N, 3)

    #     Returns:
    #     M: np.ndarray of shape (n, m) where each element M[i,j] is the RMSD between
    #         structures0[i] and structures1[j].
    #     """
    #     n = structures0.shape[0]
    #     m = structures1.shape[0]
    #     # M = np.zeros((n, m))
    #     # for i in range(n):
    #     #     for j in range(m):
    #     #         M[i, j] = self.kabsch_rmsd(structures0[i], structures1[j])
    #     M = [
    #         [self.kabsch_rmsd(structures0[i], structures1[j]) for j in range(m)]
    #         for i in range(n)
    #     ]
    #     return np.array(M)

    # def kabsch_rmsd(self, P, Q):
    #     """
    #     Compute the RMSD between two structures P and Q after optimal alignment.
    #     Both P and Q should be arrays of shape (N, 3), where N is the number of atoms.
    #     """

    #     P = P.astype(np.float32)
    #     Q = Q.astype(np.float32)
    #     # Center both structures
    #     P_centered = P - np.mean(P, axis=0)
    #     Q_centered = Q - np.mean(Q, axis=0)

    #     # Compute the covariance matrix
    #     C = np.dot(P_centered.T, Q_centered)

    #     # Singular Value Decomposition
    #     U, S, Vt = np.linalg.svd(C)

    #     d = np.linalg.det(np.dot(Vt.T, U.T))
    #     if d < 0:
    #         Vt[-1, :] *= -1

    #     # Optimal rotation matrix
    #     R = np.dot(Vt.T, U.T)

    #     # Rotate Q to optimally align with P
    #     Q_aligned = np.dot(Q_centered, R)

    #     # Compute RMSD
    #     rmsd = np.sqrt(np.mean(np.sum((P_centered - Q_aligned) ** 2, axis=1)))
    #     return rmsd

    # def batched_kabsch_rmsd(self, P, Q):
    #     """
    #     Compute pairwise RMSD after Kabsch alignment using batched SVD.

    #     Parameters:
    #     -----------
    #     P : torch.Tensor of shape (n, N, 3)
    #         Batch of n structures (each with N atoms, 3 coordinates).
    #     Q : torch.Tensor of shape (m, N, 3)
    #         Batch of m structures.

    #     Returns:
    #     --------
    #     rmsd : torch.Tensor of shape (n, m)
    #         Pairwise RMSD values after optimal alignment.
    #     """
    #     n, N, _ = P.shape
    #     m, _, _ = Q.shape

    #     P = torch.from_numpy(P).float()
    #     Q = torch.from_numpy(Q).float()
    #     P_centered = P - P.mean(dim=1, keepdim=True)
    #     Q_centered = Q - Q.mean(dim=1, keepdim=True)

    #     C = torch.einsum("ika,jkb->ijab", P_centered, Q_centered)

    #     U, S, Vh = torch.linalg.svd(C)

    #     d = torch.det(torch.matmul(U, Vh))
    #     sign_d = torch.where(d < 0, -1.0, 1.0).unsqueeze(-1).unsqueeze(-1)
    #     correction = (
    #         torch.eye(3, device=P.device)
    #         .unsqueeze(0)
    #         .unsqueeze(0)
    #         .expand(n, m, 3, 3)
    #         .clone()
    #     )
    #     correction[..., -1, -1] = sign_d.squeeze()
    #     R = torch.matmul(
    #         torch.matmul(Vh.transpose(-2, -1), correction), U.transpose(-2, -1)
    #     )

    #     Q_exp = Q_centered.unsqueeze(0).expand(n, m, N, 3)
    #     Q_aligned = torch.einsum("ijab,ijnb->ijna", R, Q_exp)
    #     P_exp = P_centered.unsqueeze(1).expand(n, m, N, 3)

    #     diff = P_exp - Q_aligned
    #     rmsd = torch.sqrt(torch.mean(torch.sum(diff**2, dim=-1), dim=-1))

    #     return rmsd
