import torch
import torch.optim as optim

from src.utils import get_entropy_F, make_symmetric
from src.ot_utils import init_matrix, wotgrad, compute_cost, compute_local_cost
from src.sinkhorn import balanced_sinkhorn_stable, unbalanced_sinkhorn
from src.wot import WOT

class LWOT(WOT):
    """
    Implementation for Learned Wavelet Optimal Transport. Inherits from WOT.

    Parameters
    ----------
    X1 : torch.tensor
        Source data matrix.
    X2 : torch.tensor
        Target data matrix.
    n_scales : int
        Number of scales to use for wavelet transform.
    w_op : str
        Wavelet operator to use. Options are "heat", "mexican_hat", "itersine", 
        "simple_tight", "half_cosine_kernel" or "meyer".
    T : torch.tensor
        Coupling matrix. If None, then the coupling matrix will be calculated.
    """
    def solve(self, p=None, q=None, epsilon=1e-2, agg_op="sum", balanced=True, rho=None, rho2=None):
        """
        Solves the (balanced or unbalanced) learned wavelet OT problem.

        ``WOT.solve(p, q)`` is called first; it is expected to populate
        ``self.p``, ``self.q``, ``self.X1_wavelet_coeffs`` and
        ``self.X2_wavelet_coeffs``, which are read here.

        Parameters
        ----------
        p : torch.tensor
            Source distribution (forwarded to ``WOT.solve``).
        q : torch.tensor
            Target distribution (forwarded to ``WOT.solve``).
        epsilon : float
            Entropic regularization parameter.
        agg_op : str
            Aggregation over scales: "sum", "max" or "mean".
        balanced : bool
            If True, solve balanced LWOT; otherwise unbalanced LWOT.
        rho : float or None
            Mass change penalty on p (unbalanced only). If None, the default
            of ``get_unbalanced_lwot_T`` is used.
        rho2 : float or None
            Mass change penalty on q (unbalanced only). If None and ``rho`` is
            given, ``rho`` is reused; if both are None, the solver defaults
            are used.

        Returns
        ----------
        T : torch.tensor (p.shape[0], q.shape[0])
            Coupling matrix (also stored on ``self.T``).
        """
        super().solve(p, q)
        if balanced:
            self.T = get_balanced_lwot_T(
                self.X1_wavelet_coeffs, self.X2_wavelet_coeffs, self.p, self.q, 
                epsilon=epsilon, agg_op=agg_op
            )
        else:
            # Forward only the penalties the caller set: passing None for both
            # would override the solver's defaults and reach torch.tensor(None).
            penalties = {}
            if rho is not None:
                penalties["rho"] = rho
                # rho2=None makes the solver fall back to rho2 = rho.
                penalties["rho2"] = rho2
            elif rho2 is not None:
                penalties["rho2"] = rho2
            self.T = get_unbalanced_lwot_T(
                self.X1_wavelet_coeffs, self.X2_wavelet_coeffs, self.p, self.q, 
                epsilon=epsilon, agg_op=agg_op, **penalties
            )

        return self.T

def get_balanced_lwot_T(wavelet_coeffs_X1, wavelet_coeffs_X2, p, q, epsilon=0.1, 
                N=100, K=10, tol=1e-9, verbose=False, agg_op="sum", h=0.4):
    """
    Runs balanced learned wavelet optimal transport.

    Alternates between (1) a balanced Sinkhorn projection that computes the
    coupling ``T`` from the current learned per-scale filters ``F1``/``F2``,
    and (2) ``K`` Adam steps on the filters that minimise
    ``-compute_cost(...) + reg``, where ``reg`` is a quadratic penalty keeping
    the (symmetrised) filters close to 1. All tensors are moved to CUDA.

    Parameters
    ----------
    wavelet_coeffs_X1 : torch.tensor
        Wavelet coefficients for the source data matrix, one 2-D matrix per
        scale.
    wavelet_coeffs_X2 : torch.tensor
        Wavelet coefficients for the target data matrix, one 2-D matrix per
        scale.
    p : torch.tensor
        Source distribution.
    q : torch.tensor
        Target distribution.
    epsilon : float
        Entropic regularization parameter for Sinkhorn projection.
    N : int
        Maximum number of outer iterations. Also passed as ``numItermax`` to
        the Sinkhorn solver.
    K : int
        Number of Adam gradient steps on the filters per outer iteration.
    tol : float
        Convergence tolerance on the norm of the change in ``T`` between
        outer iterations (checked every 10th iteration).
    verbose : bool
        If True, print the convergence error every 10th iteration.
    agg_op : str  
        Aggregation operator to use. Options are "sum", "max", or "mean".
    h : float
        Bandwidth parameter for Gaussian kernel.

    Returns
    ----------
    T : torch.tensor (p.shape[0], q.shape[0])
        Coupling matrix (on CPU).

    Raises
    ------
    AssertionError
        If ``agg_op`` is not one of the supported options.
    """
    assert agg_op == "sum" or agg_op == "max" or agg_op == "mean"

    p, q = p.cuda(), q.cuda()
    wavelet_coeffs_X1 = wavelet_coeffs_X1.cuda()
    wavelet_coeffs_X2 = wavelet_coeffs_X2.cuda()

    # Tuple unpacking enforces 2-D per-scale coefficient matrices.
    _, ns = wavelet_coeffs_X1[0].shape
    _, nt = wavelet_coeffs_X2[0].shape
    num_scales = len(wavelet_coeffs_X2)

    # Learnable per-scale filters, initialised to all ones.
    F1 = torch.nn.Parameter(torch.ones((num_scales, ns, ns)).cuda())
    F2 = torch.nn.Parameter(torch.ones((num_scales, nt, nt)).cuda())
    optimizer = optim.Adam([F1, F2], lr=1e-1)

    # Start from the product coupling p q^T.
    T = torch.ger(p, q)

    weightF1, weightF2 = get_entropy_F(wavelet_coeffs_X1, wavelet_coeffs_X2, h=h)
    weightF1 = (weightF1.sum(axis=1) / weightF1.sum(axis=1).max()).cuda()
    weightF2 = (weightF2.sum(axis=1) / weightF2.sum(axis=1).max()).cuda()
    # Loop-invariant: the square-rooted scale weights are reused every iteration.
    sqrt_w1 = torch.sqrt(weightF1)
    sqrt_w2 = torch.sqrt(weightF2)

    it = 0
    err = 1
    while err > tol and it < N:
        Tprev = T
        with torch.no_grad():
            sym_F1 = make_symmetric(F1)
            sym_F2 = make_symmetric(F2)

            constC = torch.zeros((num_scales, ns, nt), dtype=torch.float).cuda()
            hc1 = torch.zeros((num_scales, ns, ns), dtype=torch.float).cuda()
            hc2 = torch.zeros((num_scales, nt, nt), dtype=torch.float).cuda()

            for i in range(num_scales):
                constC[i, :, :], hc1[i, :, :], hc2[i, :, :] = init_matrix(
                    sqrt_w1[i] * sym_F1[i] * wavelet_coeffs_X1[i], 
                    sqrt_w2[i] * sym_F2[i] * wavelet_coeffs_X2[i], p, q
                )
            wot_loss_matrix = torch.zeros((num_scales, ns, nt)).cuda()
            for i in range(num_scales):
                wot_loss_matrix[i] = wotgrad(constC[i], hc1[i], hc2[i], T)

            # Collapse the per-scale costs into a single (ns, nt) cost matrix.
            if agg_op == "sum":
                wot_loss_matrix = wot_loss_matrix.sum(axis=0)
            elif agg_op == "max":
                wot_loss_matrix, _ = torch.max(wot_loss_matrix, axis=0)
            elif agg_op == "mean":
                wot_loss_matrix = wot_loss_matrix.mean(axis=0)

            T = balanced_sinkhorn_stable(
                p, q, wot_loss_matrix, epsilon, method='sinkhorn_stabilized', numItermax=N, cuda=True
            )

        # Filter update: gradient steps with T held fixed (T was built under no_grad).
        for _ in range(K):
            optimizer.zero_grad()
            sym_F1 = make_symmetric(F1)
            sym_F2 = make_symmetric(F2)
            loss = compute_cost(
                sqrt_w1[:, None, None] * sym_F1 * wavelet_coeffs_X1, 
                sqrt_w2[:, None, None] * sym_F2 * wavelet_coeffs_X2, T
            ).sum()
            reg = 2 * ((sym_F1 - 1) ** 2).sum() + 2 * ((sym_F2 - 1) ** 2).sum()

            loss = -(loss) + reg

            loss.backward()
            optimizer.step()

        if it % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err = torch.norm(T - Tprev)
            if verbose:
                if it % 200 == 0:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(it, err))

        it += 1

    return T.cpu()

def get_unbalanced_lwot_T(wavelet_coeffs_X1, wavelet_coeffs_X2, p, q, epsilon=0.1, 
                N=100, K=10, tol=1e-9, verbose=False, agg_op="sum", rho=5e-1, rho2=1e-2, h=0.4):
    """
    Runs unbalanced learned wavelet optimal transport.

    Alternates between (1) an unbalanced Sinkhorn solve that computes the
    coupling ``T`` from the current learned per-scale filters ``F1``/``F2``
    (warm-started from the previous dual potentials), and (2) ``K`` Adam
    steps on the filters that minimise ``-compute_cost(...) + reg``. All
    tensors are moved to CUDA.

    Parameters
    ----------
    wavelet_coeffs_X1 : torch.tensor
        Wavelet coefficients for the source data matrix, one 2-D matrix per
        scale.
    wavelet_coeffs_X2 : torch.tensor
        Wavelet coefficients for the target data matrix, one 2-D matrix per
        scale.
    p : torch.tensor
        Source distribution.
    q : torch.tensor
        Target distribution.
    epsilon : float
        Entropic regularization parameter for Sinkhorn projection.
    N : int
        Maximum number of outer iterations. Also forwarded to
        ``unbalanced_sinkhorn`` as its iteration budget.
    K : int
        Number of Adam gradient steps on the filters per outer iteration.
    tol : float
        Convergence tolerance on the norm of the change in ``T`` between
        outer iterations (checked every 10th iteration). Also forwarded to
        ``unbalanced_sinkhorn``.
    verbose : bool
        If True, print the convergence error every 10th iteration.
    agg_op : str  
        Aggregation operator to use. Options are "sum", "max", or "mean".
    rho : float
        Mass change penalty on p, infinity value reduces to balanced LWOT.
    rho2 : float or None
        Mass change penalty on q, infinity value reduces to balanced LWOT.
        If None, ``rho`` is used.
    h : float
        Bandwidth parameter for Gaussian kernel.

    Returns
    ----------
    T : torch.tensor (p.shape[0], q.shape[0])
        Coupling matrix (on CPU).

    Raises
    ------
    AssertionError
        If ``agg_op`` is not one of the supported options.
    Exception
        If the Sinkhorn solver returns a plan containing NaN.
    """
    assert agg_op == "sum" or agg_op == "max" or agg_op == "mean"

    if rho2 is None:
        rho2 = rho
    p, q = p.cuda(), q.cuda()
    wavelet_coeffs_X1 = wavelet_coeffs_X1.cuda()
    wavelet_coeffs_X2 = wavelet_coeffs_X2.cuda()

    # Tuple unpacking enforces 2-D per-scale coefficient matrices.
    _, ns = wavelet_coeffs_X1[0].shape
    _, nt = wavelet_coeffs_X2[0].shape
    num_scales = len(wavelet_coeffs_X2)

    # Learnable per-scale filters, initialised to all ones.
    F1 = torch.nn.Parameter(torch.ones((num_scales, ns, ns)).cuda())
    F2 = torch.nn.Parameter(torch.ones((num_scales, nt, nt)).cuda())
    optimizer = optim.Adam([F1, F2], lr=1e-1)

    # Start from the product coupling p q^T.
    T = torch.ger(p, q)

    weightF1, weightF2 = get_entropy_F(wavelet_coeffs_X1, wavelet_coeffs_X2, h=h)
    weightF1 = (weightF1.sum(axis=1) / weightF1.sum(axis=1).max()).cuda()
    weightF2 = (weightF2.sum(axis=1) / weightF2.sum(axis=1).max()).cuda()
    # Loop-invariant values hoisted out of the outer loop.
    sqrt_w1 = torch.sqrt(weightF1)
    sqrt_w2 = torch.sqrt(weightF2)
    rho_t = torch.tensor(rho).cuda()
    rho2_t = torch.tensor(rho2).cuda()

    # Dual potentials carried between Sinkhorn calls. They were previously read
    # before ever being assigned (UnboundLocalError on the first iteration).
    # NOTE(review): assumes unbalanced_sinkhorn uses log-domain potentials for
    # which zeros is the neutral starting point — confirm against src.sinkhorn.
    up = torch.zeros_like(p)
    vp = torch.zeros_like(q)

    it = 0
    err = 1
    while err > tol and it < N:
        Tprev = T
        with torch.no_grad():
            sym_F1 = make_symmetric(F1)
            sym_F2 = make_symmetric(F2)

            wot_loss_matrix = torch.zeros((num_scales, ns, nt)).cuda()
            for j in range(num_scales):
                wot_loss_matrix[j] = compute_local_cost(
                    T, p, sqrt_w1[j] * sym_F1[j] * wavelet_coeffs_X1[j], q, 
                    sqrt_w2[j] * sym_F2[j] * wavelet_coeffs_X2[j], epsilon, rho, rho2
                )

            # Collapse the per-scale costs into a single (ns, nt) cost matrix.
            if agg_op == "sum":
                wot_loss_matrix = wot_loss_matrix.sum(axis=0)
            elif agg_op == "max":
                wot_loss_matrix, _ = torch.max(wot_loss_matrix, axis=0)
            elif agg_op == "mean":
                wot_loss_matrix = wot_loss_matrix.mean(axis=0)

            mass_T = T.sum()

            (up, vp), T = unbalanced_sinkhorn(
                wot_loss_matrix, up, vp, p, q, mass_T + 1e-10, epsilon,
                rho_t, rho2_t,
                N, tol
            )
            if torch.any(torch.isnan(T)):
                raise Exception(
                    f"Solver got NaN plan with params (eps, rho, rho2) "
                    f" = {epsilon, rho, rho2}. Try increasing argument eps."
                )
            # Rescale so the new plan's mass is the geometric mean of the old
            # and new masses: sum = sqrt(mass_T * T.sum()).
            T = (mass_T / T.sum()).sqrt() * T

        # Filter update: gradient steps with T held fixed (T was built under no_grad).
        for _ in range(K):
            optimizer.zero_grad()
            sym_F1 = make_symmetric(F1)
            sym_F2 = make_symmetric(F2)
            # Per-scale weights, as in get_balanced_lwot_T (previously the
            # leaked loop index `j` applied the last scale's weight to all).
            loss = compute_cost(
                sqrt_w1[:, None, None] * sym_F1 * wavelet_coeffs_X1, 
                sqrt_w2[:, None, None] * sym_F2 * wavelet_coeffs_X2, T
            ).sum()
            # NOTE(review): the F2 penalty weight (.20) differs from F1 (2.0)
            # and from the balanced solver (2, 2) — confirm this is intended.
            reg = 2.0 * ((sym_F1 - 1) ** 2).sum() + .20 * ((sym_F2 - 1) ** 2).sum()

            loss = -(loss) + reg

            loss.backward()
            optimizer.step()

        if it % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err = torch.norm(T - Tprev)

            if verbose:
                if it % 200 == 0:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(it, err))

        it += 1

    return T.cpu()