import os
import pandas
import requests
from scipy.cluster.hierarchy import fcluster
from sklearn.cluster import KMeans
from sklearn.preprocessing import scale
from sklearn_extra.cluster import KMedoids
from statistics import mean
import sys
import time

from functions import *
from ip_approx import approxIP
from hard_instances import *

# File-name prefix under which each data set's npz results are saved.
save_str = dict(
    Adult='AdultData',
    Drug='DrugData',
    IndianLiver='IndianLiver',
    Car='CarData',
    BreastCancer='BreastCancer',
    Hard='HardInstance',
)

# Download location of each real data set. 'Hard' has no entry because
# data_loader() generates that instance instead of downloading it.
urls = {
    'Adult': 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data',
    'Drug': 'https://archive.ics.uci.edu/ml/machine-learning-databases/00373/drug_consumption.data',
    'IndianLiver': ('https://archive.ics.uci.edu/ml/machine-learning-databases/00225/'
                    'Indian%20Liver%20Patient%20Dataset%20(ILPD).csv'),
    'Car': 'https://archive.ics.uci.edu/ml/machine-learning-databases/car/car.data',
    'BreastCancer': 'https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data',
}

def helper(tup):
    '''
    Summarize an output of functions.measure_ip_stability().

    INPUT:
    tup: output from functions.measure_ip_stability (a tuple (is_stable,nr_unstable,violation,closest_cluster));
         only nr_unstable (tup[1]) and the per-point violations (tup[2]) are read here

    OUTPUT:
    nr_unstable: the number of points that are not stable (tup[1], passed through)
    max_violation: the maximum violation of any given point
    mean_violation: the average violation among unstable points, i.e. points whose
                    violation exceeds 1; 0.0 when no point is unstable
    all these are defined at the beginning of Section 6 in the paper
    '''
    violations = np.asarray(tup[2], dtype=float)
    # A point is unstable iff its violation exceeds 1.
    violated = violations[violations > 1]

    nr_unstable = tup[1]
    max_violation = np.amax(violations)
    # Average only over the unstable points, as documented above. With no
    # unstable point that average is undefined; report 0.0 rather than letting
    # the mean of an empty sequence raise.
    mean_violation = float(np.mean(violated)) if violated.size else 0.0
    return nr_unstable, max_violation, mean_violation

def _download_if_missing(data_set, path):
    '''
    Download the file for data_set from urls[data_set] to path, unless path already exists.

    Exits the process with a non-zero status if the download fails.
    '''
    if os.path.exists(path):
        return
    print('Data set does not exist in current folder --- have to download it', file=sys.stderr)
    try:
        # The timeout keeps the experiment from hanging forever on an unresponsive server.
        r = requests.get(urls[data_set], allow_redirects=True, timeout=60)
    except requests.RequestException:
        r = None
    if r is None or r.status_code != requests.codes.ok:
        print('Could not download the data set --- please download it manually', file=sys.stderr)
        sys.exit(1)
    print('Download successful\n', file=sys.stderr)
    with open(path, 'wb') as f:
        f.write(r.content)


# Per-data-set preprocessing applied to the raw, header-less CSV frame.
_PREPROCESSORS = {
    # first 1000 records; keep the integer-typed columns (=[0,2,4,10,11,12])
    'Adult': lambda df: df.iloc[:1000].select_dtypes(include=[np.integer]),
    # drop column 0 (presumably an ID) and keep the float-typed columns
    'Drug': lambda df: df.drop(columns=[0]).select_dtypes(include=[np.floating]),
    # drop rows with missing values, one-hot encode the non-numeric columns
    'IndianLiver': lambda df: pandas.get_dummies(df.dropna()),
    # one-hot encode the non-numeric columns
    'Car': lambda df: pandas.get_dummies(df),
    # drop columns 0 and 1 (presumably ID and diagnosis label)
    'BreastCancer': lambda df: df.drop(columns=[0, 1]),
}


def data_loader(data_set):
    '''
    Load a data set as a float matrix.

    INPUT:
    data_set: 'Hard' (generated with generate_kmeans_instance) or one of 'Adult', 'Drug',
              'IndianLiver', 'Car', 'BreastCancer'; the latter are read from
              '<data_set>.data' in the current folder, which is downloaded if it doesn't exist

    OUTPUT:
    data: the data, one row per point (a float numpy array for the downloaded data sets)
    n: the number of points, data.shape[0]

    RAISES:
    ValueError: if data_set is not one of the names above
    '''
    if data_set == 'Hard':
        data = generate_kmeans_instance(n_alpha=2000, alpha=50, r=1)
        return data, data.shape[0]

    # Validate before any download so an unknown name fails with a clear message.
    if data_set not in _PREPROCESSORS:
        raise ValueError('Unknown data set: {}'.format(data_set))

    path = '{}.data'.format(data_set)
    _download_if_missing(data_set, path)

    df = pandas.read_csv(path, sep=',', header=None)
    df = _PREPROCESSORS[data_set](df)
    data = np.array(df.values, dtype=float)
    return data, data.shape[0]


def kmeans_runner(data, distance_matrix,k):
    '''
    Cluster with a single KMeans run and report its k-center objective.

    INPUT:
    data: array whose rows are the points
    distance_matrix: not used; present so all *_runner functions share one signature
    k: number of clusters

    OUTPUT:
    clustering: the KMeans labels
    obj: the largest Euclidean distance from a point to its assigned KMeans center
    '''
    model = KMeans(n_clusters=k, n_init=1).fit(data)
    labels = model.labels_
    centroids = model.cluster_centers_

    # Seeding the max with 0 keeps the result 0 for an empty data set.
    distances = (np.linalg.norm(data[row] - centroids[labels[row]])
                 for row in range(data.shape[0]))
    obj = max([0, *distances])
    return labels, obj

def approx_runner(data, distance_matrix, k):
    '''
    Cluster with approxIP and report a k-center style objective.

    INPUT:
    data: array whose first axis indexes the points (only its length is read here)
    distance_matrix: condensed (1-D) or square pairwise distance matrix
    k: number of clusters

    OUTPUT:
    clustering: the labelling returned by approxIP
    obj: max over points i of the distance between point i and point int(clustering[i])
    '''
    _centers, labels = approxIP(distance_matrix, k)

    # Expand a condensed (1-D) matrix into square form so it can be indexed [i][j].
    full = squareform(distance_matrix) if len(distance_matrix.shape) == 1 else distance_matrix

    # NOTE(review): column int(clustering[i]) treats each label as a point index. That is
    # the distance to i's center only if approxIP labels each point with the index of its
    # center point; if labels are 0..k-1, centers[label] was probably intended.
    # Confirm against ip_approx.approxIP.
    n = data.shape[0]
    obj = max([0, *(full[row][int(labels[row])] for row in range(n))])
    return labels, obj

def random_runner(data, distance_matrix, k):
    '''
    Baseline: assign every point to one of k clusters uniformly at random.

    INPUT:
    data: array whose first axis indexes the points
    distance_matrix: condensed (1-D, as returned by pdist) or square pairwise distance matrix
    k: number of clusters

    OUTPUT:
    clustering: the random labels, drawn from {0, ..., k-1}
    obj: k-center objective of that clustering, i.e. the maximum over non-empty clusters
         of min over candidate centers c in the cluster of the largest distance from c
         to a member of the cluster
    '''
    n = data.shape[0]
    clustering = np.random.randint(k, size=n)

    if len(distance_matrix.shape) == 1:
        sq_distance_matrix = squareform(distance_matrix)
    else:
        sq_distance_matrix = np.asarray(distance_matrix)

    obj = 0
    for label in range(k):
        members = np.flatnonzero(clustering == label)
        if members.size == 0:
            # Random assignment can leave a cluster empty; it contributes no cost.
            continue
        # Index by the members' point indices, not their positions within the cluster.
        within = sq_distance_matrix[np.ix_(members, members)]
        # Row max = radius if that member were the center; the best center minimizes it.
        obj = max(obj, within.max(axis=1).min())

    # Return the same clustering whose cost was just computed.
    return clustering, obj

def run_clustering(data, distance_matrix, k, cluster_runner, nr_runs=1):
    '''
    Run one clustering algorithm nr_runs times and measure every resulting clustering.

    INPUT:
    data: the data matrix, passed through to cluster_runner
    distance_matrix: pairwise distances, passed to cluster_runner, measure_ip_stability
                     and compute_clustering_cost
    k: number of clusters
    cluster_runner: callable (data, distance_matrix, k) -> (clustering, k-center objective)
    nr_runs: number of repetitions

    OUTPUT:
    seven arrays of length nr_runs: number of unstable points, max violation and mean
    violation (as summarized by helper), the two values returned by
    compute_clustering_cost (cost, kmeans_cost), the k-center objective reported by
    cluster_runner, and the time spent in cluster_runner in seconds
    '''
    nr = np.zeros(nr_runs)
    max_viol = np.zeros(nr_runs)
    mean_viol = np.zeros(nr_runs)
    cost = np.zeros(nr_runs)
    kmeans_cost = np.zeros(nr_runs)
    kcenters_cost = np.zeros(nr_runs)
    run_time = np.zeros(nr_runs)

    for run in range(nr_runs):
        # perf_counter is monotonic and high-resolution, so unlike time.time() a
        # wall-clock adjustment during the run cannot corrupt the measured duration.
        start_time = time.perf_counter()
        clustering, kcenters_cost[run] = cluster_runner(data, distance_matrix, k)
        run_time[run] = time.perf_counter() - start_time

        nr[run], max_viol[run], mean_viol[run] = \
                helper(measure_ip_stability(distance_matrix, clustering))
        cost[run], kmeans_cost[run] = compute_clustering_cost(distance_matrix, clustering)

    return nr, max_viol, mean_viol, cost, kmeans_cost, kcenters_cost, run_time

def print_header():
    '''Print the CSV header line whose columns match the rows written by log_result().'''
    columns = (
        'data_set', 'algorithm', 'dist_function', 'k', '#run', '#unf',
        'max_viol', 'mean_viol', 'cost', 'kmeans_cost', 'kcenters_cost', 'time',
    )
    print(','.join(columns))

def log_result(
    data_set, label, dist_function, k_range, 
    nr_runs, nr_unf, maxviol, meanviol, 
    cost, kmeans_cost, kcenters_cost, time):
    '''
    Print one CSV line per k, in the column order of print_header().

    Each metric argument is a (runs x len(k_range)) array; it is reduced to its
    maximum over the runs axis (axis 0) before printing. The parameter `time`
    holds the timing array and shadows the time module inside this function only.
    '''
    per_k_maxima = [
        np.max(metric, axis=0)
        for metric in (nr_unf, maxviol, meanviol, cost, kmeans_cost, kcenters_cost, time)
    ]
    for col, k in enumerate(k_range):
        fields = [data_set, label, dist_function, k, nr_runs]
        fields.extend(column[col] for column in per_k_maxima)
        print(','.join('{}'.format(field) for field in fields))


def run_experiment(data_set, dist_function, k_range, nr_runs):
    '''
    Run the clustering algorithms on one data set for a given distance function and every
    number of clusters k in k_range, print the results and save them as an npz-file.

    The approxIP-based algorithm (approx_runner) is run once per k; KMeans (kmeans_runner,
    logged as 'kmeans++') is run nr_runs times per k. The result arrays are also stored in
    module-level globals (nr_unf_kmeans, ..., time_random).

    INPUT:
    data_set: one of 'Adult', 'Drug', 'IndianLiver', 'Car', 'BreastCancer' or 'Hard'
    dist_function: one of the metrics listed on
                    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html
    k_range: a numpy array providing the range for the number of clusters k
    nr_runs: how often to run KMeans for each k

    OUTPUT (file):
    Results/<data_set>/<save_str[data_set]>_NrReplicates=<nr_runs>_distance=<dist_function>
    (np.savez adds the .npz extension)
    NOTE: the *_random arrays are saved but never filled here (random_runner is not run),
    so they are all zeros.
    '''
    results_folder = 'Results/' + data_set
    # Creates 'Results' as well when needed; no error if the folder already exists.
    os.makedirs(results_folder, exist_ok=True)

    init_start = time.time()

    data, n = data_loader(data_set)

    # normalize data with sklearn's scale()
    data = scale(data)

    # computing the (condensed) distance matrix
    distance_matrix = pdist(data, metric=dist_function)

    init_end = time.time()

    print('init_time', '{:.06f}'.format(init_end - init_start), file=sys.stderr)

    global nr_unf_kmeans, nr_unf_const_approx, nr_unf_random
    global maxviol_kmeans, maxviol_const_approx, maxviol_random
    global meanviol_kmeans, meanviol_const_approx, meanviol_random
    global cost_kmeans, cost_const_approx, cost_random
    global kmeans_cost_kmeans, kmeans_cost_const_approx, kmeans_cost_random
    global kcenters_cost_kmeans, kcenters_cost_const_approx, kcenters_cost_random
    global time_kmeans, time_const_approx, time_random
    ### INITIATE THE MEASURES ###
    # Shape (runs, len(k_range)); the approximation is run once, hence a single row.
    nr_unf_kmeans = np.zeros((nr_runs, k_range.size))
    nr_unf_const_approx = np.zeros((1,k_range.size))
    nr_unf_random = np.zeros((nr_runs, k_range.size))

    ###
    maxviol_kmeans = np.zeros((nr_runs, k_range.size))
    maxviol_const_approx = np.zeros((1,k_range.size))
    maxviol_random = np.zeros((nr_runs, k_range.size))

    ###
    meanviol_kmeans = np.zeros((nr_runs, k_range.size))
    meanviol_const_approx = np.zeros((1,k_range.size))
    meanviol_random = np.zeros((nr_runs, k_range.size))

    ###
    cost_kmeans = np.zeros((nr_runs, k_range.size))
    cost_const_approx = np.zeros((1,k_range.size))
    cost_random = np.zeros((nr_runs, k_range.size))

    ###
    kmeans_cost_kmeans = np.zeros((nr_runs, k_range.size))
    kmeans_cost_const_approx = np.zeros((1,k_range.size))
    kmeans_cost_random = np.zeros((nr_runs, k_range.size))

    ###
    kcenters_cost_kmeans = np.zeros((nr_runs, k_range.size))
    kcenters_cost_const_approx = np.zeros((1,k_range.size))
    kcenters_cost_random = np.zeros((nr_runs, k_range.size))

    ###
    time_const_approx = np.zeros((1,k_range.size))
    time_kmeans = np.zeros((nr_runs, k_range.size))
    time_random = np.zeros((nr_runs, k_range.size))

    # Seed the global numpy RNG so the randomized runs below are reproducible.
    np.random.seed(12345)


    # Repeat the experiment with different values of k
    for ell, k in enumerate(k_range):

        print('k=' + str(k), file=sys.stderr)

        nr_unf_const_approx[:,ell], maxviol_const_approx[:,ell], meanviol_const_approx[:,ell], \
                cost_const_approx[:,ell], kmeans_cost_const_approx[:,ell], kcenters_cost_const_approx[:,ell], time_const_approx[:,ell] = run_clustering(data, distance_matrix, k, approx_runner)

        nr_unf_kmeans[:,ell], maxviol_kmeans[:,ell], meanviol_kmeans[:,ell], \
                cost_kmeans[:,ell], kmeans_cost_kmeans[:,ell], kcenters_cost_kmeans[:,ell], time_kmeans[:,ell] = run_clustering(data, distance_matrix, k, kmeans_runner, nr_runs)


    log_result(data_set, 'approx', dist_function, k_range, 1, nr_unf_const_approx, maxviol_const_approx,
            meanviol_const_approx, cost_const_approx, kmeans_cost_const_approx, kcenters_cost_const_approx, time_const_approx)

    log_result(data_set, 'kmeans++', dist_function, k_range, nr_runs, nr_unf_kmeans, maxviol_kmeans,
            meanviol_kmeans, cost_kmeans, kmeans_cost_kmeans, kcenters_cost_kmeans, time_kmeans)

    # save data
    np.savez(results_folder + '/' + save_str[data_set] + '_NrReplicates=' + str(nr_runs) + '_distance=' + dist_function,
             data_set=data_set,
             nr_runs=nr_runs,
             dist_function=dist_function,
             k_range=k_range,
             n=n,

             nr_unf_kmeans=nr_unf_kmeans,
             nr_unf_const_approx=nr_unf_const_approx,
             nr_unf_random = nr_unf_random,

             maxviol_kmeans=maxviol_kmeans,
             maxviol_const_approx=maxviol_const_approx,
             maxviol_random = maxviol_random,

             meanviol_kmeans=meanviol_kmeans,
             meanviol_const_approx=meanviol_const_approx,
             meanviol_random = meanviol_random,

             cost_kmeans=cost_kmeans,
             cost_const_approx=cost_const_approx,
             cost_random=cost_random,

             kmeans_cost_const_approx=kmeans_cost_const_approx,
             kmeans_cost_kmeans=kmeans_cost_kmeans,
             kmeans_cost_random=kmeans_cost_random,

             kcenters_cost_const_approx=kcenters_cost_const_approx,
             kcenters_cost_kmeans=kcenters_cost_kmeans,
             kcenters_cost_random=kcenters_cost_random,

             time_kmeans=time_kmeans,
             time_const_approx=time_const_approx,
             time_random=time_random
             )

def load_and_log(data_set, dist_function, k_range, nr_runs):
    '''
    Re-print the results saved by a previous run_experiment() call.

    INPUT:
    data_set, dist_function, nr_runs: the arguments the results were saved with
                                      (they determine the file name)
    k_range: the k values to print, in the column order of the saved arrays
    '''
    results_folder = 'Results/' + data_set
    # np.savez appends '.npz' to the name run_experiment gives it, so add it here too.
    file_name = (results_folder + '/' + save_str[data_set] + '_NrReplicates=' + str(nr_runs)
                 + '_distance=' + dist_function + '.npz')
    # The NpzFile keeps the archive open; the with-block closes it.
    with np.load(file_name) as data:
        log_result(data['data_set'], 'approx', dist_function, k_range, 1, data['nr_unf_const_approx'], data['maxviol_const_approx'],
                data['meanviol_const_approx'], data['cost_const_approx'], data['kmeans_cost_const_approx'], data['kcenters_cost_const_approx'], data['time_const_approx'])

        # Log the real number of KMeans runs, matching what run_experiment prints.
        log_result(data['data_set'], 'kmeans++', dist_function, k_range, nr_runs, data['nr_unf_kmeans'], data['maxviol_kmeans'],
                data['meanviol_kmeans'], data['cost_kmeans'], data['kmeans_cost_kmeans'], data['kcenters_cost_kmeans'], data['time_kmeans'])

def main():
    '''Run every experiment, printing one CSV line per (data set, algorithm, k) to stdout.'''
    real_k_range = np.arange(2, 26)
    hard_k_range = np.arange(100, 3001, 100)  # 100, 200, ..., 3000
    print_header()
    for name in ('Adult', 'BreastCancer', 'Car', 'Drug', 'IndianLiver'):
        run_experiment(name, 'euclidean', real_k_range, 10)
    run_experiment('Hard', 'euclidean', hard_k_range, 10)


if __name__ == '__main__':
    main()
