import torch
import argparse
import os 
from tqdm import tqdm

from sklearn.manifold import spectral_embedding
from gwdr.src.affinities import (
    GramAffinity,
    )

from gwdr.data.data import load_dataset

# Values of the `perp` parameter for which a precomputed affinity is looked up
# on disk (presumably perplexities; for UMAP this may be a neighbour count --
# TODO confirm against the script that writes these files).
perp_list = [20, 50, 100, 150, 200, 250, 300]
# Number of spectral-embedding components computed for every affinity.
n_max = 250

# Usage examples:
# python3 compute_spectral_embeddings.py --dataset coil --affinity_data SymmetricEntropicAffinity
# python3 compute_spectral_embeddings.py --dataset coil --affinity_data UMAP

# File-name infix of each method whose affinity is precomputed per `perp`:
# the affinity is read from '<dataset>/<dataset>_<infix><perp>.pt' and the
# embedding is written to '<dataset>/<dataset>_<infix><perp>_spectralembeddings.pt'.
PRECOMPUTED_AFFINITY_INFIX = {
    'SymmetricEntropicAffinity': '',
    'EntropicAffinity': 'EA_',
    'UMAP': 'UMAP_',
}


def affinity_path(dataset, affinity_data, perp):
    """Return the path of the precomputed affinity for one method and `perp`.

    Raises KeyError if `affinity_data` is not a key of
    PRECOMPUTED_AFFINITY_INFIX (e.g. 'GramAffinity', which is computed on the
    fly rather than loaded).
    """
    infix = PRECOMPUTED_AFFINITY_INFIX[affinity_data]
    return f'{dataset}/{dataset}_{infix}{perp}.pt'


def embedding_path(dataset, affinity_data, perp=None):
    """Return the path where the spectral embeddings are cached.

    `perp` is ignored for 'GramAffinity', which has a single output file.
    """
    if affinity_data == 'GramAffinity':
        return f'{dataset}/{dataset}_gram_spectralembeddings.pt'
    infix = PRECOMPUTED_AFFINITY_INFIX[affinity_data]
    return f'{dataset}/{dataset}_{infix}{perp}_spectralembeddings.pt'


def make_nonnegative(KX):
    """Shift `KX` in place so that its smallest entry is 0, and return it.

    The affinity is used as a graph adjacency by the spectral embedding, so
    negative coefficients are removed by subtracting the global minimum.
    `KX` is left untouched when it has no negative entry.
    """
    if (KX < 0).any():
        KX -= KX.min()
    return KX


def embed_and_save(KX, save_path):
    """Compute the `n_max`-component spectral embedding of `KX` and save it.

    `KX` is converted with `.numpy()`, so it is assumed to be a CPU torch
    tensor that does not require grad -- TODO confirm for every affinity
    source. The value returned by `spectral_embedding` is written with
    `torch.save`; the parent directory of `save_path` is created if missing so
    that the computation is not lost on a missing output folder.
    """
    KX = make_nonnegative(KX)
    embeddings = spectral_embedding(KX.numpy(), n_components=n_max,
                                    random_state=0, drop_first=False)

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(embeddings, save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Without a dataset name there is nothing to load and every path would be
    # built from None, so the option is mandatory.
    parser.add_argument("--dataset", help="the dataset on which to compute the symmetric entropic affinity",
                        type=str, required=True)
    # Restricting the choices makes a mistyped method name fail loudly instead
    # of silently computing nothing.
    parser.add_argument('--affinity_data', type=str, default='GramAffinity',
                        choices=['GramAffinity', *PRECOMPUTED_AFFINITY_INFIX])

    args = parser.parse_args()

    # The data is only consumed by the Gram branch, but it is loaded (and its
    # shape reported) for every method.
    X, _ = load_dataset(args.dataset)
    print('X shape:', X.shape)

    if args.affinity_data == 'GramAffinity':
        save_path = embedding_path(args.dataset, 'GramAffinity')
        # Existing results are kept: the embedding is computed only once.
        if not os.path.exists(save_path):
            KX = GramAffinity(centering=True).compute_affinity(X)
            embed_and_save(KX, save_path)

    else:
        for perp in tqdm(perp_list, desc='computing different perplexities'):
            save_path = embedding_path(args.dataset, args.affinity_data, perp)
            if os.path.exists(save_path):
                continue

            try:
                KX = torch.load(affinity_path(args.dataset, args.affinity_data, perp))
            except FileNotFoundError:
                # Only a missing affinity file is skipped; any other failure
                # (corrupt file, interruption, ...) is allowed to surface.
                print('could not find perp = %s' % perp)
                continue

            print('loaded KX:', KX.shape)
            embed_and_save(KX, save_path)