from sklearn import metrics, preprocessing

from arguments import get_args
import run_kauri, run_douglas, run_imm, run_exkmc, run_exshallow, run_ktree, run_randomthreshold
from data import get_data
import numpy as np
import pandas as pd
import evaluate


def load_data(args):
    """Load the dataset described by ``args`` and optionally subsample it.

    Reads ``args.dataset``, ``args.path_to_data``, ``args.n_clusters`` and
    ``args.subset_size``. When ``args.n_clusters`` is -1 it is overwritten
    in place with the number of distinct labels in ``y``. That count is taken
    on the full dataset, before any subsampling.

    Args:
        args: Parsed command-line namespace produced by ``get_args``.

    Returns:
        ``(X, y)``: the samples and their labels. When
        ``args.subset_size != 1`` both are restricted to the same random
        subset of ``int(len(X) * args.subset_size)`` rows, drawn without
        replacement.

    Raises:
        ValueError: if ``args.subset_size`` is outside ``(0, 1]`` or is so
            small that no sample would be kept.
    """
    X, y = get_data(args.dataset, args.path_to_data)

    if args.n_clusters == -1:
        args.n_clusters = len(np.unique(y))

    if args.subset_size != 1:
        # Validate up front: otherwise 0 silently yields an empty dataset and
        # negative / >1 values fail deep inside np.random.choice.
        if not 0 < args.subset_size <= 1:
            raise ValueError(
                f"subset_size must be in (0, 1], got {args.subset_size}"
            )
        n_keep = int(len(X) * args.subset_size)
        if n_keep == 0:
            raise ValueError(
                f"subset_size={args.subset_size} keeps no samples out of {len(X)}"
            )
        # Draws from NumPy's global random state; no seed is set here.
        selection = np.random.choice(len(X), size=n_keep, replace=False)
        # NOTE(review): assumes X and y support integer-array row indexing
        # (e.g. NumPy arrays); a pandas DataFrame would index columns here.
        X, y = X[selection], y[selection]

    return X, y


# Datasets that main() does not pass through StandardScaler.
_UNSCALED_DATASETS = ("car_evaluation", "congressional_votes")


def _run_method(args, X):
    """Run the clustering method named by ``args.method`` on ``X``.

    Returns:
        ``(y_pred, wad, waes, n_leaf)``: the predicted labels plus the tree
        metrics computed by ``evaluate``. For ``"douglas"``, which returns no
        tree here, the three tree metrics are -1.
    """
    # Methods whose returned trees are scored with the evaluate.*exkmc* helpers.
    exkmc_runners = {
        "imm": run_imm,
        "exkmc": run_exkmc,
        "exshallow": run_exshallow,
        "rdm": run_randomthreshold,
    }

    if args.method == "kauri":
        y_pred, tree = run_kauri.run(args, X)
        return (
            y_pred,
            evaluate.compute_kauri_wad(tree, X),
            evaluate.compute_kauri_waes(tree, X),
            evaluate.get_tree_n_leaves(tree),
        )
    if args.method == "douglas":
        return run_douglas.run(args, X), -1, -1, -1
    if args.method in exkmc_runners:
        y_pred, tree = exkmc_runners[args.method].run(args, X)
        return (
            y_pred,
            evaluate.compute_exkmc_wad(tree),
            evaluate.compute_exkmc_waes(tree),
            evaluate.get_exkmc_n_leaves(tree),
        )
    # NOTE(review): any method name not matched above falls through to ktree.
    # Presumably get_args restricts args.method to known choices; confirm,
    # otherwise a typo is silently run as ktree but labelled with the typo.
    y_pred, tree = run_ktree.run(args, X)
    return (
        y_pred,
        evaluate.compute_dt_wad(tree, X),
        evaluate.compute_dt_waes(tree, X),
        evaluate.get_tree_n_leaves(tree),
    )


def main():
    """Run one clustering experiment and export its scores as a one-row CSV.

    Steps:

    1. Parse the arguments and load (and possibly subsample) the dataset.
    2. Standardize the features unless the dataset is in
       ``_UNSCALED_DATASETS``.
    3. Run ``args.method`` and compute the external metrics (ARI,
       unsupervised accuracy), the k-means score and the tree metrics.
    4. Write a single-row CSV to ``args.output_file``.
    """
    args = get_args()

    print(f"Loading the dataset {args.dataset}")
    X, y = load_data(args)
    print(f"Retrieved {X.shape[0]} samples for {X.shape[1]} features")

    if args.dataset not in _UNSCALED_DATASETS:
        print(f"Scaling dataset {args.dataset}")
        X = preprocessing.StandardScaler().fit_transform(X)

    print(f"We will try to find {args.n_clusters} clusters with {args.method}")
    y_pred, wad, waes, n_leaf = _run_method(args, X)

    # Compute the metrics
    print("Finished running. Computing metrics")
    found_clusters = len(np.unique(y_pred))
    ari = metrics.adjusted_rand_score(y, y_pred)
    acc = evaluate.unsupervised_accuracy(y, y_pred)
    # Scored on X as the method saw it, i.e. after scaling when applied.
    kmeans_score = evaluate.compute_kmeans_score(X, y_pred)

    print(f"Exporting csv to {args.output_file}")
    results = pd.DataFrame([{"ARI": ari, "FCl": found_clusters, "Method": args.method, "K": args.n_clusters,
                             "n_leaf": n_leaf, "ACC": acc, "dataset": args.dataset, "WAD": wad, "WAES": waes,
                             "KScore": kmeans_score}])
    results.to_csv(args.output_file, index=False)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
