"""
(d = n) case, with 2 servers.
Experiments on the bcsstk13 dataset (2003 x 2003) for ranks 10, 20, 30, 40, 50, 60.
For these settings, the columns are split into two contiguous halves, one per server.
The number of columns in the coreset (per server)
is 5 * rank, and the number of rows in the Cauchy matrix is
either 5 * rank or 8 * rank.
"""

import numpy as np
from main_protocol import *
from dataloader import *
from utils import *
import pickle
import math
import argparse
import os

# Conducts trials for settings 0-10 on the per-server matrices Ais and saves a
# checkpoint after each setting via save_checkpoint. Returns None.
def conduct_trials_bcsstk13(A, Ais, approx_rank, num_trials, cauchy_size, coreset_size):
    """Run experiment settings 0-10 for one rank and checkpoint every result.

    Setting 0 is run through ``greedy_multiple_trials``. Settings 1-10 are run
    through ``run_sparse_embedding_setting``, each with its own sketch size and
    sparsity cap. Every result is handed to ``save_checkpoint`` under the
    dataset label ``"bcsstk13s"``, keyed by rank, Cauchy size and setting number.

    Args:
        A: The full matrix. Not used here; kept for interface compatibility.
        Ais: Per-server pieces of the matrix (the ``__main__`` block passes a
            two-element list of column blocks).
        approx_rank: Target rank; also used to scale the sketch sizes.
        num_trials: Number of trials forwarded to every setting.
        cauchy_size: Cauchy size forwarded to every setting and used in the
            checkpoint key.
        coreset_size: Coreset size forwarded to every setting.

    Returns:
        None. Results are persisted only through ``save_checkpoint``.
    """
    label = "bcsstk13s"

    def persist(result, number):
        # Single place that writes a checkpoint for a finished setting.
        save_checkpoint(setting_result=result, dataset_name=label,
                        test_rank=approx_rank, cauchy_size=cauchy_size,
                        setting_num=number)

    # Setting 0: greedy protocol, run before any sparse-embedding setting.
    baseline = greedy_multiple_trials(num_trials=num_trials, Ais=Ais,
                                      cauchy_size=cauchy_size,
                                      coreset_size=coreset_size,
                                      approx_rank=approx_rank, setting_num=0)
    persist(baseline, 0)

    # Settings 1-4: sketch sizes 3x and 5x the rank, sparsity capped at 20,
    # then at 40. Settings 5-10: sketch sizes rank/2, rank/3, rank/5 (rounded
    # up), sparsity capped at 2, then at 5. Order matches setting numbers.
    large = (3 * approx_rank, 5 * approx_rank)
    small = tuple(math.ceil(approx_rank / k) for k in (2, 3, 5))
    schedule = ([(size, 20) for size in large]
                + [(size, 40) for size in large]
                + [(size, 2) for size in small]
                + [(size, 5) for size in small])

    for number, (sketch, cap) in enumerate(schedule, start=1):
        config = {'num_trials': num_trials,
                  'Ais': Ais,
                  'cauchy_size': cauchy_size,
                  'coreset_size': coreset_size,
                  'sketch_size': sketch,
                  'lewis_weight_size': approx_rank,
                  'sparsity': min(sketch, cap),
                  'setting_num': number}
        persist(run_sparse_embedding_setting(setting=config), number)

if __name__ == "__main__":
    # Parse command-line arguments. type=int and required=True let argparse
    # report a proper usage error instead of crashing on int(None) when a flag
    # is missing or not an integer.
    parser = argparse.ArgumentParser(
        description="bcsstk13 experiments with 2 servers for a single rank.")
    parser.add_argument("--test_rank", type=int, required=True,
                        help="target approximation rank")
    parser.add_argument("--num_trials", type=int, required=True,
                        help="number of trials per setting")
    parser.add_argument("--data_path", default="bcsstk13.mtx",
                        help="path to the bcsstk13 matrix file "
                             "(default: bcsstk13.mtx)")
    args = parser.parse_args()
    test_rank = args.test_rank
    num_trials = args.num_trials

    # Ensure that the folder Results exists; exist_ok avoids the
    # check-then-create race of isdir + mkdir.
    folder_name = "Results/"
    os.makedirs(folder_name, exist_ok=True)

    # Load the dataset and split its columns into two contiguous halves,
    # one per server.
    A = load_bcsstk13(args.data_path)
    cutoff = A.shape[1] // 2
    Ais = [A[:, 0:cutoff], A[:, cutoff:]]

    # The coreset size stays at 5 * rank; the Cauchy size is 5 * rank on the
    # first pass and 8 * rank on the second.
    coreset_size = 5 * test_rank
    for cauchy_factor in (5, 8):
        cauchy_size = cauchy_factor * test_rank
        conduct_trials_bcsstk13(A=A, Ais=Ais, approx_rank=test_rank,
                                num_trials=num_trials,
                                cauchy_size=cauchy_size,
                                coreset_size=coreset_size)