"""
Module: gaussian_certificate_asym.py

Description:
This module contains a collection of functions to compute the dual certificates of the inverse Optimal Transport (iOT) problem.

Functions:

    - compute_precertificate(theta: np.ndarray, Sigma_X: np.ndarray, Sigma_Y: np.ndarray, eps: float) -> np.ndarray:
        Main function, computes the precertificate for a given problem depending on the value of eps.
        
    - mysqrt(X: np.ndarray) -> np.ndarray:
        Computes the square root of a matrix using Singular Value Decomposition (SVD).
    
    - mysqrtinv(X: np.ndarray) -> np.ndarray:
        Computes the inverse of the square root of a matrix.
    
    - Sigma_fun(A: np.ndarray, eps: float) -> np.ndarray:
        Computes a modified sigma function based on matrix A and a given epsilon value.

    - Sigma_xy_fun(Sigma_x: np.ndarray, Sigma_y: np.ndarray, ground_cost: np.ndarray, eps: float) -> np.ndarray:
        Computes the cross term Sigma_xy of the OT_eps plan for given Sigma_x, Sigma_y and ground cost.

    - cvxpy_solver(Sigma_x: np.ndarray, Sigma_y: np.ndarray, Sigma_xy: np.ndarray, gamma: float, verbose: bool, maxiter: int) -> Tuple[np.ndarray, np.ndarray]:
        Solves a convex optimization problem and returns the primal and dual solution.

    - compute_minimal_norm_cert(theta: np.ndarray, Sigma_X: np.ndarray, Sigma_Y: np.ndarray, eps: float, gamma: float) -> np.ndarray:
        Computes the dual solution of the minimal norm certificate for a given problem.

Usage:
    >>> import numpy as np
    >>> import gaussian_certificate_asym as mod
    >>> X = np.array([[1, 0], [0, 2]])
    >>> mod.mysqrt(X)
    >>> ...

Notes:
    - Ensure that appropriate dependencies and libraries are installed, such as cvxpy, numpy, etc.
    - Be cautious with matrix dimensions and data types to avoid unexpected errors.
"""

import jax.numpy as jnp
from jax import grad, jit, vmap,jacfwd, jacrev, jacobian
from jax import random
from jax.scipy.linalg import sqrtm


import numpy as np
from scipy.sparse import coo_matrix

import cvxpy as cp


def mysqrt(X):
    """
    Build a matrix square root of X from its singular value decomposition.

    The decomposition is requested with ``hermitian=True``, so X is treated as
    a Hermitian matrix. Only the singular values are square-rooted; the
    singular vectors are reused unchanged.

    Args:
        X (jnp.ndarray): Square input matrix, assumed to be Hermitian. Shape: (n, n).

    Returns:
        jnp.ndarray: ``U @ diag(sqrt(s)) @ Vh`` where ``U, s, Vh`` is the SVD of X.
            Shape: (n, n).

    Example:
        >>> X = jnp.array([[4., 0.], [0., 9.]])
        >>> root = mysqrt(X)
    """
    left, sing_vals, right_h = jnp.linalg.svd(
        X, full_matrices=False, hermitian=True
    )
    root_diag = jnp.diag(jnp.sqrt(sing_vals))
    return left @ root_diag @ right_h

def mysqrtinv(X):
    """
    Invert the matrix square root of X.

    The square root is obtained from ``mysqrt`` and then inverted with
    ``jnp.linalg.inv``, so the root must be non-singular.

    Args:
        X (jnp.ndarray): Square input matrix, passed unchanged to ``mysqrt``.
            Shape: (n, n).

    Returns:
        jnp.ndarray: ``inv(mysqrt(X))``. Shape: (n, n).

    Example:
        >>> X = jnp.array([[4., 0.], [0., 9.]])
        >>> inv_root = mysqrtinv(X)
    """
    root = mysqrt(X)
    return jnp.linalg.inv(root)

def Sigma_fun(A,eps):
    """
    Apply an eps-dependent scalar map to the singular values of A.

    With ``A = U diag(s) Vh`` the result is ``U diag(g(s)) Vh`` where

        g(s) = sqrt(1 + eps**2 / (4 * s**2)) - eps / (2 * s)    for s > 0.

    Singular values that are not strictly positive keep a reciprocal of 0
    in the formula, so an exactly-zero singular value is mapped to 1.
    NOTE(review): the limit of g(s) as s -> 0+ is 0, not 1 -- confirm that
    the value 1 is really what is wanted for rank-deficient A.

    The reciprocal is built with ``jnp.where`` instead of boolean-mask
    indexing, so the function can be traced by ``jit`` / ``vmap`` (mask
    indexing has a data-dependent shape and cannot be traced). The inner
    ``jnp.where`` keeps the division away from zeros so that no inf is
    produced in the discarded branch.

    Args:
        A (jnp.ndarray): Input matrix. Shape: (n, m).
        eps (float): The epsilon parameter entering the singular value map.

    Returns:
        jnp.ndarray: ``U @ diag(g(s)) @ Vh``, same shape as A.

    Example:
        >>> A = jnp.array([[1., 2.], [3., 4.]])
        >>> eps = 0.1
        >>> Sigma_A = Sigma_fun(A, eps)
    """
    u,s,vh = jnp.linalg.svd(A,full_matrices=False)

    # Reciprocal of the strictly positive singular values; other entries are
    # passed through unchanged (singular values are non-negative, so that
    # means they stay at 0).
    positive = s > 0
    safe_s = jnp.where(positive, s, 1.0)
    s_inv = jnp.where(positive, 1.0 / safe_s, s)

    s2 = jnp.sqrt(1+.25* eps**2 * s_inv**2)-.5* eps * s_inv
    return u@jnp.diag(s2)@vh


def Sigma_xy_fun(Sigma_x,Sigma_y,ground_cost,eps):
    """
    Compute the cross term Sigma_xy of the OT_eps plan for given Sigma_x and Sigma_y.

    With ``A = sqrt(Sigma_x) @ ground_cost @ sqrt(Sigma_y)`` and its SVD
    ``A = U diag(s) Vh``, the function forms

        M = U diag(sqrt(1 + eps**2 / (4 * s**2))) Vh - (eps / 2) * pinv(A).T

    and returns ``-sqrt(Sigma_x) @ M @ sqrt(Sigma_y)``. The square roots come
    from ``mysqrt``. The singular values are divided by directly, so a zero
    singular value of A produces non-finite entries.

    Args:
        Sigma_x (jnp.ndarray): Matrix Sigma_x, passed to ``mysqrt``. Shape: (n, n).
        Sigma_y (jnp.ndarray): Matrix Sigma_y, passed to ``mysqrt``. Shape: (m, m).
        ground_cost (jnp.ndarray): The ground cost matrix. Shape: (n, m).
        eps (float): The epsilon parameter of the formula above.

    Returns:
        jnp.ndarray: Resultant Sigma_xy matrix. Shape: (n, m).

    Example:
        >>> Sigma_x = jnp.array([[1., 0.], [0., 2.]])
        >>> Sigma_y = jnp.array([[2., 0.], [0., 3.]])
        >>> ground_cost = jnp.array([[1., 2.], [3., 4.]])
        >>> eps = 0.1
        >>> Sigma_xy = Sigma_xy_fun(Sigma_x, Sigma_y, ground_cost, eps)
    """
    root_x = mysqrt(Sigma_x)
    root_y = mysqrt(Sigma_y)

    # Ground cost expressed in the whitened coordinates
    scaled_cost = root_x@ground_cost@root_y

    left, sing_vals, right_h = jnp.linalg.svd(
        scaled_cost, full_matrices=False, hermitian=False
    )
    pinv_transposed = jnp.linalg.pinv(scaled_cost).T

    stretched = jnp.sqrt(1+(eps**2*0.25)*1/(sing_vals**2))
    core = left@jnp.diag(stretched)@right_h-0.5*eps*pinv_transposed

    # Map back to the original coordinates, with the sign flipped
    return -root_x@core@root_y


def cvxpy_solver(Sigma_x, Sigma_y, Sigma_xy, gamma,verbose=False,maxiter=5000):
    """
    Solves the inverse Optimal Transport (iOT) problem with Gaussian parameterization using CVXPY.

    The problem handed to CVXPY is

        maximize    log det [[Sigma_x, Z], [Z.T, Sigma_y]]
        subject to  |Z - Sigma_xy| <= gamma   (entrywise)

    over a free (n, m) matrix Z. From the optimal Z the function returns

        dualSolution = (Z - Sigma_xy) / gamma
        primal       = -pinv(Sigma_x) @ Z @ inv(Sigma_y - Z.T @ pinv(Sigma_x) @ Z)

    Args:
        Sigma_x (np.ndarray): Input matrix Sigma_x, assumed to be symmetric positive definite. Shape: (n, n).
        Sigma_y (np.ndarray): Input matrix Sigma_y, assumed to be symmetric positive definite. Shape: (m, m).
        Sigma_xy (np.ndarray): Target cross term that Z is constrained to stay
            within ``gamma`` of (entrywise). Shape: (n, m).
        gamma (float): Entrywise tolerance of the constraint; also the scale
            dividing the returned dual solution.
        verbose (bool, optional): If True, the solver will display information about the progress. Default is False.
        maxiter (int, optional): Accepted for interface compatibility but
            currently NOT forwarded to the solver. Default is 5000.

    Returns:
        tuple: A tuple containing:
            - primal (np.ndarray): The primal solution recovered from Z. Shape: (n, m).
            - dualSolution (np.ndarray): The dual solution ``(Z - Sigma_xy) / gamma``. Shape: (n, m).

    Raises:
        RuntimeError: If the solver finishes without producing a value for Z
            (e.g. the problem is reported infeasible or unbounded).

    Example:
        >>> Sigma_x = np.array([[1, 0], [0, 2]])
        >>> Sigma_y = np.array([[2, 0], [0, 3]])
        >>> Sigma_xy = np.array([[0.1, 0.2], [0.3, 0.4]])
        >>> gamma = 0.1
        >>> primal, dual = cvxpy_solver(Sigma_x, Sigma_y, Sigma_xy, gamma, verbose=True)
    """

    n,m = Sigma_xy.shape

    Z = cp.Variable(shape=(n,m), PSD=False)

    objctiveFun=cp.Maximize(cp.log_det(cp.bmat([[Sigma_x,Z],[Z.T,Sigma_y]])))

    # Two one-sided inequalities encode the entrywise bound |Z - Sigma_xy| <= gamma
    constraints=[-Sigma_xy+Z <= gamma,Sigma_xy-Z <= gamma]

    prob = cp.Problem(objctiveFun, constraints)

    prob.solve(verbose=verbose)

    shiftedDual=Z.value
    # Inaccurate-but-usable solutions are still accepted (best effort); only
    # the case where no value exists at all is turned into an explicit error,
    # instead of the opaque TypeError the arithmetic below would raise on None.
    if shiftedDual is None:
        raise RuntimeError(
            'CVXPY Error: solver returned no solution (status: %s)' % prob.status
        )

    dualSolution=-(Sigma_xy-shiftedDual)/gamma

    #recover the primal solution
    Sigma_x_inv = np.linalg.pinv(Sigma_x)
    inv=np.linalg.inv(Sigma_y-shiftedDual.T@Sigma_x_inv@shiftedDual)
    primal = -Sigma_x_inv@shiftedDual@inv

    return primal, dualSolution


def compute_minimal_norm_cert(theta, Sigma_X, Sigma_Y, eps, gamma=1e-4):
    """
    Return the dual solution of the minimal norm certificate problem.

    The cross term is first computed with ``Sigma_xy_fun`` from the ground
    cost and the two covariances; it is then passed to ``cvxpy_solver`` and
    the second element of the solver's result (the dual solution) is returned.

    Args:
        theta (np.ndarray): The ground cost matrix. Shape: (n, m).
        Sigma_X (np.ndarray): The covariance matrix for X. Shape: (n, n).
        Sigma_Y (np.ndarray): The covariance matrix for Y. Shape: (m, m).
        eps (float): Parameter forwarded to ``Sigma_xy_fun``.
        gamma (float, optional): Constant forwarded to ``cvxpy_solver``. Default is 1e-4.

    Returns:
        np.ndarray: The dual solution returned by ``cvxpy_solver``. Shape: (n, m).

    Example:
        >>> theta = np.array([[1, 2], [3, 4]])
        >>> Sigma_X = np.array([[1, 0], [0, 1]])
        >>> Sigma_Y = np.array([[2, 0], [0, 2]])
        >>> eps = 0.1
        >>> minimal_norm_cert = compute_minimal_norm_cert(theta, Sigma_X, Sigma_Y, eps)
    """
    # The sizes are not used below; the unpacking only fails early when
    # theta is not two-dimensional.
    _n_rows, _n_cols = theta.shape

    cross_term = Sigma_xy_fun(Sigma_X, Sigma_Y, theta, eps)
    _primal, certificate = cvxpy_solver(
        Sigma_X, Sigma_Y, cross_term, gamma, verbose=False, maxiter=10000
    )
    return certificate



def compute_precertificate(theta, Sigma_X, Sigma_Y, eps):
    """
    Computes the precertificate for a given problem depending on the value of eps.

    In every regime a square operator ``M`` of size (n*m, n*m) is built, and
    the result is

        M[:, K] @ inv(M[K, K]) @ sign(theta[support])

    where ``support`` are the non-zero entries of ``theta`` and ``K`` their
    flattened indices. The operator depends on ``eps``:

        - ``eps == inf``: ``M = kron(Sigma_X, Sigma_Y)``, ``K = j + i*m``.
        - ``eps == 0``:   ``M = kron(inv(theta), inv(theta))``, ``K = i + j*n``
          (only valid for Sigma_X = Sigma_Y = Id; theta must be invertible).
        - otherwise:      ``M`` is the inverse of
          ``kron(Zr, Zc) + kron(theta, theta.T) @ T / eps**2`` with
          ``K = j + i*m``, where ``T`` is the permutation matrix assembled
          below and ``Zr``, ``Zc`` are built from ``Sigma_fun`` and ``mysqrt``.

    Args:
        theta (np.ndarray): The ground cost matrix. Shape: (n, m).
        Sigma_X (np.ndarray): The covariance matrix for X. Shape: (n, n).
        Sigma_Y (np.ndarray): The covariance matrix for Y. Shape: (m, m).
        eps (float): Selects the regime described above.

    Returns:
        np.ndarray: The computed precertificate. Shape: (n * m, ).

    Example:
        >>> theta = np.array([[1, 2], [3, 4]])
        >>> Sigma_X = np.array([[1, 0], [0, 1]])
        >>> Sigma_Y = np.array([[2, 0], [0, 2]])
        >>> eps = 0.1
        >>> precertificate = compute_precertificate(theta, Sigma_X, Sigma_Y, eps)
    """
    n, m = theta.shape

    # Support of theta and the signs carried by its non-zero entries
    rows, cols = jnp.where(theta != 0)
    signs = np.sign(theta[rows, cols])

    def _extend_from_support(operator, support):
        # Restrict the operator to the support, invert that block, and
        # push the sign pattern back through the support columns.
        on_support = operator[support[:, None], support[None, :]]
        return operator[:, support] @ np.linalg.inv(on_support) @ signs

    if eps == jnp.inf:
        return _extend_from_support(jnp.kron(Sigma_X, Sigma_Y), cols + rows*m)

    if eps == 0:
        # only valid for Sigma_X=Sigma_Y=Id
        theta_inv = np.linalg.inv(theta)
        return _extend_from_support(np.kron(theta_inv, theta_inv), rows + cols*n)

    root_y = mysqrt(Sigma_Y)
    root_x = mysqrt(Sigma_X)

    # cross covariance
    S = Sigma_fun(root_x @ theta @ root_y, eps)

    # transpose matrix: a one at (i + j*n, j + i*m) for every pair (i, j)
    grid_i, grid_j = (axis.ravel() for axis in np.indices((n, m)))
    T = coo_matrix((np.ones(n*m), (grid_i + grid_j*n, grid_j + grid_i*m)))
    T = jnp.asarray(T.toarray())

    Zc = jnp.linalg.inv(Sigma_Y - root_y @ S.T @ S @ root_y)
    Zr = jnp.linalg.inv(Sigma_X - root_x @ S @ S.T @ root_x)
    operator_inv = jnp.kron(Zr, Zc) + jnp.kron(theta, theta.T) @ T/eps**2
    operator = jnp.linalg.inv(operator_inv)

    # Computation of the certificate
    return _extend_from_support(operator, cols + rows*m)
