import os
import numpy as np
import argparse
import torch
import faiss
from pathlib import Path
from utils.non_neg_qpsolver import non_negative_qpsolver
import scipy.sparse.linalg
import matplotlib.pyplot as plt
import matplotlib.font_manager
import seaborn as sbn
from scipy.spatial.distance import squareform, pdist
from tqdm import tqdm
import warnings
from joblib import Parallel, delayed
import time 

# Command-line interface of the analysis script.
parser = argparse.ArgumentParser(description='NNK Embedding Analysis of SSL')
parser.add_argument('--log_dir', type=Path, default='/logs/NNK_augmentation_eval',
                    help='directory to load SSL features')
parser.add_argument('--model', default='supervised', type=str,
                    help="model to be loaded: supervised")
parser.add_argument('--type_transfo', default='all', type=str,
                    help='Type of transformation to be analyzed: none, all, crop, colorjitter')
parser.add_argument('--type_transfo_2', default='crop', type=str,
                    help='Type of transformation to be compared with: crop, colorjitter')
parser.add_argument('--n_augmentations', default=50, type=int,
                    help='number of transformations per sample')
parser.add_argument('--data_index', default=13, type=int,
                    help='Data sample to be analyzed')
parser.add_argument('--extract_analysis', default='get_graph', type=str,
                    help="analysis to run: get_graph, save_subpsaces, get_diam, get_angle, "
                         "get_neighbors_ID, get_angle_local, get_angle_between_augs, "
                         "get_angle_between_sem_aug")
parser.add_argument('--layer', default='backbone_features', type=str,
                    help='Which layer the analysis is performed: backbone_features, projector_features, backbone_layer')
parser.add_argument('--top_k', default=25, type=int, help="initial no. of neighbors")
# Both switches write to the same destination so that `--no-trained` really
# turns `args.trained` off (it previously set an unrelated `no_trained` flag).
parser.add_argument('--trained', dest='trained', action='store_true',
                    help='analyze features of the trained model')
parser.add_argument('--no-trained', dest='trained', action='store_false',
                    help='analyze features of the untrained model (default)')
parser.set_defaults(trained=False)


def subspace_angles(QA, QB):
    """
    Principal angles between the column spaces of ``QA`` and ``QB``.

    Modified code from scipy.linalg: the singular values are obtained as the
    square roots of the eigenvalues of small Gram matrices instead of through
    an SVD, and the cosine/sine switch uses a relative threshold (20% of the
    total squared cosine energy).

    NOTE(review): no orthonormalisation is done here, so both inputs are
    assumed to already hold orthonormal columns -- confirm against callers.

    Parameters
    ----------
    QA, QB : 2-D arrays with the same number of rows.

    Returns
    -------
    theta : 1-D array with ``min(QA.shape[1], QB.shape[1])`` angles in radians.
    """
    # Cross-Gram matrix; its singular values are the cosines of the angles.
    QA_H_QB = np.dot(QA.T, QB)

    # Residual of the smaller basis after projection onto the larger one
    # (its singular values are the sines), and the Gram matrix of the cosines.
    if QA.shape[1] >= QB.shape[1]:
        B = QB - np.dot(QA, QA_H_QB)
        cos_gram = QA_H_QB.T @ QA_H_QB
    else:
        B = QA - np.dot(QB, QA_H_QB.T)
        cos_gram = QA_H_QB @ QA_H_QB.T

    # Both Gram matrices are symmetric, so use the symmetric solver: it always
    # returns real eigenvalues in ascending order. They are reversed to get
    # the descending order of singular values that the combination step below
    # relies on (the general `eigvals` guarantees neither realness nor order).
    sigma = np.sqrt(np.clip(np.linalg.eigvalsh(cos_gram)[::-1], np.float32(0.), np.float32(1.)))

    # Entries whose squared cosine is large take the arcsin branch.
    mask = sigma ** 2 >= 0.2 * (sigma ** 2).sum()
    mu_arcsin = np.arcsin(np.sqrt(np.clip(np.linalg.eigvalsh(B.T @ B)[::-1], np.float32(0.), np.float32(1.))))

    # Compute the principal angles
    # with reverse ordering of sigma because smallest sigma belongs to largest
    # angle theta
    theta = np.where(mask, mu_arcsin, np.arccos(np.clip(sigma[::-1], np.float32(-1.), np.float32(1.))))
    return theta


def principal_components(centered_features):
    """
    Unit-norm principal directions spanned by the rows of ``centered_features``.

    The directions are obtained from the eigen-decomposition of the (small)
    Gram matrix ``X @ X.T``: for every positive eigenvalue ``d_j`` with
    eigenvector ``v_j`` the direction is ``v_j.T @ X / sqrt(d_j)``.

    Parameters
    ----------
    centered_features : 2-D array, one row per point. As the name suggests the
        rows are expected to be already centered; no centering is done here.

    Returns
    -------
    * with at most one row: the input divided by its norm (clipped at 1e-10);
    * ``[np.nan]`` when the Gram matrix has no positive eigenvalue;
    * otherwise a 2-D array with one direction per row, ordered by
      decreasing eigenvalue.
    """
    n_nodes = len(centered_features)
    if n_nodes <= 1:
        return centered_features / np.clip(np.linalg.norm(centered_features), 1e-10, None)

    # The Gram matrix is symmetric: `eigh` guarantees real eigenvalues and
    # orthonormal eigenvectors (the general `eig` guarantees neither, nor any
    # ordering). `eigh` sorts ascending, so flip to get the largest first.
    d, v = np.linalg.eigh(centered_features @ centered_features.T)
    d, v = d[::-1], v[:, ::-1]

    # Largest eigenvalue not positive -> all points coincide with the centre.
    if d[0] <= 0.0:
        return [np.nan]

    positive = np.where(d > 0)[0]
    principal_dirs = v[:, positive].T @ centered_features / np.sqrt(d[positive][:, None])
    return principal_dirs

def get_nnk_weighted_graph(aug_features, topk):
    """
    Build an NNK-weighted neighbourhood graph over the rows of ``aug_features``.

    A FAISS L2 index returns ``topk + 1`` neighbours per row; the first column
    of the search result is dropped (presumably the query itself -- verify).
    For every row a non-negative quadratic program is then solved on the inner
    products with its neighbours, and the solution, normalised to sum to one,
    replaces the neighbour distances.

    Returns
    -------
    (weights, indices) : two arrays with one row per input row and ``topk``
        columns.
    """
    _, dim = aug_features.shape
    # The features are used as given (no row normalisation is applied).
    queries = aug_features
    support_data = aug_features

    knn_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(dim))
    knn_index.add(support_data)

    found_values, found_indices = knn_index.search(queries, topk + 1)
    similarities = found_values[:, 1:]
    indices = found_indices[:, 1:]

    for row, (query, neighbor_indices) in enumerate(zip(queries, indices)):
        x_support = support_data[neighbor_indices]
        g_i = np.dot(x_support, query)
        G_i = x_support @ x_support.T
        x_opt = non_negative_qpsolver(G_i, g_i, g_i, x_tol=1e-10)
        # We need the normalization to account for origin-shift invariance!
        similarities[row] = x_opt / np.sum(x_opt)
    return similarities, indices

def get_diameter_poly(features, sim, indices):
    """
    Diameter of the neighbourhood polytope of every node.

    For node ``i`` the neighbours with a non-zero weight in ``sim[i]`` are
    selected, their features are L2-normalised (all-zero rows stay zero), and
    the diameter is the largest value of ``2 * (1 - <u, v>)`` over all selected
    pairs, i.e. the largest squared Euclidean distance between unit vectors.

    Parameters
    ----------
    features : 2-D float array, one row per node.
    sim : per-node neighbour weights; zero means "not a neighbour".
    indices : per-node neighbour indices, aligned with ``sim``.

    Returns
    -------
    1-D array with one non-negative diameter per node. Nodes without any
    selected neighbour get a diameter of 0.
    """
    n, _ = features.shape
    norm_aug = np.linalg.norm(features, axis=1, keepdims=True)
    normalized_features = np.divide(features, norm_aug,
                                    out=np.zeros_like(features), where=norm_aug != 0)

    diam_per_poly = np.zeros(n)
    for i in range(n):
        selected_nodes = indices[i, np.nonzero(sim[i])[0]]
        if selected_nodes.size == 0:
            # No neighbour selected: `.max()` of an empty matrix would raise.
            continue
        selected_node_feat = normalized_features[selected_nodes]
        norm_dist = 2 * (1 - selected_node_feat @ selected_node_feat.T)
        diam_per_poly[i] = norm_dist.max()

    # Rounding can push the zero self-distances slightly below zero.
    return np.clip(diam_per_poly, a_min=0, a_max=None)

def get_angles_local(node_principal_components, sim, indices):
    """
    Mean subspace affinity between every node and its NNK neighbours.

    For node ``i`` with more than one selected neighbour (non-zero weight in
    ``sim[i]``) the affinity to neighbour ``k`` is
    ``sqrt(mean(cos(theta) ** 2))`` where ``theta`` are the principal angles
    between the two local subspaces. Nodes or neighbours whose components
    contain NaN are skipped.

    Parameters
    ----------
    node_principal_components : sequence with one entry per node, each a 2-D
        array of directions (one per row) or a NaN placeholder.
    sim, indices : per-node neighbour weights and neighbour indices.

    Returns
    -------
    1-D array with one mean affinity per node that produced at least one
    affinity; an empty array when no node did.
    """
    n = sim.shape[0]

    pairwise_angle_local = []
    for i in range(n):
        nnk_neighbors = indices[i, np.nonzero(sim[i])[0]]
        if np.isnan(node_principal_components[i]).any() or len(nnk_neighbors) <= 1:
            continue

        affinity_i = []
        for k in nnk_neighbors:
            if np.isnan(node_principal_components[k]).any():
                continue
            angle_i = subspace_angles(node_principal_components[i].T, node_principal_components[k].T)
            affinity_i.append(np.sqrt((np.cos(angle_i)**2).mean()))

        if affinity_i:
            pairwise_angle_local.append(np.asarray(affinity_i).mean())

    # np.hstack refuses an empty list, so return an empty result explicitly.
    if not pairwise_angle_local:
        return np.empty(0)
    return np.hstack(pairwise_angle_local)

def get_aff(centered_features_aug1,centered_features_aug2):
    """
    Affinity between the row spaces of two feature matrices.

    Returns ``sqrt(mean(cos(theta) ** 2))`` over the principal angles
    ``theta`` computed by ``scipy.linalg.subspace_angles`` on the transposed
    inputs, or ``None`` when either input does not have a norm above 1e-10.
    """
    tol = 1e-10
    if not np.linalg.norm(centered_features_aug1) > tol:
        return None
    if not np.linalg.norm(centered_features_aug2) > tol:
        return None

    cosines = np.cos(scipy.linalg.subspace_angles(centered_features_aug1.T, centered_features_aug2.T))
    return np.sqrt(np.mean(cosines ** 2))


def get_node_principal_components(features_aug1, sim_aug1, indices_aug1):
    """
    Local principal directions around every node.

    For node ``i`` the neighbours with a non-zero weight in ``sim_aug1[i]``
    are looked up through ``indices_aug1[i]``, shifted so that node ``i`` is
    the origin, and handed to ``principal_components``.

    Returns a list with one entry per node (whatever ``principal_components``
    returned for it).
    """
    # Features are used as given (no row normalisation is applied).
    feats = features_aug1

    def _local_directions(node):
        neighbours = indices_aug1[node, np.nonzero(sim_aug1[node])[0]]
        return principal_components(feats[neighbours] - feats[node])

    return [_local_directions(node) for node in range(features_aug1.shape[0])]


def get_angle_between_augs(node_principal_components_aug1, node_principal_components_aug2, sim_aug1, sim_aug2, indices_aug1, indices_aug2):
    """
    Mean subspace affinity between every node of graph 1 and all nodes of graph 2.

    A node is usable when its principal components contain no NaN and it has
    more than one selected neighbour (non-zero weight in its ``sim`` row).
    For every usable node ``i`` of graph 1 the affinity
    ``sqrt(mean(cos(theta) ** 2))`` to every usable node of graph 2 is
    computed from the principal angles ``theta`` and averaged.

    Returns
    -------
    1-D array with one mean affinity per usable node of graph 1 that had at
    least one usable counterpart; an empty array when there is none.
    """
    n = indices_aug1.shape[0]
    n2 = indices_aug2.shape[0]

    # Usability of the graph-2 nodes does not depend on the graph-1 node, so
    # determine it once instead of once per outer iteration.
    usable_aug2 = [
        k for k in range(n2)
        if not np.isnan(node_principal_components_aug2[k]).any()
        and len(indices_aug2[k, np.nonzero(sim_aug2[k])[0]]) > 1
    ]

    pairwise_angles = []
    for i in range(n):
        selected_nodes_aug1 = indices_aug1[i, np.nonzero(sim_aug1[i])[0]]
        if np.isnan(node_principal_components_aug1[i]).any() or len(selected_nodes_aug1) <= 1:
            continue

        affinity = []
        for k in usable_aug2:
            angles = subspace_angles(node_principal_components_aug1[i].T, node_principal_components_aug2[k].T)
            affinity.append(np.sqrt((np.cos(angles)**2).mean()))

        if affinity:
            pairwise_angles.append(np.asarray(affinity).mean())

    # np.hstack refuses an empty list, so return an empty result explicitly.
    if not pairwise_angles:
        return np.empty(0)
    return np.hstack(pairwise_angles)

def get_graph_for_all(index, features, args):
    """
    Build NNK graphs either over all features or separately per source sample.

    With ``args.type_transfo == 'none'`` a single graph over ``features`` is
    returned as ``(graph, indices)``. Otherwise the rows are grouped by the
    distinct values of ``index`` and two lists (graphs, neighbour indices)
    with one entry per distinct value, in sorted order, are returned.
    """
    if args.type_transfo == 'none':
        whole_graph, whole_indices = get_nnk_weighted_graph(features, args.top_k)
        return whole_graph, whole_indices

    per_source = [
        get_nnk_weighted_graph(features[np.where(index == source_id)], args.top_k)
        for source_id in tqdm(np.unique(index))
    ]
    graph_index = [graph for graph, _ in per_source]
    indices_index = [neighbours for _, neighbours in per_source]
    return graph_index, indices_index

def get_diam_for_all(index, graphs, indices, features, args):
    """
    Polytope diameters for all nodes, flattened into one 1-D array.

    With ``args.type_transfo == 'none'`` the diameters come from a single
    graph; otherwise ``graphs[i]`` / ``indices[i]`` are paired with the rows
    of ``features`` belonging to the i-th distinct value of ``index`` and the
    per-group results are concatenated.
    """
    if args.type_transfo == 'none':
        return np.hstack(get_diameter_poly(features, graphs, indices))

    per_source = [
        get_diameter_poly(features[np.where(index == source_id)], graphs[pos], indices[pos])
        for pos, source_id in enumerate(np.unique(index))
    ]
    return np.hstack(per_source)


def get_ID(X, sim, indices):
    """
    Per-point estimate computed from the log-ratios of neighbour distances.

    For every row of ``X`` the squared distances to the rows listed in
    ``indices[i]`` are computed. When all of them are positive and there is
    more than one, the inverse estimate is
    ``sum(0.5 * log(max_dist / dist)) / (k - 1)``; otherwise it is 0.
    ``sim`` is accepted but not used.

    Returns
    -------
    (mle, inv_mle) : two 1-D arrays; ``mle`` is the element-wise reciprocal of
        ``inv_mle`` (so entries with ``inv_mle == 0`` divide by zero).
    """
    n_points = X.shape[0]
    sq_norms = np.einsum('ij,ij->i', X, X)
    inv_mle = np.zeros(n_points)

    for row in range(n_points):
        neighbours = indices[row, :]
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>
        sq_dist = sq_norms[row] + sq_norms[neighbours] - (2 * X[neighbours]) @ X[row]
        if not (sq_dist.min() > 0 and len(sq_dist) > 1):
            # Entry keeps its initial value of 0.
            continue
        # 0.5 to account for sq. distance
        log_ratios = 0.5 * np.log(sq_dist.max() / sq_dist)
        inv_mle[row] = np.sum(log_ratios) / (len(sq_dist) - 1)

    mle = 1. / inv_mle
    return mle, inv_mle

def get_neighbors_ID_for_all(index, graphs, indices, features, args):
    """
    Number of non-zero graph weights per node plus the ``get_ID`` outputs.

    With ``args.type_transfo == 'none'`` everything is computed on a single
    graph; otherwise ``graphs[i]`` / ``indices[i]`` are paired with the rows
    of ``features`` belonging to the i-th distinct value of ``index`` and the
    per-group results are concatenated.

    Returns ``(neighbor_count, mle, inv_mle)``.
    """
    if args.type_transfo == 'none':
        n_neighbors = np.count_nonzero(graphs, 1)
        mle, inv_mle = get_ID(features, graphs, indices)
        return n_neighbors, mle, inv_mle

    counts, mles, inv_mles = [], [], []
    for pos, source_id in enumerate(np.unique(index)):
        source_features = features[np.where(index == source_id)]
        source_graph, source_indices = graphs[pos], indices[pos]
        counts.append(np.count_nonzero(source_graph, 1))
        source_mle, source_inv_mle = get_ID(source_features, source_graph, source_indices)
        mles.append(source_mle)
        inv_mles.append(source_inv_mle)

    return np.hstack(counts), np.hstack(mles), np.hstack(inv_mles)



# ---- Script setup: parse the CLI, load the saved features, prepare output dir ----
args = parser.parse_args()
if args.type_transfo == 'none':
    # Without a transformation the input file is named with n_aug_1.
    args.n_augmentations = 1

feature_file = f'{args.model}_{args.type_transfo}_n_aug_{args.n_augmentations}_trained_{args.trained}_ddp_rank_0.npz'
filename = os.path.join(args.log_dir, args.model , args.type_transfo, feature_file)
data = np.load(filename, allow_pickle=True)
features = data[args.layer]
index, label = data['indices'], data['labels']
label_list = np.unique(label)

print(f' LAYER --- {args.layer}')
# Flatten everything but the leading axis so each row is one feature vector.
n_rows = features.shape[0]
features = features.reshape(n_rows, -1)
print(f' FEATURE SHAPE ------ {features.shape}')

# All analysis outputs of this model go below this directory.
save_dir = os.path.join(args.log_dir, args.model, 'nnk_analysis_save')
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# ---- Analysis dispatch: exactly one branch runs, selected by --extract_analysis ----

def _graph_file(label_value):
    """Path of the saved NNK graph archive of one label (first transformation)."""
    return os.path.join(save_dir, f"graph_label_{label_value}_layer_{args.layer}_trained_{args.trained}_aug_{args.type_transfo}.npz")


def _subspace_file(label_value, transfo):
    """Path of the saved subspace-direction archive of one label and transformation."""
    return os.path.join(save_dir, 'subspace_directions', f"subspace_dir_{label_value}_layer_{args.layer}_trained_{args.trained}_aug_{transfo}.npz")


if args.extract_analysis == 'get_graph':
    # Build and store one graph archive per label.
    for k in tqdm(label_list):
        graph_per_label, indices_per_label = get_graph_for_all(index[label == k], features[label == k], args)
        filename = _graph_file(k)
        print(f' Saving Graph {filename}')
        np.savez_compressed(filename, graphs = graph_per_label, indices = indices_per_label)

# 'save_subpsaces' is the historical (misspelled) mode name; keep accepting it.
elif args.extract_analysis in ('save_subspaces', 'save_subpsaces'):
    print(f' ----------------- Saving Subspaces: Aug {args.type_transfo} -----------------')
    subspace_save_dir = os.path.join(save_dir, 'subspace_directions')
    os.makedirs(subspace_save_dir, exist_ok=True)

    for k in tqdm(label_list):
        graph_save = np.load(_graph_file(k), allow_pickle=True)
        graphs1 = graph_save['graphs']
        indices1 = graph_save['indices']
        features_k = features[label == k]
        if args.type_transfo == 'none':
            node_principal_components_aug1 = get_node_principal_components(features_k, graphs1, indices1)
        else:
            index_i = index[label == k]
            node_principal_components_aug1 = []
            for i, y in enumerate(np.unique(index_i)):
                # Same grouping as in get_graph_for_all: rows of THIS label
                # that belong to source sample y (the graphs were built on
                # exactly these rows, so positional slicing of the global
                # feature matrix would pair graphs with the wrong features).
                features_index = features_k[np.where(index_i == y)]
                graph_index, indices_index = graphs1[i], indices1[i]
                node_principal_components_aug1.append(get_node_principal_components(features_index, graph_index, indices_index))

        # Nodes can have different numbers of directions. Recent NumPy
        # versions refuse to turn such ragged nested lists into an array
        # implicitly, so fall back to an explicit object array.
        try:
            subspace_dir = np.asanyarray(node_principal_components_aug1)
        except ValueError:
            subspace_dir = np.empty(len(node_principal_components_aug1), dtype=object)
            for j, components in enumerate(node_principal_components_aug1):
                subspace_dir[j] = components

        save_filename = _subspace_file(k, args.type_transfo)
        np.savez_compressed(save_filename, subspace_dir=subspace_dir, graphs=graphs1, indices=indices1)

elif args.extract_analysis == 'get_diam':
    print(f' ----------------- Diameters: Aug {args.type_transfo} -----------------')
    diam_per_label = []
    for k in tqdm(label_list):
        graph_save = np.load(_graph_file(k), allow_pickle=True)
        graphs = graph_save['graphs']
        indices = graph_save['indices']
        diam_per_label.append(get_diam_for_all(index[label == k], graphs, indices, features[label == k], args))
    filename = os.path.join(save_dir, f"diameter_layer_{args.layer}_trained_{args.trained}_aug_{args.type_transfo}.npz")
    print(f' Saving Polytopes Diameters {filename}')
    np.savez_compressed(filename, output= np.hstack(diam_per_label), label= label)

elif args.extract_analysis == 'get_angle':
    print(f' ----------------- Angle: Aug {args.type_transfo} -----------------')
    # NOTE(review): `get_angle_for_all` is not defined in this file, so this
    # mode cannot run as written. Fail up front with a clear message instead
    # of a NameError after the first archive has been loaded.
    if 'get_angle_for_all' not in globals():
        raise NotImplementedError(
            "extract_analysis='get_angle' requires get_angle_for_all(), which is not defined; "
            "use 'get_angle_local' or provide the function.")
    angle_per_label = []
    for k in tqdm(label_list):
        graph_save = np.load(_graph_file(k), allow_pickle=True)
        graphs = graph_save['graphs']
        indices = graph_save['indices']
        angle_per_label.append(get_angle_for_all(index[label == k], graphs, indices, features[label == k], args))
    filename = os.path.join(save_dir, f"angle_layer_{args.layer}_trained_{args.trained}_aug_{args.type_transfo}.npz")
    print(f' Saving Angle Manifold {filename}')
    np.savez_compressed(filename, output= np.hstack(angle_per_label), label = label)

elif  args.extract_analysis == 'get_neighbors_ID':
    print(f' ----------------- # of NNK neighbors, ID: Aug {args.type_transfo} -----------------')
    neighbors_per_label = []
    mle_per_label = []
    inv_mle_per_label = []
    for k in tqdm(label_list):
        graph_save = np.load(_graph_file(k), allow_pickle=True)
        graphs = graph_save['graphs']
        indices = graph_save['indices']
        neighbor_count, mle, inv_mle = get_neighbors_ID_for_all(index[label == k], graphs, indices, features[label == k], args)
        neighbors_per_label.append(neighbor_count)
        mle_per_label.append(mle)
        inv_mle_per_label.append(inv_mle)

    filename = os.path.join(save_dir, f"neighbor_ID_layer_{args.layer}_trained_{args.trained}_aug_{args.type_transfo}.npz")
    print(f' Saving # Neighbors, ID for Manifold {filename}')
    np.savez_compressed(filename, neighbors_per_label= np.hstack(neighbors_per_label), mle_per_label= np.hstack(mle_per_label),
                        inv_mle_per_label = np.hstack(inv_mle_per_label), label = label)

elif args.extract_analysis == 'get_angle_local':
    print(f' ----------------- Angle Local: Aug {args.type_transfo} -----------------')
    angle_per_label = []
    for k in tqdm(label_list):
        subspace_save = np.load(_subspace_file(k, args.type_transfo), allow_pickle=True)
        graphs1 = subspace_save['graphs']
        indices1 = subspace_save['indices']
        node_principal_components_aug1_all = subspace_save['subspace_dir']

        if args.type_transfo == 'none':
            k_angles = get_angles_local(node_principal_components_aug1_all, graphs1, indices1)
        else:
            index_i = index[label == k]
            k_angles = []
            for i, _ in enumerate(np.unique(index_i)):
                graph_index1, indices_index1 = graphs1[i], indices1[i]
                node_principal_components_aug1 = node_principal_components_aug1_all[i]
                k_angles.append(get_angles_local(node_principal_components_aug1, graph_index1, indices_index1))

        angle_per_label.append(np.hstack(k_angles))

    filename = os.path.join(save_dir, f"angle_local_layer_{args.layer}_trained_{args.trained}_aug_{args.type_transfo}.npz")
    print(f' Saving Local Angle Manifold {filename}')
    np.savez_compressed(filename, output= np.hstack(angle_per_label), label = label)

elif args.extract_analysis == 'get_angle_between_augs':
    t = time.time()
    print(f' ----------------- Angle Between: Aug1 {args.type_transfo} and Aug2 {args.type_transfo_2}  -----------------')
    angle_between = []
    for k in tqdm(label_list):
        subspace_save = np.load(_subspace_file(k, args.type_transfo), allow_pickle=True)
        graphs1 = subspace_save['graphs']
        indices1 = subspace_save['indices']
        node_principal_components_aug1_all = subspace_save['subspace_dir']

        subspace_save = np.load(_subspace_file(k, args.type_transfo_2), allow_pickle=True)
        graphs2 = subspace_save['graphs']
        indices2 = subspace_save['indices']
        node_principal_components_aug2_all = subspace_save['subspace_dir']

        index_i = index[label == k]
        k_angle_between = []
        # Compare, per source sample, the subspaces of the two transformations.
        for i, _ in enumerate(np.unique(index_i)):
            graph_index1, indices_index1 = graphs1[i], indices1[i]
            graph_index2, indices_index2 = graphs2[i], indices2[i]

            node_principal_components_aug1 = node_principal_components_aug1_all[i]
            node_principal_components_aug2 = node_principal_components_aug2_all[i]

            k_angle_between.append(get_angle_between_augs(node_principal_components_aug1, node_principal_components_aug2, graph_index1, graph_index2, indices_index1, indices_index2))

        angle_between.append(np.hstack(k_angle_between))

    print(time.time() - t)
    filename = os.path.join(save_dir, f"between_angles_layer_{args.layer}_trained_{args.trained}_aug1_{args.type_transfo}_aug2_{args.type_transfo_2}.npz")
    print(f' Saving Angle Between Augmentations Manifold: {filename}')
    np.savez_compressed(filename, output= np.hstack(angle_between), label = label)

elif args.extract_analysis == 'get_angle_between_sem_aug':
    print(f' ----------------- Angle Between: Aug1 {args.type_transfo} and Aug2 {args.type_transfo_2}  -----------------')
    # NOTE(review): the number of augmentations of the second transformation
    # is hard-coded to 50 in this file name -- confirm that this is intended.
    data2 = np.load(os.path.join(args.log_dir, args.model , args.type_transfo_2,f'{args.model}_{args.type_transfo_2}_n_aug_50_trained_{args.trained}_ddp_rank_0.npz'), allow_pickle=True)
    index2 = data2['indices']
    label2 = data2['labels']

    angle_between = []
    # Iterate over the label values themselves; they are not positions into
    # label_list and must not be used to index it.
    for k in tqdm(label_list):
        subspace_save = np.load(_subspace_file(k, args.type_transfo), allow_pickle=True)
        graphs1 = subspace_save['graphs']
        indices1 = subspace_save['indices']
        node_principal_components_aug1 = subspace_save['subspace_dir']

        subspace_save = np.load(_subspace_file(k, args.type_transfo_2), allow_pickle=True)
        graphs2 = subspace_save['graphs']
        indices2 = subspace_save['indices']
        node_principal_components_aug2_all = subspace_save['subspace_dir']

        index_i = index2[label2 == k]
        k_angle_between = []
        for i, _ in enumerate(np.unique(index_i)):
            graph_index2, indices_index2 = graphs2[i], indices2[i]
            node_principal_components_aug2 = node_principal_components_aug2_all[i]

            k_angle_between.append(get_angle_between_augs(node_principal_components_aug1, node_principal_components_aug2, graphs1, graph_index2, indices1, indices_index2))
        angle_between.append(np.hstack(k_angle_between))

    filename = os.path.join(save_dir, f"between_angles_layer_{args.layer}_trained_{args.trained}_aug1_{args.type_transfo}_aug2_{args.type_transfo_2}.npz")
    print(f' Saving Angle Between Sem. and Augs. Manifold: {filename}')
    np.savez_compressed(filename, output= np.hstack(angle_between), label = label)

# elif args.extract_analysis == 'get_angle_between_sem': # This is correlated with imagenet accuracy
#     angle_between = []
#     for label_k in tqdm(range(len(label_list)-1)):
#         k = label_list[label_k]
#         filename1 = os.path.join(save_dir, 'subspace_directions', f"subspace_dir_{k}_layer_{args.layer}_trained_{args.trained}_aug_{args.type_transfo}.npz")
#         subspace_save = np.load(filename1, allow_pickle=True)
#         graphs1 = subspace_save['graphs']
#         indices1 = subspace_save['indices']
#         node_principal_components_aug1 = subspace_save['subspace_dir']
        
#         for label_j in range(label_k+1, len(label_list)):
#             j = label_list[label_j]
#             filename2 = os.path.join(save_dir, 'subspace_directions', f"subspace_dir_{j}_layer_{args.layer}_trained_{args.trained}_aug_{args.type_transfo}.npz")
#             subspace_save = np.load(filename2, allow_pickle=True)
#             graphs2 = subspace_save['graphs']
#             indices2 = subspace_save['indices']
#             node_principal_components_aug2 = subspace_save['subspace_dir']
#             # if j % 10 == 0:
#             #     print(f"Completed {j}/{len(label_list)}..")
            
#             angle_between.append(get_angle_between_augs(node_principal_components_aug1, node_principal_components_aug2, graphs1, graphs2, indices1, indices2))

#     filename = os.path.join(save_dir, f"between_angles_layer_{args.layer}_trained_{args.trained}_aug1_{args.type_transfo}_aug2_{args.type_transfo_2}.npz")

#     print(f' Saving Angle Between Sem. Manifold: {filename}')
#     np.savez_compressed(filename, output= np.hstack(angle_between), label = label) 

