import numpy as np
import ot
# Name of an extended-precision NumPy dtype. It is not referenced elsewhere in
# this section (note: ``np.float128`` is not available on every platform).
dtype_prec = "float128"
# Shared defaults for the iterative routines below: convergence tolerance and
# maximum number of iterations.
esp_control, max_iter_control = 0.01, 100

def p_norm(x, p):
    """Return the norm of ``x`` of order ``p`` using :func:`numpy.linalg.norm`.

    ``p`` is forwarded unchanged as NumPy's ``ord`` argument, so ``None``
    selects NumPy's default (the 2-norm for 1-D input).
    """
    order = p
    return np.linalg.norm(x, order)

def compute_P(source, target, p = None):
    """Return the pairwise cost matrix ``C[i, j] = ||source[i] - target[j]||_p ** p``.

    Despite its name, this builds a cost matrix (Minkowski distance raised to
    the power ``p``), not a transport plan.

    Parameters
    ----------
    source : np.ndarray, shape (m1, dim)
    target : np.ndarray, shape (m2, dim)
    p : float or None
        Minkowski exponent. ``None`` is treated as 2, mirroring ``p_norm``,
        which falls back to the 2-norm for ``None``.

    Returns
    -------
    np.ndarray, shape (m1, m2)
    """
    if p is None:
        # Without this, the default would fail at ``C ** None`` (TypeError).
        p = 2
    C = ot.dist(source, target, metric="minkowski", p=p)
    return C ** p

def compute_s(source, target, P, p = None, eta1 = 0.01, esp = esp_control, max_iter = max_iter_control):
    """Fit a translation ``s`` of ``source`` towards ``target`` by gradient descent.

    Minimises over ``s``::

        E(s) = sum_{(i, j) kept} P[i, j] * sum_k |source[i, k] - target[j, k] + s[k]| ** p

    where only pairs with ``P[i, j] >= 1 / (m1 * m2)`` are kept. Each step
    uses the gradient of ``E`` without its constant factor ``p``; that factor
    is effectively absorbed into ``eta1``.

    Parameters
    ----------
    source : np.ndarray, shape (m1, dim)
    target : np.ndarray, shape (m2, dim)
    P : np.ndarray, shape (m1, m2)
        Weight of each (source, target) pair.
    p : float or None
        Cost exponent. ``None`` is treated as 2, consistent with ``p_norm``
        falling back to the 2-norm for ``None``.
    eta1 : float
        Gradient step size.
    esp : float
        Stop once the ``p``-norm of the (scaled) gradient is at most ``esp``.
    max_iter : int
        Maximum number of gradient steps.

    Returns
    -------
    np.ndarray, shape (dim,)
        The fitted translation. The search starts from the zero vector.
    """
    if p is None:
        # Otherwise ``abs(diff) ** (p - 1)`` would raise TypeError.
        p = 2
    m1, m2 = len(source), len(target)
    dim = source.shape[1]
    threshold = 1 / (m1 * m2)

    # The set of kept pairs, their weights and the unshifted differences depend
    # only on P, source and target, not on s, so compute them once.
    rows, cols = np.nonzero(P >= threshold)
    weights = P[rows, cols]                   # shape (k,)
    base_diff = source[rows] - target[cols]   # shape (k, dim)

    s = np.zeros(dim)
    for _ in range(max_iter):
        diff = base_diff + s
        if p == 1:
            fij = np.sign(diff)
        else:
            fij = np.sign(diff) * (np.abs(diff) ** (p - 1))
        # Weighted sum over all kept pairs; yields zeros(dim) if none are kept.
        grads_E = weights @ fij
        s = s - eta1 * grads_E
        if p_norm(grads_E, p) <= esp:
            break
    return s


def rescale_to_nonnegative(source_probs, target_probs, matrix, tol=esp_control, max_iter = max_iter_control):
    """
    Alternately rescale the rows and columns of ``matrix`` so that its row
    sums approach ``source_probs`` and its column sums approach
    ``target_probs`` (Sinkhorn-style matrix balancing).

    Entries are only multiplied by per-row and per-column factors. The
    routine does not clip negative values, so the result is non-negative only
    if ``matrix`` is. A row or column whose current sum is zero yields an
    infinite or NaN factor (division by zero), which propagates NaN into the
    result.

    Parameters:
    source_probs (array-like): Desired row sums, one per row of ``matrix``.
    target_probs (array-like): Desired column sums, one per column of ``matrix``.
    matrix (np.ndarray): 2-D input matrix. It is copied, not modified in place.
    tol (float): Absolute tolerance (``atol`` of ``np.allclose``) used to
        decide that both sets of marginals have been reached.
    max_iter (int): Maximum number of row-then-column sweeps.

    Returns:
    np.ndarray: The rescaled matrix. If ``max_iter`` sweeps run out before
    convergence, the last iterate is returned.
    """
    wanted_rows = source_probs
    wanted_cols = target_probs

    def _marginals_match(mat):
        # Short-circuits: the column check runs only if the rows already match.
        return (np.allclose(np.sum(mat, axis=1), wanted_rows, atol=tol)
                and np.allclose(np.sum(mat, axis=0), wanted_cols, atol=tol))

    balanced = matrix.copy()
    for _sweep in range(max_iter):
        # Row pass: scale row i by wanted_rows[i] / (current sum of row i).
        per_row = wanted_rows / np.sum(balanced, axis=1)
        balanced = balanced * per_row[:, None]
        # Column pass: scale column j by wanted_cols[j] / (current sum of column j).
        per_col = wanted_cols / np.sum(balanced, axis=0)
        balanced = balanced * per_col
        if _marginals_match(balanced):
            break
    return balanced

def compute_RWp(source_dist, target_dist, p, eta2 = 0.01, eps2 = esp_control, maxiter = max_iter_control):
    """Jointly fit a translation of the source support and a coupling ``P``.

    Each outer round:

    1. fits a translation ``s`` of the source locations for the current ``P``
       (``compute_s``);
    2. builds the shifted cost ``C[i, j] = ||source[i] + s - target[j]||_p ** p``
       (``compute_P``);
    3. takes a gradient step on ``sum(C * P)`` with respect to ``P``, clips at
       zero, and rebalances the marginals towards ``source_probs`` /
       ``target_probs`` (``rescale_to_nonnegative``).

    Iteration stops when two successive translations differ by at most
    ``eps2`` in ``p``-norm, or after ``maxiter`` rounds.

    Parameters
    ----------
    source_dist, target_dist :
        Objects exposing ``locs`` (point locations) and ``probs`` (weights).
        Presumably ``locs`` has shape (m, dim) and ``probs`` shape (m,);
        TODO confirm against the distribution class.
    p : float
        Cost exponent.
    eta2 : float
        Step size of the coupling update.
    eps2 : float
        Tolerance on the change of the translation between rounds.
    maxiter : int
        Maximum number of outer rounds. Must be at least 1.

    Returns
    -------
    float
        ``sum(C * P)`` for the final shifted cost and coupling. This is the
        p-th power of the cost, not its p-th root.

    Raises
    ------
    ValueError
        If ``maxiter < 1``. No round would run, so there would be no cost
        to return.
    """
    if maxiter < 1:
        raise ValueError("maxiter must be at least 1")
    source, source_probs = source_dist.locs, source_dist.probs
    target, target_probs = target_dist.locs, target_dist.probs
    m1, m2 = len(source), len(target)
    # Start from the uniform coupling.
    P = (1 / (m1 * m2)) * np.ones((m1, m2))
    s_last = None
    for _ in range(maxiter):
        s = compute_s(source, target, P, p)
        C = compute_P(source + s, target, p = p)
        # d/dP of sum(C * P) is C: gradient step, then project onto P >= 0.
        P = np.maximum(P - eta2 * C, 0)
        # Pull the row/column sums back towards the two distributions' weights.
        P = rescale_to_nonnegative(source_probs, target_probs, P)
        if s_last is not None and p_norm(s - s_last, p) <= eps2:
            break
        s_last = s
    return np.sum(C * P)
