import torch
import argparse
import os 

from gwdr.src.affinities import (
    SymmetricEntropicAffinity,
    EntropicAffinity,
    GramAffinity,
    UMAPAffinityIn)

from gwdr.data.data import load_dataset

def list_of_ints(arg):
    """Parse a comma-separated string such as ``"20,50,100"`` into a list of ints.

    Intended as an argparse ``type=`` converter. Each comma-separated token is
    passed to ``int``; a token that is not an integer raises ``ValueError``.
    """
    return [int(token) for token in arg.split(',')]


# Affinity computed for every nonzero perplexity in --perp_list:
#   'SEA'  -> SymmetricEntropicAffinity
#   'EA'   -> EntropicAffinity
#   'UMAP' -> UMAPAffinityIn
#   'PCA'  -> only valid with perp=0 (centered Gram affinity)
mode = 'SEA'
list_modes = ['PCA', 'SEA', 'EA', 'UMAP']
# Validate explicitly instead of with `assert`, which is stripped under `python -O`.
if mode not in list_modes:
    raise ValueError(f'Unknown mode = {mode}. Set mode in {list_modes}')
# Example usage:
# python3 compute_sea_aff.py --dataset coil --lr 0.1 --max_iter 1000000 --device cuda:1 --perp_list 20

def _output_dir(dataset):
    """Return ``<cwd>/<dataset>``, creating the directory if it does not exist.

    ``os.makedirs(..., exist_ok=True)`` replaces a check-then-``os.mkdir``
    pattern, which is racy and cannot create intermediate directories.
    """
    out_dir = os.path.join(os.path.abspath('.'), dataset)
    os.makedirs(out_dir, exist_ok=True)
    return out_dir


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compute input affinity matrices for a dataset and save them as .pt files.")
    parser.add_argument("--dataset", help="the dataset on which to compute the symmetric entropic affinity",
                        type=str, required=True)
    parser.add_argument("--lr", help="lr used to compute the symmetric entropic affinity",
                        type=float, default=1e0)
    parser.add_argument("--max_iter", help="max_iter used to compute the symmetric entropic affinity",
                        type=int, default=10000)
    parser.add_argument("--device", help="device on which to compute the symmetric entropic affinity",
                        type=str, default="cpu")
    parser.add_argument("--perp_list", help="perplexity hyperparameters used for symmetric entropic affinity",
                        type=list_of_ints, default=[20, 50, 100, 150, 200, 250, 300])
    # Defaults to the module-level `mode`, so running without --mode behaves as before.
    parser.add_argument("--mode", help="affinity to compute for each nonzero perplexity",
                        type=str, default=mode, choices=list_modes)
    args = parser.parse_args()

    X, _ = load_dataset(args.dataset, device=args.device)
    # Created once up front so every branch (including perp=0) can save into it.
    out_dir = _output_dir(args.dataset)

    for perp in args.perp_list:
        if perp == 0:
            # perp=0 selects the centered Gram (PCA) affinity, independent of --mode;
            # it is skipped if the file already exists.
            save_path = os.path.join(out_dir, f'{args.dataset}_PCA.pt')
            if not os.path.exists(save_path):
                aff = GramAffinity(centering=True)
                P = aff.compute_affinity(X)
                torch.save(P.cpu(), save_path)
            continue

        # Pick the affinity and the output file name; file names are kept
        # exactly as before so downstream loaders keep working.
        if args.mode == 'SEA':
            aff = SymmetricEntropicAffinity(perp, max_iter=args.max_iter, lr=args.lr)
            file_name = f'{args.dataset}_{perp}.pt'
        elif args.mode == 'EA':
            aff = EntropicAffinity(perp, max_iter=args.max_iter)
            file_name = f'{args.dataset}_EA_{perp}.pt'
        elif args.mode == 'UMAP':
            aff = UMAPAffinityIn(
                n_neighbors=perp,
                max_iter=args.max_iter,
                device=args.device,
                verbose=True)
            file_name = f'{args.dataset}_UMAP_{perp}.pt'
        elif args.mode == 'PCA':
            # Raising a plain str is a TypeError in Python 3; raise a real exception.
            raise ValueError(f'mode PCA not compatible with perplexity = {perp}. Set perp=0.')
        else:
            raise ValueError(f'Unknown mode = {args.mode}. Set mode in {list_modes}')

        P = aff.compute_affinity(X)
        torch.save(P.cpu(), os.path.join(out_dir, file_name))