from scipy.optimize import linear_sum_assignment
from sklearn import metrics, preprocessing

from arguments import get_args
import run_kauri, run_douglas, run_imm, run_exkmc, run_exshallow, run_ktree, run_randomthreshold
from data import get_data
import numpy as np
import pandas as pd


def load_data(args):
    """Load the requested dataset, optionally keeping a random subset of it.

    Reads ``args.dataset``, ``args.path_to_data``, ``args.n_clusters`` and
    ``args.subset_size``.

    Side effect: when ``args.n_clusters`` is ``-1`` it is overwritten in place
    with the number of distinct values found in the labels ``y``.

    Returns:
        The pair ``(X, y)`` as returned by ``get_data``, restricted to a
        random selection of rows (drawn without replacement) whenever
        ``args.subset_size`` differs from 1.
    """
    X, y = get_data(args.dataset, args.path_to_data)

    # -1 acts as "not specified": fall back to the number of distinct labels.
    if args.n_clusters == -1:
        args.n_clusters = len(np.unique(y))

    # A subset size of exactly 1 means "keep everything": nothing to sample.
    if args.subset_size == 1:
        return X, y

    n_samples = len(X)
    n_kept = int(n_samples * args.subset_size)
    kept = np.random.choice(n_samples, size=n_kept, replace=False)
    # Index features and labels with the same positions so they stay aligned.
    return X[kept], y[kept]


def main():
    """Run one clustering method on a dataset and export its predictions.

    Steps, in order:
      1. Parse the command-line arguments with ``get_args``.
      2. Load the dataset through ``load_data``.
      3. If ``args.gap`` is set, replace the features with uniform noise drawn
         inside the per-feature [min, max] range of the loaded data.
      4. Standardise the features, except for the ``car_evaluation`` and
         ``congressional_votes`` datasets.
      5. Run the method named by ``args.method``; any name not listed in the
         dispatch table falls back to ``run_ktree``.
      6. Write the result to ``args.output_file`` as CSV.

    Output layout:
      * gap mode: one row per sample, columns ``f_0 .. f_{d-1}`` holding the
        (possibly scaled) features followed by ``y`` holding the prediction;
      * otherwise: a single row with one column ``c_i`` per sample.
    """
    args = get_args()

    print(f"Loading the dataset {args.dataset}")
    # The ground-truth labels are not used past loading, only the features.
    X, _ = load_data(args)
    if args.gap:
        # Keep the shape and the per-feature bounds of the real data, but
        # discard its structure by sampling every entry uniformly at random.
        X = np.random.uniform(size=X.shape, low=X.min(0), high=X.max(0))
    print(f"Retrieved {X.shape[0]} samples for {X.shape[1]} features")

    # NOTE(review): these two datasets are presumably categorical, which would
    # explain why they are not standardised -- confirm against get_data.
    if args.dataset not in ("car_evaluation", "congressional_votes"):
        print(f"Scaling dataset {args.dataset}")
        X = preprocessing.StandardScaler().fit_transform(X)

    print(f"We will try to find {args.n_clusters} clusters with {args.method}")
    runners = {
        "kauri": run_kauri,
        "douglas": run_douglas,
        "imm": run_imm,
        "exkmc": run_exkmc,
        "exshallow": run_exshallow,
        "rdm": run_randomthreshold,
    }
    # Unlisted method names use the k-tree runner, as the original chain did.
    y_pred = runners.get(args.method, run_ktree).run(args, X)

    # Export the predictions (no metric is computed here).
    print(f"Finished running. Exporting predictions to {args.output_file}")
    if args.gap:
        # The generated features are saved next to the predictions because
        # they are random and could not be recovered otherwise.
        columns = [f"f_{i}" for i in range(X.shape[1])] + ["y"]
        values = np.concatenate([X, np.expand_dims(y_pred, 1)], axis=1)
        table = pd.DataFrame(values, columns=columns)
    else:
        columns = [f"c_{i}" for i in range(len(y_pred))]
        table = pd.DataFrame([y_pred], columns=columns)
    table.to_csv(args.output_file, index=False)


# Run the experiment only when this file is executed as a script, so that
# importing the module does not trigger a run.
if __name__ == "__main__":
    main()
