
import os
import torch
import numpy as np
import sklearn
from sys import argv
from time import time
from omegaconf import DictConfig
from tqdm import tqdm 

from sklearn.metrics import silhouette_score
from sklearn.manifold import trustworthiness

from gwdr.src.affinities import GramAffinity
from gwdr.src.dimension_reduction import AffinityMatcher
from gwdr.src.dr_grid import DR_grid, trusthworthiness_grid, silhouette_grid

from gwdr.data.data import load_dataset

#%% hyperparameters

# python3 main_grid.py 'mnist' cuda:1
# python3 main_grid.py 'fashion_mnist' cuda:1
# python3 main_grid.py 'coil' cuda:1
# python3 main_grid.py 'snareseq1' cuda:1
# python3 main_grid.py 'snareseq2' cuda:1
# python3 main_grid.py 'citeseq' cuda:1
# python3 main_grid.py 'zeisel' cuda:1

# Command-line interface: ``python3 main_grid.py <dataset> <device>``.
# Both arguments are positional. If either one is missing, BOTH fall back to
# the defaults below so the script can be launched without arguments.
try:
    dataset = str(argv[1])
    device = str(argv[2])
except IndexError:
    # Only a missing positional argument is expected here; any other
    # exception (e.g. KeyboardInterrupt) must not be silently swallowed.
    dataset = 'mnist'
    device = 'cpu'

# Values of pixels_per_row iterated over by the main loop: 10, 20, ..., 90.
output_sam_list = [10 * i for i in range(1, 10, 1)]
# Random seeds; every benchmark is repeated once per seed.
seed_list = [0, 1, 2, 3, 4]

#%% settings

# Floating-point precision used for the tensors created by this script.
dtype = torch.float64

# Experiment configuration. It is wrapped into a DictConfig in the main
# section and also saved as config.pt next to the srGW results.
cfg_dict = {
    # loss / affinities
    'marginal_loss': False,
    'affinity_data': 'GramAffinity',
    'affinity_embedding': 'GramAffinity',
    'loss_fun': 'square_loss',
    # solver stopping criteria
    'max_iter': 2000,
    'tol': 1e-6,
    # run description
    'dataset': dataset,
    'device': device,
    'T_init': 'random',
}

# All results are written below <current working directory>/res_grid/<dataset>/.
abspath = os.path.abspath('.')
grid_res = abspath + '/res_grid/%s/' % dataset
print('grid_res:', grid_res)
os.makedirs(grid_res, exist_ok=True)


#%%

if __name__ == "__main__":

    # Load the dataset on the requested device; CPU copies are kept because
    # the evaluation metrics below are computed on CPU data.
    X, Y = load_dataset(dataset, device=device)
    X_cpu, Y_cpu = X.cpu(), Y.cpu()

    cfg = DictConfig(cfg_dict)

    affinity_data = GramAffinity(centering=True)
    affinity_embedding = GramAffinity()

    # benchmark to classical DR
    res_classical_DR_repo = grid_res + f'DR_CX{cfg.affinity_data}_CZ{cfg.affinity_embedding}/'
    res_classical_DR_file = res_classical_DR_repo + 'results.pt'

    # The benchmark is skipped only when its results file already exists.
    # Testing for the directory instead would skip the benchmark forever
    # after a run that was interrupted between creating the directory and
    # saving the results.
    if not os.path.exists(res_classical_DR_file):
        os.makedirs(res_classical_DR_repo, exist_ok=True)

        res_classical_DR = {'silhouette': [], 'trust': [], 'time': [], 'seed': []}
        for seed in tqdm(seed_list, desc='benchmark classical DR'):
            # seed both RNGs so that the random initialisation is reproducible
            torch.manual_seed(seed)
            np.random.seed(seed)

            # the timing covers fitting and the transfer of the embedding to CPU
            start = time()
            AM = AffinityMatcher(
                affinity_data, affinity_embedding, init="random", loss_fun='square_loss',
                lr=1e0, max_iter=1000, verbose=False)
            Z_pca = AM.fit_transform(X).cpu()
            end = time()

            # quality of the embedding: neighbourhood preservation w.r.t. the
            # input data and cluster separation w.r.t. the labels
            trust_pca = trustworthiness(X_cpu, Z_pca)
            ss_pca = silhouette_score(Z_pca, Y_cpu)

            res_classical_DR['silhouette'].append(ss_pca)
            res_classical_DR['trust'].append(trust_pca)
            res_classical_DR['time'].append(end - start)
            res_classical_DR['seed'].append(seed)

        torch.save(res_classical_DR, res_classical_DR_file)
    
    # setup benchmark with srGW projections on grids

    # Unit weight per sample: it scales the rows of the random initial
    # transport plans and is passed to DR_grid as `h`.
    n_samples = X.size(0)
    h = torch.ones(n_samples, dtype=dtype, device=device)

    # One output directory per solver: conditional gradient (CG), mirror
    # descent (MD) and mirror descent on the entropic objective (esrGW).
    res_srgw_cg_path = grid_res + f'srGW_CG_CX{cfg.affinity_data}_CZ{cfg.affinity_embedding}/'
    res_srgw_md_path = grid_res + f'srGW_MD_CX{cfg.affinity_data}_CZ{cfg.affinity_embedding}/'
    res_esrgw_md_path = grid_res + f'esrGW_MD_CX{cfg.affinity_data}_CZ{cfg.affinity_embedding}/'
    
    # validate 10 successive values of the entropic hyperparameter epsilon,
    # e.g. in {1, 5, 10, 50, 100, 500, 1000, ...},
    # where initial_epsilon is the first epsilon for which the algorithms are stable
    
    def incr_epsilon(eps):
        """Return the value following `eps` on the grid 1, 5, 10, 50, 100, ...

        A power of ten is multiplied by 5 and any other value by 2, so that
        starting from a power of ten the sequence alternates between 10**k
        and 5 * 10**k.

        The power-of-ten test compares log10(eps) with its nearest integer up
        to a small absolute tolerance rather than with exact float equality:
        a log10 result that is off by one ulp would otherwise be classified
        as "not a power of ten" and silently derail the sequence.
        """
        log_eps = np.log10(eps)
        if np.isclose(log_eps, np.round(log_eps), rtol=0.0, atol=1e-9):
            return eps * 5  # power of ten
        return eps * 2

    # Loop-invariant preprocessing, done once before iterating over the grid
    # sizes: none of it depends on pixels_per_row.
    # X is shifted by its global scalar mean (X.mean() reduces over all
    # entries); X_cpu keeps the un-shifted data.
    # A dense Gram matrix (`Kx_ = affinity_data.compute_affinity(X)`) used to
    # be computed here as well, but it was never read: only the low-rank
    # factors below are handed to the solvers.
    # NOTE(review): this assumes compute_affinity has no side effect that the
    # solvers rely on - confirm.
    X = X - X.mean()
    Kx = torch.stack([X, X], 0) # low-rank representation to speed-up computation

    # NOTE(review): Cx_vmax is half the value range of the shifted data X
    # (Kx only stacks X twice), not of the Gram matrix - confirm that this
    # is the intended scale for vmax.
    Cx_vmax = torch.abs(Kx.min() - Kx.max()).item() / 2.

    for pixels_per_row in output_sam_list:

        print('--- pixels_per_row: ', pixels_per_row)
        cfg.pixels_per_row = pixels_per_row
        # setup initial random transport plan to initialize all srGW solvers

        # One plan per seed, stored in the same order as seed_list. Each plan
        # has one row per sample and pixels_per_row ** 2 columns; rows are
        # normalised to sum to one and then scaled by the sample weights h.
        list_T0 = []
        for seed in seed_list:
            torch.manual_seed(seed)
            T0_random = torch.rand(X.shape[0], cfg.pixels_per_row ** 2, dtype=X.dtype, device=device)
            T0_random /= T0_random.sum(-1, keepdim=True)
            T0_random *= h[:,None]
            # a fresh tensor is created at every iteration, so no copy is needed
            list_T0.append(T0_random)
        
        # benchmark for grid using conditional gradient solver
        # For every vmax scale, DR_grid is run once per seed with epsilon=0;
        # losses, metrics, timings and transport plans are saved per scale.
        for alpha_vmax in tqdm([0.25, 0.5, 0.75, 1.], desc='(CG) srGW - validate vmax'):
            # e.g. '/px10_alphavmax025/' (the dot of alpha_vmax is dropped)
            cg_subrepo = '/px%s_alphavmax%s/'%(
                pixels_per_row, str(alpha_vmax).replace('.', ''))
            os.makedirs(res_srgw_cg_path + cg_subrepo, exist_ok=True)

            res_srgw_cg = {'loss':[], 'silhouette':[], 'trust':[], 'time':[], 'seed':[]}
            list_T = []
            vmax = np.round(Cx_vmax * alpha_vmax)
            # implies that CZ in [-vmax, vmax]

            # Pair every seed with its initial plan by position: list_T0 is
            # built in seed_list order, so this stays correct even when the
            # seeds are not 0, 1, 2, ... (indexing list_T0 by the seed value
            # would not).
            for seed, T0 in zip(seed_list, list_T0):
                start = time()
                (T, log_), Z, Kz = DR_grid(
                    X, Kx, affinity_data, affinity_embedding, h=h, vmax=vmax, epsilon=0., loss_fun=cfg.loss_fun,
                    max_iter=cfg.max_iter, tol=cfg.tol, pixels_per_row=pixels_per_row, init=T0,
                    objective='exact', marginal_loss=cfg.marginal_loss, log=True, return_grid=True)
                end = time()

                # metrics are evaluated on CPU copies
                T_cpu, Z_cpu = T.cpu(), Z.cpu()
                srgw_loss = log_['srgw_dist'].item()
                trust = trusthworthiness_grid(T_cpu, Z_cpu, X_cpu)
                ss = silhouette_grid(T_cpu, Z_cpu, Y_cpu)

                res_srgw_cg['loss'].append(srgw_loss)
                res_srgw_cg['silhouette'].append(ss)
                res_srgw_cg['trust'].append(trust)
                res_srgw_cg['time'].append(end - start)
                res_srgw_cg['seed'].append(seed)

                list_T.append(T_cpu)

            torch.save(res_srgw_cg, res_srgw_cg_path + cg_subrepo +'/results.pt')
            torch.save(list_T, res_srgw_cg_path + cg_subrepo +'/list_T.pt')
            torch.save(cfg_dict, res_srgw_cg_path + cg_subrepo +'/config.pt')
            
        # benchmark for grid using mirror descent with exact srgw objective
        # For every vmax scale: (1) find the first epsilon of the grid for
        # which the solver does not return NaN, (2) benchmark 10 successive
        # epsilon values starting from it, once per seed.
        for alpha_vmax in tqdm([0.25, 0.5, 0.75, 1.], desc='(MD) srGW - validate vmax'):
            # e.g. '/px10_alphavmax025/' (the dot of alpha_vmax is dropped)
            md_subrepo = '/px%s_alphavmax%s/'%(
                pixels_per_row, str(alpha_vmax).replace('.', ''))
            os.makedirs(res_srgw_md_path + md_subrepo, exist_ok=True)

            vmax = np.round(Cx_vmax * alpha_vmax)
            # implies that CZ in [-vmax, vmax]

            # The initial epsilon is searched again for every alpha_vmax
            # because it is sensitive to alpha_vmax.
            print('looking for a proper epsilon')
            # test whether an initial epsilon value leads to convergence
            # else increase epsilon value
            # NOTE: the search has no upper bound; it only stops once a run
            # returns a non-NaN loss.
            local_seed = 0  # position in list_T0 of the plan used for the search
            epsilon = 1.
            initial_epsilon = None
            while initial_epsilon is None:

                (T, log_), _, _ = DR_grid(
                    X, Kx, affinity_data, affinity_embedding, h=h, vmax=vmax, epsilon=epsilon, loss_fun=cfg.loss_fun,
                    max_iter=cfg.max_iter, tol=cfg.tol, pixels_per_row=pixels_per_row, init=list_T0[local_seed],
                    objective='exact', marginal_loss=cfg.marginal_loss, log=True, return_grid=True)

                if torch.isnan(log_['srgw_dist']):
                    epsilon = incr_epsilon(epsilon)
                else:
                    initial_epsilon = epsilon
            print('found a proper initial epsilon:', initial_epsilon)

            epsilon = initial_epsilon
            for _ in range(10):

                res_srgw_md = {'loss':[], 'silhouette':[], 'trust':[], 'time':[], 'seed':[]}
                list_T = []

                # Pair every seed with its initial plan by position: list_T0
                # is built in seed_list order, so this stays correct even
                # when the seeds are not 0, 1, 2, ...
                for seed, T0 in zip(seed_list, list_T0):
                    start = time()
                    (T, log_), Z, Kz = DR_grid(
                        X, Kx, affinity_data, affinity_embedding, h=h, vmax=vmax, epsilon=epsilon, loss_fun=cfg.loss_fun,
                        max_iter=cfg.max_iter, tol=cfg.tol, pixels_per_row=pixels_per_row, init=T0,
                        objective='exact', marginal_loss=cfg.marginal_loss, log=True, return_grid=True)
                    end = time()

                    # a NaN loss at this stage aborts the benchmark
                    assert not torch.isnan(log_['srgw_dist'])

                    # metrics are evaluated on CPU copies
                    T_cpu, Z_cpu = T.cpu(), Z.cpu()
                    srgw_loss = log_['srgw_dist'].item()
                    trust = trusthworthiness_grid(T_cpu, Z_cpu, X_cpu)
                    ss = silhouette_grid(T_cpu, Z_cpu, Y_cpu)

                    res_srgw_md['loss'].append(srgw_loss)
                    res_srgw_md['silhouette'].append(ss)
                    res_srgw_md['trust'].append(trust)
                    res_srgw_md['time'].append(end - start)
                    res_srgw_md['seed'].append(seed)

                    list_T.append(T_cpu)

                # one results file and one list of plans per epsilon value
                torch.save(res_srgw_md, res_srgw_md_path + md_subrepo +'/eps%s_results.pt'%epsilon)
                torch.save(list_T, res_srgw_md_path + md_subrepo +'/eps%s_list_T.pt'%epsilon)
                torch.save(cfg_dict, res_srgw_md_path + md_subrepo +'/config.pt')

                epsilon = incr_epsilon(epsilon)
        
        # benchmark for grid using mirror descent with * entropic * srgw objective
        # For every vmax scale: (1) find the first epsilon of the grid for
        # which no seed returns NaN, (2) benchmark 10 successive epsilon
        # values starting from it, once per seed.
        for alpha_vmax in tqdm([0.25, 0.5, 0.75, 1.], desc='(MD) srGW entropic - validate vmax'):
            # e.g. '/px10_alphavmax025/' (the dot of alpha_vmax is dropped)
            emd_subrepo = '/px%s_alphavmax%s/'%(
                pixels_per_row, str(alpha_vmax).replace('.', ''))
            os.makedirs(res_esrgw_md_path + emd_subrepo, exist_ok=True)

            vmax = np.round(Cx_vmax * alpha_vmax)

            print('looking for a proper epsilon')
            # test whether an initial epsilon value leads to convergence
            # else increase epsilon value
            # it is quite more sensitive with the entropic srGW objective
            # so we check across all seeds
            # NOTE: the search has no upper bound; it only stops once every
            # seed returns a non-NaN loss for the same epsilon.
            epsilon = 1.
            initial_epsilon = None
            while initial_epsilon is None:

                # Each seed is checked with ITS OWN initial plan. Using one
                # fixed plan for every iteration would make the loop repeat
                # the same run instead of checking across seeds.
                for seed, T0 in zip(seed_list, list_T0):
                    (T, log_), _, _ = DR_grid(
                        X, Kx, affinity_data, affinity_embedding, h=h, vmax=vmax, epsilon=epsilon, loss_fun=cfg.loss_fun,
                        max_iter=cfg.max_iter, tol=cfg.tol, pixels_per_row=pixels_per_row, init=T0,
                        objective='entropic', marginal_loss=cfg.marginal_loss, log=True, return_grid=True)

                    if torch.isnan(log_['srgw_dist']):
                        # unstable for this seed: try the next epsilon
                        epsilon = incr_epsilon(epsilon)
                        break
                else:
                    # no break: all seeds went well
                    initial_epsilon = epsilon
            print('found a proper initial epsilon:', initial_epsilon)

            epsilon = initial_epsilon
            for _ in range(10):

                res_esrgw_md = {'loss':[], 'silhouette':[], 'trust':[], 'time':[], 'seed':[]}
                list_T = []

                # Pair every seed with its initial plan by position: list_T0
                # is built in seed_list order, so this stays correct even
                # when the seeds are not 0, 1, 2, ...
                for seed, T0 in zip(seed_list, list_T0):
                    start = time()
                    (T, log_), Z, Kz = DR_grid(
                        X, Kx, affinity_data, affinity_embedding, h=h, vmax=vmax, epsilon=epsilon, loss_fun=cfg.loss_fun,
                        max_iter=cfg.max_iter, tol=cfg.tol, pixels_per_row=pixels_per_row, init=T0,
                        objective='entropic', marginal_loss=cfg.marginal_loss, log=True, return_grid=True)
                    end = time()

                    # metrics are evaluated on CPU copies
                    T_cpu, Z_cpu = T.cpu(), Z.cpu()
                    # NaN losses are not rejected here; they are stored as-is
                    srgw_loss = log_['srgw_dist'].item()

                    trust = trusthworthiness_grid(T_cpu, Z_cpu, X_cpu)
                    ss = silhouette_grid(T_cpu, Z_cpu, Y_cpu)

                    res_esrgw_md['loss'].append(srgw_loss)
                    res_esrgw_md['silhouette'].append(ss)
                    res_esrgw_md['trust'].append(trust)
                    res_esrgw_md['time'].append(end - start)
                    res_esrgw_md['seed'].append(seed)

                    list_T.append(T_cpu)

                # one results file and one list of plans per epsilon value
                torch.save(res_esrgw_md, res_esrgw_md_path + emd_subrepo + '/eps%s_results.pt'%epsilon)
                torch.save(list_T, res_esrgw_md_path + emd_subrepo + '/eps%s_list_T.pt'%epsilon)
                torch.save(cfg_dict, res_esrgw_md_path + emd_subrepo + '/config.pt')

                epsilon = incr_epsilon(epsilon)
