import numpy as np
import torch
import ot
import random
import tqdm
from torch.nn.functional import pad
def quantile_function(qs, cws, xs):
    """Look up, column by column, the support value reached at each level.

    Every column is handled independently: for a level ``q`` the result is the
    entry of ``xs`` at the first position whose cumulative weight in ``cws``
    is ``>= q`` (left-sided search).

    Args:
        qs: 2-D tensor of query levels, one column per distribution.
        cws: 2-D tensor of cumulative weights; dim 0 indexes the samples.
            NOTE(review): ``torch.searchsorted`` assumes each column is
            sorted in non-decreasing order -- confirm at the call site.
        xs: tensor of support values gathered along dim 0.

    Returns:
        Tensor with the shape of ``qs`` holding the selected entries of ``xs``.
    """
    last_row = xs.shape[0] - 1
    # searchsorted operates along the innermost dimension, so move the
    # sample axis there (and make the memory contiguous) before searching.
    sorted_levels = cws.T.contiguous()
    queries = qs.T.contiguous()
    positions = torch.searchsorted(sorted_levels, queries, right=False).T
    # A query above the largest cumulative weight lands one past the end;
    # clamp it back onto the last sample.
    return torch.gather(xs, 0, positions.clamp(0, last_row))
def compute_true_Wasserstein(X, Y, p=2):
    """Exact optimal-transport cost between two uniformly weighted point sets.

    The ground cost is the Euclidean distance raised to the power ``p`` and
    the value returned is the optimal transport cost for that ground cost
    (no ``1/p`` root is taken).

    Args:
        X: tensor whose rows are the source points.
        Y: tensor whose rows are the target points.
        p: exponent applied to the Euclidean ground distance. With the
            default ``p=2`` the cost matrix is POT's default squared
            Euclidean distance, exactly as before.

    Returns:
        The value returned by ``ot.emd2`` for uniform weights on both sets.
    """
    # Move to host memory first: ``.numpy()`` raises on CUDA tensors.
    x_np = X.detach().cpu().numpy()
    y_np = Y.detach().cpu().numpy()
    if p == 2:
        # ot.dist defaults to the squared Euclidean metric.
        cost = ot.dist(x_np, y_np)
    else:
        # Previously ``p`` was silently ignored; honour it here.
        cost = ot.dist(x_np, y_np, metric='euclidean') ** p
    a = np.ones((X.shape[0],)) / X.shape[0]
    b = np.ones((Y.shape[0],)) / Y.shape[0]
    return ot.emd2(a, b, cost)
def compute_Wasserstein(M, device='cpu', e=0):
    """Transport cost ``sum(pi * M)`` for a plan computed by POT from ``M``.

    The plan is obtained on the CPU from a detached copy of ``M`` and then
    multiplied with the original ``M``, so gradients reach ``M`` only through
    that final product, not through the plan.

    Args:
        M: cost matrix tensor.
        device: device the float32 plan is moved to before the product.
        e: regularisation strength; ``0`` selects ``ot.emd``, any other value
            selects ``ot.sinkhorn`` with ``reg=e``.

    Returns:
        Scalar tensor ``sum(pi * M)``.
    """
    cost = M.cpu().detach().numpy()
    # Empty weight lists are passed for both marginals (POT documents this
    # as "use uniform weights").
    if e == 0:
        plan = ot.emd([], [], cost)
    else:
        plan = ot.sinkhorn([], [], cost, reg=e)
    plan = torch.from_numpy(plan.astype('float32')).to(device)
    return (plan * M).sum()

def rand_projections(dim, num_projections=1000, device='cpu'):
    """Sample random unit-norm projection directions.

    Args:
        dim: dimensionality of each direction.
        num_projections: how many directions to draw.
        device: device on which the tensor is created.

    Returns:
        Tensor of shape ``(num_projections, dim)``; each row is a standard
        normal draw divided by its Euclidean norm.
    """
    raw = torch.randn((num_projections, dim), device=device)
    row_norms = raw.pow(2).sum(dim=1, keepdim=True).sqrt()
    return raw / row_norms


def one_dimensional_Wasserstein_prod(X, Y, theta, p):
    """Per-projection 1-D transport cost for two equally sized point sets.

    Both sets are projected onto every row of ``theta``; the projections are
    sorted along the sample axis and compared position by position.

    Args:
        X: tensor of points, one per row.
        Y: tensor of points with the same number of rows as ``X`` (the sorted
            projections are subtracted elementwise).
        theta: projection directions, one per row.
        p: exponent applied to the absolute differences.

    Returns:
        Tensor of shape ``(1, num_directions)`` holding the mean of
        ``|sorted(X.theta) - sorted(Y.theta)| ** p`` over the samples.
    """
    directions = theta.transpose(0, 1)
    x_proj = torch.matmul(X, directions)
    y_proj = torch.matmul(Y, directions)
    # Collapse any trailing dimensions so that dim 0 is the sample axis.
    x_proj = x_proj.view(x_proj.shape[0], -1)
    y_proj = y_proj.view(y_proj.shape[0], -1)
    x_sorted, _ = torch.sort(x_proj, dim=0)
    y_sorted, _ = torch.sort(y_proj, dim=0)
    gaps = torch.abs(x_sorted - y_sorted)
    return torch.mean(torch.pow(gaps, p), dim=0, keepdim=True)

def one_dimensional_Wasserstein(X, Y, theta, u_weights=None, v_weights=None, p=2):
    """Per-projection 1-D transport cost between two (weighted) point sets.

    When both sets have the same number of rows and no weights are supplied,
    the work is delegated to ``one_dimensional_Wasserstein_prod``. Otherwise
    the cost is obtained by integrating ``|F_u^{-1} - F_v^{-1}| ** p`` over
    the merged cumulative-weight levels of the two projected samples.

    Args:
        X: tensor of points, one per row.
        Y: tensor of points, one per row.
        theta: projection directions, one per row.
        u_weights: optional weights for the rows of ``X``; uniform ``1/n``
            when omitted. A tensor with fewer dimensions than the projected
            values is repeated across the projections.
        v_weights: same as ``u_weights`` but for ``Y``.
        p: exponent applied to the absolute quantile differences.

    Returns:
        Tensor of per-projection costs. NOTE: the shape differs between the
        two code paths -- ``(1, num_directions)`` on the equal-size unweighted
        shortcut and ``(num_directions,)`` on the general path.
    """
    no_weights = u_weights is None and v_weights is None
    if no_weights and X.shape[0] == Y.shape[0]:
        return one_dimensional_Wasserstein_prod(X, Y, theta, p)

    directions = theta.transpose(0, 1)
    u_values = torch.matmul(X, directions)
    v_values = torch.matmul(Y, directions)

    def _weights_like(weights, values):
        # Build (or broadcast) one weight per sample and per projection.
        if weights is None:
            return torch.full(values.shape, 1. / values.shape[0],
                              dtype=values.dtype, device=values.device)
        if weights.ndim != values.ndim:
            return torch.repeat_interleave(
                weights[..., None], values.shape[-1], -1)
        return weights

    u_weights = _weights_like(u_weights, u_values)
    v_weights = _weights_like(v_weights, v_values)

    # Sort every projection and carry the weights along with their samples.
    u_values, u_order = torch.sort(u_values, 0)
    v_values, v_order = torch.sort(v_values, 0)
    u_cumweights = torch.cumsum(torch.gather(u_weights, 0, u_order), 0)
    v_cumweights = torch.cumsum(torch.gather(v_weights, 0, v_order), 0)

    # Every cumulative weight of either sample is a breakpoint of the
    # integrand; evaluate both quantile functions at all of them.
    levels = torch.sort(torch.cat((u_cumweights, v_cumweights), 0), 0)[0]
    u_quantiles = quantile_function(levels, u_cumweights, u_values)
    v_quantiles = quantile_function(levels, v_cumweights, v_values)

    # Prepend a zero row along dim 0 (torch's pad spec lists the last
    # dimension first) so consecutive differences give interval widths.
    levels = pad(levels, (0, 0) * (levels.ndim - 1) + (1, 0))
    widths = levels[1:, ...] - levels[:-1, ...]

    gaps = torch.abs(u_quantiles - v_quantiles)
    return torch.sum(widths * torch.pow(gaps, p), dim=0)


def BSW(Xs, X, L=10, p=2, device='cpu'):
    """Mean sliced transport cost between ``X`` and every point set in ``Xs``.

    Args:
        Xs: stacked point sets; dim 0 indexes the sets and dim 1 the points
            (``transform`` passes a ``(K, n, 3)`` stack). Each set must have
            as many points as ``X`` because sorted projections are
            subtracted elementwise.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor: the cost averaged over projections, then over sets.
    """
    directions = rand_projections(X.size(1), L, device).transpose(0, 1)
    sets_sorted, _ = torch.sort(torch.matmul(Xs, directions), dim=1)
    x_sorted, _ = torch.sort(torch.matmul(X, directions), dim=0)
    gaps = torch.abs(sets_sorted - x_sorted)
    per_projection = torch.pow(gaps, p).mean(dim=1)  # K x L
    per_set = per_projection.mean(dim=1)
    return per_set.mean()

def OBSW(Xs, X, L=10, lam=1, p=2, device='cpu'):
    """Mean sliced cost plus a penalty on how unequal the per-set costs are.

    The penalty is ``lam`` times the average absolute difference between the
    sliced costs of all ordered pairs of distinct sets.

    Args:
        Xs: stacked point sets; dim 0 indexes the sets and dim 1 the points
            (``transform`` passes a ``(K, n, 3)`` stack). Each set must have
            as many points as ``X``.
        X: tensor of points, one per row.
        L: number of random projection directions.
        lam: weight of the pairwise-difference penalty.
        p: exponent applied to the absolute differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor ``mean(sw) + lam * mean_{i != j} |sw_i - sw_j|``. With a
        single set there are no pairs and only ``mean(sw)`` is returned.
    """
    directions = rand_projections(X.size(1), L, device).transpose(0, 1)
    sets_sorted = torch.sort(torch.matmul(Xs, directions), dim=1)[0]
    x_sorted = torch.sort(torch.matmul(X, directions), dim=0)[0]
    gaps = torch.abs(sets_sorted - x_sorted)
    per_projection = torch.mean(torch.pow(gaps, p), dim=1)  # K x L
    sw = torch.mean(per_projection, dim=1)
    loss = torch.mean(sw)

    num_sets = sw.shape[0]
    if num_sets < 2:
        # K*K - K is zero here; the old expression evaluated 0/0 and turned
        # the whole loss into NaN.
        return loss
    column = sw.view(-1, 1)
    pairwise_total = torch.cdist(column, column, p=1).sum()
    return loss + lam * pairwise_total / (num_sets * num_sets - num_sets)
def BSW_list(Xs, X, L=10, p=2, device='cpu'):
    """Mean sliced transport cost between ``X`` and each point set in ``Xs``.

    List counterpart of ``BSW`` for point sets that may differ in size.

    Args:
        Xs: sequence of point-set tensors, one point per row.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute quantile differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor: the cost averaged over projections, then over sets.
    """
    theta = rand_projections(X.size(1), L, device)
    # one_dimensional_Wasserstein returns (1, L) on its equal-size shortcut
    # and (L,) otherwise; flatten so sets of mixed sizes can be stacked.
    per_set = [
        one_dimensional_Wasserstein(points, X, theta, p=p).reshape(-1)
        for points in Xs
    ]
    distances = torch.stack(per_set, dim=0)  # K x L
    sw = torch.mean(distances, dim=1)
    return torch.mean(sw)

def lowerboundFBSW(Xs, X, L=10, p=2, device='cpu'):
    """Largest per-set sliced transport cost between ``X`` and the sets in ``Xs``.

    Args:
        Xs: stacked point sets; dim 0 indexes the sets and dim 1 the points
            (``transform`` passes a ``(K, n, 3)`` stack). Each set must have
            as many points as ``X``.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor: the cost is averaged over projections for every set
        and the maximum over the sets is returned.
    """
    directions = rand_projections(X.size(1), L, device).transpose(0, 1)
    projected_sets = torch.matmul(Xs, directions)
    projected_x = torch.matmul(X, directions)
    gaps = torch.abs(
        torch.sort(projected_sets, dim=1)[0] - torch.sort(projected_x, dim=0)[0]
    )
    per_projection = gaps.pow(p).mean(dim=1)  # K x L
    per_set = per_projection.mean(dim=1)
    return per_set.max()

def lowerboundFBSW_list(Xs, X, L=10, p=2, device='cpu'):
    """Largest per-set sliced transport cost; list counterpart of ``lowerboundFBSW``.

    Args:
        Xs: sequence of point-set tensors, one point per row.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute quantile differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor: the cost is averaged over projections for every set
        and the maximum over the sets is returned.
    """
    theta = rand_projections(X.size(1), L, device)
    # one_dimensional_Wasserstein returns (1, L) on its equal-size shortcut
    # and (L,) otherwise. Without flattening, equal-size inputs stacked to
    # (K, 1, L), so the mean below ran over the singleton axis and the max
    # was taken over individual projections instead of per-set averages.
    per_set = [
        one_dimensional_Wasserstein(points, X, theta, p=p).reshape(-1)
        for points in Xs
    ]
    distances = torch.stack(per_set, dim=0)  # K x L
    sw = torch.mean(distances, dim=1)
    return torch.max(sw)

def FBSW(Xs, X, L=10, p=2, device='cpu'):
    """Average over projections of the worst per-set 1-D transport cost.

    Args:
        Xs: stacked point sets; dim 0 indexes the sets and dim 1 the points
            (``transform`` passes a ``(K, n, 3)`` stack). Each set must have
            as many points as ``X``.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor: for every projection the maximum cost over the sets
        is taken, and these maxima are averaged.
    """
    directions = rand_projections(X.size(1), L, device).transpose(0, 1)
    sets_sorted, _ = torch.sort(torch.matmul(Xs, directions), dim=1)
    x_sorted, _ = torch.sort(torch.matmul(X, directions), dim=0)
    per_projection = torch.abs(sets_sorted - x_sorted).pow(p).mean(dim=1)  # K x L
    worst_per_projection, _ = torch.max(per_projection, dim=0)
    return worst_per_projection.mean()

def FBSW_list(Xs, X, L=10, p=2, device='cpu'):
    """Average over projections of the worst per-set cost; list counterpart of ``FBSW``.

    Args:
        Xs: sequence of point-set tensors, one point per row.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute quantile differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor: for every projection the maximum cost over the sets
        is taken, and these maxima are averaged.
    """
    theta = rand_projections(X.size(1), L, device)
    # one_dimensional_Wasserstein returns (1, L) on its equal-size shortcut
    # and (L,) otherwise; flatten so sets of mixed sizes can be stacked.
    per_set = [
        one_dimensional_Wasserstein(points, X, theta, p=p).reshape(-1)
        for points in Xs
    ]
    distances = torch.stack(per_set, dim=0)  # K x L
    sw = torch.max(distances, dim=0)[0]
    return torch.mean(sw)


def EFBSW(Xs, X, L=10, p=2, device='cpu'):
    """Softmax-weighted sum over projections of the worst per-set 1-D cost.

    Args:
        Xs: stacked point sets; dim 0 indexes the sets and dim 1 the points
            (``transform`` passes a ``(K, n, 3)`` stack). Each set must have
            as many points as ``X``.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor ``sum_l softmax(m)_l * m_l`` where ``m_l`` is the
        maximum cost over the sets for projection ``l``. The softmax weights
        are not detached, so gradients also flow through them.
    """
    directions = rand_projections(X.size(1), L, device).transpose(0, 1)
    sets_sorted = torch.matmul(Xs, directions).sort(dim=1)[0]
    x_sorted = torch.matmul(X, directions).sort(dim=0)[0]
    gaps = torch.abs(sets_sorted - x_sorted)
    per_projection = torch.mean(torch.pow(gaps, p), dim=1)  # K x L
    worst = per_projection.max(dim=0)[0]
    weight = torch.softmax(worst, dim=-1)
    return (weight * worst).sum()

def EFBSW_list(Xs, X, L=10, p=2, device='cpu'):
    """Softmax-weighted worst per-set cost; list counterpart of ``EFBSW``.

    Args:
        Xs: sequence of point-set tensors, one point per row.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute quantile differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor ``sum_l softmax(m)_l * m_l`` where ``m_l`` is the
        maximum cost over the sets for projection ``l``.
    """
    theta = rand_projections(X.size(1), L, device)
    # one_dimensional_Wasserstein returns (1, L) on its equal-size shortcut
    # and (L,) otherwise; flatten so sets of mixed sizes can be stacked.
    per_set = [
        one_dimensional_Wasserstein(points, X, theta, p=p).reshape(-1)
        for points in Xs
    ]
    distances = torch.stack(per_set, dim=0)  # K x L
    sw = torch.max(distances, dim=0)[0]
    weight = torch.softmax(sw, dim=-1)
    return torch.sum(weight * sw)


def lowerbound_EFBSW(Xs, X, L=10, p=2, device='cpu'):
    """Largest per-set cost under softmax weights shared across the sets.

    The projection weights are the softmax of the per-projection maximum
    cost over the sets; each set's costs are then combined with those
    weights and the largest combined value is returned.

    Args:
        Xs: stacked point sets; dim 0 indexes the sets and dim 1 the points
            (``transform`` passes a ``(K, n, 3)`` stack). Each set must have
            as many points as ``X``.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor ``max_k sum_l w_l * cost[k, l]``.
    """
    directions = rand_projections(X.size(1), L, device).transpose(0, 1)
    sets_sorted, _ = torch.sort(torch.matmul(Xs, directions), dim=1)
    x_sorted, _ = torch.sort(torch.matmul(X, directions), dim=0)
    gaps = torch.abs(sets_sorted - x_sorted)
    per_projection = torch.pow(gaps, p).mean(dim=1)  # K x L
    worst, _ = torch.max(per_projection, dim=0)
    weight = torch.softmax(worst, dim=-1)
    weighted_per_set = (weight.view(1, L) * per_projection).sum(dim=1)
    return weighted_per_set.max()

def lowerbound_EFBSW_list(Xs, X, L=10, p=2, device='cpu'):
    """Largest softmax-weighted per-set cost; list counterpart of ``lowerbound_EFBSW``.

    Args:
        Xs: sequence of point-set tensors, one point per row.
        X: tensor of points, one per row.
        L: number of random projection directions.
        p: exponent applied to the absolute quantile differences.
        device: device on which the directions are sampled.

    Returns:
        Scalar tensor ``max_k sum_l w_l * cost[k, l]`` where ``w`` is the
        softmax of the per-projection maximum cost over the sets.
    """
    theta = rand_projections(X.size(1), L, device)
    # one_dimensional_Wasserstein returns (1, L) on its equal-size shortcut
    # and (L,) otherwise. Without flattening, equal-size inputs stacked to
    # (K, 1, L), so the sum below collapsed the singleton axis instead of
    # the projections and the result was the largest single weighted term.
    per_set = [
        one_dimensional_Wasserstein(points, X, theta, p=p).reshape(-1)
        for points in Xs
    ]
    distances = torch.stack(per_set, dim=0)  # K x L
    sw = torch.max(distances, dim=0)[0]
    weight = torch.softmax(sw, dim=-1)
    return torch.max(torch.sum(weight.view(1, L) * distances, dim=1))

def transform(src, target1, target2, src_label, origin, sw_type='SW', L=10,
              num_iter=1000, lam=1, device='cuda', lr=0.0001):
    """Move the source points towards two targets by SGD on a sliced loss.

    ``src``, ``target1`` and ``target2`` are reshaped to rows of 3 values.
    The source rows are optimised with plain SGD against the chosen loss and
    clamped to ``[0, 255]`` after every step.

    Args:
        src: array-like source values (reshaped to ``(-1, 3)``).
        target1: first array-like target (reshaped to ``(-1, 3)``).
        target2: second array-like target (reshaped to ``(-1, 3)``). Both
            targets need the same number of rows, as they are stacked.
        src_label: index applied to the rows of the optimised points before
            they are reshaped to ``origin.shape`` (presumably a per-pixel
            row index -- verify against the caller).
        origin: object whose ``shape`` gives the output image shape.
        sw_type: loss selector, one of ``'bsw'``, ``'obsw'``, ``'fbsw'``,
            ``'fbswl'``, ``'efbsw'``, ``'efbswl'``. The default ``'SW'`` is
            not a valid selector and must be overridden.
        L: number of random projections per iteration.
        num_iter: number of SGD steps.
        lam: penalty weight, used by ``'obsw'`` only.
        device: torch device used for the optimisation.
        lr: SGD learning rate.

    Returns:
        Tuple ``(s, img_ot_transf)``: the optimised points as a NumPy array
        and the ``uint8`` image built from ``s[src_label]``, rescaled so that
        its maximum becomes 255.

    Raises:
        ValueError: if ``sw_type`` is not recognised and at least one
            iteration is requested.
    """
    loss_fns = {
        'bsw': BSW,
        'obsw': OBSW,
        'fbsw': FBSW,
        'fbswl': lowerboundFBSW,
        'efbsw': EFBSW,
        'efbswl': lowerbound_EFBSW,
    }
    loss_fn = loss_fns.get(sw_type)
    # An unknown selector used to surface as an UnboundLocalError inside the
    # loop; fail with a clear message instead. With no iterations the loss
    # is never evaluated, so that case keeps working as before.
    if loss_fn is None and num_iter > 0:
        raise ValueError(
            "Unknown sw_type %r; expected one of %s"
            % (sw_type, sorted(loss_fns))
        )
    loss_kwargs = {'L': L, 'device': device}
    if sw_type == 'obsw':
        loss_kwargs['lam'] = lam

    def _as_points(values):
        # Flatten to rows of 3 values and move to the working device.
        points = np.array(values).reshape(-1, 3)
        return torch.from_numpy(points).float().to(device)

    s = torch.nn.parameter.Parameter(_as_points(src))
    Xs = torch.stack([_as_points(target1), _as_points(target2)], dim=0)
    opt = torch.optim.SGD([s], lr=lr)
    for _ in tqdm.tqdm(range(num_iter)):
        opt.zero_grad()
        # Scale the loss by the number of source points.
        g_loss = loss_fn(Xs, s, **loss_kwargs) * s.shape[0]
        g_loss.backward()
        opt.step()
        # Project back onto the valid value range without recording the
        # operation in the autograd graph.
        with torch.no_grad():
            s.clamp_(min=0, max=255)
    s = s.cpu().detach().numpy()
    img_ot_transf = s[src_label].reshape(origin.shape)
    peak = np.max(img_ot_transf)
    if peak != 0:
        # Skip the rescale when the peak is zero: dividing would produce
        # NaNs and an undefined uint8 conversion.
        img_ot_transf = img_ot_transf / peak * 255
    img_ot_transf = img_ot_transf.astype("uint8")
    return s, img_ot_transf
