from sklearn.neighbors import KNeighborsRegressor
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
from numpy import linalg as LA
import pickle
import time
from sklearn.utils import shuffle
from sklearn.ensemble import RandomTreesEmbedding
from sklearn.manifold import MDS
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler
import matplotlib
import matplotlib.pyplot as plt


# Basic sensitivity calculation suited for multiprocessing; works for both training and test data.
# ord_data and ord_labels are the p-norm orders used in the feature space and the label space, respectively.
def sens_calc(
    ref_data_input,
    ref_labels_input,
    data_input,
    label_input,
    ord_data,
    ord_labels,
    handle,
    bootstrap_iterations,
    bootstraps_length
):
    """Estimate a bootstrapped sensitivity score for every sample in ``data_input``.

    For each sample ``x`` (with proxy label ``y``) and each bootstrap round, a
    random subset of at most ``bootstraps_length`` reference samples is drawn
    without replacement. Among the reference samples whose label differs from
    ``y``, the largest quotient

        ||y_ref - y||_{ord_labels} / (||x_ref - x||_{ord_data} + 1e-7)

    is recorded. The sample's score is the median of these per-round maxima.

    Args:
        ref_data_input: Reference samples, one per row.
        ref_labels_input: Labels of the reference samples (same length).
        data_input: Samples to score, one per row.
        label_input: Labels of ``data_input``. ``None`` for test data, in which
            case proxy labels are predicted by a 1-nearest-neighbour model
            fitted on the reference set; otherwise these labels are used as-is.
        ord_data: ``ord`` passed to ``numpy.linalg.norm`` for data differences.
        ord_labels: ``ord`` passed to ``numpy.linalg.norm`` for label differences.
        handle: ``"REG"`` (regression) or ``"CLA"`` (classification). Only
            consulted when ``label_input`` is ``None``.
        bootstrap_iterations: Number of bootstrap rounds per sample.
        bootstraps_length: Size of each bootstrap subset (the subset is capped
            at the size of the reference set).

    Returns:
        A list with one score per row of ``data_input``. A bootstrap round in
        which no reference label differs from the proxy label contributes
        ``-inf`` to the median.

    Raises:
        ValueError: If ``label_input`` is ``None`` and ``handle`` is neither
            ``"REG"`` nor ``"CLA"``.
    """
    epsilon = 1e-7
    start_time = time.time()

    # Fail early with a clear message instead of hitting an unbound
    # ``proxy_labels`` deep inside the main loop.
    if label_input is None and handle not in ("REG", "CLA"):
        raise ValueError(f"handle must be 'REG' or 'CLA', got {handle!r}")

    # The inputs are only read, never modified, so no defensive copies are needed.
    ref_data = np.asarray(ref_data_input)
    ref_labels = np.asarray(ref_labels_input)
    data = np.asarray(data_input)

    len_ref_data = len(ref_data)
    len_data = len(data)
    sensitivities = [0.0] * len_data

    if label_input is None:
        # Test data: predict proxy labels with a 1-NN model on the reference set.
        # n_jobs=1 because the nn-models cannot use multiprocessing inside child
        # processes. Both models use the default Euclidean (Minkowski p=2) metric.
        if handle == "REG":
            print("Determining Proxy Regression-Labels", "Time Elapsed:", round(time.time()-start_time, 2), "Secs")
            print()
            neigh = KNeighborsRegressor(n_neighbors=1, n_jobs=1)
        else:
            print("Determining Proxy Classification-Labels", "Time Elapsed:", round(time.time()-start_time, 2), "Secs")
            print()
            neigh = KNeighborsClassifier(n_neighbors=1, n_jobs=1)
        neigh.fit(ref_data, ref_labels)
        proxy_labels = neigh.predict(data)
    else:
        # Training data: the true labels serve as proxy labels.
        print("Copying Labels")
        print()
        proxy_labels = np.asarray(label_input)

    print("Calculating Sensitivity Lists", "Time Elapsed:", round(time.time()-start_time, 2), "Secs")
    print()

    # Per-round maxima for the current sample; every entry is overwritten
    # before the median is taken.
    boot = np.full(bootstrap_iterations, -1.0, dtype='float32')

    for index in range(len_data):
        if index % 100 == 0:
            print("Percentage:", round(index / len_data, 4), "Time Elapsed:", round(time.time()-start_time, 2), "Secs")

        for bs in range(bootstrap_iterations):
            # Draw the bootstrap subset without replacement. Shuffling an index
            # range with NumPy's global RNG and keeping the first
            # ``bootstraps_length`` entries selects the subset without copying
            # the whole reference set.
            subset = np.random.permutation(len_ref_data)[:bootstraps_length]

            label_diffs = _row_norms(ref_labels[subset] - proxy_labels[index], ord_labels)
            data_diffs = _row_norms(ref_data[subset] - data[index], ord_data)
            ratios = label_diffs / (data_diffs + epsilon)

            # Only reference samples with a different label count. NaN quotients
            # are skipped, matching the element-wise ``max()`` this replaces.
            valid = (label_diffs > 0) & ~np.isnan(ratios)
            boot[bs] = np.max(ratios[valid], initial=-np.inf)

        # Median over bootstrap rounds (the original code describes this as a
        # robustness value assuming |y|_inf = 1).
        sensitivities[index] = np.median(boot)

    return sensitivities


def _row_norms(diffs, order):
    """Return the ``order``-norm of each row of ``diffs`` as a 1-D float array."""
    if diffs.ndim == 2:
        # Rows are vectors: a single vectorised call.
        return LA.norm(diffs, ord=order, axis=1)
    # Scalar or matrix-valued rows: keep numpy.linalg.norm's per-row semantics exactly.
    return np.array([LA.norm(row, ord=order) for row in diffs], dtype=float)


def manifold_embedder_and_plotter(data_handle, data, robustness_vals, robustness_ordering, n_neighs, _n_jobs_, show_graphs, path, i_t_r, sample_size):
    """Embed a subsample of ``data`` in 2-D with several methods and save a scatter plot per method.

    ``sample_size`` points are picked at evenly spaced positions of
    ``robustness_ordering`` (an index ordering of ``data``). Each point is
    coloured by its position in that sample and sized by its min-max-scaled
    value in ``robustness_vals``.

    Args:
        data_handle: Prefix prepended to every embedding name / output file name.
        data: Array indexable by an integer index array (rows are samples).
        robustness_vals: One value per row of ``data``; used for marker sizes.
        robustness_ordering: Sequence of row indices of ``data``. Per the
            original comments it runs from least to most robust; this function
            does not check that.
        n_neighs: Currently unused.
        _n_jobs_: ``n_jobs`` passed to the MDS embedding.
        show_graphs: If true, ``plt.show()`` is called after each plot is saved.
        path: String prefix for the output files (concatenated as-is).
        i_t_r: Suffix appended to each file name.
        sample_size: Number of points to embed and plot.

    Side effects:
        Writes ``path + data_handle + <embedding name> + '_' + str(i_t_r) + '.png'``
        for every embedding.
    """
    # Adapted from scikit-learn's "Manifold learning on handwritten digits" example:
    # https://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html
    # n_components=2 embeds (and plots) in two dimensions.
    embeddings = {

        "MDS_Embedding": MDS(n_components=2, n_init=5, max_iter=300, n_jobs=_n_jobs_),

        "Random_Trees_Embedding_50": make_pipeline(RandomTreesEmbedding(n_estimators=50, max_depth=10), TruncatedSVD(n_components=2)),
        "Random_Trees_Embedding_100": make_pipeline(RandomTreesEmbedding(n_estimators=100, max_depth=10), TruncatedSVD(n_components=2)),
        "Random_Trees_Embedding_200": make_pipeline(RandomTreesEmbedding(n_estimators=200, max_depth=10), TruncatedSVD(n_components=2)),
        "Random_Trees_Embedding_300": make_pipeline(RandomTreesEmbedding(n_estimators=300, max_depth=10), TruncatedSVD(n_components=2)),

    }

    factor = 200
    len_data = len(data)
    ratio = len_data / sample_size
    # Take every "ratio"-th entry of the ordering (e.g. every tenth point).
    colour_ordering = [robustness_ordering[int(k*ratio)] for k in range(sample_size)]
    data_sample = data[np.array(colour_ordering)]
    robustness_sample = np.array(robustness_vals, dtype='float32')[np.array(colour_ordering)].reshape(-1, 1)

    # Marker sizes and colours depend only on the sample, so compute them once.
    # NOTE(review): MinMaxScaler rejects non-finite input; -inf scores would fail here.
    marker_sizes = factor * MinMaxScaler().fit_transform(robustness_sample)
    colours = list(range(sample_size))

    for name, transformer in embeddings.items():

        name = data_handle + name

        projection = transformer.fit_transform(data_sample.copy())

        fig = plt.figure(figsize=(15, 15))
        try:
            ax = fig.add_subplot(111)
            # 'seismic' colour map: low colour index -> blue, high -> red.
            ax.scatter(projection[:, 0], projection[:, 1], c=colours, cmap='seismic', s=marker_sizes, edgecolor="k")
            fig.tight_layout()
            ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
            fig.savefig(path + name + '_' + str(i_t_r) + '.png', dpi=200)
            if show_graphs:
                plt.show()
        finally:
            # Release the figure; without this every call leaks one figure per
            # embedding (pyplot keeps them alive and warns after 20).
            plt.close(fig)

  
