"""
SISTA.py
========
This module, SISTA.py, contains implementations of various functions and algorithms related to
Sparse Inverse Shrinkage Thresholding Algorithm (SISTA) and inverse Optimal Transport (iOT) Variational Problem.
The main functionalities of this package are to provide methods for finding cost vectors for the sparse iOT problem,
performing entropic regularized optimal transport, and solving variational problems in the iOT setting.

Functions
---------
- `Eucl_Cost(x: jnp.ndarray, y: jnp.ndarray) -> Callable`:
    Constructs a cost linear operator between input vectors `x` and `y`.

- `sinkhorn(p: jnp.ndarray, q: jnp.ndarray, c_ij: jnp.ndarray, eps: float, maxiter: int = 1000) -> jnp.ndarray`:
    Runs log-domain Sinkhorn iterations and returns the coupling matrix for the given probability vectors,
    cost matrix, and entropic regularization parameter.

- `wthresh(x: jnp.ndarray, gamma: float) -> jnp.ndarray`:
    Performs element-wise soft thresholding on the input array `x` with threshold `gamma`.

- `SISTA(hat_pi: jnp.ndarray, eps: float, beta_init: jnp.ndarray, cost: Callable, gamma: float = .01, 
         rho: float = .1, maxiter: int = 500) -> Tuple[jnp.ndarray, List[float]]`:
    Implements the SISTA algorithm to solve the sparse iOT problem and returns the cost vector `beta` and a list
    of objective function values at each iteration.

- `iOT_VarPro_emp(n: int, eps: float, x_init: jnp.ndarray, cost: Callable, gamma: float = .01, 
                 niter: int = 30) -> Tuple[jnp.ndarray, OptimizeResult]`:
    Implements the iOT Variational Problem with empirical distributions, optimizing over semi-dual variables to find
    the optimal cost vector `beta` for a given cost operator.

- `iOT_VarPro(hat_pi: jnp.ndarray, eps: float, x_init: jnp.ndarray, cost: Callable, gamma: float = .01, 
             niter: int = 30) -> Tuple[jnp.ndarray, OptimizeResult]`:
    Implements the iOT Variational Problem using given coupling matrix `hat_pi`, and solving it with BFGS optimization
    on the semi-dual form of the iOT problem.
"""


import jax.numpy as jnp
from jax import grad, jit, vmap,jacfwd, jacrev, jacobian
from jax import random
from jax.scipy.linalg import sqrtm
from jax.scipy.special import logsumexp
from jax.scipy.optimize import minimize

def Eucl_Cost(x,y):
    """
    Build the bilinear cost operator ``C[i, j] = y_i^T beta x_j`` together
    with its adjoint.

    Parameters:
    - x (jnp.ndarray): 2D array of shape (n, r); row ``j`` is the vector `x_j`.
    - y (jnp.ndarray): 2D array of shape (m, c); row ``i`` is the vector `y_i`.

    Returns:
    - Callable: a function ``cost(z, mode=0)``.
        * ``mode == 0`` (forward): `z` carries the ``c * r`` entries of `beta`
          and is reshaped to (c, r). The result is the (m, n) matrix whose
          (i, j) entry is ``y_i^T beta x_j``.
        * any other `mode` (adjoint): `z` is an (m, n) matrix. The result is
          the flat vector of length ``c * r`` obtained by flattening
          ``sum_{i,j} z[i, j] * outer(y_i, x_j)``.

    Usage:
    ```python
    x = jnp.array([[1., 2.], [3., 4.]])
    y = jnp.array([[5., 6.], [7., 8.]])
    cost_function = Eucl_Cost(x, y)
    C = cost_function(jnp.ones(4), mode=0)         # (2, 2) cost matrix
    g = cost_function(jnp.ones((2, 2)), mode=1)    # flat vector of length 4
    ```

    Note:
    - Both `x` and `y` must be two-dimensional; their row counts fix the shape
      of the cost matrix and their column counts fix the shape of `beta`.
    """
    # Only the feature dimensions are needed; unpacking also enforces 2D inputs.
    _, x_dim = x.shape
    _, y_dim = y.shape

    def cost(z, mode=0):
        if mode != 0:
            # Adjoint: contract the (m, n) matrix with x first, then with y.
            z_times_x = jnp.einsum('mn,nr->mr', z, x)
            flat_grad = jnp.einsum('mr,mc->cr', z_times_x, y)
            return jnp.ravel(flat_grad)

        # Forward: interpret z as the (c, r) matrix beta and form y beta x^T.
        beta_mat = jnp.reshape(z, (y_dim, x_dim))
        beta_times_x = jnp.einsum('cr,nr->cn', beta_mat, x)
        return jnp.einsum('cn,mc->mn', beta_times_x, y)

    return cost


def sinkhorn(p,q,c_ij,eps,maxiter=1000):
    """
    Runs log-domain Sinkhorn iterations and returns the entropic-regularized
    coupling matrix between two discrete probability distributions.

    Parameters:
    - p (jnp.ndarray): 1D array of length n; the marginal matched by the
                       COLUMN sums of the returned coupling.
    - q (jnp.ndarray): 1D array of length m; the marginal matched by the
                       ROW sums of the returned coupling.
    - c_ij (jnp.ndarray): 2D cost matrix of shape (m, n) = (len(q), len(p)).
                          Entry ``c_ij[i, j]`` is the cost of pairing the i-th
                          atom of `q` with the j-th atom of `p`.
    - eps (float): The entropic regularization parameter. Smaller values give
                   a coupling closer to the unregularized optimal plan.
    - maxiter (int, optional): Number of Sinkhorn sweeps. Default is 1000.

    Returns:
    - jnp.ndarray: Coupling matrix of shape (m, n),
                   ``exp(f_i[:, None] + g_j[None, :] - c_ij / eps)``.

    Usage:
    ```python
    p = jnp.array([0.5, 0.5])
    q = jnp.array([0.5, 0.5])
    c_ij = jnp.array([[0.0, 1.0], [1.0, 0.0]])
    eps = 0.01
    coupling_matrix = sinkhorn(p, q, c_ij, eps)
    ```

    Note:
    - There is no convergence test: exactly `maxiter` sweeps are performed.
      Because the row potential is updated last, the row sums equal `q` up to
      rounding, while the column sums equal `p` only up to the remaining
      Sinkhorn error.
    - The potentials `f_i` and `g_j` are kept in the log domain (already
      divided by `eps`), which avoids forming ``exp(-c_ij / eps)`` directly.
    """
    log_p = jnp.log(p)
    log_q = jnp.log(q)
    f_i = jnp.ones(q.shape)
    g_j = jnp.ones(p.shape)

    # The scaled cost does not change across sweeps, so compute it once.
    neg_scaled_cost = -c_ij/eps

    for _ in range(maxiter):
        # Column update (matches p), then row update (matches q).
        g_j = log_p - logsumexp(neg_scaled_cost + f_i[:,None], axis=0)
        f_i = log_q - logsumexp(neg_scaled_cost + g_j[None,:], axis=1)

    return jnp.exp(f_i[:,None] + g_j[None,:] + neg_scaled_cost)


def wthresh(x, gamma):
    """
    Element-wise soft-thresholding (shrinkage) operator.

    Each entry of `x` has its magnitude reduced by `gamma`; entries whose
    magnitude does not exceed `gamma` become zero. The sign of every
    surviving entry is preserved.

    Parameters:
    - x (jnp.ndarray): Input array of any shape.
    - gamma (float): Threshold level subtracted from the absolute values.

    Returns:
    - jnp.ndarray: ``sign(x) * max(|x| - gamma, 0)``, with the same shape as `x`.

    Usage:
    ```python
    x = jnp.array([1.5, -0.5, 0.2, -1.7])
    result = wthresh(x, 0.5)  # approximately jnp.array([1.0, 0.0, 0.0, -1.2])
    ```
    """
    shrunk_magnitude = jnp.abs(x) - gamma
    floor = jnp.zeros_like(x)
    clipped = jnp.maximum(shrunk_magnitude, floor)
    return clipped * jnp.sign(x)



def SISTA(hat_pi,  eps, beta_init, cost, gamma = .01, rho = .1, maxiter = 500):
    """
    Solves the sparse inverse Optimal Transport (iOT) problem with the SISTA
    scheme: each iteration performs one log-domain Sinkhorn update of the dual
    potentials followed by one soft-thresholded gradient step on `beta`.
    It is typically slower than the semi-dual solvers and is mainly useful
    as a baseline for comparison.

    Parameters:
    - hat_pi (jnp.ndarray): The observed coupling matrix, of shape (m, n).
    - eps (float): The entropic regularization parameter.
    - beta_init (jnp.ndarray): Initial value of the cost vector `beta`.
    - cost (Callable): Cost operator. ``cost(beta)`` must return the (m, n) cost
                       matrix and ``cost(M, mode=1)`` must apply the adjoint to
                       an (m, n) matrix, returning an array shaped like `beta`
                       (e.g. the callable returned by `Eucl_Cost`).
    - gamma (float, optional): The L1 regularization parameter. Default is .01.
    - rho (float, optional): The stepsize of the gradient step. Default is .1.
    - maxiter (int, optional): The number of iterations. Default is 500.

    Returns:
    - beta (jnp.ndarray): The cost vector after `maxiter` iterations.
    - fvals (list): One objective value (a JAX scalar) per iteration, evaluated
                    after the Sinkhorn update and before the `beta` update.

    The recorded objective is:
    f = eps * sum(pi_ij) + sum(cost_trans_pi * beta)
      - eps * sum(hat_pi.T @ f_i) - eps * sum(hat_pi @ g_j)
      + gamma * sum(abs(beta))

    Here ``cost_trans_pi = cost(hat_pi, mode=1)``, `pi_ij` is the coupling
    ``exp(f_i[:, None] + g_j[None, :] - cost(beta) / eps)`` and `f_i`, `g_j`
    are the log-domain dual variables updated in the Sinkhorn step.

    Usage:
    ```python
    hat_pi = some_coupling_matrix
    beta_init = some_initial_vector
    cost_function = some_cost_function # e.g., Eucl_Cost(x, y)
    eps = 0.01
    gamma = 0.1
    rho = 0.1
    maxiter = 500
    beta, fvals = SISTA(hat_pi, eps, beta_init, cost_function, gamma, rho, maxiter)
    ```

    Note:
    - There is no stopping criterion; all `maxiter` iterations are always run.
    - fvals are recorded to monitor the convergence of the algorithm.
    """
    # Adjoint of the cost operator applied to the data; constant over iterations.
    cost_trans_pi = cost(hat_pi, mode=1)

    def objective(pi_ij, f_i, g_j, beta):
        f= eps*jnp.sum(pi_ij)+jnp.sum(cost_trans_pi*beta) \
              -eps*jnp.sum(hat_pi.T@f_i) - eps*jnp.sum(hat_pi@g_j)\
        +gamma*jnp.sum(jnp.abs(beta))
        return f

    fvals = []
    beta = beta_init
    m,n = hat_pi.shape
    # Column sums (length n) and row sums (length m) of the observed coupling.
    log_p = jnp.log(jnp.sum(hat_pi,axis=0))
    log_q = jnp.log(jnp.sum(hat_pi,axis=1))
    f_i = jnp.ones((m,))
    g_j = jnp.ones((n,))

    for _ in range(maxiter):
        # The scaled cost is reused by both Sinkhorn updates and the coupling.
        neg_scaled_cost = -cost(beta)/eps

        # sinkhorn steps
        g_j = log_p - logsumexp( neg_scaled_cost + f_i[:,None] ,axis=0)
        f_i = log_q - logsumexp( neg_scaled_cost + g_j[None,:] ,axis=1)

        # record objective
        pi_ij = jnp.exp( f_i[:,None] + g_j[None,:] + neg_scaled_cost )
        fvals.append(objective(pi_ij, f_i, g_j, beta))

        # Gradient of the smooth part w.r.t. beta is cost(hat_pi - pi_ij, mode=1);
        # the L1 term is handled by soft thresholding with level rho*gamma.
        beta = wthresh( beta - rho* cost( hat_pi - pi_ij , mode=1) , rho*gamma )

    return beta, fvals

def iOT_VarPro_emp(n,  eps, x_init, cost, gamma = .01, niter=30):
    """
    Solves the inverse Optimal Transport (iOT) variational problem for empirical
    (uniform, paired) data by running BFGS on the semi-dual objective.

    The cost vector is parameterized as ``beta = u * v`` (element-wise product)
    and the optimization variable is the concatenation ``[u, v, f]``, where `u`
    and `v` each have ``k = (len(x_init) - n) // 2`` entries and `f` has `n`.

    Parameters:
    - n (int): The number of samples; also the length of the dual variable `f`.
    - eps (float): The entropic regularization parameter.
    - x_init (jnp.ndarray): Initial value of ``[u, v, f]``, of length ``2*k + n``.
    - cost (Callable): Cost operator; ``cost(beta)`` must return an (n, n) cost matrix.
    - gamma (float, optional): Weight of the squared L2 penalties on `u` and `v`.
                               Since ``min_{u*v = b} (u^2 + v^2) / 2 = |b|``, this
                               acts as an L1 penalty on `beta`. Default is .01.
    - niter (int, optional): Maximum number of BFGS iterations. Default is 30.

    Returns:
    - beta (jnp.ndarray): The cost vector ``u * v`` at the final iterate.
    - res: The result object returned by `jax.scipy.optimize.minimize`
           (fields such as `x`, `fun`, `success`, `status`, `nit`).

    Raises:
    - ValueError: If ``len(x_init) - n`` is negative or odd, i.e. `x_init`
                  cannot be split into ``[u, v, f]``.

    The semi-dual objective function that is minimized is defined as:
    obj = -sum(f/n) + eps * sum(1/n * logsumexp(-C/eps + f[:, None]/eps, axis=0))
        + gamma * norm(u)^2/2 + gamma * norm(v)^2/2 + trace(C)/n

    where ``C = cost(u * v)``. The ``trace(C)/n`` term is the average cost of
    the observed pairs (sample i matched with sample i).

    Usage:
    ```python
    n = number_of_elements
    eps = 0.01
    x_init = some_initial_vector
    cost_function = some_cost_function # e.g., Eucl_Cost(x, y)
    gamma = 0.1
    niter = 30
    beta, res = iOT_VarPro_emp(n, eps, x_init, cost_function, gamma, niter)
    ```
    """
    n_uv = len(x_init) - n
    if n_uv < 0 or n_uv % 2 != 0:
        raise ValueError(
            f"x_init must have length 2*k + n for some integer k >= 0; "
            f"got len(x_init)={len(x_init)} with n={n}"
        )
    k = n_uv // 2

    # Semi-dual objective in the stacked variable [u, v, f].
    def iOT_semidual(uvf):
        u = uvf[:k]
        v = uvf[k:2*k]
        f = uvf[2*k:]
        C = cost(u*v)
        obj = -jnp.sum(f/n) +eps*jnp.sum(1/n * logsumexp(-C/eps+f[:,None]/eps,axis=0) ) \
        + gamma* jnp.linalg.norm(u)**2/2 + gamma*jnp.linalg.norm(v)**2/2 + jnp.trace(C)/n
        return obj

    # Minimize the semi-dual objective with BFGS.
    res = minimize(iOT_semidual, x_init, method='BFGS', options={'maxiter': niter})
    x = res.x
    beta = x[:k]*x[k:2*k]

    return beta, res

def iOT_VarPro(hat_pi,  eps, x_init, cost, gamma = .01, niter=30):
    """
    Solves the inverse Optimal Transport (iOT) variational problem for a given
    coupling matrix `hat_pi` by running BFGS on the semi-dual objective.

    The cost vector is parameterized as ``beta = u * v`` (element-wise product)
    and the optimization variable is the concatenation ``[u, v, f]``, where `f`
    has one entry per ROW of `hat_pi` and `u`, `v` each have
    ``k = (len(x_init) - hat_pi.shape[0]) // 2`` entries.

    Parameters:
    - hat_pi (jnp.ndarray): The observed coupling matrix, of shape (m, n).
    - eps (float): The entropic regularization parameter.
    - x_init (jnp.ndarray): Initial value of ``[u, v, f]``, of length ``2*k + m``.
    - cost (Callable): Cost operator; ``cost(beta)`` must return an (m, n) cost matrix.
    - gamma (float, optional): Weight of the squared L2 penalties on `u` and `v`.
                               Since ``min_{u*v = b} (u^2 + v^2) / 2 = |b|``, this
                               acts as an L1 penalty on `beta`. Default is .01.
    - niter (int, optional): Maximum number of BFGS iterations. Default is 30.

    Returns:
    - beta (jnp.ndarray): The cost vector ``u * v`` at the final iterate.
    - res: The result object returned by `jax.scipy.optimize.minimize`
           (fields such as `x`, `fun`, `success`, `status`, `nit`).

    Raises:
    - ValueError: If ``len(x_init) - m`` is negative or odd, i.e. `x_init`
                  cannot be split into ``[u, v, f]``.

    The semi-dual objective function that is minimized is defined as:
    obj = -sum(f*q) + eps * sum(p * logsumexp(-C/eps + f[:, None]/eps, axis=0))
        + gamma * norm(u)^2/2 + gamma * norm(v)^2/2 + sum(C*hat_pi)

    where ``C = cost(u * v)``, ``p = hat_pi.sum(axis=0)`` (column sums, length n)
    and ``q = hat_pi.sum(axis=1)`` (row sums, length m).

    Usage:
    ```python
    hat_pi = some_coupling_matrix
    eps = 0.01
    x_init = some_initial_vector
    cost_function = some_cost_function # e.g., Eucl_Cost(x, y)
    gamma = 0.1
    niter = 30
    beta, res = iOT_VarPro(hat_pi, eps, x_init, cost_function, gamma, niter)
    ```
    """
    p = jnp.sum(hat_pi, axis=0)
    q = jnp.sum(hat_pi, axis = 1)

    n_uv = len(x_init) - len(q)
    if n_uv < 0 or n_uv % 2 != 0:
        raise ValueError(
            f"x_init must have length 2*k + {len(q)} (rows of hat_pi) for some "
            f"integer k >= 0; got len(x_init)={len(x_init)}"
        )
    k = n_uv // 2

    # Semi-dual objective in the stacked variable [u, v, f].
    def iOT_semidual(uvf):
        u = uvf[:k]
        v = uvf[k:2*k]
        f = uvf[2*k:]
        C = cost(u*v)

        obj = -jnp.sum(f*q) +eps*jnp.sum(p* logsumexp(-C/eps+f[:,None]/eps,axis=0) ) \
            + gamma* jnp.linalg.norm(u)**2/2 + gamma*jnp.linalg.norm(v)**2/2 + jnp.sum(C*hat_pi)

        return obj

    # Minimize the semi-dual objective with BFGS.
    res = minimize(iOT_semidual, x_init, method='BFGS', options={'maxiter': niter})
    x = res.x
    beta = x[:k]*x[k:2*k]

    return beta, res