import torch

def init_matrix(wc_1, wc_2, p, q):
    """Precompute the constant and factor terms of a square-loss cost tensor.

    The element-wise functions used are ``f1(a) = a**2 / 2``,
    ``f2(b) = b**2 / 2``, ``h1(a) = a`` and ``h2(b) = b``, i.e. the pieces of
    the decomposition ``(a - b)**2 / 2 = f1(a) + f2(b) - h1(a) * h2(b)``.

    Parameters
    ----------
    wc_1: torch.Tensor, 2-D
    First matrix; its number of columns must equal ``len(p)``.

    wc_2: torch.Tensor, 2-D
    Second matrix; its number of columns must equal ``len(q)``.

    p: torch.Tensor
    Weights associated with ``wc_1`` (reshaped to a column vector).

    q: torch.Tensor
    Weights associated with ``wc_2`` (reshaped to a row vector).

    Returns
    ----------
    constC: torch.Tensor
    ``f1(wc_1) @ p @ 1_q^T + 1_p @ q^T @ f2(wc_2)^T``.

    hc1: torch.Tensor
    ``h1(wc_1)``, which is ``wc_1`` itself.

    hc2: torch.Tensor
    ``h2(wc_2)``, which is ``wc_2`` itself.
    """
    half_sq_1 = (wc_1 ** 2) / 2
    half_sq_2 = (wc_2 ** 2) / 2

    # Create the broadcasting vectors on the dtype/device of the operands they
    # are multiplied with, instead of forcing CUDA: this keeps the function
    # usable on CPU-only machines and with non-float32 inputs.
    ones_q = torch.ones(1, len(q), dtype=wc_1.dtype, device=wc_1.device)
    ones_p = torch.ones(len(p), 1, dtype=wc_2.dtype, device=wc_2.device)

    # Row term: (f1(wc_1) @ p) repeated across len(q) columns.
    constc1 = torch.mm(torch.mm(half_sq_1, p.view(-1, 1)), ones_q)
    # Column term: (q^T @ f2(wc_2)^T) repeated across len(p) rows.
    constc2 = torch.mm(ones_p, torch.mm(q.view(1, -1), torch.t(half_sq_2)))

    return constc1 + constc2, wc_1, wc_2

def tensor_product(constC, hC1, hC2, T):
    """Return ``constC - hC1 @ T @ hC2^T``.

    Parameters
    ----------
    constC: torch.Tensor
    Constant term added to the result.

    hC1: torch.Tensor, 2-D
    Left factor of the matrix product.

    hC2: torch.Tensor, 2-D
    Right factor; it is transposed before the multiplication.

    T: torch.Tensor, 2-D
    Matrix sandwiched between ``hC1`` and ``hC2^T``.

    Returns
    ----------
    loss_matrix: torch.Tensor
    Sum of ``constC`` and the negated triple product.
    """
    triple = torch.mm(torch.mm(hC1, T), hC2.t())
    return constC + torch.neg(triple)

def wotloss(constC, hC1, hC2, T):
    """Return the scalar loss obtained by weighting the loss matrix with T.

    The loss matrix is ``tensor_product(constC, hC1, hC2, T)``; it is
    multiplied element-wise by ``T`` and all entries are summed.

    Returns
    ----------
    loss: torch.Tensor (0-dim)
    Sum over all entries of ``loss_matrix * T``.
    """
    weighted = tensor_product(constC, hC1, hC2, T) * T
    return weighted.sum()

def wotgrad(constC, hC1, hC2, T):
    """Return twice the loss matrix ``tensor_product(constC, hC1, hC2, T)``.

    NOTE(review): presumably used as the gradient of ``wotloss`` with respect
    to ``T`` -- confirm against the caller.
    """
    loss_matrix = tensor_product(constC, hC1, hC2, T)
    return 2 * loss_matrix

def compute_cost(wavelet_coeffs_X1, wavelet_coeffs_X2, T):
    """Accumulate a squared-difference cost matrix over the batch dimension.

    Shapes below follow from the einsum subscripts used in the body.

    Parameters
    ----------
    wavelet_coeffs_X1: torch.Tensor of size [Batch, n1, d1]
    First batch of coefficient matrices.

    wavelet_coeffs_X2: torch.Tensor of size [Batch, n2, d2]
    Second batch of coefficient matrices.

    T: torch.Tensor of size [d1, d2]
    Coupling matrix shared by every batch entry.

    Returns
    ----------
    cost_matrix: torch.Tensor of size [n1, n2]
    Per-batch ``X1**2 @ mu + X2**2 @ nu - 2 * X1 @ T @ X2^T`` summed over the
    batch, where ``mu`` / ``nu`` are the row / column sums of ``T``.
    """
    n_batch = len(wavelet_coeffs_X1)
    plan = T.unsqueeze(0).repeat(n_batch, 1, 1)

    # Cross term: X1 @ T @ X2^T for every batch entry.
    x2_plan = torch.einsum("bkl,bjl->bkj", wavelet_coeffs_X2, plan)
    cross = torch.einsum("bij,bkj->bik", wavelet_coeffs_X1, x2_plan)

    # Squared terms weighted by the marginals of the coupling.
    row_mass = torch.sum(plan, dim=2)
    col_mass = torch.sum(plan, dim=1)
    sq_x = torch.einsum("bij,bj->bi", wavelet_coeffs_X1 ** 2, row_mass)
    sq_y = torch.einsum("bkl,bl->bk", wavelet_coeffs_X2 ** 2, col_mass)

    per_batch = sq_x[:, :, None] + sq_y[:, None, :] - 2 * cross
    return per_batch.sum(axis=0)

def compute_loss(dx, dy, T, T_prev, F1, F2):
    """Compute a per-batch quadratic distortion between two couplings.

    Shapes below follow from the einsum subscripts used in the body.

    Parameters
    ----------
    dx: torch.Tensor of size [Batch, size_X, size_X]
    Batch of matrices for the first space.

    dy: torch.Tensor of size [Batch, size_Y, size_Y]
    Batch of matrices for the second space.

    T: torch.Tensor of size [size_X, size_Y]
    Current coupling, shared by every batch entry.

    T_prev: torch.Tensor of size [size_X, size_Y]
    Second coupling (e.g. from a previous iteration), shared by every batch
    entry.

    F1, F2:
    Unused; kept so existing callers keep working.

    Returns
    ----------
    loss: torch.Tensor of size [Batch]
    ``distxx + distyy - 2 * distxy`` for each batch entry.
    """
    # Take the batch size from the data instead of a hard-coded 20, and use
    # expand (a view) rather than repeat (a copy) since T is only read.
    n_batch = dx.shape[0]
    T = T.unsqueeze(0).expand(n_batch, -1, -1)
    T_prev = T_prev.unsqueeze(0).expand(n_batch, -1, -1)

    # Squared terms weighted by the row / column marginals of both couplings.
    distxx = torch.einsum(
        "ijk,ij,ik->i", dx ** 2, T.sum(dim=2), T_prev.sum(dim=2)
    )
    distyy = torch.einsum(
        "ijk,ij,ik->i", dy ** 2, T.sum(dim=1), T_prev.sum(dim=1)
    )
    # Cross term: <dx @ T, T_prev @ dy> per batch entry.
    distxy = torch.sum(
        torch.einsum("kij,kjl->kil", dx, T)
        * torch.einsum("kij,kjl->kil", T_prev, dy),
        dim=(1, 2),
    )
    return distxx + distyy - 2 * distxy

def compute_local_cost(pi, a, dx, b, dy, eps, rho, rho2, complete_cost=True):
    """Compute the local cost by averaging the distortion with the current
    transport plan.

    The implementation is unbatched: all einsum subscripts and broadcasts
    below operate on 1-D measures and 2-D matrices.

    Parameters
    ----------
    pi: torch.Tensor of size [size_X, size_Y]
    transport plan used to compute local cost

    a: torch.Tensor of size [size_X]
    Input measure of the first mm-space.

    dx: torch.Tensor of size [size_X, size_X]
    Input metric of the first mm-space.

    b: torch.Tensor of size [size_Y]
    Input measure of the second mm-space.

    dy: torch.Tensor of size [size_Y, size_Y]
    Input metric of the second mm-space.

    eps: float
    Strength of entropic regularization.

    rho: float
    Strength of penalty on the first marginal of pi.

    rho2: float or None
    Strength of penalty on the second marginal of pi. If set to None it is
    equal to rho.

    complete_cost: bool
    If set to True, computes the full local cost, otherwise it computes the
    cross-part on (X,Y) to reduce computational complexity.

    Returns
    ----------
    lcost: torch.Tensor of size [size_X, size_Y]
    local cost depending on the current transport plan.
    """
    # Honour the documented default; comparing None with a float would raise.
    if rho2 is None:
        rho2 = rho

    # Cross term dx @ pi @ dy^T.
    distxy = torch.einsum("ij,kj->ik", dx, torch.einsum("kl,jl->kj", dy, pi))
    # Scalar entropic term of pi against the product measure a x b; the 1e-10
    # offset keeps the log finite where pi is zero.
    kl_pi = torch.sum(
        pi * (pi / (a[:, None] * b[None, :]) + 1e-10).log()
    )
    if not complete_cost:
        return - 2 * distxy + eps * kl_pi

    mu, nu = torch.sum(pi, dim=1), torch.sum(pi, dim=0)
    distxx = torch.einsum("ij,j->i", dx ** 2, mu)
    distyy = torch.einsum("kl,l->k", dy ** 2, nu)

    lcost = (distxx[:, None] + distyy[None, :] - 2 * distxy) + eps * kl_pi
    # Marginal penalties are scalars added to every entry; an infinite
    # strength means the term is skipped.
    if rho < float("Inf"):
        lcost = lcost + rho * torch.sum(mu * (mu / a + 1e-10).log())
    if rho2 < float("Inf"):
        lcost = lcost + rho2 * torch.sum(nu * (nu / b + 1e-10).log())
    return lcost