#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Evaluate the computation time of shape distances.

For every vertex count in ``vector_n`` and ``n_repeat`` repetitions, a random
subset of shapes is drawn from ``./data/<data>.npz``, vertices are sampled and
rescaled on each shape, and the pairwise distance matrix is computed with the
model chosen by ``--model`` (``gw``, ``sgw`` or ``distrib_min_sse``). The CPU
time per pair is saved as an ``.npz`` file in ``./result/``.
"""

import os

# Restrict every BLAS/OpenMP backend to one thread so the benchmark measures
# single-CPU cost. These variables are read when numpy/torch initialise their
# backends, hence they are set before those libraries are imported.
for _thread_var in (
    "OMP_NUM_THREADS",         # export OMP_NUM_THREADS=1
    "OPENBLAS_NUM_THREADS",    # export OPENBLAS_NUM_THREADS=1
    "MKL_NUM_THREADS",         # export MKL_NUM_THREADS=1
    "VECLIB_MAXIMUM_THREADS",  # export VECLIB_MAXIMUM_THREADS=1
    "NUMEXPR_NUM_THREADS",     # export NUMEXPR_NUM_THREADS=1
):
    os.environ[_thread_var] = "1"
del _thread_var

import sys
from pathlib import Path
# Append the parent of the current working directory to the import path;
# presumably this is where the project-local modules imported below
# (data_utils, metrics_utility, TransformNet) live when the script is launched
# from its own folder — TODO confirm.
sys.path.append(str(Path('.').absolute().parent))

from torch import optim
import numpy as np
# CPU time of the current process (not wall-clock time); every `tic = time()`
# measurement below uses it.
from time import process_time as time
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Project-local helpers (defined outside this file).
from data_utils import  samples_n_vertices
from metrics_utility import my_pairwise_distances, metric_gw, metric_sgw, metric_distrib_min_sse
from TransformNet import TransformNet, TransformLatenttoOrig

import argparse

#%% Command-line arguments
parser = argparse.ArgumentParser(description='Shape : computation time evaluation')
parser.add_argument('--nproj', type=int, default=1000,
                    help='number of projections for sliced GW (model sgw)')
parser.add_argument('--nproj_dist_d', type=int, default=10,
                    help='number of projections (nproj_dist) for distrib_min_sse')
parser.add_argument('--n_iter_inner', type=int, default=1,
                    help='number of inner sup/inf iterations for distrib_min_sse')
parser.add_argument('--max_epoch', type=int, default=50,
                    help='number of epochs for distrib_min_sse')
parser.add_argument('--dim_latent', type=int, default=5,
                    help='latent dimension for distrib_min_sse')
# Restrict to the supported values so that a typo fails immediately with a
# clear message, instead of crashing later (scaling: `scaler` never defined)
# or silently timing nothing (model: no branch of the timing loop matches).
parser.add_argument('--scaling', type=str, default='standard',
                    choices=['standard', 'minmax'],
                    help='per-shape normalization applied before timing')
parser.add_argument('--data', type=str, default='shapes',
                    help='dataset name, loaded from ./data/<data>.npz')
parser.add_argument('--model', type=str, default='distrib_min_sse',
                    choices=['gw', 'sgw', 'distrib_min_sse'],
                    help='distance whose computation time is evaluated')

args = parser.parse_args()
print(args)

nproj = args.nproj                # nb of projections for sliced GW
nproj_dist_d = args.nproj_dist_d  # nb of projections passed as nproj_dist to distrib_min_sse
n_iter_inner = args.n_iter_inner  # num of inner iterations for distrib_min_sse
max_epoch = args.max_epoch        # number of epochs for distrib_min_sse
dim_latent = args.dim_latent
scaling = args.scaling
data = args.data
model = args.model

#%% other parameters
vector_n = [100, 250, 500, 1000, 1500, 2000]  # numbers of vertices sampled on each shape
n_samples_timing = 10  # number of shapes randomly drawn to evaluate the computation burden
# Number of distance evaluations per timing. Distances are computed between the
# drawn shapes and themselves, so this assumes my_pairwise_distances evaluates
# the full n_samples_timing x n_samples_timing matrix (TODO confirm). Deriving
# n_pairs from n_samples_timing keeps the two consistent for any sample count.
n_pairs = n_samples_timing ** 2
n_jobs_time = 1  # single CPU, to monitor the average computation time of n_pairs distances
device = "cpu"
n_repeat = 10  # number of independent runs per vertex count
err_gw = 0  # number of GW runs that raised ZeroDivisionError

# Init recorded performances: one row per entry of vector_n, one column per run
time_gw = np.zeros((len(vector_n), n_repeat))
time_sgw = np.zeros((len(vector_n), n_repeat))
time_distrib_min_sse = np.zeros((len(vector_n), n_repeat))


# paths
expe = "timing"
pathdata = './data/'
filename = data  # dataset name from --data; read from pathdata + filename + '.npz'
pathres = './result/'
# np.savez does not create missing folders, so make sure the output folder exists
os.makedirs(pathres, exist_ok=True)
res_filename = f"shape_{expe}_{model}_dataset_{data}_s_{scaling}_nproj_{nproj}_nproj_d_{nproj_dist_d}_nb{max_epoch}_iterinner{n_iter_inner:d}_latent{dim_latent:d}"

#%% Timing loop
# Validate the configuration once, before any work: an unsupported scaling would
# otherwise only surface as a NameError on `scaler`, and an unsupported model
# would silently store all-zero timings.
# fit_transform refits the scaler on every shape, so a single instance is shared.
if scaling == "standard":
    scaler = StandardScaler()
elif scaling == "minmax":
    scaler = MinMaxScaler(feature_range=(-1, 1))
else:
    raise ValueError(f"Unknown scaling {scaling!r}: expected 'standard' or 'minmax'")
if model not in ("gw", "sgw", "distrib_min_sse"):
    raise ValueError(f"Unknown model {model!r}: expected 'gw', 'sgw' or 'distrib_min_sse'")

# Load the shapes dataset once (named `npz` so the --data value in `data` is not
# shadowed); a new random subset of shapes is still drawn for each vertex count.
with np.load(pathdata + filename + '.npz', allow_pickle=True) as npz:
    X_all = npz['X']

for i, n_vertices in enumerate(vector_n):
    print(f"\nn = {n_vertices}")

    # draw randomly the indices of the shapes used to evaluate computation time of each method
    ind_samples = np.random.choice(X_all.shape[0], n_samples_timing, replace=False)
    X = X_all[ind_samples]

    dim_s = X[0].shape[1]  # dimension of the vertex coordinates (3 for 3D shapes)
    dim_t = dim_s

    for j in range(n_repeat):
        print(f"{model} - run = {j}")
        # random sampling of n_vertices vertices on each shape
        Xred = samples_n_vertices(X, n_vertices)

        # normalize each shape independently with the selected scaler
        for idx, shape in enumerate(Xred):
            Xred[idx] = scaler.fit_transform(shape)

        ## ========== GW =========
        if model == "gw":
            tic = time()
            try:
                my_pairwise_distances(X=Xred, Y=Xred, metric=metric_gw, n_jobs=n_jobs_time)
            except ZeroDivisionError as err:
                print('Handling run-time error:', err)
                err_gw += 1
                # Mark the failed run as NaN rather than leaving a spurious 0 s
                # timing that would bias averages (use np.nanmean downstream).
                time_gw[i, j] = np.nan
            else:
                time_gw[i, j] = (time() - tic) / n_pairs

        ## ============= Sliced GW ================
        elif model == "sgw":
            tic = time()
            my_pairwise_distances(X=Xred, Y=Xred, metric=metric_sgw, n_jobs=n_jobs_time, nproj=nproj)
            time_sgw[i, j] = (time() - tic) / n_pairs

        ## =========== Distributional Min SSE ================
        elif model == "distrib_min_sse":
            # Networks and optimizers are rebuilt for every run; their
            # construction is deliberately excluded from the timed section.
            transf_net = TransformNet(dim_latent).to(device)
            fp = TransformLatenttoOrig(dim_latent, dim_s).to(device)
            fq = TransformLatenttoOrig(dim_latent, dim_t).to(device)

            transf_net_optim = optim.Adam(transf_net.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.5)
            fp_optim = optim.Adam(fp.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.5)
            fq_optim = optim.Adam(fq.parameters(), lr=0.001, betas=(0.5, 0.999), weight_decay=0.5)

            tic = time()
            my_pairwise_distances(X=Xred, Y=Xred, metric=metric_distrib_min_sse, n_jobs=n_jobs_time,
                                  transf_net=transf_net, s_latent2orig_net=fp, t_latent2orig_net=fq,
                                  opt_trannet=transf_net_optim, opt_s=fp_optim, opt_t=fq_optim,
                                  dim_latent=dim_latent, nproj_dist=nproj_dist_d,
                                  num_epochs=max_epoch, num_sup_iter=n_iter_inner, num_inf_iter=n_iter_inner)

            time_distrib_min_sse[i, j] = (time() - tic) / n_pairs

        # save intermediate results after every run so an interrupted benchmark
        # keeps what has been measured so far
        np.savez(pathres + res_filename + "-partial",
                 time_gw=time_gw,
                 time_sgw=time_sgw,
                 time_distrib_min_sse=time_distrib_min_sse,
                 vector_n=vector_n,
                 n_pairs=n_pairs,
                 n_repeat=n_repeat,
                 err_gw=err_gw,
                 )

#%% save final results
# Timing arrays have shape (len(vector_n), n_repeat) and hold the CPU time
# (process_time) per pair of shapes. err_gw counts the GW runs that raised
# ZeroDivisionError; it is saved so that failed runs remain traceable.
np.savez(pathres + res_filename,
         time_gw=time_gw,
         time_sgw=time_sgw,
         time_distrib_min_sse=time_distrib_min_sse,
         vector_n=vector_n,
         n_pairs=n_pairs,
         n_repeat=n_repeat,
         err_gw=err_gw,
         )

# Remove the intermediate saved results (np.savez appended the ".npz" suffix)
partial_path = pathres + res_filename + "-partial.npz"
if os.path.exists(partial_path):
    os.remove(partial_path)