import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import romb
from timeit import default_timer as timer
from typing import Optional
from scipy.linalg import cho_factor
from scipy.linalg import solve_triangular, solve
from scipy.sparse.linalg import cg
from scipy.spatial import distance_matrix
from scipy.optimize import toms748

### Calculate distances between observations in an array
def normalise_distances_by_diameter(dist_mtx):
    """Rescale a distance matrix so that its largest entry becomes 1.

    The diameter (largest entry of ``dist_mtx``) is printed, rounded to two
    decimals, before rescaling.

    Parameters
    ----------
    dist_mtx : numpy.ndarray
        Matrix of pairwise distances.

    Returns
    -------
    numpy.ndarray
        ``dist_mtx / diameter``. When the diameter is 0 a float copy of
        ``dist_mtx`` is returned unscaled instead.
    """
    diameter = np.max(dist_mtx)
    print("Diameter " + str(round(diameter, 2)))
    if diameter == 0:
        # Dividing by a zero diameter would turn every entry into NaN (0/0)
        # and poison all downstream magnitude computations.
        return np.array(dist_mtx, dtype=float)
    return dist_mtx/diameter

def get_dist(X, p=2, normalise_by_diameter=True):
    """Return the pairwise distance matrix between the rows of ``X``.

    Parameters
    ----------
    X : array-like
        One observation per row.
    p : float
        Order of the Minkowski norm passed to ``scipy.spatial.distance_matrix``.
    normalise_by_diameter : bool
        When true, the matrix is passed through
        ``normalise_distances_by_diameter`` before being returned.

    Returns
    -------
    numpy.ndarray
        Square matrix of pairwise distances.
    """
    pairwise = distance_matrix(X, X, p=p)
    if not normalise_by_diameter:
        return pairwise
    return normalise_distances_by_diameter(pairwise)

### Magnitude calculations
def magnitude_cholesky_inversion_smart(D):
    """Compute the magnitude for distance matrix ``D`` via a Cholesky factor.

    The similarity matrix ``exp(-D)`` is factorised with ``cho_factor`` and a
    single triangular solve against the all-ones vector gives ``y``; the
    magnitude (sum of all entries of the inverse similarity matrix) is then
    ``y . y``, so the inverse is never formed explicitly.

    Parameters
    ----------
    D : numpy.ndarray
        Square matrix of (already scaled) pairwise distances.

    Returns
    -------
    float
        The magnitude.
    """
    similarity = np.exp(-D)
    n_points = similarity.shape[0]
    # Default cho_factor output is the upper factor U with similarity = U^T U.
    factor, _is_lower = cho_factor(similarity)
    # trans=1 solves U^T y = 1.
    y = solve_triangular(factor, np.ones(n_points), trans=1)
    return y @ y

def magnitude_naive(D):
    """Compute the magnitude for distance matrix ``D`` by explicit inversion.

    Parameters
    ----------
    D : numpy.ndarray
        Square matrix of (already scaled) pairwise distances.

    Returns
    -------
    float
        Sum of all entries of the inverse of ``exp(-D)``.
    """
    similarity = np.exp(-D)
    return np.linalg.inv(similarity).sum()

def compute_magnitude_from_distances(dist_mtx, ts=np.arange(0.01, 5, 0.01), method="smart"):
    """Evaluate the magnitude function on a distance matrix for each scale in ``ts``.

    Parameters
    ----------
    dist_mtx : numpy.ndarray
        Square matrix of pairwise distances.
    ts : iterable of float
        Scales ``t`` at which the magnitude of ``t * dist_mtx`` is evaluated.
    method : str
        ``"smart"`` uses ``magnitude_cholesky_inversion_smart``; any other
        value uses ``magnitude_naive``.

    Returns
    -------
    numpy.ndarray
        One magnitude value per entry of ``ts``. A scale of exactly 0 yields 1
        without any matrix computation.

    Notes
    -----
    If the linear algebra fails for some ``t``, a message is printed and the
    computation is retried with 0.01 added to the diagonal of the similarity
    matrix ``exp(-t * dist_mtx)``.
    """
    if method=="smart":
        mag_fn = magnitude_cholesky_inversion_smart
    else:
        mag_fn = magnitude_naive

    magnitude = []
    for t in ts:
        if t==0:
            magnitude.append(1)
            continue

        scaled = t*dist_mtx
        try:
            magnitude.append(mag_fn(scaled))
        except np.linalg.LinAlgError as e:
            print(f'Exception: {e} for t: {t} perturbing matrix')
            # mag_fn expects *distances* and exponentiates them itself, so the
            # perturbation has to be expressed on the distance side: choose
            # the diagonal d' with exp(-d') == exp(-d) + 0.01. Off-diagonal
            # entries are left untouched.
            perturbed = np.array(scaled, dtype=float)
            np.fill_diagonal(
                perturbed, -np.log(np.exp(-np.diag(perturbed)) + 0.01)
            )
            magnitude.append(mag_fn(perturbed))

    return np.array(magnitude)

def compute_magnitude_cholesky_inversion(W, p=2, ts=np.arange(0.01, 5, 0.01), normalise_by_diameter=True):
    """Compute the magnitude Mag(tX) of a point cloud for every t in ``ts``.

    Parameters
    ----------
    W : array-like
        Matrix where each row is a point in R^n.
    p : float
        Order of the metric used for the pairwise distances
        (1 for l1 / Manhattan, 2 for l2 / Euclidean).
    ts : iterable of float
        Values of t at which the magnitude is computed.
    normalise_by_diameter : bool
        Forwarded to ``get_dist``.

    Returns
    -------
    numpy.ndarray
        Result of ``compute_magnitude_from_distances`` on the distance matrix,
        called with its default method.
    """
    return compute_magnitude_from_distances(
        get_dist(W, p, normalise_by_diameter=normalise_by_diameter),
        ts=ts,
    )

### Magnitude weight calculations
def magnitude_weights_naive(D):
    """Compute magnitude weights for distance matrix ``D`` by explicit inversion.

    Parameters
    ----------
    D : numpy.ndarray
        Square matrix of (already scaled) pairwise distances.

    Returns
    -------
    numpy.ndarray
        Row sums of the inverse of ``exp(-D)``, one weight per point.
    """
    inverse = np.linalg.inv(np.exp(-D))
    return np.sum(inverse, axis=1)

def magnitude_weights(D):
    """Compute magnitude weights for distance matrix ``D`` with a linear solve.

    Solves ``exp(-D) w = 1`` using ``scipy.linalg.solve`` with
    ``assume_a="pos"`` (the similarity matrix is treated as positive definite).

    Parameters
    ----------
    D : numpy.ndarray
        Square matrix of (already scaled) pairwise distances.

    Returns
    -------
    numpy.ndarray
        One weight per point.
    """
    similarity = np.exp(-D)
    rhs = np.ones(similarity.shape[0])
    return solve(similarity, rhs, assume_a="pos")

def magnitude_weights_cholesky(D):
    """Compute magnitude weights for distance matrix ``D`` via a Cholesky factor.

    Solves ``exp(-D) w = 1`` with one forward and one backward triangular
    solve against the all-ones vector, instead of building the full inverse
    and summing its rows.

    Parameters
    ----------
    D : numpy.ndarray
        Square matrix of (already scaled) pairwise distances.

    Returns
    -------
    numpy.ndarray
        One weight per point.
    """
    M = np.exp(-D)
    c, lower = cho_factor(M)
    ones = np.ones(M.shape[0])
    # M = L L^T (lower factor) or M = U^T U (upper factor); pick the
    # transposition flags so that the first solve always removes the
    # left-hand factor and the second one the right-hand factor.
    first, second = (0, 1) if lower else (1, 0)
    y = solve_triangular(c, ones, lower=lower, trans=first)
    return solve_triangular(c, y, lower=lower, trans=second)

def magnitude_weights_conjugate_gradient_iteration(D):
    """Approximate magnitude weights for ``D`` with conjugate-gradient iteration.

    Runs ``scipy.sparse.linalg.cg`` on ``exp(-D) w = 1`` with ``atol=1e-3``.
    The convergence flag returned by ``cg`` is discarded.

    Parameters
    ----------
    D : numpy.ndarray
        Square matrix of (already scaled) pairwise distances.

    Returns
    -------
    numpy.ndarray
        Approximate weight per point.
    """
    similarity = np.exp(-D)
    rhs = np.ones(similarity.shape[0])
    return cg(similarity, rhs, atol=1e-3)[0]

def compute_magnitude_weights_from_distances(dist_mtx, ts=np.arange(0.01, 5, 0.01), method="smart"):
    """Compute the magnitude weights of every point for each scale in ``ts``.

    Parameters
    ----------
    dist_mtx : numpy.ndarray
        Square matrix of pairwise distances.
    ts : sequence of float
        Scales ``t`` at which the weights of ``t * dist_mtx`` are computed.
    method : str
        ``"smart"`` -> ``magnitude_weights``, ``"cholesky"`` ->
        ``magnitude_weights_cholesky``, ``"conjugate_gradient_iteration"`` ->
        ``magnitude_weights_conjugate_gradient_iteration``; any other value
        uses ``magnitude_weights_naive``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_points, len(ts))``; column ``i`` holds the weights
        at scale ``ts[i]``. Columns for a scale of exactly 0 are left as zeros.

    Notes
    -----
    If the linear algebra fails for some ``t``, a message is printed and the
    computation is retried with 0.01 added to the diagonal of the similarity
    matrix ``exp(-t * dist_mtx)``.
    """
    if method=="smart":
        mag_fn = magnitude_weights
    elif method=="cholesky":
        mag_fn = magnitude_weights_cholesky
    elif method =="conjugate_gradient_iteration":
        mag_fn = magnitude_weights_conjugate_gradient_iteration
    else:
        mag_fn = magnitude_weights_naive

    magnitude = np.zeros(shape=(dist_mtx.shape[0], len(ts)))
    for i, t in enumerate(ts):
        if t==0:
            continue

        scaled = t*dist_mtx
        try:
            magnitude[:,i] = mag_fn(scaled)
        except np.linalg.LinAlgError as e:
            print(f'Exception: {e} for t: {t} perturbing matrix')
            # mag_fn expects *distances* and exponentiates them itself, so the
            # perturbation has to be expressed on the distance side: choose
            # the diagonal d' with exp(-d') == exp(-d) + 0.01. Off-diagonal
            # entries are left untouched.
            perturbed = np.array(scaled, dtype=float)
            np.fill_diagonal(
                perturbed, -np.log(np.exp(-np.diag(perturbed)) + 0.01)
            )
            magnitude[:,i] = mag_fn(perturbed)
    return magnitude

def compute_magnitude_weights(W, p=2, ts=np.arange(0.01, 5, 0.01), normalise_by_diameter=True):
    """Compute the magnitude weights of a point cloud for every t in ``ts``.

    Parameters
    ----------
    W : array-like
        Matrix where each row is a point in R^n.
    p : float
        Order of the metric used for the pairwise distances
        (1 for l1 / Manhattan, 2 for l2 / Euclidean).
    ts : sequence of float
        Values of t at which the weights are computed.
    normalise_by_diameter : bool
        Forwarded to ``get_dist``.

    Returns
    -------
    numpy.ndarray
        Result of ``compute_magnitude_weights_from_distances`` on the distance
        matrix, called with its default method.
    """
    return compute_magnitude_weights_from_distances(
        get_dist(W, p=p, normalise_by_diameter=normalise_by_diameter),
        ts=ts,
    )

### Helper function to remove duplicate rows / points
def remove_duplicates_in_emb(X, y):
    """Drop duplicate rows of ``X`` and keep the matching entries of ``y``.

    Rows are de-duplicated with ``np.unique(axis=0)``, so the returned rows
    are in sorted order rather than input order. The number of distinct rows
    is printed.

    Parameters
    ----------
    X : numpy.ndarray
        One point per row.
    y : numpy.ndarray
        Per-point values, indexable by an integer index array.

    Returns
    -------
    tuple
        ``(X_unique, y_unique, indices, n)``: the distinct rows, the entries
        of ``y`` at the first occurrence of each distinct row, those
        first-occurrence indices, and the number of distinct rows.
    """
    deduped, first_seen = np.unique(X, axis=0, return_index=True)
    kept_labels = y[first_seen]
    n_distinct = deduped.shape[0]
    print(f"{round(n_distinct)} distinct points in X_emb")
    return deduped, kept_labels, first_seen, n_distinct

### Magnitude convergence estimation
def specify_magnitude_computations(input_type="points", p=2, normalise_by_diameter=True):
    """Select the function used to evaluate the magnitude function.

    Parameters
    ----------
    input_type : str
        ``"distances"`` returns ``compute_magnitude_from_distances`` as is
        (``p`` and ``normalise_by_diameter`` do not apply to a ready-made
        distance matrix). Any other value returns a wrapper around
        ``compute_magnitude_cholesky_inversion`` for raw point clouds.
    p : float
        Default metric order bound into the point-cloud wrapper.
    normalise_by_diameter : bool
        Default normalisation flag bound into the point-cloud wrapper.

    Returns
    -------
    callable
        Callable taking the data as first argument and accepting ``ts=...``.
    """
    if input_type=="distances":
        return compute_magnitude_from_distances

    # A named closure rather than a lambda that rebinds ``comp_mag``: the
    # latter ends up calling itself, and it could not forward ``ts`` either.
    def comp_mag(arg, p=p, **kwargs):
        kwargs.setdefault("normalise_by_diameter", normalise_by_diameter)
        return compute_magnitude_cholesky_inversion(arg, p=p, **kwargs)

    return comp_mag

def mag_convergence(W, x0, x1, f=None, max_iterations=100):
    """Find a root of ``f`` inside the bracket ``[x0, x1]`` with TOMS 748.

    Parameters
    ----------
    W : object
        Not used by this function.
    x0, x1 : float
        Ends of the bracketing interval.
    f : callable
        Scalar function whose root is sought.
    max_iterations : int
        Passed to ``toms748`` as ``maxiter``.

    Returns
    -------
    float
        The root returned by ``scipy.optimize.toms748`` (``rtol=1e-05``).
    """
    root = toms748(f, x0, x1, rtol=1e-05, maxiter=max_iterations)
    return root

def guess_target_point(X, comp_mag, target_value, guess=10, max_expansions=100):
    """Find the scale ``t`` at which the magnitude reaches ``target_value``.

    Starting from ``guess``, the upper end of the search bracket is multiplied
    by 10 until the magnitude there is at least ``target_value``; the root of
    ``magnitude(t) - target_value`` is then located with ``mag_convergence``.

    Parameters
    ----------
    X : object
        Data handed to ``comp_mag`` as its first argument.
    comp_mag : callable
        Called as ``comp_mag(X, ts=[t])``; element 0 of the result is taken
        as the magnitude at scale ``t``.
    target_value : float
        Magnitude value to reach.
    guess : float
        Initial upper end of the bracket.
    max_expansions : int
        Maximum number of times the bracket may be enlarged.

    Returns
    -------
    float
        Scale at which the magnitude equals ``target_value``.

    Raises
    ------
    ValueError
        If the magnitude is still below ``target_value`` after
        ``max_expansions`` enlargements (the target is presumably unreachable).
    """
    def f(x, W=X):
        mag = comp_mag(W, ts=[x])
        return mag[0] - target_value

    lower_guess = 0
    f_guess = f(guess)
    expansions = 0
    while f_guess < 0:
        # Without a bound this loop never ends when the magnitude can never
        # reach the target.
        if expansions >= max_expansions:
            raise ValueError(
                f"magnitude did not reach target value {target_value} "
                f"after {max_expansions} bracket expansions (last t={guess})"
            )
        lower_guess = guess
        guess = guess*10
        f_guess = f(guess)
        expansions += 1

    conv = mag_convergence(X, lower_guess, guess, f, max_iterations=100)
    return conv

def compute_magnitude_until_target(X, target_value, num_intervals=40, guess=100,
    p=2, input_type="points", normalise_by_diameter=True):
    """Evaluate the magnitude function from 0 up to the scale hitting a target.

    Parameters
    ----------
    X : array-like
        Point cloud (one point per row) when ``input_type == "points"``;
        otherwise it is used directly as the distance matrix.
    target_value : float
        Magnitude value that defines the upper end of the scale range.
    num_intervals : int
        Number of equally spaced scales between 0 and the target scale.
    guess : float
        Initial guess handed to ``guess_target_point``.
    p : float
        Metric order used when distances are computed from points.
    input_type : str
        ``"points"`` triggers the distance computation via ``get_dist``.
    normalise_by_diameter : bool
        Forwarded to ``get_dist`` and ``specify_magnitude_computations``.

    Returns
    -------
    tuple
        ``(magnitude, ts, conv)``: magnitude values, the scales they were
        evaluated at, and the scale at which the target is reached.
    """
    distances = X
    if input_type=="points":
        distances = get_dist(X, p=p, normalise_by_diameter=normalise_by_diameter)

    # From here on the data is always a distance matrix.
    comp_mag = specify_magnitude_computations(
        input_type="distances",
        normalise_by_diameter=normalise_by_diameter,
        p=p,
    )

    conv = guess_target_point(distances, comp_mag, target_value, guess=guess)
    ts = np.linspace(0, conv, num=num_intervals)
    return comp_mag(distances, ts=ts), ts, conv

def compute_target_scale(X, target_value, guess=100, p=2, input_type="points", normalise_by_diameter=True):
    """Return the scale ``t`` at which the magnitude reaches ``target_value``.

    Parameters
    ----------
    X : array-like
        Point cloud (one point per row) when ``input_type == "points"``;
        otherwise it is used directly as the distance matrix.
    target_value : float
        Magnitude value to reach.
    guess : float
        Initial guess handed to ``guess_target_point``.
    p : float
        Metric order used when distances are computed from points.
    input_type : str
        ``"points"`` triggers the distance computation via ``get_dist``.
    normalise_by_diameter : bool
        Forwarded to ``get_dist`` and ``specify_magnitude_computations``.

    Returns
    -------
    float
        The scale found by ``guess_target_point``.
    """
    distances = X
    if input_type=="points":
        distances = get_dist(X, p=p, normalise_by_diameter=normalise_by_diameter)

    # From here on the data is always a distance matrix.
    comp_mag = specify_magnitude_computations(
        input_type="distances",
        normalise_by_diameter=normalise_by_diameter,
        p=p,
    )
    return guess_target_point(distances, comp_mag, target_value, guess=guess)

### Magnitude function difference
### Calculate the difference between two magnitude functions evaluated across the same number of equally spaced intervals
def magnitude_diff(mag_original, mag_emb, diff_method="trpz", p=1):
    """Return an L^p-style difference between two sampled magnitude functions.

    Both inputs are assumed to have been evaluated on the same number of
    steps across comparable intervals.

    Parameters
    ----------
    mag_original, mag_emb : numpy.ndarray
        1-D arrays of magnitude values of equal length.
    diff_method : str
        ``"trpz"`` aggregates with the trapezoidal rule, ``"romb"`` with
        ``scipy.integrate.romb`` (which needs ``2**k + 1`` samples); any other
        value uses a plain sum.
    p : float
        Exponent applied to the absolute pointwise differences.

    Returns
    -------
    float
        ``aggregate(|mag_emb - mag_original| ** p) ** (1 / p) / n_intervals``
        where ``n_intervals`` is the number of samples minus one.
    """
    n_intervals=(mag_emb.shape[0]-1)

    if diff_method=="trpz":
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid;
        # prefer the new name when the installed NumPy provides it.
        summ = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    elif diff_method=="romb":
        summ = romb
    else:
        summ = np.sum

    # Take the absolute value *before* raising to p: a negative difference
    # raised to a non-integer power would otherwise be NaN.
    pointwise_diff = np.abs(mag_emb-mag_original)
    mag_diff = (summ(pointwise_diff**p))**(1/p)/n_intervals
    return mag_diff


### Magnitude weight difference
def magnitude_weight_diff(mag_original, mag_emb, diff_method="trpz", p=1, multi_scale=False):
    """Return the difference between two sets of sampled magnitude weights.

    Both inputs are assumed to have been evaluated on the same number of
    steps across comparable intervals, with scales along axis 1.

    Parameters
    ----------
    mag_original, mag_emb : numpy.ndarray
        2-D arrays of identical shape; axis 0 is summed over, axis 1 indexes
        the scales.
    diff_method : str
        ``"trpz"`` aggregates with the trapezoidal rule, ``"romb"`` with
        ``scipy.integrate.romb`` (which needs ``2**k + 1`` samples); any other
        value uses a plain sum.
    p : float
        Exponent applied to the per-scale sums.
    multi_scale : bool
        When true, return the per-scale sums of absolute differences instead
        of a single aggregated number.

    Returns
    -------
    numpy.ndarray or float
        The per-scale sums if ``multi_scale``; otherwise
        ``aggregate(sums ** p) ** (1 / p) / n_intervals`` where
        ``n_intervals`` is the number of scales minus one.
    """
    n_intervals=(mag_emb.shape[1]-1)

    pointwise_diff = np.abs(mag_emb-mag_original)
    sum_per_interval = pointwise_diff.sum(axis=0)
    if multi_scale:
        return sum_per_interval

    if diff_method=="trpz":
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid;
        # prefer the new name when the installed NumPy provides it.
        summ = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    elif diff_method=="romb":
        summ = romb
    else:
        summ = np.sum

    # sum_per_interval is a sum of absolute values, hence already non-negative.
    mag_diff = (summ(sum_per_interval**p))**(1/p)/n_intervals
    return mag_diff

### Magnitude weight difference across multiple examples
def get_diffs_from_weights(weights, all_differences=False):
    """Compare the magnitude weights of several embeddings pairwise.

    Every pair of entries whose weight arrays have the same number of rows is
    compared with ``magnitude_weight_diff`` (Romberg integration) after both
    arrays are divided by the number of rows of ``weights["original"]``.
    Pairs with differing row counts get NaN.

    Parameters
    ----------
    weights : dict
        Maps a name to a 2-D weight array; must contain the key ``"original"``.
    all_differences : bool
        When true, return the full symmetric table of differences.

    Returns
    -------
    pandas.DataFrame or pandas.Series
        The symmetric difference table, or only its ``"original"`` column
        sorted in ascending order.
    """
    n_original = weights["original"].shape[0]
    keys = list(weights.keys())
    n_keys = len(keys)
    mag_diffs = np.zeros(shape=(n_keys, n_keys))

    # Only the upper triangle is filled; it is mirrored afterwards.
    for i, first_key in enumerate(keys):
        for j in range(i + 1, n_keys):
            first = weights[first_key]
            second = weights[keys[j]]
            if first.shape[0] != second.shape[0]:
                mag_diffs[i, j] = np.nan
                continue
            mag_diffs[i, j] = magnitude_weight_diff(
                first/n_original, second/n_original, diff_method="romb"
            )

    diff_df = pd.DataFrame(mag_diffs + mag_diffs.T, index=keys, columns=keys)
    return diff_df if all_differences else diff_df["original"].sort_values()
    
def get_diffs_from_mag_functions(weights, n_points=None,  all_differences=False):
    """Compare the magnitude functions of several embeddings pairwise.

    Every pair of entries whose magnitude arrays have the same length is
    compared with ``magnitude_diff`` (Romberg integration). Pairs with
    differing lengths get NaN.

    Parameters
    ----------
    weights : dict
        Maps a name to a sequence whose element 0 is the 1-D array of
        magnitude values; must contain the key ``"original"`` unless
        ``all_differences`` is true.
    n_points : mapping, optional
        Maps each name to the number of points of that space. When given,
        each magnitude function is divided by the count of *its own* space
        before comparison; when omitted no normalisation is applied.
    all_differences : bool
        When true, return the full symmetric table of differences.

    Returns
    -------
    pandas.DataFrame or pandas.Series
        The symmetric difference table, or only its ``"original"`` column
        sorted in ascending order.
    """
    keys = list(weights.keys())
    mag_diffs = np.zeros(shape=(len(keys), len(keys)))
    for i, k in enumerate(keys):
        for j in range(i+1, len(keys)):
            mag_first = weights[k][0]
            mag_second = weights[keys[j]][0]
            if mag_first.shape[0] != mag_second.shape[0]:
                mag_diffs[i, j] = np.nan
                continue
            # Normalise each function by the size of the space it belongs to;
            # using the "original" count for the first function of every pair
            # made the result depend on the ordering of the dict.
            if n_points is None:
                n_first = n_second = 1
            else:
                n_first = n_points[k]
                n_second = n_points[keys[j]]
            mag_diffs[i, j] = magnitude_diff(
                mag_first/n_first, mag_second/n_second, diff_method="romb"
            )
    diff_df = pd.DataFrame(mag_diffs + mag_diffs.T, index=keys, columns=keys)
    if all_differences:
        return diff_df
    else:
        return diff_df["original"].sort_values()

####################################
### Results for the Swiss Roll #####
####################################
### Define the embedding methods

def isomap(X, n_neighbors=30, n_components = 2):
    """Embed ``X`` with scikit-learn's Isomap.

    Parameters
    ----------
    X : array-like
        Input data, one sample per row.
    n_neighbors : int
        Number of neighbours passed to ``Isomap``.
    n_components : int
        Dimension of the embedding.

    Returns
    -------
    numpy.ndarray
        The embedded coordinates. ``Isomap`` is constructed with ``p=1``.
    """
    from sklearn.manifold import Isomap

    model = Isomap(n_neighbors=n_neighbors, n_components=n_components, p=1)
    return model.fit_transform(X)

def mds(X, n_components = 2):
    """Embed ``X`` with scikit-learn's multidimensional scaling (MDS).

    Parameters
    ----------
    X : array-like
        Input data, one sample per row.
    n_components : int
        Dimension of the embedding.

    Returns
    -------
    numpy.ndarray
        The embedded coordinates.
    """
    from sklearn.manifold import MDS

    # Fixed settings for this experiment; random_state makes runs repeatable.
    settings = {
        "n_components": n_components,
        "max_iter": 50,
        "n_init": 4,
        "random_state": 0,
        "normalized_stress": False,
    }
    return MDS(**settings).fit_transform(X)

def lle(X, n_neighbors=30, n_components=2):
    """Embed ``X`` with scikit-learn's locally linear embedding.

    Parameters
    ----------
    X : array-like
        Input data, one sample per row.
    n_neighbors : int
        Number of neighbours used for each point.
    n_components : int
        Dimension of the embedding.

    Returns
    -------
    numpy.ndarray
        The embedded coordinates; the second value returned by
        ``locally_linear_embedding`` is discarded.
    """
    from sklearn.manifold import locally_linear_embedding

    result = locally_linear_embedding(
        X, n_neighbors=n_neighbors, n_components=n_components
    )
    return result[0]

def phate(
    X, n_pca: Optional[int] = None, gamma: float = 1
):
    """Embed ``X`` with PHATE.

    Parameters
    ----------
    X : array-like
        Input data, one sample per row.
    n_pca : int, optional
        Passed to ``PHATE`` as ``n_pca``.
    gamma : float
        Passed to ``PHATE`` as ``gamma``.

    Returns
    -------
    numpy.ndarray
        The embedded coordinates. PHATE runs silently (``verbose=False``)
        with ``n_jobs=-1``.
    """
    from phate import PHATE

    operator = PHATE(n_pca=n_pca, verbose=False, n_jobs=-1, gamma=gamma)
    return operator.fit_transform(X)


def pca(X, n_comps=2, svd_solver='arpack', random_state=22):
    """Project ``X`` onto its leading principal components with scikit-learn.

    Parameters
    ----------
    X : array-like
        Input data, one sample per row.
    n_comps : int
        Number of components to keep.
    svd_solver : str
        Passed to ``PCA`` as ``svd_solver``.
    random_state : int
        Passed to ``PCA`` as ``random_state``.

    Returns
    -------
    numpy.ndarray
        The projected coordinates.
    """
    from sklearn.decomposition import PCA

    reducer = PCA(
        n_components=n_comps,
        svd_solver=svd_solver,
        random_state=random_state,
    )
    return reducer.fit_transform(X)

def tsne(X, n_pca=None, n_neighbors=30):
    """Embed ``X`` with scanpy's t-SNE, optionally after a PCA reduction.

    Parameters
    ----------
    X : numpy.ndarray
        Input data, one sample per row.
    n_pca : int, optional
        Number of principal components to reduce to before t-SNE. It is
        capped at ``min(X.shape) - 1``. When None, t-SNE runs on the data
        without the explicit PCA step.
    n_neighbors : int
        Used as the t-SNE perplexity.

    Returns
    -------
    numpy.ndarray
        The coordinates stored by scanpy in ``adata.obsm["X_tsne"]``.
    """
    import scanpy as sc
    import anndata as ad

    adata = ad.AnnData(X)

    if n_pca is not None:
        n_pca = min(n_pca, (min(X.shape[0], X.shape[1])-1))
        # sc.tl.pca works in place on an AnnData (it fills
        # adata.obsm["X_pca"]) and returns None, so its result must not be
        # assigned back to obsm.
        sc.tl.pca(adata, n_comps=n_pca, svd_solver="arpack")
        sc.tl.tsne(adata, use_rep="X_pca", n_pcs=n_pca, perplexity = n_neighbors)
    else:
        sc.tl.tsne(adata, perplexity = n_neighbors)
    X_emb = adata.obsm["X_tsne"]
    return X_emb

def umap(X, n_pca=None,  densmap=False, n_neighbors=30):
    """Embed ``X`` with UMAP, optionally after a scanpy PCA reduction.

    Parameters
    ----------
    X : array-like
        Input data, one sample per row.
    n_pca : int, optional
        Number of principal components to reduce to before UMAP. When None,
        UMAP runs on the data as stored in the AnnData wrapper.
    densmap : bool
        Passed to ``UMAP`` as ``densmap``.
    n_neighbors : int
        Passed to ``UMAP`` as ``n_neighbors``.

    Returns
    -------
    numpy.ndarray
        The embedded coordinates (``random_state=42``).
    """
    from umap import UMAP
    import anndata as ad
    import scanpy as sc

    wrapped = ad.AnnData(X)

    if n_pca is None:
        features = wrapped.X
    else:
        # In-place PCA: the result is read back from obsm.
        sc.tl.pca(wrapped, n_comps=n_pca, svd_solver="arpack")
        features = wrapped.obsm["X_pca"]

    reducer = UMAP(densmap=densmap, random_state=42, n_neighbors=n_neighbors)
    return reducer.fit_transform(features)

def get_swiss_roll_embeddings(X, true_emb, lables, plot=True):
    """Run every embedding method on ``X`` and optionally plot the results.

    Parameters
    ----------
    X : numpy.ndarray
        Input data, one sample per row.
    true_emb : array-like
        Ground-truth embedding, returned unchanged under the name "truth".
    lables : array-like
        Per-point values used to colour the scatter plots.
    plot : bool
        When true, each embedding is drawn into its own panel of a 2 x 4
        figure (first two coordinates only). The figure itself is created
        in either case.

    Returns
    -------
    tuple
        ``(embeddings, names)``: the embeddings in the order pca, mds, phate,
        tsne, umap, isomap, lle, truth, followed by ``X`` itself under the
        name "original".
    """
    def return_truth(X, true_emb=true_emb):
        return true_emb

    n_cols = 4
    fig, ax = plt.subplots(2, n_cols, figsize=(16, 9))
    methods = [
        ("pca", pca),
        ("mds", mds),
        ("phate", phate),
        ("tsne", tsne),
        ("umap", umap),
        ("isomap", isomap),
        ("lle", lle),
        ("truth", return_truth),
    ]
    embeddings = []

    ### Create the embeddings
    X_original = X.copy()
    for idx, (name, method) in enumerate(methods):
        print(name)
        one_emb = method(X_original)
        embeddings.append(one_emb)
        if plot:
            # Panels are filled row by row.
            row, col = divmod(idx, n_cols)
            panel = ax[row, col]
            panel.scatter([b[0] for b in one_emb], [b[1] for b in one_emb], c=lables)
            panel.set_title(name)

    embeddings.append(X)
    names = [name for name, _ in methods]
    names.append("original")
    return embeddings, names


### Calculate magnitude weights and magnitude functions
def magnitude_embedding_experiment(embeddings, names, colors, num_intervals, p, target_proportion=0.95):
    """Compute magnitude weights and magnitude functions for several embeddings.

    For each embedding, duplicate points are removed, the scale at which the
    magnitude reaches ``target_proportion`` times the number of distinct
    points is located, and the magnitude weights are evaluated on
    ``num_intervals`` equally spaced scales from 0 to that scale. Progress and
    timings are printed along the way.

    Parameters
    ----------
    embeddings : iterable of numpy.ndarray
        The embeddings to analyse, one point per row.
    names : iterable of str
        Name of each embedding, in the same order.
    colors : numpy.ndarray
        Per-point values; copied and passed to the de-duplication helper.
    num_intervals : int
        Number of scales at which the weights are evaluated.
    p : float
        Metric order used for the pairwise distances.
    target_proportion : float
        Fraction of the number of distinct points used as target magnitude.

    Returns
    -------
    tuple
        ``(mag_functions, weights, other_information)``: a dict mapping each
        name to ``[magnitudes, ts]``, a dict mapping each name to its weight
        array, and a DataFrame indexed by ``names`` with the columns
        ``unique_points``, ``conv_scales``, ``time1`` and ``time2``.
    """
    lables = colors.copy()
    weights = {}
    mag_functions = {}
    nr_points, convs, times, times2 = [], [], [], []

    for this_emb, name in zip(embeddings, names):
        print(name)

        emb, _, _, n = remove_duplicates_in_emb(this_emb, lables)
        target_value = n*target_proportion

        # Stage 1: locate the scale at which the target magnitude is reached.
        tic = timer()
        conv = compute_target_scale(emb, target_value, guess=100, p=p, input_type="points", normalise_by_diameter=True)
        ts = np.linspace(0, conv, num=num_intervals)
        search_time = timer() - tic
        print(f"finding convergence took {round(search_time,2)} seconds")

        # Stage 2: weights on the grid; their column sums are the magnitudes.
        tic = timer()
        point_weights = compute_magnitude_weights(emb, p=p, ts=ts, normalise_by_diameter=True)
        magnitudes = point_weights.sum(axis=0)
        weights_time = timer() - tic
        print(f"computing weights took {round(weights_time,2)} seconds")

        weights[name] = point_weights
        mag_functions[name] = [magnitudes, ts]
        nr_points.append(n)
        convs.append(conv)
        times.append(search_time)
        times2.append(weights_time)
        print("-------------------")

    other_information = pd.DataFrame(
        {
            "unique_points": nr_points,
            "conv_scales": convs,
            "time1": times,
            "time2": times2,
        },
        index=names,
    )
    return mag_functions, weights, other_information