import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .sliced_ot import emd1D, emd1D_dual, emd1D_dual_backprop

def logsumexp(f, a):
    """Return the weighted log-sum-exp ``log(sum_i a_i * exp(f_i))``.

    The weights are folded into the exponent as ``f + log(a)`` and the
    reduction is delegated to ``torch.logsumexp``, which is numerically
    stabilised and returns ``-inf`` (rather than NaN) when every weight in
    a reduction is zero.

    Args:
        f: tensor of exponents.
        a: tensor of weights with the same number of dimensions as ``f``.
            Entries equal to zero contribute ``log(0) = -inf`` and so drop
            out of the sum.

    Returns:
        If ``f`` has more than one dimension, the reduction over ``dim=1``
        reshaped to a column of shape ``(-1, 1)``. Otherwise a 0-dim tensor
        holding the reduction over all entries.
    """
    assert f.dim() == a.dim()
    # Compute the weighted exponent once; it is the only quantity needed.
    z = f + torch.log(a)
    if f.dim() > 1:
        return torch.logsumexp(z, dim=1).reshape(-1, 1)
    # reshape(-1) keeps the reduction valid for 0-dim inputs as well.
    return torch.logsumexp(z.reshape(-1), dim=0)


def rescale_potentials(f, g, a, b, rho1, rho2):
    """Compute the translation applied to a pair of potentials.

    The value returned is
    ``tau * (logsumexp(-f / rho1, a) - logsumexp(-g / rho2, b))`` with
    ``tau = rho1 * rho2 / (rho1 + rho2)``. Callers in this file add it to
    ``f`` and subtract it from ``g``.

    Args:
        f, g: potentials.
        a, b: weights paired with ``f`` and ``g`` respectively.
        rho1, rho2: scaling parameters dividing ``f`` and ``g``.

    Returns:
        The translation, shaped as returned by ``logsumexp``.
    """
    tau = (rho1 * rho2) / (rho1 + rho2)
    log_mass_f = logsumexp(-f / rho1, a)
    log_mass_g = logsumexp(-g / rho2, b)
    return tau * (log_mass_f - log_mass_g)


def kullback_leibler(a, b):
    """Kullback-Leibler divergence between non-negative weights.

    Computes ``sum(a * log(a / b + 1e-12)) - sum(a) + sum(b)`` along the
    last dimension. The ``1e-12`` offset keeps the logarithm finite where
    ``a`` is zero.

    Args:
        a: first weight tensor.
        b: second (reference) weight tensor.

    Returns:
        The divergence, reduced over the last dimension.
    """
    ratio = a / b + 1e-12
    entropy_term = torch.sum(a * torch.log(ratio), dim=-1)
    return entropy_term - torch.sum(a, dim=-1) + torch.sum(b, dim=-1)


def sample_projections(num_features, num_projections, dummy_data, seed_proj=None):
    """Draw random projection directions with unit L2 norm.

    Standard-normal entries are sampled on the CPU, converted to the dtype
    and device of ``dummy_data``, and each column is normalised.

    Args:
        num_features: number of rows of the returned matrix.
        num_projections: number of columns (directions) to draw.
        dummy_data: tensor whose dtype and device the result is moved to.
        seed_proj: optional seed. When given, the draw is reproducible and
            uses a private generator, so the global torch RNG state is
            left untouched.

    Returns:
        Tensor of shape ``(num_features, num_projections)`` whose columns
        have unit L2 norm.
    """
    generator = None
    if seed_proj is not None:
        # Seeding a local generator avoids reseeding the process-wide RNG,
        # which would silently alter every later random draw of the caller.
        generator = torch.Generator()
        generator.manual_seed(seed_proj)
    shape = [num_features, num_projections]
    projections = torch.normal(mean=torch.zeros(shape), std=torch.ones(shape), generator=generator)
    projections = projections.type(dummy_data.dtype).to(dummy_data.device)
    return F.normalize(projections, p=2, dim=0)


def project_support(x, y, projections):
    """Project two point clouds onto a set of directions.

    Args:
        x, y: sample matrices multiplied on the right by ``projections``.
        projections: matrix whose columns are the directions.

    Returns:
        Tuple ``(x_proj, y_proj)`` holding the transposes of
        ``x @ projections`` and ``y @ projections``.
    """
    return torch.matmul(x, projections).T, torch.matmul(y, projections).T


def sort_support(x_proj):
    """Sort projected samples along the last dimension.

    Args:
        x_proj: tensor of projected samples.

    Returns:
        Tuple ``(x_sorted, x_sorter, x_rev_sort)``: the sorted values, the
        indices that sort ``x_proj``, and the inverse permutation that maps
        sorted positions back to the original order.
    """
    ordered = torch.sort(x_proj, dim=-1)
    # argsort of a permutation yields its inverse.
    inverse_perm = torch.argsort(ordered.indices, dim=-1)
    return ordered.values, ordered.indices, inverse_perm


def sample_project_sort_data(x, y, num_projections, seed_proj=None):
    """Sample directions, project both point clouds and sort the projections.

    Args:
        x, y: sample matrices; the feature dimension is read from
            ``x.shape[1]``.
        num_projections: number of random directions to draw.
        seed_proj: optional seed forwarded to ``sample_projections``.

    Returns:
        Tuple ``(x_sorted, x_sorter, x_rev_sort, y_sorted, y_sorter,
        y_rev_sort, projections)``.
    """
    # Directions have shape (num_features, num_projections).
    directions = sample_projections(x.shape[1], num_projections, dummy_data=x, seed_proj=seed_proj)

    # Project along every direction, then sort each projected cloud.
    x_proj, y_proj = project_support(x, y, directions)
    return (*sort_support(x_proj), *sort_support(y_proj), directions)

    

def sliced_unbalanced_ot(a, b, x, y, p, num_projections, rho1, rho2=None, niter=10, mode='backprop', seed_proj=None):
    """Sliced unbalanced OT with one pair of potentials per projection.

    Random directions are drawn once, both point clouds are projected and
    sorted, and the potentials are refined by ``niter`` iterations using the
    step size ``2 / (2 + k)`` (the "FW" step of the original comments).

    Args:
        a, b: weights of the samples in ``x`` and ``y``. They are re-indexed
            by the per-projection sorting permutations.
        x, y: sample matrices; features are read from ``x.shape[1]``.
        p: exponent forwarded to the 1D solvers.
        num_projections: number of random directions.
        rho1: marginal penalty applied to the first measure.
        rho2: marginal penalty applied to the second measure; defaults to
            ``rho1`` when ``None``.
        niter: number of iterations.
        mode: ``'backprop'`` or ``'icdf'``, selecting the 1D dual solver.
        seed_proj: optional seed for the random directions.

    Returns:
        Tuple ``(loss, f, g, A, B, projections)``. ``f``, ``g``, ``A`` and
        ``B`` are given per projection and mapped back to the original
        (unsorted) sample order.
    """
    if rho2 is None:
        rho2 = rho1
    assert mode in ['backprop', 'icdf']
    # Resolve the 1D dual solver once instead of re-testing `mode` each iteration.
    dual_solvers = {'icdf': emd1D_dual, 'backprop': emd1D_dual_backprop}
    dual_solver = dual_solvers[mode]

    # 1 ---- Draw random directions, project and sort. The seed is forwarded
    # so that `seed_proj` actually controls the directions.
    x_sorted, x_sorter, x_rev_sort, y_sorted, y_sorter, y_rev_sort, projections = sample_project_sort_data(
        x, y, num_projections, seed_proj=seed_proj
    )
    a = a[..., x_sorter]
    b = b[..., y_sorter]

    # 2 ---- Initialise the potentials (sorted order, one row per projection)
    f = torch.zeros_like(a)
    g = torch.zeros_like(b)

    for k in range(niter):
        # Translate the potentials
        transl = rescale_potentials(f, g, a, b, rho1, rho2)
        f = f + transl
        g = g - transl
        # Update the reweighted measures
        A = a * torch.exp(-f / rho1)
        B = b * torch.exp(-g / rho2)
        # Solve the 1D problems for new potentials; the solver's loss is not needed here
        fd, gd, _ = dual_solver(x_sorted, y_sorted, u_weights=A, v_weights=B, p=p, require_sort=False)
        # Default step size
        t = 2. / (2. + k)
        f = f + t * (fd - f)
        g = g + t * (gd - g)

    # 3 ---- Final translation, then evaluate the loss
    transl = rescale_potentials(f, g, a, b, rho1, rho2)
    f, g = f + transl, g - transl
    A, B = a * torch.exp(-f / rho1), b * torch.exp(-g / rho2)
    loss = torch.mean(emd1D(x_sorted, y_sorted, u_weights=A, v_weights=B, p=p, require_sort=False))
    loss = loss + rho1 * torch.mean(kullback_leibler(A, a)) + rho2 * torch.mean(kullback_leibler(B, b))

    # 4 ---- Undo the sorting so outputs follow the original sample order
    f, g = torch.gather(f, 1, x_rev_sort), torch.gather(g, 1, y_rev_sort)
    A, B = torch.gather(A, 1, x_rev_sort), torch.gather(B, 1, y_rev_sort)

    return loss, f, g, A, B, projections



def reweighted_sliced_ot(a, b, x, y, p, num_projections, rho1, rho2=None, niter=10, mode='backprop', stochastic_proj=False):
    """Sliced OT between reweighted measures, with potentials shared by all projections.

    Unlike ``sliced_unbalanced_ot``, a single pair of potentials indexed by
    the original (unsorted) samples is maintained. At each iteration the
    per-projection dual potentials are mapped back to the original order and
    averaged over the projections before the ``2 / (2 + k)`` step is taken.

    Args:
        a, b: weights of the samples in ``x`` and ``y``.
        x, y: sample matrices; potentials have length ``x.shape[0]`` and
            ``y.shape[0]``.
        p: exponent forwarded to the 1D solvers.
        num_projections: number of random directions.
        rho1: marginal penalty applied to the first measure.
        rho2: marginal penalty applied to the second measure; defaults to
            ``rho1`` when ``None``.
        niter: number of iterations.
        mode: ``'backprop'`` or ``'icdf'``, selecting the 1D dual solver.
        stochastic_proj: when True, new directions are drawn at every
            iteration; otherwise they are drawn once before the loop.

    Returns:
        Tuple ``(loss, f, g, A, B, projections)`` where ``A`` and ``B`` are
        the reweighted measures in the original sample order and
        ``projections`` are the directions used for the final loss.
    """
    if rho2 is None:
        rho2 = rho1
    assert mode in ['backprop', 'icdf']
    # Resolve the 1D dual solver once instead of re-testing `mode` each iteration.
    dual_solvers = {'icdf': emd1D_dual, 'backprop': emd1D_dual_backprop}
    dual_solver = dual_solvers[mode]

    # 1 ---- Draw random directions (fixed for the whole run unless stochastic)
    proj_data = None
    if not stochastic_proj:
        proj_data = sample_project_sort_data(x, y, num_projections)

    # 2 ---- Initialise potentials - WARNING: they correspond to non-sorted samples
    f = torch.zeros(x.shape[0], dtype=a.dtype, device=a.device)
    g = torch.zeros(y.shape[0], dtype=a.dtype, device=a.device)

    for k in range(niter):
        # Translate the potentials
        transl = rescale_potentials(f, g, a, b, rho1, rho2)
        f = f + transl
        g = g - transl

        # Stochastic version: sample new directions and re-sort the data
        if stochastic_proj:
            proj_data = sample_project_sort_data(x, y, num_projections)
        x_sorted, x_sorter, x_rev_sort, y_sorted, y_sorter, y_rev_sort, projections = proj_data

        # Update the measures and put them in sorted order for each projection
        A = (a * torch.exp(-f / rho1))[..., x_sorter]
        B = (b * torch.exp(-g / rho2))[..., y_sorter]

        # Solve the 1D problems for new potentials; the solver's loss is not needed here
        fd, gd, _ = dual_solver(x_sorted, y_sorted, u_weights=A, v_weights=B, p=p, require_sort=False)
        # Default step size; dual potentials are unsorted, then averaged over projections
        t = 2. / (2. + k)
        f = f + t * (torch.mean(torch.gather(fd, 1, x_rev_sort), dim=0) - f)
        g = g + t * (torch.mean(torch.gather(gd, 1, y_rev_sort), dim=0) - g)

    # With stochastic projections and no iteration nothing has been drawn yet;
    # draw once so the final evaluation below has directions to work with.
    if proj_data is None:
        proj_data = sample_project_sort_data(x, y, num_projections)
    x_sorted, x_sorter, x_rev_sort, y_sorted, y_sorter, y_rev_sort, projections = proj_data

    # 3 ---- Final translation, then evaluate the loss
    transl = rescale_potentials(f, g, a, b, rho1, rho2)
    f, g = f + transl, g - transl
    A, B = (a * torch.exp(-f / rho1))[..., x_sorter], (b * torch.exp(-g / rho2))[..., y_sorter]
    loss = torch.mean(emd1D(x_sorted, y_sorted, u_weights=A, v_weights=B, p=p, require_sort=False))
    # The returned measures and the KL penalties use the original sample order
    A, B = a * torch.exp(-f / rho1), b * torch.exp(-g / rho2)
    loss = loss + rho1 * kullback_leibler(A, a) + rho2 * kullback_leibler(B, b)

    return loss, f, g, A, B, projections
    

def sliced_ot(a, b, x, y, p, num_projections, niter=10, mode='backprop', stochastic_proj=False):
    """Sliced OT loss and projection-averaged dual potentials.

    Directions are drawn once, both point clouds are projected and sorted,
    the 1D dual problems are solved, and the resulting potentials are mapped
    back to the original sample order and averaged over the projections.

    Args:
        a, b: weights of the samples in ``x`` and ``y``.
        x, y: sample matrices.
        p: exponent forwarded to the 1D solvers.
        num_projections: number of random directions.
        niter: accepted for signature compatibility; not used.
        mode: ``'backprop'`` or ``'icdf'``, selecting the 1D dual solver.
        stochastic_proj: accepted for signature compatibility; directions
            are drawn exactly once whatever its value.

    Returns:
        Tuple ``(loss, f, g, projections)`` with ``f`` and ``g`` in the
        original (unsorted) sample order.
    """
    assert mode in ['backprop', 'icdf']
    dual_solvers = {'icdf': emd1D_dual, 'backprop': emd1D_dual_backprop}
    dual_solver = dual_solvers[mode]

    # 1 ---- Draw random directions, project and sort (a single draw in all cases)
    x_sorted, x_sorter, x_rev_sort, y_sorted, y_sorter, y_rev_sort, projections = sample_project_sort_data(
        x, y, num_projections
    )

    # 2 ---- Put the weights in sorted order for each projection
    A = a[..., x_sorter]
    B = b[..., y_sorter]

    # 3 ---- Solve the 1D dual problems; the solver's loss is recomputed below
    fd, gd, _ = dual_solver(x_sorted, y_sorted, u_weights=A, v_weights=B, p=p, require_sort=False)
    # Undo the sorting and average the potentials over the projections
    f = torch.mean(torch.gather(fd, 1, x_rev_sort), dim=0)
    g = torch.mean(torch.gather(gd, 1, y_rev_sort), dim=0)

    # 4 ---- Evaluate the loss
    loss = torch.mean(emd1D(x_sorted, y_sorted, u_weights=A, v_weights=B, p=p, require_sort=False))

    return loss, f, g, projections