#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
"""

import os
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1" # export OMP_NUM_THREADS=1
os.environ["OPENBLAS_NUM_THREADS"] = "1" # export OPENBLAS_NUM_THREADS=1
os.environ["MKL_NUM_THREADS"] = "1" # export MKL_NUM_THREADS=1
os.environ["VECLIB_MAXIMUM_THREADS"] = "1" # export VECLIB_MAXIMUM_THREADS=1
os.environ["NUMEXPR_NUM_THREADS"] = "1" # export NUMEXPR_NUM_THREADS=1

import sys
from pathlib import Path
# Make the parent of the current working directory importable (presumably where
# the project-local modules imported below live — confirm).
sys.path.append(str(Path.cwd().parent))


from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import top_k_accuracy_score

import torch
from torch import optim
import numpy as np

from data_utils import samples_n_vertices
from metrics_utility import my_pairwise_distances, metric_gw, metric_sgw, metric_risgw, metric_distrib_min_sse
from TransformNet import TransformNet, TransformLatenttoOrig

import argparse

def select_transforms(trsf, choices_trsf: list):
    """Return the positions of the shapes whose transformation is selected.

    Parameters
    ----------
    trsf : iterable
        Transformation label of each shape (e.g. the dataset's ``transf``
        array).
    choices_trsf : list
        Transformation labels to keep.

    Returns
    -------
    list of int
        Indices ``k``, in increasing order, such that ``trsf[k]`` is one of
        ``choices_trsf``.
    """
    return [k for k, tf in enumerate(trsf) if tf in choices_trsf]

#%%
parser = argparse.ArgumentParser(description='Shape Invariance Experiment')
parser.add_argument('--n_vertices', type=int, default=100,
                    help='number of vertices sampled on each shape')
parser.add_argument('--nproj', type=int, default=1000,
                    help='number of projections for the sgw and risgw metrics')
parser.add_argument('--nproj_dist_d', type=int, default=10,
                    help='number of projections (nproj_dist) for distrib_min_sse')
parser.add_argument('--n_iter_inner', type=int, default=2,
                    help='number of inner sup and inf iterations for distrib_min_sse')
parser.add_argument('--max_epoch', type=int, default=50,
                    help='number of epochs for distrib_min_sse')
parser.add_argument('--dim_latent', type=int, default=5,
                    help='latent dimension of the distrib_min_sse networks')
parser.add_argument('--max_iter_ri', type=int, default=100,
                    help='max_iter passed to the risgw metric')
parser.add_argument('--lr_ri', type=float, default=0.01,
                    help='learning rate passed to the risgw metric')
# `choices` rejects unknown values up front: an unknown scaling would leave the
# scaler undefined mid-run, an unknown model would silently compute nothing.
parser.add_argument('--scaling', type=str, default='minmax',
                    choices=['minmax', 'standard'],
                    help='per-shape normalisation')
parser.add_argument('--data', type=str, default='correspondence',
                    help='dataset name, loaded from ./data/<data>.npz')
parser.add_argument('--model', type=str, default='distrib_min_sse',
                    choices=['gw', 'sgw', 'risgw', 'distrib_min_sse'],
                    help='distance used for the 1-NN classification')

args = parser.parse_args()
print(args)


n_vertices = args.n_vertices
nproj = args.nproj                # nb of projections for sgw / risgw
nproj_dist_d = args.nproj_dist_d  # nb of projections for distrib_min_sse
n_iter_inner = args.n_iter_inner  # num of inner iterations for distrib_min_sse
max_epoch = args.max_epoch        # number of epochs for distrib_min_sse
dim_latent = args.dim_latent
max_iter_ri = args.max_iter_ri
lr_ri = args.lr_ri
scaling = args.scaling
data = args.data
model = args.model

# other parameters
transformations = ["null", "isometry"]  # transformation labels of the shapes kept
vector_k = [1, 2]  # to compute top-1 and top-2 accuracies
n_repeat = 10      # number of runs

expe = "InvarianceKNN"
n_jobs_pw = 3  # n_jobs for the gw / sgw / risgw pairwise distances
device = 'cpu'  # torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# init: one row per run, one column per k in vector_k. Filled with NaN (not
# np.empty) so entries that are never computed -- models that were not
# selected, failed GW runs -- are saved as NaN instead of uninitialised memory.
allperf_dist_gw = np.full((n_repeat, len(vector_k)), np.nan)
allperf_dist_sgw = np.full((n_repeat, len(vector_k)), np.nan)
allperf_dist_risgw = np.full((n_repeat, len(vector_k)), np.nan)
allperf_distrib_min_sse = np.full((n_repeat, len(vector_k)), np.nan)

exception_gw = 0  # number of GW runs aborted by a ZeroDivisionError


np.random.seed(0)
torch.manual_seed(0)

# paths
pathdata = './data/'
filename = data  # dataset name given by --data
pathres = './result/'
os.makedirs(pathres, exist_ok=True)  # np.savez does not create missing directories
# NOTE: max_iter_ri and lr_ri are not part of the name, so risgw runs that only
# differ in those settings write to the same file.
res_filename = f"shape_{expe}_{model}_dataset_{data}_vertices_{n_vertices}_s_{scaling}_nproj_{nproj}_nproj_d_{nproj_dist_d}_nb{max_epoch}_iterinner{n_iter_inner:d}_latent{dim_latent:d}"


def _knn_topk(knn, dist_mat, labels, ks):
    """Fit ``knn`` on a precomputed distance matrix and score it on itself.

    ``dist_mat`` is the square matrix of pairwise distances between the shapes;
    the caller fills its diagonal with a large value so a shape is not picked
    as its own nearest neighbour. Absolute values are used because sklearn's
    ``metric="precomputed"`` rejects negative distances (NOTE(review):
    presumably some metrics return small negative values -- confirm).

    Returns ``(proba, accs)``: ``knn.predict_proba`` on the matrix, and the
    top-``k`` accuracy against ``labels`` for each ``k`` in ``ks``.
    """
    dist_abs = np.abs(dist_mat)
    knn.fit(dist_abs, labels)
    proba = knn.predict_proba(dist_abs)
    accs = [top_k_accuracy_score(labels, proba, k=k) for k in ks]
    return proba, accs


def _print_topk(perf_row):
    """Print the top-1 / top-2 accuracies of one run (assumes vector_k == [1, 2])."""
    print("Accuracy: Top 1 \t Top 2")
    print(f"\t {perf_row[0]:.3f} \t\t {perf_row[1]:.3f}")


# Fail fast on settings that would otherwise crash mid-run (undefined scaler)
# or silently skip every model and save empty results.
if model not in ("gw", "sgw", "risgw", "distrib_min_sse"):
    raise ValueError(f"Unknown model {model!r}")
if scaling not in ("standard", "minmax"):
    raise ValueError(f"Unknown scaling {scaling!r}")

for i in range(n_repeat):

    print(f"\n{model} - run = {i}")
    # (re)load the shapes dataset for this run. The handle is named `npz` so it
    # does not shadow `data`, the dataset name given by --data.
    # NOTE: allow_pickle=True unpickles object arrays -- only load trusted files.
    with np.load(pathdata + filename + '.npz', allow_pickle=True) as npz:
        y = npz['y']
        X = npz['X']
        trf = npz['transf']

    # keep only the shapes whose transformation is listed in `transformations`
    # (null transformation + isometries)
    ind = select_transforms(trf, transformations)
    X = X[ind]
    y = y[ind]
    trsf = trf[ind]

    # in-place sampling of n vertices on each shape
    X = samples_n_vertices(X, n_vertices)
    dim_s = X[0].shape[1]  # presumably the dimension of the vertex coordinates
    dim_t = dim_s

    # normalize each shape independently (otherwise distance computation for
    # GW messes up); fit_transform refits the scaler on every shape
    scaler = (StandardScaler() if scaling == "standard"
              else MinMaxScaler(feature_range=(-1, 1)))
    for idx, shape in enumerate(X):
        X[idx] = scaler.fit_transform(shape)

    # ----- pairwise distance matrix for the selected metric + 1-NN classification -----
    knn = KNeighborsClassifier(n_neighbors=1, metric="precomputed")

    ## ========== GW =========
    if model == "gw":
        try:
            mat_gw = my_pairwise_distances(X=X, Y=X, metric=metric_gw, n_jobs=n_jobs_pw)
        except ZeroDivisionError as err:
            print('Handling run-time error:', err)
            exception_gw += 1
            allperf_dist_gw[i, :] = np.nan  # no result for this run
        else:
            np.fill_diagonal(mat_gw, 1e8)  # set arbitrarily the diagonal elts to a large value
            y_gw, allperf_dist_gw[i] = _knn_topk(knn, mat_gw, y, vector_k)
            _print_topk(allperf_dist_gw[i])

    ## ============= Sliced GW ================
    elif model == "sgw":
        mat_sgw = my_pairwise_distances(X=X, Y=X, metric=metric_sgw, n_jobs=n_jobs_pw, nproj=nproj)
        np.fill_diagonal(mat_sgw, 1e8)
        y_sgw, allperf_dist_sgw[i] = _knn_topk(knn, mat_sgw, y, vector_k)
        _print_topk(allperf_dist_sgw[i])

    ## ============= Rotation Invariant Sliced GW ================
    elif model == "risgw":
        mat_risgw = my_pairwise_distances(X=X, Y=X, metric=metric_risgw, n_jobs=n_jobs_pw,
                                          nproj=nproj, lr=lr_ri, max_iter=max_iter_ri)
        np.fill_diagonal(mat_risgw, 1e8)
        y_risgw, allperf_dist_risgw[i] = _knn_topk(knn, mat_risgw, y, vector_k)
        _print_topk(allperf_dist_risgw[i])

    ## =========== Distributional Min SSE ================
    elif model == "distrib_min_sse":
        transf_net = TransformNet(dim_latent).to(device)
        fp = TransformLatenttoOrig(dim_latent, dim_s).to(device)
        fq = TransformLatenttoOrig(dim_latent, dim_t).to(device)

        transf_net_optim = optim.Adam(transf_net.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.5)
        fp_optim = optim.Adam(fp.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.5)
        fq_optim = optim.Adam(fq.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.5)

        # The same networks and optimizers are handed to every pairwise
        # evaluation (presumably updated by it), hence n_jobs=1.
        mat_dmsse = my_pairwise_distances(X=X, Y=X, metric=metric_distrib_min_sse, n_jobs=1,
                                          transf_net=transf_net, s_latent2orig_net=fp, t_latent2orig_net=fq,
                                          opt_trannet=transf_net_optim, opt_s=fp_optim, opt_t=fq_optim,
                                          dim_latent=dim_latent, nproj_dist=nproj_dist_d,
                                          num_epochs=max_epoch, num_sup_iter=n_iter_inner,
                                          num_inf_iter=n_iter_inner)
        np.fill_diagonal(mat_dmsse, 1e8)
        y_dmsse, allperf_distrib_min_sse[i] = _knn_topk(knn, mat_dmsse, y, vector_k)
        _print_topk(allperf_distrib_min_sse[i])

    # save intermediate results after every run
    np.savez(pathres+res_filename+"-partial",
             vector_k=vector_k,
             perf_gw=allperf_dist_gw,
             perf_sgw=allperf_dist_sgw,
             perf_risgw=allperf_dist_risgw,
             perf_distrib_min_sse=allperf_distrib_min_sse,
             n_vertices=n_vertices,
             n_repeat=n_repeat,
             nproj=nproj,
             nproj_dist_d=nproj_dist_d,
             max_iter_ri=max_iter_ri,
             lr_ri=lr_ri,
             max_epoch=max_epoch,
             transformations=transformations,
             n_iter_inner=n_iter_inner,
             dim_latent=dim_latent,
             exception_gw=exception_gw,
             )
    
#%% save final results   
# Only the array of the selected model is filled by the experiment loop; the
# number of failed GW runs is saved alongside so those rows can be interpreted.
np.savez(pathres+res_filename,
         vector_k=vector_k,
         perf_gw=allperf_dist_gw,
         perf_sgw=allperf_dist_sgw,
         perf_risgw=allperf_dist_risgw,
         perf_distrib_min_sse=allperf_distrib_min_sse,
         n_vertices=n_vertices,
         n_repeat=n_repeat,
         nproj=nproj,
         nproj_dist_d=nproj_dist_d,
         max_iter_ri=max_iter_ri,
         lr_ri=lr_ri,
         max_epoch=max_epoch,
         transformations=transformations,
         n_iter_inner=n_iter_inner,
         dim_latent=dim_latent,
         exception_gw=exception_gw,
         )

# Remove the intermediate saved results. The partial file may legitimately be
# absent (e.g. n_repeat == 0 means the loop never saved one).
try:
    os.remove(pathres+res_filename+"-partial.npz")
except FileNotFoundError:
    pass



