import torch
from torch.autograd import Variable
import cv2
import numpy as np

def sinkhorn_normalized(x, y, epsilon, n, niter):
    """Return the debiased Sinkhorn quantity 2*W(x, y) - W(x, x) - W(y, y).

    W is ``sinkhorn_loss`` evaluated with the same ``epsilon``, ``n`` and
    ``niter``. Subtracting the two self-transport terms removes the entropic
    bias, so passing the same point set as ``x`` and ``y`` yields 0.
    """
    cross_term = sinkhorn_loss(x, y, epsilon, n, niter)
    self_term_x = sinkhorn_loss(x, x, epsilon, n, niter)
    self_term_y = sinkhorn_loss(y, y, epsilon, n, niter)
    return 2 * cross_term - self_term_x - self_term_y


def sinkhorn_loss(x, y, epsilon, n, niter):
    """
    Approximate the entropy-regularized OT cost between two empirical measures.

    Both measures have ``n`` points with uniform weights ``1/n``, located at
    ``x`` and ``y``. The dual potentials ``u`` and ``v`` are updated in the
    log domain for at most ``niter`` iterations. The loop stops early once the
    L1 change of ``u`` between two iterations drops below 0.1.

    Args:
        x, y: point locations passed to ``cost_matrix``. They are presumably
            (n, d) tensors; ``n`` must match their number of rows.
        epsilon: entropic regularization strength (must be > 0).
        n: number of points in each measure.
        niter: maximum number of Sinkhorn iterations.

    Returns:
        0-d tensor ``sum(pi * C)``. ``pi = exp((-C + u_i + v_j) / epsilon)``
        is the transport plan built from the final potentials.
    """
    C = cost_matrix(x, y)  # pairwise cost |x_i - y_j|^2

    # Uniform marginals on x's device. promote_types keeps at least float32
    # (as before) but no longer downcasts float64 inputs to float32.
    work_dtype = torch.promote_types(x.dtype, torch.float32)
    mu = torch.full((n,), 1.0 / n, dtype=work_dtype, device=x.device)
    nu = torch.full((n,), 1.0 / n, dtype=work_dtype, device=x.device)
    log_mu = torch.log(mu)  # loop-invariant
    log_nu = torch.log(nu)

    thresh = 10 ** (-1)  # stopping criterion on the L1 change of u

    def M(u, v):
        # Modified cost M_ij = (-C_ij + u_i + v_j) / epsilon for log-domain updates.
        return (-C + u.unsqueeze(1) + v.unsqueeze(0)) / epsilon

    u = torch.zeros_like(mu)
    v = torch.zeros_like(nu)
    for _ in range(niter):
        u_prev = u
        # torch.logsumexp is exact and overflow/underflow-safe. The former
        # log(sum(exp(A)) + 1e-6) collapsed to log(1e-6) whenever exp(A)
        # underflowed, e.g. for small epsilon or large costs.
        u = epsilon * (log_mu - torch.logsumexp(M(u, v), dim=1)) + u
        v = epsilon * (log_nu - torch.logsumexp(M(u, v).t(), dim=1)) + v
        err = (u - u_prev).abs().sum()
        if err.item() < thresh:
            break

    pi = torch.exp(M(u, v))  # transport plan pi = diag(a) * K * diag(b)
    return torch.sum(pi * C)  # Sinkhorn cost


def cost_matrix(x, y, p=2):
    """Return the pairwise cost matrix C with C[i, j] = sum_k |x[i, k] - y[j, k]|^p.

    ``x`` is broadcast along a new axis 1 and ``y`` along a new axis 0. The
    absolute differences raised to ``p`` are then summed over axis 2 (the
    coordinate axis for 2-D inputs).
    """
    pairwise_diff = x[:, None] - y[None, :]
    return pairwise_diff.abs().pow(p).sum(dim=2)


def MMD(x, y, kernel='gaussian', k_sigma=0.12):
    """
    Compute a kernel two-sample discrepancy between sample sets x and y.

    The samples are stacked and all pairwise Euclidean distances are computed
    with ``torch.cdist``. Block means include the diagonal, so this is the
    biased (V-statistic) estimate.

    Args:
        x: tensor of m samples. Assumed (m, d); y must share the trailing
            dimension so that vstack/cdist work.
        y: tensor of samples from the second set.
        kernel: 'gaussian' (exp(-d^2 / (2 sigma^2))), 'laplacian'
            (exp(-d / sigma)) or 'l2' (raw distance). Matching is
            case-insensitive.
        k_sigma: bandwidth of the gaussian / laplacian kernels.

    Returns:
        0-d tensor. For 'gaussian' / 'laplacian': mean(Kxx) - 2*mean(Kxy) + mean(Kyy).
        For 'l2': mean(Dxy) - 0.5*(mean(Dxx) + mean(Dyy)), i.e. half the energy distance.

    Raises:
        ValueError: if ``kernel`` is not one of the supported names.
    """
    kernel = kernel.lower()
    if kernel not in ('l2', 'gaussian', 'laplacian'):
        # Previously an unknown name left K unbound and crashed later with
        # UnboundLocalError, after the expensive cdist had been computed.
        raise ValueError(
            f"Unsupported kernel {kernel!r}; expected 'l2', 'gaussian' or 'laplacian'"
        )

    m = x.size(0)

    X = torch.vstack([x, y])
    dist = torch.cdist(X, X)
    if kernel == 'l2':
        K = dist
    elif kernel == 'gaussian':
        K = torch.exp(-(dist ** 2) / (2 * k_sigma * k_sigma))
    else:  # 'laplacian'
        K = torch.exp(-dist / k_sigma)

    xx = K[:m, :m].mean()
    xy = K[:m, m:].mean()
    yy = K[m:, m:].mean()

    if kernel == 'l2':
        return xy - 0.5 * (xx + yy)
    return xx - 2 * xy + yy
    
    
def calc_color_histogram(image, bins=256, mask=None):
    """
    Compute a normalized histogram for each of the image's three channels.

    Args:
        image: np.ndarray, BGR image (channels 0, 1, 2 are histogrammed).
        bins: int, number of histogram bins over the value range [0, 256).
        mask: np.ndarray or None, optional mask restricting the region counted.

    Returns:
        hist: np.ndarray of shape (3, bins). Row i is the histogram of channel i,
            divided by its total count (plus 1e-6 to avoid division by zero).
    """
    def _normalized_channel_hist(channel):
        counts = cv2.calcHist([image], [channel], mask, [bins], [0, 256]).flatten()
        return counts / (counts.sum() + 1e-6)

    return np.array([_normalized_channel_hist(c) for c in range(3)])  # B, G, R

def color_histogram_emd(hist1, hist2):
    """
    Compute the Earth Mover's Distance between two colour histograms.

    The EMD is computed independently for each of the three channels, using
    each bin index as the bin's 1-D ground position and its value as its mass.
    The three per-channel distances are then averaged.

    Args:
        hist1, hist2: np.ndarray of shape (3, bins), e.g. as returned by
            calc_color_histogram.

    Returns:
        emd_mean: float, mean of the three per-channel EMD values.
    """
    emd_total = 0
    for i in range(3):
        # cv2.EMD expects every signature row as [weight, coord_1, ..., coord_d].
        # The bin mass must come first, followed by the bin index as its 1-D
        # position. The previous order made the indices act as weights.
        signature1 = np.column_stack(
            [hist1[i], np.arange(hist1.shape[1])]
        ).astype(np.float32)
        signature2 = np.column_stack(
            [hist2[i], np.arange(hist2.shape[1])]
        ).astype(np.float32)

        emd, _, _ = cv2.EMD(signature1, signature2, cv2.DIST_L2)
        emd_total += emd
    return emd_total / 3

def histogram_intersection(hist1, hist2):
    """
    Return the channel-averaged histogram intersection of two histograms.

    For each of the first three rows (channels), sum the element-wise minimum
    of the two histograms, then average the three sums. For non-negative
    histograms whose rows sum to at most 1, the result lies in [0, 1].

    Args:
        hist1, hist2: np.ndarray, shape=(3, bins)

    Returns:
        intersection_mean: float
    """
    per_channel_overlap = (
        np.sum(np.minimum(hist1[channel], hist2[channel])) for channel in range(3)
    )
    return sum(per_channel_overlap) / 3