""" This file constructs affinity by first defining three different commonly used strategies.
These strategies are defined in the appendix.
Next, the code defines all possible permutations of the three strategies and their subsets.
Finally, the code constructs affinity matrices from the representation matrix C using the combination of 
strategies defined above.
The code then uses k-means on each affinity matrix to cluster the images.
The clustering accuracy as well as other output metrics are recorded for each affinity matrix and saved in
the output directory in the path given by input argument {save_path}.

Input arguments:
{save_path}: output directory path. Note, this path should be the same path where the output of executing
'./exps/240301-benchmark-mc.py' is saved
dataset: {dataset name} = cifar_10_clip or cifar_100_clip or cifar_20_clip
N: an integer value. The code solves the model for the first N input data samples in the CIFAR dataset """


import time
import sys
import argparse
from sklearn.preprocessing import normalize
from sklearn.neighbors import kneighbors_graph
import itertools

import os
os.environ['KMP_WARNINGS'] = 'off' # for SSC and EnSC errors

cwd = os.getcwd()
sys.path.insert(0, cwd)

from utils.cluster_metrics import *
from data.datasets_sc import parse_dataset
from utils.processC import processC

# Command-line interface for the affinity post-processing sweep.
parser = argparse.ArgumentParser(description='General Training Pipeline')
parser.add_argument('--dataset', type=str, default='cifar_10_clip')
parser.add_argument('--knn', type=int, default=100, 
                    help='number of neighbors in knn')
parser.add_argument('--method', type=str, default='ssc_lasso_lars_active',
                    help='base similarity')
parser.add_argument('--n_init', type=int, default=20,
                    help='number of distinct kmeans runs to average over')
parser.add_argument('--solver_type', type=str, default='lm',
                    help='eigensolver type',
                    choices=['lm', 'la', 'shift_invert'])
parser.add_argument('--extra_dim', type=int, default=0,
                    help='number extra dimensions in spectral embedding')
parser.add_argument('--eigs_tol', type=float, default=0,
                    help='tolerance for stopping eigensolver')
parser.add_argument('--no_print_aff', action='store_true', default=False,
                    help='set to not print affinity info')
parser.add_argument('--N', type=int, default=50000,
                    help='number of points (for certain datasets)')
parser.add_argument('--normalize', type=str, default='none',
                    help='normalizing the dataset',
                    choices=['none', 'unit_sph', 'whiten', 'whiten_unit_sph'])
parser.add_argument('--mp', action='store_true', default=False, help='whether to use multiprocessing')
parser.add_argument('--processes', type=int, default=4, help='number of processes if using multiprocessing')
parser.add_argument('--save_path', type=str, default='./results/240306-mc-cifar10-clip-nonorm-50k', help='feature path')
parser.add_argument('--seed', type=int, default=0, help='random seed for reproducibility')
# `store_true` combined with `default=True` can never yield False, so each of
# the two switches below gets a companion `--no_*` flag writing to the same
# destination. The original flags and their defaults are left untouched.
parser.add_argument('--symmetrize', action='store_true', default=True, help='Symmetrize C?')
parser.add_argument('--no_symmetrize', action='store_false', dest='symmetrize',
                    help='set to skip the symmetrization step (S)')
parser.add_argument('--normalize_aff', action='store_true', default=True, help='Normalize C?')
parser.add_argument('--no_normalize_aff', action='store_false', dest='normalize_aff',
                    help='set to skip the normalization step (N)')
parser.add_argument('--connect_kNN', type=int, default=15, help='Generate graph of k-NN from C? 0 for no, k for yes')

args = parser.parse_args()
print(args)


# Seed both random number generators with the user-supplied seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# parse_dataset yields the data matrix, the labels and the number of classes.
X, label, nclass = parse_dataset(args)

# this should always be true as a result of parse_datasets
assert isinstance(X, np.ndarray)

print('Data Shape:', X.shape, ' | nclass:', nclass, '\n')

# Build the string of enabled post-processing steps, always in the order
# S (symmetrize), N (normalize), K (k-NN graph).
use_knn = args.connect_kNN > 0
affinity_structure = ''.join(
    code
    for code, enabled in (
        ('S', args.symmetrize),
        ('N', args.normalize_aff),
        ('K', use_knn),
    )
    if enabled
)
if use_knn:
    # Only defined when the K step is enabled; the K step is its only reader.
    affinity_kNeighbors = args.connect_kNN

print(affinity_structure)

# Build the list of hyper-parameter configurations to post-process.
# Each entry is a zero-argument callable that prints its hyper-parameters and
# returns the tag string that the main loop inserts into the name of the
# stored `*_C.npy` file.
solvers = []
if 'ssc' in args.method and 'omp' not in args.method:
    print('Using SSC builder')
    def builder(gamma, eta):
        """Return a callable reporting one (gamma, eta) SSC configuration."""
        def _solver():
            print(f'gamma: {gamma} | eta: {eta}')
            return f'gam_{gamma}_eta{eta}'
        return _solver
    gamma_lst = (20, 50)
    # alternative grid: eta_lst = (0.2, 2, 20)
    eta_lst = (1, 20, 100, 400)
    # product() iterates gamma in the outer position and eta in the inner one.
    for gamma, eta in itertools.product(gamma_lst, eta_lst):
        solvers.append(builder(gamma, eta))
elif 'ensc' in args.method:
    print('Using EnSC builder')
    def builder(gamma, tau, eta):
        """Return a callable reporting one (gamma, tau, eta) EnSC configuration."""
        def _solver():
            print(f'gamma: {gamma} | tau: {tau} | eta: {eta}')
            # NOTE: eta is printed but is not part of the returned tag.
            return f'gam_{gamma}_tau_{tau}'
        return _solver
    gamma_lst = (1, 2, 5, 10, 20, 50)
    # alternative grid: gamma_lst = (2, 5, 10, 20, 50)
    tau_lst = (0.9, 0.8)
    # alternative grid: eta_lst = (0.2, 2, 20)
    eta_lst = (1,)
    for gamma, tau, eta in itertools.product(gamma_lst, tau_lst, eta_lst):
        solvers.append(builder(gamma, tau, eta))
else:
    raise ValueError('Invalid method')

# Common path prefix for the stored C matrices and for the results CSV.
filename = f'{args.save_path}/{args.dataset}/{args.method}'

# exist_ok=True avoids the check-then-create race when several runs start
# at the same time and target the same directory.
os.makedirs(os.path.dirname(filename), exist_ok=True)




def permute(a, l, r, C, param_str, best_result):
    """Evaluate every ordering of the step codes in ``a[l:r]`` by recursive swapping.

    Positions before ``l`` are treated as already fixed. When ``l == r`` the
    string ``a`` is a complete ordering: its steps are applied to ``|C|`` one
    after another, the resulting affinity is passed to
    ``spectral_clustering_metrics``, and one row is appended to
    ``filename + '.csv'``.

    Step codes handled here:
        'S': average of the absolute matrix and its absolute transpose
        'N': ``normalize(|A|, 'l2')``
        'K': ``kneighbors_graph`` connectivity graph with
             ``affinity_kNeighbors`` neighbours, then symmetrized like 'S'

    Args:
        a: string of step codes.
        l: index of the first position still to be permuted.
        r: one past the last position to be permuted.
        C: matrix the affinity is built from.
        param_str: tag of the hyper-parameter configuration; the ordering is
            appended to it for the CSV row.
        best_result: tuple ``(acc, nmi, fd_error, nnz, tag)`` of the best run
            so far; only element 0 is compared.

    Returns:
        The best-result tuple, replaced whenever a strictly higher mean
        accuracy was found.
    """
    if l != r:
        for i in range(l, r):
            # Swap on a fresh copy, so `a` itself never needs to be restored.
            chars = list(a)
            chars[l], chars[i] = chars[i], chars[l]
            best_result = permute(''.join(chars), l + 1, r, C, param_str, best_result)
        return best_result

    tag = f'{param_str}_{a}'
    steps = list(a)
    print(steps)

    A = np.abs(C)
    for step in steps:
        if step == 'S':
            A = 0.5 * (np.abs(A) + np.abs(A.T))
        elif step == 'N':
            A = normalize(np.abs(A), 'l2')
        elif step == 'K':
            graph = kneighbors_graph(A, affinity_kNeighbors, mode='connectivity', include_self=False)
            A = 0.5 * (np.abs(graph) + np.abs(graph.T))

    print('start analysis')
    acc_lst, nmi_lst, fd_error, nnz = spectral_clustering_metrics(
        A, nclass, label,
        n_init=args.n_init, solver_type=args.solver_type,
        extra_dim=args.extra_dim, tol=args.eigs_tol)
    acc_val = np.mean(acc_lst)
    nmi_val = np.mean(nmi_lst)
    if acc_val > best_result[0]:
        best_result = acc_val, nmi_val, fd_error, nnz, tag

    print(f'Writing results to {filename}')
    row = f'{args.method},{tag},{acc_val:.4f},{nmi_val:.4f},{fd_error:.4f},{nnz:.1f}\n'
    with open(filename+'.csv', 'a') as f:
        f.write(row)
    return best_result

def permutate_all_substrings(text, C, param_str, best_result):
    """Evaluate every ordered selection of step codes drawn from ``text``.

    For each length from 1 to ``len(text)`` and each
    ``itertools.permutations(text, length)``, the selected steps are applied
    in order to ``|C|``. The resulting affinity is passed to
    ``spectral_clustering_metrics`` and one row is appended to
    ``filename + '.csv'``.

    Step codes handled here:
        'S': average of the absolute matrix and its absolute transpose
        'N': ``normalize(A, 'l2')``
        'D': reserved for a projection onto doubly stochastic matrices;
             currently leaves the matrix unchanged
        'L': ``processC(A, 0.90)``
        'K': ``kneighbors_graph`` connectivity graph with
             ``affinity_kNeighbors`` neighbours (not symmetrized afterwards)
    Any other code is ignored.

    Args:
        text: string of step codes to combine.
        C: matrix the affinity is built from.
        param_str: tag of the hyper-parameter configuration; each ordering is
            appended to it for the CSV row.
        best_result: tuple ``(acc, nmi, fd_error, nnz, tag)`` of the best run
            so far; only element 0 is compared.

    Returns:
        The best-result tuple, replaced whenever a strictly higher mean
        accuracy was found. Returned unchanged when ``text`` is empty.
    """
    # Names inside the lambdas are resolved only when a step is applied.
    transforms = {
        'S': lambda M: 0.5 * (np.abs(M) + np.abs(M.T)),
        'N': lambda M: normalize(M, 'l2'),
        # 'D' is deliberately absent: add code to proj on set of Doubly Stochastic matrices
        'L': lambda M: processC(M, 0.90),
        'K': lambda M: kneighbors_graph(M, affinity_kNeighbors, mode='connectivity', include_self=False),
    }

    # Shorter selections first, each length in itertools.permutations order.
    orderings = itertools.chain.from_iterable(
        itertools.permutations(text, size) for size in range(1, len(text) + 1)
    )

    for ordering in orderings:
        tag = f'{param_str}_{"".join(ordering)}'
        print(list(ordering))

        A = np.abs(C)
        for step in ordering:
            transform = transforms.get(step)
            if transform is not None:
                A = transform(A)

        print('start analysis')
        acc_lst, nmi_lst, fd_error, nnz = spectral_clustering_metrics(
            A, nclass, label,
            n_init=args.n_init, solver_type=args.solver_type,
            extra_dim=args.extra_dim, tol=args.eigs_tol)
        acc_val = np.mean(acc_lst)
        nmi_val = np.mean(nmi_lst)
        if acc_val > best_result[0]:
            best_result = acc_val, nmi_val, fd_error, nnz, tag

        print(f'Writing results to {filename}')
        row = f'{args.method},{tag},{acc_val:.4f},{nmi_val:.4f},{fd_error:.4f},{nnz:.1f}\n'
        with open(filename+'.csv', 'a') as f:
            f.write(row)
    return best_result


# Post-process the stored C matrix of every hyper-parameter configuration and
# append a summary row, marked BEST, for the best ordering found.
for solver in solvers:
    print('--------------------------------------------')
    start_time = time.time()
    # Calling the solver prints its hyper-parameters and returns the tag that
    # identifies the stored matrix file.
    param_str = solver()
    C = np.load(f'{filename}_{param_str}_C.npy',allow_pickle=True)
    # Sentinel: the accuracy of -1 is replaced by the first evaluated ordering.
    best_result = (-1, -1, -1, -1, -1)
    best_result = permutate_all_substrings(affinity_structure, C, param_str, best_result)
    # Alternative that only tries full-length orderings:
    # best_result = permute(affinity_structure,0,len(affinity_structure), C, param_str, best_result)
    with open(filename+'.csv', 'a') as f:
        f.write(f'{args.method},{best_result[4]},{best_result[0]:.4f},{best_result[1]:.4f},{best_result[2]:.4f},{best_result[3]:.1f},BEST\n')
    # The timer was previously started but never read; report it.
    print(f'Elapsed time for {param_str}: {time.time() - start_time:.1f}s')
