import gzip
import multiprocessing
import os
import time

import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import fcluster
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import chebyshev, cityblock, euclidean

import metrics

def load_data(path):
    """Load a dataset file and derive a short dataset name from its path.

    Supported formats:
      * ``.gz``  -- gzip-compressed numeric text, parsed with ``np.loadtxt``
        (default whitespace delimiter).
      * ``.csv`` -- CSV parsed with ``pd.read_csv`` (first row is the header)
        and converted to a NumPy array.

    The dataset name is the file's base name up to its first underscore,
    e.g. ``datasets\\airfoil_clusters.csv`` -> ``airfoil``. Both ``\\`` and
    ``/`` are accepted as directory separators.

    Args:
        path: path to the dataset file.

    Returns:
        ``(matrix, name)`` -- the data as a NumPy array and the dataset name.

    Raises:
        ValueError: if ``path`` ends with neither ``.gz`` nor ``.csv``.
    """
    if path.endswith('.gz'):
        with gzip.open(path, 'rt') as file:
            matrix = np.loadtxt(file)
    elif path.endswith('.csv'):
        matrix = pd.read_csv(path).to_numpy()
    else:
        # Previously fell through and crashed later with UnboundLocalError.
        raise ValueError(f'Unsupported file type (expected .gz or .csv): {path}')

    # Normalise separators so Windows- and POSIX-style paths both yield the
    # bare file name before taking the prefix.
    base_name = path.replace('\\', '/').rsplit('/', 1)[-1]
    name = base_name.split('_')[0]
    return matrix, name

def run_tests(k, clusters, n, distance, data_set_name, method_name, lists_size):
    """Score one flat clustering with the ``metrics`` module.

    Args:
        k: number of clusters the labels were cut into.
        clusters: cluster labels for the n observations.
        n: number of observations.
        distance: distance matrix passed through to every metric.
        data_set_name, method_name, lists_size: descriptive labels copied
            into the result row.

    Returns:
        A one-row DataFrame with columns data_set, method, k_size, k,
        sep_min, max_avg, max_diam, sep_avg, cs_ratio_AV
        (= max_avg / sep_min), cs_ratio_DM (= max_diam / sep_min) and time
        (wall-clock seconds spent in this call).
    """
    t0 = time.time()
    row = pd.DataFrame({'data_set': [data_set_name], 'method': [method_name]})
    row['k_size'] = lists_size
    row['k'] = [k]

    # Column name -> metric function, evaluated in this order.
    metric_table = (
        ('sep_min', metrics.sep_min),
        ('max_avg', metrics.max_avg),
        ('max_diam', metrics.max_diameter),
        ('sep_avg', metrics.sep_avg),
    )
    for column, metric_fn in metric_table:
        row[column] = [metric_fn(n, clusters, k, distance)]

    # Series division (not Python floats) so a zero separation yields inf
    # rather than raising ZeroDivisionError.
    separation = row['sep_min']
    row['cs_ratio_AV'] = row['max_avg'] / separation
    row['cs_ratio_DM'] = row['max_diam'] / separation
    row['time'] = [time.time() - t0]
    return row

def run_tests_parallel(k_list_complete, k_to_clusters, n, dist_matrix, data_set_name, method_name, k_lists):
    """Evaluate every k in a process pool and return the combined results.

    Each k is scored by ``process_k`` in a worker process. The size label
    ('small' / 'medium' / 'large' / 'error') comes from ``get_size``.

    Args:
        k_list_complete: k values to evaluate (may contain duplicates).
        k_to_clusters: mapping k -> cluster labels for that k.
        n: number of observations.
        dist_matrix: distance matrix passed to the metrics.
        data_set_name, method_name: labels copied into every result row.
        k_lists: dict of 'small' / 'medium' / 'large' k lists used for labelling.

    Returns:
        DataFrame of all per-k rows, stably sorted by ``k``. It is an empty
        DataFrame when ``k_list_complete`` is empty.
    """
    # Leave two cores for the rest of the system, but always request at least
    # one worker: Pool(processes=0) raises ValueError on 1-2 core machines.
    core_count = max(1, multiprocessing.cpu_count() - 2)

    # Ship each worker only the labels it needs instead of pickling the whole
    # k -> labels dict for every task. process_k still indexes by k.
    tasks = [
        (k, {k: k_to_clusters[k]}, n, dist_matrix, data_set_name, method_name,
         get_size(k, k_lists))
        for k in k_list_complete
    ]
    if not tasks:
        return pd.DataFrame()

    # The context manager shuts the pool down even if a worker raises.
    # starmap has already collected every result before the block exits.
    with multiprocessing.Pool(processes=core_count) as pool:
        results = pool.starmap(process_k, tasks)

    # Concatenate once rather than growing a DataFrame in a loop.
    all_data = pd.concat(results, ignore_index=True)
    # Stable sort keeps duplicate k rows in task order (small, medium, large).
    return all_data.sort_values(by=['k'], kind='stable')

def get_size(k, k_lists):
    """Return the label of the first k list that contains ``k``.

    The lists are checked in the order 'small', 'medium', 'large'. When ``k``
    appears in several lists, the earliest label wins. Returns 'error' if
    ``k`` is in none of them.
    """
    return next(
        (label for label in ('small', 'medium', 'large') if k in k_lists[label]),
        'error',
    )

def process_k(k, k_to_clusters, n, dist_matrix, data_set_name, method_name, list_size):
    """Score the clustering for a single k (the Pool.starmap worker).

    Looks up the labels for ``k`` in ``k_to_clusters``, runs ``run_tests`` on
    them, prints a progress line and returns the one-row result DataFrame.
    """
    labels = k_to_clusters[k]
    result = run_tests(k, labels, n, dist_matrix, data_set_name, method_name, list_size)
    print('Done k= ', k)
    return result

def _cut_linkage_to_k(linkage_matrix, k):
    """Return 1-based labels for exactly ``k`` flat clusters of a linkage tree.

    The labels come from replaying the first n - k merges of ``linkage_matrix``.
    """
    n = linkage_matrix.shape[0] + 1
    # Nodes 0..n-1 are observations; node n + i is the cluster created by
    # row i. root[node] becomes the top-most kept cluster containing node.
    root = np.arange(2 * n - 1)
    # A row only references clusters from earlier rows. Walking the kept rows
    # backwards therefore finalises each parent's root before it is pushed
    # down to its children.
    for i in range(n - k - 1, -1, -1):
        parent_root = root[n + i]
        root[int(linkage_matrix[i, 0])] = parent_root
        root[int(linkage_matrix[i, 1])] = parent_root
    # Renumber the k distinct roots as 1..k, the same label range fcluster uses.
    _, labels = np.unique(root[:n], return_inverse=True)
    return labels + 1


def try_method(matrix, dist_matrix, data_set_name, method_name):
    """Cluster ``matrix`` hierarchically and evaluate flat cuts at many k.

    ``matrix`` (rows = observations) is clustered with
    ``scipy.cluster.hierarchy.linkage(matrix, method=method_name)``. The tree
    is then cut into exactly k clusters for every k in three groups:

      * small:  2..10
      * medium: int(sqrt(n)) - 4 .. int(sqrt(n)) + 4
      * large:  n//10, n//9, ..., n//2

    k values outside 2..n, which are possible only for very small n, are
    dropped. Each cut is scored by ``run_tests_parallel`` using ``dist_matrix``.

    NOTE: for small n the groups can overlap. A shared k is then evaluated
    once per occurrence, but ``get_size`` labels every occurrence with the
    first group that contains it.

    Returns:
        DataFrame of per-k results from ``run_tests_parallel``.
    """
    n = matrix.shape[0]
    sqrt_n = int(np.sqrt(n))
    k_lists = {
        'small': [2, 3, 4, 5, 6, 7, 8, 9, 10],
        'medium': [sqrt_n + offset for offset in range(-4, 5)],
        'large': [n // divisor for divisor in range(10, 1, -1)],
    }
    # Cutting a tree over n points into k < 2 or k > n clusters is meaningless.
    # The old index arithmetic also picked the wrong row or raised for such k.
    k_lists = {size: [k for k in ks if 2 <= k <= n] for size, ks in k_lists.items()}
    k_list_complete = k_lists['small'] + k_lists['medium'] + k_lists['large']

    linkage_matrix = linkage(matrix, method=method_name)

    # Replay merges instead of thresholding just below a merge height. With
    # tied heights the threshold approach returned more than k clusters.
    k_to_clusters = {
        k: _cut_linkage_to_k(linkage_matrix, k) for k in dict.fromkeys(k_list_complete)
    }

    return run_tests_parallel(
        k_list_complete, k_to_clusters, n, dist_matrix, data_set_name, method_name, k_lists
    )

def try_base(matrix, dist_matrix, data_set_name, methods=('single', 'complete', 'average')):
    """Run ``try_method`` for each linkage method on one dataset.

    Args:
        matrix: observation matrix handed to ``try_method`` for clustering.
        dist_matrix: distance matrix handed to ``try_method`` for scoring.
        data_set_name: label written into the results.
        methods: linkage methods to evaluate, in order. 'ward' is not in the
            default list; pass it explicitly to include it.

    Returns:
        One DataFrame holding every method's rows, reindexed 0..N-1.
        It is an empty DataFrame when ``methods`` is empty.
    """
    print('\n\n Starting data set: ', data_set_name, '\n')
    per_method = []
    for method in methods:
        print('Method: ', method)
        per_method.append(try_method(matrix, dist_matrix, data_set_name, method))

    if not per_method:
        return pd.DataFrame()
    # Concatenate once. Seeding with an empty DataFrame and concatenating in a
    # loop recopies the accumulated rows each time. It also relies on pandas'
    # deprecated treatment of empty entries when picking column dtypes.
    return pd.concat(per_method, ignore_index=True)

def get_average(all_results, path):
    """Write per-(data_set, method, k_size) metric means to ``path``.

    The 'time' and 'k' columns are dropped and the remaining columns are
    averaged per group. The result is written as a '|'-separated CSV without
    an index, with columns in a fixed order. ``all_results`` is not modified.
    Returns None.
    """
    output_columns = [
        'data_set', 'method', 'k_size',
        'sep_min', 'sep_avg', 'max_diam', 'max_avg',
        'cs_ratio_DM', 'cs_ratio_AV',
    ]
    summary = (
        all_results
        .drop(columns=['time', 'k'])
        .groupby(['data_set', 'method', 'k_size'])
        .mean()
        .reset_index()
    )
    summary[output_columns].to_csv(path, sep='|', index=False)

def main(file_paths_list=None):
    """Run the clustering experiments on every dataset file.

    For each file, ``main`` does the following:
      * loads it with ``load_data``;
      * builds a Euclidean distance matrix via ``metrics.dist_matrix_creator``;
      * evaluates all linkage methods with ``try_base``;
      * writes the rows to ``results/<name>_results.csv`` ('|'-separated).

    A failure on one file is reported and the loop continues with the next.

    Args:
        file_paths_list: optional list of dataset paths. Defaults to the
            built-in list of files under ``datasets``.

    Returns:
        DataFrame concatenating the results of every file that succeeded.
        It is an empty DataFrame if none did.
    """
    if file_paths_list is None:
        file_paths_list = [
            os.path.join('datasets', file_name)
            for file_name in (
                'airfoil_clusters.csv', 'banknote_clusters.csv',
                'collins_30_clusters.gz', 'concrete_clusters.csv',
                'digits_10_clusters.gz', 'geographicalmusic_clusters.csv',
                'mice_8_clusters.gz', 'qsarfish_clusters.csv',
                'tripadvisor_clusters.csv', 'vowel_11_clusters.gz',
            )
        ]

    frames = []
    for file_path in file_paths_list:
        try:
            start_time = time.time()
            print('Loading file: ', file_path)
            matrix, name = load_data(file_path)
            print('Loaded file: ', name)
            dist_matrix = metrics.dist_matrix_creator(matrix, euclidean)

            data = try_base(matrix, dist_matrix, name)

            out_path = os.path.join('results', name + '_results.csv')
            # to_csv does not create missing directories. Without this, the
            # whole dataset's computation is lost to a FileNotFoundError.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            data.to_csv(out_path, index=False, sep='|')
            frames.append(data)
            print('Finished ', name, ' in ', time.time() - start_time)

        except Exception as e:
            # Top-level boundary: report and move on to the next dataset.
            # The failure may come from loading, clustering or writing.
            print('Error processing file: ', file_path)
            print('Error: ', e)

    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    
if __name__ == '__main__':
    # The guard is required: with the spawn start method, multiprocessing
    # workers re-import this module and must not re-run the experiments.
    main()
