""" This file solves the (LWMC) and (EWMC) models defined in the paper for each input data sample.
For every hyperparameter setting, the representation matrix C and the affinity matrix A are saved as
.npy files under {save_path}/{dataset}/, and spectral-clustering metrics are appended to a CSV file there.

Input arguments:
dataset: {dataset name} = cifar_10_clip or cifar_100_clip or cifar_20_clip
save_path: {output directory path} = path of the directory where output will be saved
N: an integer value. The code solves the model for the first N input data samples in the CIFAR dataset
expo_weights: if this flag is set, (EWMC) is solved; otherwise (LWMC) is solved
seed: random seed value """


import time
import sys
import argparse

import os
os.environ['KMP_WARNINGS'] = 'off' # for SSC and EnSC errors

cwd = os.getcwd()
sys.path.insert(0, cwd)

from utils.cluster_metrics import *
from utils.selfrepresentation import ElasticNetManifoldClustering
from data.datasets_sc import parse_dataset

# Command-line interface. Only flags consumed in this file are documented in
# detail; --knn, --N and --normalize are passed on to parse_dataset via `args`.
parser = argparse.ArgumentParser(
    description='Solve the (LWMC)/(EWMC) self-representation models and '
                'evaluate spectral clustering of the resulting affinities')
parser.add_argument('--dataset', type=str, default='cifar_10_clip')
parser.add_argument('--knn', type=int, default=100,
                    help='number of neighbors in knn')
parser.add_argument('--method', type=str, default='ssc_spams',
                    help="self-representation method: must contain 'ssc' (without "
                         "'omp') or 'ensc', one solver name (spams, compare_elastic, "
                         "elastic_cd, lasso_lars, lasso_cd), and optionally 'active' "
                         "to enable active support")
parser.add_argument('--n_init', type=int, default=20,
                    help='number of distinct kmeans runs to average over')
parser.add_argument('--solver_type', type=str, default='lm',
                    help='eigensolver type',
                    choices=['lm', 'la', 'shift_invert'])
parser.add_argument('--extra_dim', type=int, default=0,
                    help='number extra dimensions in spectral embedding')
# argparse does not apply `type` to non-string defaults, so spell the default
# as a float to keep args.eigs_tol a float whether or not the flag is given.
parser.add_argument('--eigs_tol', type=float, default=0.0,
                    help='tolerance for stopping eigensolver')
parser.add_argument('--no_print_aff', action='store_true',
                    help='set to not print affinity info')
parser.add_argument('--N', type=int, default=50,
                    help='number of points (for certain datasets)')
parser.add_argument('--normalize', type=str, default='none',
                    help='normalizing the dataset',
                    choices=['none', 'unit_sph', 'whiten', 'whiten_unit_sph'])
parser.add_argument('--save_path', type=str, default='./results/240209-cifar10-clip',
                    help='output directory; results are written under <save_path>/<dataset>/')
parser.add_argument('--seed', type=int, default=0, help='random seed for reproducibility')
parser.add_argument('--expo_weights', action='store_true',
                    help='set to define W as an exponential function of distances')

args = parser.parse_args()
print(args)


# Seed the torch and numpy global RNGs from --seed before any data is loaded.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(args.seed)

X, label, nclass = parse_dataset(args)

# parse_dataset is expected to always return X as a numpy.ndarray;
# guard that invariant before the solvers consume X.
assert isinstance(X, np.ndarray)

print('Data Shape:', X.shape, ' | nclass:', nclass, '\n')


### COMPUTE SIMILARITY ###

solvers = []

# Pick the sparse-coding backend from --method. Candidates are tested in this
# order, and the first one whose name is a substring of --method wins.
_ALGORITHM_NAMES = ('spams', 'compare_elastic', 'elastic_cd', 'lasso_lars', 'lasso_cd')
algorithm = next((name for name in _ALGORITHM_NAMES if name in args.method), None)
if algorithm is None:
    raise ValueError('Invalid algorithm')

# Any --method containing 'active' turns on the active-support option.
active_support = 'active' in args.method

print('args.method:', args.method, ' | Algorithm:', algorithm, ' | Active Support:', active_support)


def _fit_and_collect(gamma, tau, eta, banner, param_str):
    """Fit one ElasticNetManifoldClustering model on X and return its outputs.

    Prints ``banner`` first, then returns ``(A, param_str, C)``: the model's
    affinity matrix, the hyperparameter tag used to label CSV rows and .npy
    filenames, and the representation matrix.
    """
    print(banner)
    model = ElasticNetManifoldClustering(n_clusters=nclass, algorithm=algorithm,
                                         active_support=active_support, gamma=gamma,
                                         tau=tau, eta=eta, expW=args.expo_weights)
    model.fit_self_representation(X)
    model._representation_to_affinity()
    return model.affinity_matrix_, param_str, model.representation_matrix_


if 'ssc' in args.method and 'omp' not in args.method:
    print('Using SSC builder')

    def builder(gamma, eta):
        # The SSC sweep fixes tau = 1.0; only gamma and eta vary.
        return lambda: _fit_and_collect(gamma, 1.0, eta,
                                        f'gamma: {gamma} | eta: {eta}',
                                        f'gam_{gamma}_eta{eta}')

    gamma_lst = (50, 20)
    eta_lst = (1, 20, 100, 400)
    # Grid order: gamma outer, eta inner.
    solvers.extend(builder(g, e) for g in gamma_lst for e in eta_lst)
elif 'ensc' in args.method:
    print('Using EnSC builder')

    def builder(gamma, tau, eta):
        # NOTE(review): the param tag omits eta, so runs that differ only in eta
        # would share a CSV tag and overwrite each other's .npy files. This is
        # safe only while eta_lst has a single entry.
        return lambda: _fit_and_collect(gamma, tau, eta,
                                        f'gamma: {gamma} | tau: {tau} | eta: {eta}',
                                        f'gam_{gamma}_tau_{tau}')

    gamma_lst = (2, 5, 10, 20, 50)
    tau_lst = (0.9, 0.8)
    eta_lst = (1,)
    # Grid order: gamma outer, then tau, then eta innermost.
    solvers.extend(builder(g, t, e) for g in gamma_lst for t in tau_lst for e in eta_lst)
else:
    raise ValueError('Invalid method')
    
# Common prefix for every output of this run: '<prefix>.csv' collects the
# metrics, '<prefix>_<params>_A.npy' / '_C.npy' hold the matrices.
filename = f'{args.save_path}/{args.dataset}/{args.method}'

# exist_ok=True replaces the exists-then-create check, which could fail if
# another run created the directory in between.
os.makedirs(os.path.dirname(filename), exist_ok=True)

### CLUSTER ###
FULL_START_TIME = time.time()
# Best configuration seen so far, ordered as
# (acc, nmi, fd_error, nnz, param_str, elapsed). acc starts at -1 so the first
# solver's result replaces this placeholder.
best_result = (-1, -1, -1, -1, -1, -1)
for i, solver in enumerate(solvers):
    print('--------------------------------------------')
    start_time = time.time()
    A, param_str, C = solver()
    elapsed = time.time() - start_time
    print(f'Elapsed (sec): {elapsed:.2f}')
    if not args.no_print_aff:
        print_affinity_info(A)
    # Cluster A spectrally and score it against `label`; acc/nmi come back as
    # per-run lists (over n_init kmeans runs) and are averaged below.
    acc_lst, nmi_lst, fd_error, nnz = spectral_clustering_metrics(
        A, nclass, label,
        n_init=args.n_init, solver_type=args.solver_type,
        extra_dim=args.extra_dim, tol=args.eigs_tol)
    acc_val, nmi_val = np.mean(acc_lst), np.mean(nmi_lst)
    if acc_val > best_result[0]:
        best_result = acc_val, nmi_val, fd_error, nnz, param_str, elapsed
    print(f'Writing results to {filename}')
    # Append and close after every solver so finished results survive a later crash.
    with open(filename + '.csv', 'a') as f:
        f.write(f'{args.method},{param_str},{acc_val:.4f},{nmi_val:.4f},'
                f'{fd_error:.4f},{nnz:.1f},{elapsed:.2f},result {i}\n')
    # Saving is best-effort: a failed write is reported and the sweep goes on.
    # Deliberately no debugger here; it would block or abort unattended runs.
    for suffix, matrix in (('A', A), ('C', C)):
        out_path = f'{filename}_{param_str}_{suffix}.npy'
        try:
            np.save(out_path, matrix)
        except Exception as err:
            print(f'Error saving {out_path}: {err!r}')

# Summary row, using the same column layout as the per-result rows.
best_acc, best_nmi, best_fd_error, best_nnz, best_params, best_elapsed = best_result
with open(filename + '.csv', 'a') as f:
    f.write(f'{args.method},{best_params},{best_acc:.4f},{best_nmi:.4f},'
            f'{best_fd_error:.4f},{best_nnz:.1f},{best_elapsed:.2f},BEST\n')

print("\nFULL ELAPSED TIME:", time.time() - FULL_START_TIME)