"""
isolet dataset (6238 x 617) for ranks 10, 20, 30, 40, 50, 60.
We again split the columns into 2 servers. The number of columns 
in the coreset (per server) is (ceiling of 2.5 * rank). The number 
of rows in the Cauchy matrix is either 5 * rank, 8 * rank or 10 * rank.
"""

import numpy as np
from main_protocol import *
from dataloader import *
from utils import *
import pickle
import math
import argparse

def _rank_k_svd_error(A, rank):
    """Return the residual of the best rank-`rank` SVD approximation of A.

    By the Eckart-Young theorem the Frobenius-norm residual of the truncated
    SVD is the root of the sum of the squared discarded singular values, so
    the approximation itself never has to be formed.

    NOTE(review): the residual is reported in Frobenius norm. If the
    per-setting results are measured in a different norm (e.g. an entrywise
    l1 cost), this baseline should be switched to that norm -- confirm
    against whatever consumes the pickled results.
    """
    singular_values = np.linalg.svd(A, compute_uv=False)
    # Slicing past the end yields an empty array, so rank >= min(A.shape)
    # correctly gives an error of 0.0.
    return float(np.sqrt(np.sum(singular_values[rank:] ** 2)))


def conduct_trials_isolet_row_large(A, Ais, approx_rank, num_trials, cauchy_size, coreset_size):
    """Run settings 0-8 on the isolet data and return their results.

    Setting 0 is run through `greedy_multiple_trials`; settings 1-8 are run
    through `run_sparse_embedding_setting` and differ only in the sketch size
    and the cap on the sparsity. Every setting result is checkpointed with
    `save_checkpoint` as soon as it is available, so partial progress
    survives a crash in a later setting.

    Args:
        A: the full data matrix; used only for the SVD baseline error.
        Ais: list of per-server column blocks of A.
        approx_rank: target rank of the approximation.
        num_trials: number of trials per setting.
        cauchy_size: number of rows in the Cauchy matrix.
        coreset_size: number of columns in the coreset (per server).

    Returns:
        A pair `(result_list, svd_error)`: `result_list` holds the setting
        results for settings 0 through 8, in that order, and `svd_error` is
        the rank-`approx_rank` SVD residual of A (see `_rank_k_svd_error`).
    """
    dataset_name = "isolet_row_large"
    result_list = []

    def record(setting_result, setting_num):
        # Checkpoint first so the result is on disk even if a later setting fails.
        save_checkpoint(setting_result=setting_result, dataset_name=dataset_name, test_rank=approx_rank, cauchy_size=cauchy_size, setting_num=setting_num)
        result_list.append(setting_result)

    # Setting 0
    record(
        greedy_multiple_trials(num_trials=num_trials, Ais=Ais, cauchy_size=cauchy_size, coreset_size=coreset_size, approx_rank=approx_rank, setting_num=0),
        0,
    )

    # Settings 1-8 as (setting_num, sketch_size, sparsity cap). Settings 1-4
    # use sketches larger than the rank; settings 5-8 use sketches that are
    # a fraction of the rank.
    sparse_settings = [
        (1, 5 * approx_rank, 20),
        (2, 8 * approx_rank, 20),
        (3, 5 * approx_rank, 40),
        (4, 8 * approx_rank, 40),
        (5, math.ceil(approx_rank / 3), 2),
        (6, math.ceil(approx_rank / 5), 2),
        (7, math.ceil(approx_rank / 3), 5),
        (8, math.ceil(approx_rank / 5), 5),
    ]
    for setting_num, sketch_size, sparsity_cap in sparse_settings:
        setting = {
            'num_trials': num_trials,
            'Ais': Ais,
            'cauchy_size': cauchy_size,
            'coreset_size': coreset_size,
            'sketch_size': sketch_size,
            'lewis_weight_size': approx_rank,
            # The sparsity can never exceed the sketch size itself.
            'sparsity': min(sketch_size, sparsity_cap),
            'setting_num': setting_num,
        }
        record(run_sparse_embedding_setting(setting=setting), setting_num)

    return result_list, _rank_k_svd_error(A, approx_rank)

if __name__ == "__main__":
    # This file never imports `os` itself; import it here rather than relying
    # on one of the star imports above happening to re-export it.
    import os

    # Parse command-line arguments. Both are mandatory integers, so let
    # argparse validate them and print a usage message on bad input.
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_rank", type=int, required=True)
    parser.add_argument("--num_trials", type=int, required=True)
    args = parser.parse_args()

    # Ensure that the folder Results exists (no error if it already does).
    folder_name = "Results"
    os.makedirs(folder_name, exist_ok=True)

    # Load the dataset and split its columns evenly across 2 servers.
    A = load_isolet("isolet1+2+3+4.csv")
    cutoff = A.shape[1] // 2
    Ais = [A[:, 0:cutoff], A[:, cutoff:]]

    # Fix parameters.
    test_rank = args.test_rank
    num_trials = args.num_trials
    coreset_size = math.ceil(2.5 * test_rank)
    cauchy_sizes = [5 * test_rank, 8 * test_rank, 10 * test_rank]
    for cauchy_size in cauchy_sizes:
        # Perform the experiment with these parameters and pickle the results.
        test_rank_result_list, test_rank_svd_error = conduct_trials_isolet_row_large(A=A, Ais=Ais, approx_rank=test_rank, num_trials=num_trials, cauchy_size=cauchy_size, coreset_size=coreset_size)
        results_object = (test_rank_result_list, test_rank_svd_error)
        filename = os.path.join(folder_name, "isolet_row_large_rank_%d_results_cauchy_size_%d" % (test_rank, cauchy_size))
        with open(filename, "wb") as outfile:
            pickle.dump(results_object, outfile)