import pandas as pd
import numpy as np
from sklearn import metrics
from os import listdir
from os.path import isfile, join, dirname
from datasets.datasets import Datasets
from torch.utils.data.dataloader import DataLoader
from metrics import cluster
from utils import utils
from analysis.plot import PlotConfusionMatrix
from collections import defaultdict
import re


def dataset_to_metrics(metrics_names, results_paths, dataset_name, root, output_path,
                       param_file=None, coil20_unprocessed=False, start_idx=0, stop_idx=0):
    """Score saved clustering runs of ``dataset_name`` against its test labels and write a summary.

    Args:
        metrics_names: colon-separated metric identifiers handled by ``evaluate``
            (e.g. ``"ce:nmi:pur"``).
        results_paths: directory containing ``<prefix>_<i>.results`` files, where
            ``<prefix>`` is ``dataset_name`` up to its first ``.``.
        dataset_name: dataset identifier forwarded to ``Datasets``; also the prefix of
            every metric column.
        root: root folder forwarded to ``Datasets``.
        output_path: output prefix handed to ``write_output`` (which writes
            ``output_path + '.csv'``).
        param_file: optional parameter file joined to the per-run rows.
        coil20_unprocessed: forwarded to ``Datasets``.
        start_idx, stop_idx: requested run range, normalised by
            ``utils.adjust_start_stop_limits`` against the number of ``.results`` files.
    """
    utils.create_folders(output_path)

    metrics_names = metrics_names.split(":")
    dataset = Datasets(dataset=dataset_name, root_folder=root, flatten=True, coil20_unprocessed=coil20_unprocessed)

    test_loader = DataLoader(dataset.test_data, shuffle=False)

    # Collect the label batches and concatenate once; concatenating inside the loop
    # is quadratic. The empty float seed keeps the original float64 result dtype.
    label_chunks = [np.array([])]
    for _, target in test_loader:
        label_chunks.append(target.numpy())
    true_labels = np.concatenate(label_chunks)

    repeats = len([f for f in listdir(results_paths)
                   if isfile(join(results_paths, f)) and not f.startswith('.') and f.endswith(".results")])
    start, stop = utils.adjust_start_stop_limits(repeats, start_idx, stop_idx)

    file_prefix = dataset_name.split(".")[0]
    metrics_values = defaultdict(list)
    num_nodes = defaultdict(list)
    for i in range(start, stop):
        results_file = join(results_paths, "{0}_{1}.results".format(file_prefix, i))
        print(results_file)

        # Close the handle promptly instead of leaking one per run.
        with open(results_file, 'r') as handle:
            lines = handle.readlines()
        # The first tab-separated field of the first line is the number of clusters found.
        found_clusters = int(lines[0].split("\t")[0])
        metrics_values[dataset_name + ".n_nodes"].append(found_clusters)
        num_nodes[dataset_name].append(found_clusters)

        if found_clusters + 1 < len(lines):
            # Per-sample rows follow the first found_clusters + 1 lines
            # (presumably one header line plus one line per cluster).
            results = pd.read_csv(results_file, sep="\t", skiprows=found_clusters + 1, header=None)
            evaluate(results, metrics_names, metrics_values, dataset_name,
                     index=i,
                     data_size=len(true_labels),
                     results_paths=results_paths,
                     output_path=output_path,
                     true_labels=true_labels)
        else:
            # No per-sample rows for this run. Pad every metric column that evaluate()
            # would have filled, so all columns keep one entry per run. Ragged lists
            # make pd.DataFrame raise in write_output; NaN is later filled with 0.
            for metric in metrics_names:
                if metric in ("acc", "ce", "nmi", "pur"):
                    metrics_values[dataset_name + "." + metric].append(np.nan)

    write_output(metrics_names, metrics_values, num_nodes, output_path, param_file, start_idx=start, stop_idx=stop)


def true_to_metrics(metric_name, results_paths, true_path, output_path, repeats, param_file=None):
    """Score saved clustering runs against per-file ground truth and write a summary.

    Every non-hidden file in ``true_path`` not ending in ``.true`` is loaded with
    ``utils.get_data_targets``. Runs ``0..repeats-1`` of the matching
    ``<stem>_<i>.results`` files in ``results_paths`` are then evaluated.

    Args:
        metric_name: colon-separated metric identifiers handled by ``evaluate``.
        results_paths: directory containing the ``.results`` files.
        true_path: directory containing the ground-truth data files.
        output_path: output prefix handed to ``write_output``.
        repeats: number of runs per data file.
        param_file: optional parameter file joined to the per-run rows.
    """
    utils.create_folders(output_path)

    metrics_names = metric_name.split(":")

    files = sorted(f for f in listdir(true_path)
                   if isfile(join(true_path, f)) and not f.startswith('.') and not f.endswith(".true"))

    metrics_values = defaultdict(list)
    num_nodes = defaultdict(list)

    for file_name in files:
        data = utils.get_data_targets(true_path, file_name)
        print("{0}".format(join(results_paths, file_name)))

        dataset_name = file_name.split(".")[0]

        for i in range(repeats):
            results_file = join(results_paths, "{0}_{1}.results".format(dataset_name, i))
            # Close the handle promptly instead of leaking one per run.
            with open(results_file, 'r') as handle:
                lines = handle.readlines()
            # The first tab-separated field of the first line is the number of clusters found.
            found_clusters = int(lines[0].split("\t")[0])

            metrics_values[dataset_name + ".n_nodes"].append(found_clusters)
            num_nodes[dataset_name].append(found_clusters)

            if found_clusters + 1 < len(lines):
                results = pd.read_csv(results_file, sep="\t", skiprows=found_clusters + 1, header=None)

                # Column 0 holds the sample indexes; use them to pick the matching true labels.
                indexes = results.iloc[:, 0].to_list()
                true = list(map(int, data[indexes]))

                evaluate(results, metrics_names, metrics_values, dataset_name,
                         index=i,
                         data_size=len(data),
                         results_paths=results_paths,
                         output_path=output_path,
                         true_labels=true)
            else:
                # No per-sample rows for this run. Pad every metric column that evaluate()
                # would have filled, so all columns keep one entry per run. Ragged lists
                # make pd.DataFrame raise in write_output; NaN is later filled with 0.
                for metric in metrics_names:
                    if metric in ("acc", "ce", "nmi", "pur"):
                        metrics_values[dataset_name + "." + metric].append(np.nan)

    write_output(metrics_names, metrics_values, num_nodes, output_path, param_file)


def evaluate(results, metrics_names, metrics_values, dataset_name, index, data_size,
             results_paths=None, output_path=None, true_labels=None):
    """Score one run and record the requested metrics.

    Args:
        results: per-sample table of one run. Column 1 holds the cluster assignment;
            column 2 (read only for ``"acc"``) holds the predicted label.
        metrics_names: metric identifiers.
            - ``"acc"``, ``"ce"``, ``"nmi"`` and ``"pur"`` append a value under
              ``"<dataset_name>.<metric>"``.
            - ``"cm"`` saves a confusion matrix.
            - Any other name is ignored.
        metrics_values: mapping of column name -> list of values (a
            ``defaultdict(list)`` in this module's callers); mutated in place.
        dataset_name: prefix of the metric keys. Its part before the first ``.`` names
            the confusion-matrix files.
        index: run number, used in the confusion-matrix file names.
        data_size: denominator of the accuracy.
        results_paths: not used by this function.
        output_path: confusion matrices are written under
            ``dirname(output_path)/confusion_matrix``.
        true_labels: ground-truth labels aligned with the rows of ``results``.
    """
    clusters = results.iloc[:, 1].to_list()

    for name in metrics_names:
        if name == 'cm':
            confusion = cluster.predict_to_confusion(true_labels, clusters)
            confusion = cluster.maximize_trace(confusion)
            cm_dir = join(dirname(output_path), "confusion_matrix")
            # A placeholder file name inside cm_dir is passed on purpose; presumably
            # create_folders builds the parent directories of the given path.
            utils.create_folders(join(cm_dir, "filename"))

            PlotConfusionMatrix().save_cm(join(cm_dir, "{0}_{1}.png".format(dataset_name.split(".")[0], index)),
                                          confusion, int(max(true_labels) + 1))
            PlotConfusionMatrix().save_txt(join(cm_dir, "{0}_{1}.txt".format(dataset_name.split(".")[0], index)),
                                           confusion)
            continue

        if name == "acc":
            predicted = results.iloc[:, 2].to_list()
            hits = metrics.accuracy_score(predicted, true_labels, normalize=False)
            score = float(hits) / data_size
        elif name == 'ce':
            score = cluster.predict_to_clustering_error(true_labels, clusters)
        elif name == 'nmi':
            score = cluster.nmi(true_labels, clusters)
        elif name == 'pur':
            score = cluster.purity(true_labels, clusters)
        else:
            # Unrecognised metric names are skipped silently.
            continue

        metrics_values[dataset_name + "." + name].append(score)


def write_output(metrics_names, metrics_values, num_nodes, output_path, param_file, start_idx=0, stop_idx=0):
    """Write the per-run metrics summary to ``output_path + '.csv'``.

    Nothing is written when the only requested metric is ``"cm"``, since that
    metric only produces confusion-matrix files.

    Args:
        metrics_names: list of requested metric identifiers.
        metrics_values: mapping of column name -> per-run values; turned into a
            DataFrame with missing values filled with 0.
        num_nodes: mapping of dataset name -> per-run cluster counts.
        output_path: output prefix; ``.csv`` is appended.
        param_file: optional parameter file forwarded to ``write_csv_body``.
        start_idx, stop_idx: run range forwarded to ``write_csv_body``.
    """
    if len(metrics_names) == 1 and metrics_names[0] == 'cm':
        return

    metrics_values = pd.DataFrame(metrics_values).fillna(0)
    num_nodes = pd.DataFrame(num_nodes).fillna(0)

    line = write_csv_header(metrics_values, num_nodes)
    # Use a context manager so the CSV is flushed and closed deterministically
    # (the original handle was never closed).
    with open(output_path + '.csv', 'w') as output_file:
        write_csv_body(metrics_values, line, metrics_values.mean().values, output_file, param_file,
                       metrics_values.std().values, start_idx, stop_idx)


def write_csv_header(metrics_values, num_nodes):
    """Build the summary lines describing the best run of every metric column.

    Args:
        metrics_values: DataFrame with one row per run and one column per
            ``"<dataset>.<metric>"`` key.
        num_nodes: DataFrame with one row per run and one column per dataset,
            holding the number of clusters found in that run.

    Returns:
        Three newline-terminated CSV lines:
        - ``max_value``: each column's maximum.
        - ``num_nodes``: the cluster count of the dataset owning the column, at
          the row where the column peaks.
        - ``index_set``: the row label of each column's maximum.
    """
    line = "max_value," + ",".join(map(str, metrics_values.max().values)) + "\n"

    dataset_names = list(num_nodes.keys())
    n_nodes_string = []
    for metric in metrics_values.keys():
        # Column names are "<dataset>.<metric>". Match on the full "<dataset>." prefix,
        # not a substring, so dataset "a" does not capture column "ab.ce". Take the
        # longest match in case one dataset name prefixes another.
        owners = [dataset for dataset in dataset_names if str(metric).startswith(str(dataset) + ".")]
        if not owners:
            continue
        owner = max(owners, key=lambda d: len(str(d)))
        n_nodes = num_nodes[owner].values
        metrics_idxmax = np.argmax(metrics_values[metric].values)
        n_nodes_string.append(n_nodes[metrics_idxmax])

    line += "num_nodes," + ",".join(map(str, n_nodes_string)) + "\n"
    line += "index_set," + ",".join(map(str, metrics_values.idxmax().values)) + "\n"

    return line


def write_csv_body(metrics_values, line, mean_value, output_file, param_file, std_value, start_idx=0, stop_idx=0):
    """Append the mean/std rows and the per-run table to ``line``, then write it.

    Args:
        metrics_values: DataFrame with one row per run and one column per metric key.
        line: text already built for the file (the summary header lines).
        mean_value: per-column means, in column order.
        output_file: writable text handle receiving the final text.
        param_file: optional parameter file. When given, its rows (restricted to
            ``start_idx``/``stop_idx`` via ``utils.parameters_start_stop``) are
            joined column-wise to the runs.
        std_value: per-column standard deviations, in column order.
        start_idx, stop_idx: run range used to select parameter rows.
    """
    line += "mean_value," + ",".join(map(str, mean_value)) + "\n"
    line += "std_value," + ",".join(map(str, std_value)) + "\n\n"

    if param_file is not None:

        params = utils.read_params(param_file)
        params = utils.parameters_start_stop(params, start_idx, stop_idx)

        names = params.columns
        names = list(map(lambda x: x.strip(), names))

        line += "experiment," + ",".join(metrics_values.keys()) + "," + ",".join(names) + "\n"

        data_string = pd.concat([metrics_values, params], axis=1).to_string(header=False, index=True)

    else:
        line += "experiment," + ",".join(metrics_values.keys()) + "\n"

        data_string = metrics_values.to_string(header=False, index=True)

    # Turn the fixed-width table into CSV by collapsing runs of 2+ horizontal
    # whitespace into commas. Raw string: '\s'/'\S' are invalid escapes in a normal
    # literal. The old leading '\s' could also match a newline and merge rows.
    line += re.sub(r'[^\S\r\n]{2,}', ',', data_string)
    output_file.write(line)


if __name__ == '__main__':
    # Example run: summarise clustering error (ce), NMI and purity for runs 0-29
    # stored in the results directory below, with the matching .lhs parameters.
    run_config = dict(
        metrics_names="ce:nmi:pur",
        results_paths="../results/invrelu_push_ae_mean_som_mean_a_lhs_b0_nn_sum_mean0_autoencoder30_r3_b256_0/",
        dataset_name="mnist",
        root="../raw-datasets/",
        output_path="../results/invrelu_push_ae_mean_som_mean_a_lhs_b0_nn_sum_mean0_autoencoder30_r3_b256/invrelu_push_ae_mean_som_mean_a_lhs_b0_nn_sum_mean0_autoencoder30_r3_b256_0",
        param_file="../arguments/autoencoder30_r3.lhs",
        start_idx=0,
        stop_idx=30,
    )
    dataset_to_metrics(**run_config)

    # Earlier true_to_metrics runs (ground truth taken from files in a directory):
    # true_to_metrics("acc:ce", "../results/uci-3T3F-50p-b64-nmax-short-l01", "../raw-datasets/Realdata_3T3F_Test2",
    #                 repeats=50, output_path="../results/sera_que_foi/uci-3T3F-50p-b64-nmax-short-s01",
    #                 param_file="../arguments/som_50p_nmax_short.lhs")

    # true_to_metrics("acc", "../results/uci-3T3F-50p-b64-nmax-short-s10", "../raw-datasets/Realdata_3T3F_Test2",
    #                 repeats=50, output_path="../results/uci-3T3F-50p-b64-nmax-short/uci-3T3F-50p-b64-nmax-short-s10",
    #                 param_file="../arguments/som_50p_nmax_short.lhs")

    # true_to_metrics("acc", "../results/uci-3T3F-50p-b128-nmax-short-l50", "../raw-datasets/Realdata_3T3F_Test2",
    #                repeats=50, output_path="../results/uci-3T3F-50p-b128-nmax-short/uci-3T3F-50p-b128-nmax-short-l50",
    #                param_file="../arguments/som_50p_nmax_short.lhs")

    # true_to_metrics("acc", "../results/uci-3T3F-50p-b64-nmax-short-l100", "../raw-datasets/Realdata_3T3F_Test2",
    #                 repeats=50, output_path="../results/uci-3T3F-50p-b64-nmax-short/uci-3T3F-50p-b64-nmax-short-l100",
    #                 param_file="../arguments/som_50p_nmax_short.lhs")
