import torch

from src.ot_utils import init_matrix, wotgrad, compute_local_cost
from src.sinkhorn import balanced_sinkhorn_stable, unbalanced_sinkhorn
from src.wot import WOT

class VanillaWOT(WOT):
    """
    Implementation for Vanilla Wavelet Optimal Transport. Inherits from WOT.

    Parameters
    ----------
    X1 : torch.tensor
        Source data matrix.
    X2 : torch.tensor
        Target data matrix.
    n_scales : int
        Number of scales to use for wavelet transform.
    w_op : str
        Wavelet operator to use. Options are "heat", "mexican_hat", "itersine", 
        "simple_tight", "half_cosine_kernel" or "meyer".
    T : torch.tensor
        Coupling matrix. If None, then the coupling matrix will be calculated.
    """
    def solve(self, p=None, q=None, epsilon=1e-2, agg_op="sum", balanced=True, rho=None, rho2=None):
        """
        Solves the optimal transport problem and stores the result in ``self.T``.

        Parameters
        ----------
        p : torch.tensor, optional
            Source distribution, forwarded to ``WOT.solve``.
        q : torch.tensor, optional
            Target distribution, forwarded to ``WOT.solve``.
        epsilon : float
            Entropic regularization parameter passed to the solver.
        agg_op : str
            How per-scale costs are combined: "sum", "max" or "mean".
        balanced : bool
            If True, use ``get_balanced_T``; otherwise use ``get_unbalanced_T``.
        rho : float, optional
            Mass change penalty on p (unbalanced only). When None, the
            default of ``get_unbalanced_T`` is used.
        rho2 : float, optional
            Mass change penalty on q (unbalanced only). When None,
            ``get_unbalanced_T`` reuses ``rho``.

        Returns
        ----------
        T : torch.tensor (p.shape[0], q.shape[0])
            Coupling matrix (also stored as ``self.T``).
        """
        # NOTE(review): WOT.solve is not visible here; it presumably sets
        # self.p / self.q and the wavelet coefficient attributes used below.
        super().solve(p, q)
        if balanced:
            self.T = get_balanced_T(
                self.X1_wavelet_coeffs, self.X2_wavelet_coeffs, self.p, self.q, 
                epsilon=epsilon, agg_op=agg_op
            )
        else:
            # Forwarding rho=None would override get_unbalanced_T's default
            # and reach torch.tensor(None); only pass rho when it is set.
            penalties = {"rho2": rho2}
            if rho is not None:
                penalties["rho"] = rho
            self.T = get_unbalanced_T(
                self.X1_wavelet_coeffs, self.X2_wavelet_coeffs, self.p, self.q, 
                epsilon=epsilon, agg_op=agg_op, **penalties
            )

        return self.T

def get_balanced_T(wavelet_coeffs_X1, wavelet_coeffs_X2, p, q, epsilon=0.1, 
                N=100, tol=1e-9, verbose=False, agg_op="sum", K=None):
    """
    Runs balanced entropy based wavelet optimal transport. 

    All tensors are moved to CUDA; a GPU is required.

    Parameters
    ----------
    wavelet_coeffs_X1 : torch.tensor
        Wavelet coefficients for the source data matrix, indexable per scale;
        each scale is a square (ns, ns) matrix.
    wavelet_coeffs_X2 : torch.tensor
        Wavelet coefficients for the target data matrix, indexable per scale;
        each scale is a square (nt, nt) matrix.
    p : torch.tensor
        Source distribution.
    q : torch.tensor
        Target distribution.
    epsilon : float
        Entropic regularization parameter for Sinkhorn projection.
    N : int
        Maximum number of outer iterations.
    tol : float
        Convergence tolerance on the norm of the change of the coupling
        between iterations (checked every 10th outer iteration).
    verbose : bool
        If True, print out progress.
    agg_op : str  
        Aggregation operator to use. Options are "sum", "max", or "mean".
    K : int, optional
        Maximum number of iterations for each Sinkhorn projection.
        Defaults to ``N`` (the previous, implicit behavior).

    Returns
    ----------
    T : torch.tensor (p.shape[0], q.shape[0])
        Coupling matrix, on CPU.

    Raises
    ------
    ValueError
        If ``agg_op`` is not one of "sum", "max", "mean".
    """
    if agg_op not in ("sum", "max", "mean"):
        raise ValueError(
            f'agg_op must be "sum", "max" or "mean", got {agg_op!r}'
        )
    if K is None:
        K = N

    p, q = p.cuda(), q.cuda()
    wavelet_coeffs_X1 = wavelet_coeffs_X1.cuda()
    wavelet_coeffs_X2 = wavelet_coeffs_X2.cuda()

    ns, _ = wavelet_coeffs_X1[0].shape
    nt, _ = wavelet_coeffs_X2[0].shape
    num_scales = len(wavelet_coeffs_X2)

    # init_matrix only depends on the wavelet coefficients and the marginals,
    # none of which change inside the loop, so it is computed once here.
    constC = torch.zeros((num_scales, ns, nt), dtype=torch.float).cuda()
    hc1 = torch.zeros((num_scales, ns, ns), dtype=torch.float).cuda()
    hc2 = torch.zeros((num_scales, nt, nt), dtype=torch.float).cuda()
    for i in range(num_scales):
        constC[i], hc1[i], hc2[i] = init_matrix(
            wavelet_coeffs_X1[i], wavelet_coeffs_X2[i], p, q
        )

    # Start from the independent coupling p q^T.
    T = torch.outer(p, q)

    it = 0
    err = 1.0
    while err > tol and it < N:
        T_prev = T

        # Per-scale gradient of the WOT objective at the current coupling.
        wot_loss_matrix = torch.zeros((num_scales, ns, nt)).cuda()
        for i in range(num_scales):
            wot_loss_matrix[i] = wotgrad(constC[i], hc1[i], hc2[i], T)

        # Collapse the scale axis into a single (ns, nt) cost matrix.
        if agg_op == "sum":
            cost = wot_loss_matrix.sum(dim=0)
        elif agg_op == "max":
            cost = torch.max(wot_loss_matrix, dim=0).values
        else:  # "mean"
            cost = wot_loss_matrix.mean(dim=0)

        T = balanced_sinkhorn_stable(
            p, q, cost, epsilon,
            method='sinkhorn_stabilized', numItermax=K, cuda=True
        )

        if it % 10 == 0:
            # Checking the error only every 10th iteration saves a device sync
            # per iteration.
            err = torch.norm(T - T_prev).item()

            if verbose:
                if it % 200 == 0:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(it, err))

        it += 1

    return T.cpu()

def get_unbalanced_T(wavelet_coeffs_X1, wavelet_coeffs_X2, p, q, epsilon=0.01,
                     N=100, tol=1e-6, rho=1.0, rho2=1e-2, agg_op="sum",
                     verbose=False):
    """
    Runs unbalanced learned wavelet optimal transport. 

    All tensors are moved to CUDA; a GPU is required.

    Parameters
    ----------
    wavelet_coeffs_X1 : torch.tensor
        Wavelet coefficients for the source data matrix, shape
        (num_scales, ns, ns).
    wavelet_coeffs_X2 : torch.tensor
        Wavelet coefficients for the target data matrix, shape
        (num_scales, nt, nt).
    p : torch.tensor
        Source distribution.
    q : torch.tensor
        Target distribution.
    epsilon : float
        Entropic regularization parameter for Sinkhorn projection.
    N : int
        Maximum number of outer iterations; also passed to the Sinkhorn solver.
    tol : float
        Tolerance for Sinkhorn projection and for the outer stopping test
        (max absolute change of the coupling).
    agg_op : str  
        Aggregation operator to use. Options are "sum", "max", or "mean".
    rho : float
        Mass change penalty on p, infinity value reduces to balanced LWOT.
        None falls back to 1.0.
    rho2 : float    
        Mass change penalty on q, infinity value reduces to balanced LWOT.
        None reuses ``rho``.
    verbose : bool
        If True, print the outer iteration index.

    Returns
    ----------
    T : torch.tensor (p.shape[0], q.shape[0])
        Coupling matrix, on CPU.

    Raises
    ------
    ValueError
        If ``agg_op`` is not one of "sum", "max", "mean".
    RuntimeError
        If the solver produces a NaN coupling.
    """
    if agg_op not in ("sum", "max", "mean"):
        raise ValueError(
            f'agg_op must be "sum", "max" or "mean", got {agg_op!r}'
        )
    if rho is None:
        rho = 1.0
    if rho2 is None:
        rho2 = rho

    wavelet_coeffs_X1 = wavelet_coeffs_X1.cuda()
    wavelet_coeffs_X2 = wavelet_coeffs_X2.cuda()
    p = p.cuda()
    q = q.cuda()

    _, ns, _ = wavelet_coeffs_X1.shape
    num_scales, nt, _ = wavelet_coeffs_X2.shape

    # Initialize plan (independent coupling) and dual potentials.
    T = torch.outer(p, q)
    up, vp = torch.zeros_like(p), torch.zeros_like(q)

    # Penalties are constant over the solve; build the device tensors once.
    rho_t = torch.as_tensor(rho).cuda()
    rho2_t = torch.as_tensor(rho2).cuda()

    for i in range(N):
        if verbose:
            print(f"iter {i}")
        T_prev = T.clone()

        cost = torch.zeros((num_scales, ns, nt)).cuda()
        for j in range(num_scales):
            cost[j] = compute_local_cost(
                T, p, wavelet_coeffs_X1[j], q, wavelet_coeffs_X2[j], epsilon, rho, rho2
            )

        # Collapse the per-scale costs into one (ns, nt) matrix, matching the
        # shape of the plan T handed to the Sinkhorn solver.
        if agg_op == "sum":
            cost = cost.sum(dim=0)
        elif agg_op == "max":
            cost = torch.max(cost, dim=0).values
        else:  # "mean"
            cost = cost.mean(dim=0)

        mass_T = T.sum()

        (up, vp), T = unbalanced_sinkhorn(
            cost, up, vp, p, q, mass_T + 1e-10, epsilon, rho_t, rho2_t,
            N, tol
        )
        if torch.any(torch.isnan(T)):
            raise RuntimeError(
                f"Solver got NaN plan with params (eps, rho, rho2) "
                f" = {epsilon, rho, rho2}. Try increasing argument eps."
            )
        # Rescale so the plan's mass is the geometric mean of the old and the
        # new mass.
        T = (mass_T / T.sum()).sqrt() * T
        if (T - T_prev).abs().max().item() < tol:
            break

    return T.cpu()