import torch
import numpy as np
from tqdm import tqdm
from sklearn.manifold import SpectralEmbedding
from entropic_affinity import SNE_affinity, se_affinity, se_affinity_dual_ascent
from self_sinkhorn import log_selfsink
import root_finding as root_finding


# Maps an optimizer name to its torch.optim class; the fitting routines look
# up their ``optimizer`` argument here.
OPTIMIZERS = dict(Adam=torch.optim.Adam, SGD=torch.optim.SGD)


# ---------- Spectral dimension reduction methods ----------


def PCA(X, q=2, center=True):
    """Project the rows of ``X`` onto their leading ``q`` principal directions.

    Parameters
    ----------
    X : torch.Tensor
        2-D data matrix (assumed one sample per row).
    q : int
        Number of components to keep.
    center : bool
        Subtract the column means before the SVD (default). PCA is defined on
        centered data; without centering the first direction mostly captures
        the data mean. Pass ``False`` to get a plain truncated SVD of ``X``.

    Returns
    -------
    torch.Tensor
        Scores ``U[:, :q] * S[:q]`` of the (centered) data, one row per row of
        ``X`` and ``min(q, *X.shape)`` columns.
    """
    if center:
        X = X - X.mean(dim=0, keepdim=True)
    U, S, _ = torch.linalg.svd(X, full_matrices=False)
    return U[:, :q] * S[:q]


def LE(P, q=2):
    """Laplacian Eigenmaps embedding of a precomputed affinity matrix.

    Thin wrapper around scikit-learn's ``SpectralEmbedding`` configured with
    ``affinity='precomputed'``, so ``P`` is used directly as the graph
    affinity. Returns the ``fit_transform`` result with ``q`` columns.
    """
    spectral = SpectralEmbedding(affinity='precomputed', n_components=q)
    embedding = spectral.fit_transform(P)
    return embedding


# ---------- Dimension reduction methods related to stochastic neighbor embedding, including ours ----------


def affinity_coupling(P0, Z, kernel=None, eps=1.0, lr=1, max_iter=1000, optimizer='Adam', loss='KL', verbose=True,
                      tol=1e-4, pz=2, exaggeration=False, exaggeration_coef=12, exaggeration_iter=100, patience=10):
    """Optimize the embedding ``Z`` so that its affinity matches ``P0``.

    At each iteration the pairwise costs ``cdist(Z, Z, p=pz) ** 2`` are turned
    into a log-affinity ``log_Q``. ``kl_div(log_Q, P, reduction='sum')``,
    i.e. KL(P || Q), is then minimized with respect to ``Z``.

    Parameters
    ----------
    P0 : torch.Tensor
        Target affinity (the ``target`` argument of ``kl_div``).
    Z : torch.Tensor
        Initial embedding. Its ``requires_grad`` flag is set to True and it is
        updated in place by the optimizer.
    kernel : str or None
        ``'student'`` or ``'gaussian'`` selects a fixed kernel normalized over
        the whole matrix. Any other value (including None) uses
        ``log_selfsink``, and ``'tsnekhorn'`` turns on its ``student`` option.
    eps : float
        Regularization passed to ``log_selfsink``. It must not be None on that
        path.
    lr : float
        Learning rate.
    max_iter : int
        Maximum number of optimizer steps.
    optimizer : str
        Key into ``OPTIMIZERS`` (``'Adam'`` or ``'SGD'``).
    loss : str
        Unused; the loss is always the KL divergence. Kept for API
        compatibility.
    verbose : bool
        Show the progress bar and print a message on convergence.
    tol : float
        Threshold on the relative loss change used for convergence.
    pz : float
        Norm order used by ``torch.cdist`` on ``Z``.
    exaggeration : bool
        Multiply the target by ``exaggeration_coef`` during the first
        ``exaggeration_iter`` iterations.
    exaggeration_coef, exaggeration_iter : number, int
        Factor and duration of that phase. The defaults reproduce the former
        hard-coded 12 and 100.
    patience : int
        Stop once the relative change stays below ``tol`` for more than
        ``patience`` consecutive iterations (former hard-coded value: 10).

    Returns
    -------
    (torch.Tensor, dict)
        The detached final embedding, and a log whose ``'loss'`` entry lists
        the loss value of every iteration.

    Raises
    ------
    ValueError
        If ``eps`` is None on the ``log_selfsink`` path.
    Exception
        If the loss becomes NaN.
    """
    Z.requires_grad = True
    f = None
    opt = OPTIMIZERS[optimizer]([Z], lr=lr)
    counter_cv = 0

    log = {'loss': []}

    # disable=... keeps the progress bar silent when verbose is False.
    pbar = tqdm(range(max_iter), disable=not verbose)
    for k in pbar:
        C = torch.cdist(Z, Z, p=pz)**2
        opt.zero_grad()

        # Early exaggeration: scale the target during the first iterations only.
        if exaggeration and k < exaggeration_iter:
            P = exaggeration_coef * P0
        else:
            P = P0

        if kernel == 'student' or kernel == 'gaussian':
            log_k = log_kernel(C, kernel=kernel)
            # Global normalization: exp(log_Q) sums to one over all entries.
            log_Q = log_k - torch.logsumexp(log_k, dim=(0, 1))
        else:
            if eps is None:
                raise ValueError('eps must not be None for the self-Sinkhorn affinity')
            student = (kernel == 'tsnekhorn')
            # The returned f is fed back on the next call (presumably a warm
            # start for log_selfsink's dual variable).
            log_Q, f = log_selfsink(C=C, eps=eps, f=f, student=student)

        loss_value = torch.nn.functional.kl_div(log_Q, P, reduction='sum')
        if torch.isnan(loss_value):
            raise Exception(f'NaN in loss at iteration {k}')

        loss_value.backward()
        opt.step()

        log['loss'].append(loss_value.item())

        if k > 1:
            prev, curr = log['loss'][-2], log['loss'][-1]
            # Relative change. Guard against a previous loss of exactly 0,
            # which would otherwise raise ZeroDivisionError.
            if prev != 0:
                delta = abs(curr - prev) / abs(prev)
            else:
                delta = 0.0 if curr == 0 else float('inf')

            if delta < tol:
                counter_cv += 1
                # Stop only once the criterion has held for several iterations in a row.
                if counter_cv > patience:
                    if verbose:
                        print('---------- delta loss convergence ----------')
                    break
            else:
                counter_cv = 0

            if verbose:
                pbar.set_description(f'Loss : {curr: .3e}, '
                                     f'delta : {delta: .3e} '
                                     )

    return Z.detach(), log


def log_kernel(C: torch.Tensor, kernel: str) -> torch.Tensor:
    """Return the element-wise, unnormalized log-kernel of the costs ``C``.

    Parameters
    ----------
    C : torch.Tensor
        Cost tensor. ``affinity_coupling`` passes squared pairwise distances.
    kernel : str
        ``'student'`` gives the heavy-tailed kernel ``-log(1 + C)``. Any other
        value falls back to the Gaussian log-kernel ``-C / 2``.

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as ``C``.
    """
    if kernel == 'student':
        # log1p keeps precision for tiny costs (near-duplicate points), where
        # 1 + C rounds to 1 and log(1 + C) collapses to exactly 0.
        return -torch.log1p(C)
    # Gaussian; also the fallback for any name other than 'student'.
    return - 0.5 * C

def SNE(X, Z, perp, **coupling_kwargs):
    """Fit ``Z`` to the ``SNE_affinity`` of ``X`` at perplexity ``perp``.

    The affinity is built from squared Euclidean distances between rows of
    ``X``. Extra keyword arguments go to ``affinity_coupling``, whose
    ``(embedding, log)`` result is returned unchanged.
    """
    sq_dists = torch.cdist(X, X, p=2).pow(2)
    return affinity_coupling(SNE_affinity(sq_dists, perp), Z, **coupling_kwargs)


def DSNE(X, Z, eps, **coupling_kwargs):
    """Fit ``Z`` to the self-Sinkhorn affinity of ``X``.

    The target is ``exp`` of the first output of ``log_selfsink`` applied to
    the squared Euclidean distances of ``X`` with regularization ``eps``. This
    ``eps`` is consumed here only. ``affinity_coupling`` gets its own ``eps``
    solely through ``coupling_kwargs`` (otherwise its default applies).
    """
    sq_dists = torch.cdist(X, X, p=2).pow(2)
    log_T0 = log_selfsink(sq_dists, eps=eps)[0]
    return affinity_coupling(log_T0.exp(), Z, **coupling_kwargs)


def SSNE(X, Z, perp, **coupling_kwargs):
    """Fit ``Z`` to the ``se_affinity_dual_ascent`` affinity of ``X``.

    The affinity is computed from squared Euclidean distances of ``X`` with
    perplexity ``perp``. Remaining keyword arguments are forwarded to
    ``affinity_coupling``, whose result is returned.
    """
    sq_dists = torch.cdist(X, X, p=2).pow(2)
    target = se_affinity_dual_ascent(sq_dists, perp=perp)
    return affinity_coupling(target, Z, **coupling_kwargs)


# ---------- UMAP ----------


def umap(P, Z, optimizer='Adam', lr=1e-3, tol=1e-4, max_iter=10000, verbose=True, a=1.929, b=0.7915,
         **coupling_kwargs):
    """Fit the embedding ``Z`` to the affinity ``P`` with UMAP's loss.

    The low-dimensional similarity is ``W = 1 / (1 + a * d ** (2 * b))``,
    where ``d`` are the pairwise Euclidean distances of ``Z``. ``W`` is fit to
    ``P`` with ``torch.nn.BCELoss`` (mean binary cross-entropy).

    Parameters
    ----------
    P : torch.Tensor
        Target affinities, used as the BCE target (expected in [0, 1]).
    Z : torch.Tensor
        Initial embedding. Its ``requires_grad`` flag is set to True and it is
        updated in place by the optimizer.
    optimizer : str
        Key into ``OPTIMIZERS``.
    lr, tol, max_iter : float, float, int
        Learning rate, relative loss-change threshold, and step budget. The
        loop stops at the first iteration whose relative change is below
        ``tol``.
    verbose : bool
        Show the progress bar and print a message on convergence.
    a, b : float
        Shape of the heavy-tailed low-dimensional kernel. The defaults are the
        former hard-coded values (see the UMAP paper for more details).
    **coupling_kwargs
        Accepted for call compatibility with the SNE-style routines and
        ignored.

    Returns
    -------
    (torch.Tensor, dict)
        The detached final embedding and a log with the per-iteration
        ``'loss'``.

    Raises
    ------
    Exception
        If the loss becomes NaN.
    """
    Z.requires_grad = True
    opt = OPTIMIZERS[optimizer]([Z], lr=lr)

    log = {'loss': []}

    criterion = torch.nn.BCELoss()

    # disable=... keeps the progress bar silent when verbose is False.
    pbar = tqdm(range(max_iter), disable=not verbose)
    for k in pbar:
        Cz = torch.pow(torch.cdist(Z, Z, p=2), 2*b)
        opt.zero_grad()

        W = torch.pow(1 + a*Cz, -1)
        loss_value = criterion(W, P)
        if torch.isnan(loss_value):
            raise Exception(f'NaN in loss at iteration {k}')

        loss_value.backward()
        opt.step()

        log['loss'].append(loss_value.item())

        if k > 1:
            prev, curr = log['loss'][-2], log['loss'][-1]
            # Relative change. Guard against a previous loss of exactly 0,
            # which would otherwise raise ZeroDivisionError.
            if prev != 0:
                delta = abs(curr - prev) / abs(prev)
            else:
                delta = 0.0 if curr == 0 else float('inf')

            if delta < tol:
                if verbose:
                    print('---------- delta loss convergence ----------')
                break

            if verbose:
                pbar.set_description(f'Loss : {curr: .3e}, '
                                     f'delta : {delta: .3e} '
                                     )

    return Z.detach(), log


def umap_affinity(C, perp=30, max_iter=100, tol=1e-5, verbose=True):
    """Build a symmetric UMAP-style fuzzy affinity matrix from the costs ``C``.

    For every row ``i`` a bandwidth ``sigma_i`` is found with
    ``root_finding.false_position`` such that
    ``sum_j exp(-(C[i, j] - rho_i) / sigma_i) == log(perp)``, where ``rho_i``
    is the minimum of row ``i``. The directed affinities are then symmetrized
    with the fuzzy union ``P + P.T - P * P.T``.

    NOTE(review): ``rho`` is the plain row minimum. If ``C`` has a zero
    diagonal, ``rho`` is 0 for every row rather than the nearest-neighbour
    distance UMAP usually uses. Confirm what callers pass in.

    Parameters
    ----------
    C : torch.Tensor
        Square cost matrix.
    perp : float
        Target value ``log(perp)`` for each row's mass.
    max_iter, tol, verbose
        Forwarded to ``root_finding.false_position``. ``verbose`` also prints
        a header.

    Returns
    -------
    torch.Tensor
        The symmetrized affinity matrix.
    """
    n_samples = C.shape[0]
    rho = torch.min(C, dim=-1).values
    shifted_costs = C - rho[:, None]  # non-negative, zero at each row's minimum
    target_mass = np.log(perp)

    def mass_gap(sigma):
        # Per-row residual; its root is the calibrated bandwidth.
        log_mass = torch.logsumexp(-shifted_costs / sigma[:, None], dim=-1)
        return log_mass.exp() - target_mass

    if verbose:
        print('---------- Computing the Affinity Matrix ----------')

    sigma_star, _, _ = root_finding.false_position(
        f=mass_gap, n=n_samples, begin=None, end=None,
        tol=tol, max_iter=max_iter, verbose=verbose,
    )

    directed = torch.exp(-shifted_costs / sigma_star[:, None])
    return directed + directed.T - directed * directed.T
