from l1_solver import *
from dataloader import *
from main_protocol import *
import math
import time
import os
import json
import matplotlib.pyplot as plt
import argparse

"""
k is the number of new columns being selected.
The columns of B, together with the subset S,
are used to fit the columns of A. The columns
of B are greedily selected. This is done in
the l_2 norm. When this algorithm is being used
as a subroutine in the distributed setting, a
PCP-sketch of A will be given in the place of A.

If desired, we could also use the lazier-than-lazy
greedy algorithm (also due to [Altschuler, et al.])
where instead of checking all columns of B at every
iteration, we instead check a subset of size
(d/k * log(1/delta)), where d is the number of columns
of B. Decreasing delta improves the approximation
guarantee as discussed in Theorem 5 of that work.

S is None on the first round of dist_greedy, and
this also works in the case where we just end up
doing one round of dist_greedy.
"""
def greedy_one_round_single_server(A, B, S, k, lazier_than_lazy=False, delta=0.005):
    n, _ = A.shape
    _, num_cols = B.shape

    result_subset = np.zeros((n, k))
    if S == None:
        offset = 0
    else:
        _, offset = S.shape

    # Add k columns of B to result_subset that fit A best.
    for iteration in range(k):
        if S == None:
            evaluation_subset = result_subset
        else:
            evaluation_subset = np.hstack((S, result_subset))
        
        current_right_factor = np.linalg.pinv(evaluation_subset) @ A
        best_l2_error = np.linalg.norm(evaluation_subset @ current_right_factor - A)
        best_column_index = -1

        # Indices to check - if lazier_than_lazy is True,
        # then we only check a random subset of indices, and
        # otherwise we check all indices.
        check_indices = np.arange(start=0, stop=num_cols, step=1)
        if lazier_than_lazy:
            num_samples = math.ceil(num_cols/k * np.log(1/delta))
            check_indices = np.random.choice(check_indices, size=num_samples)
        
        for column in check_indices:
            # Test if this column is better than the running min
            evaluation_subset[:, offset + iteration] = B[:, column]
            new_right_factor = np.linalg.pinv(evaluation_subset) @ A
            new_l2_error = np.linalg.norm(evaluation_subset @ new_right_factor - A)
            if new_l2_error < best_l2_error:
                best_l2_error = new_l2_error
                best_column_index = column
        
        # Update result subset
        result_subset[:, iteration] = B[:, best_column_index]

    #print(result_subset.shape)

    return result_subset


"""
Creates a Projection-cost-preserving sketch of A
to send to all of the servers. This is the PCPS
found by [Cohen, et al. 2015] and used for distributed
greedy CSS by [Altschuler, et al., 2016]
"""
def create_pcp_sketch(A, num_cols_in_sketch=300):
    _, A_cols = A.shape
    sketch_shape = (A_cols, num_cols_in_sketch)
    sketch = np.random.randint(low=0, high=2, size=sketch_shape) # entries 0/1
    sketch = 2 * sketch - 1
    sketch = 1/np.sqrt(num_cols_in_sketch) * sketch
    return (A @ sketch)

"""
Performs one round of the DistGreedy algorithm
by [Altschuler, et al., 2016]. Here, S is the
subset of columns of the A_i (equivalently A)
which has been aggregated over previous rounds
(or None if this is the first round).

If random=True, then every iteration, the columns 
are partitioned randomly between servers (columns 
of A are shuffled). Otherwise, A_i is specified, 
and the columns are not shuffled.

Returns the selected subset (together with S if
S is not None) along with the work and span of this
round of dist_greedy.
"""
def dist_greedy_one_round(A, num_servers, S, k, A_i=None, random=True, num_cols_in_pcps=300, lazier_than_lazy=False):

    # Shuffle columns and afterwards, partition by
    # splitting the columns evenly between the servers.
    # For the purpose of work/span measurement, this is
    # not included.
    if random:
        np.random.shuffle(A.T)
        A_i = []
        _, num_cols = A.shape
        cols_starting_point = 0
        for i in range(num_servers):
            new_cols_starting_point = ((i + 1) * num_cols)//num_servers
            column_group = A[:, cols_starting_point:new_cols_starting_point]
            A_i.append(column_group)
            cols_starting_point = new_cols_starting_point
    else:
        assert(A_i != None)
    
    work_start = time.process_time()
    span = 0
    
    # Compute PCPs and send to each server.
    pcp_start = time.process_time()
    A_pcps = create_pcp_sketch(A, num_cols_in_sketch=num_cols_in_pcps)
    pcp_time = time.process_time() - pcp_start
    span += pcp_time

    # Run greedy on each server. If S is not equal to None, then
    # these will return the subsets which work best *with S*. 
    server_subsets = []
    max_greedy_time = 0
    for i in range(num_servers):
        greedy_start = time.process_time()
        #print("server %s" % i)
        subset_i = greedy_one_round_single_server(A=A_pcps, B=A_i[i], S=S, k=k, lazier_than_lazy=lazier_than_lazy)
        server_subsets.append(subset_i)
        greedy_time = time.process_time() - greedy_start
        max_greedy_time = max(max_greedy_time, greedy_time)
    span += max_greedy_time

    # Run greedy CSS on the coordinator, using the
    # union of those subsets, to fit A.
    #print("coordinator")
    coordinator_greedy_start = time.process_time()
    union_of_server_subsets = np.hstack(tuple(server_subsets))
    coordinator_subset = greedy_one_round_single_server(A=A_pcps, B=union_of_server_subsets, S=S, k=k, lazier_than_lazy=lazier_than_lazy)
    coordinator_greedy_time = time.process_time() - coordinator_greedy_start
    span += coordinator_greedy_time

    for computed_subset in server_subsets:
        assert(type(computed_subset) == np.ndarray)
    assert(type(coordinator_subset) == np.ndarray)

    # Return the subset which gives best error, out
    # of coordinator_subset and server_subsets (account
    # for S if it is not None.)
    choosing_subset_time = time.process_time()
    subset_list = server_subsets
    subset_list.append(coordinator_subset)
    best_subset = None
    best_subset_error = np.linalg.norm(A)
    for computed_subset in subset_list:
        if S == None:
            test_subset = computed_subset
        else:
            test_subset = np.hstack((S, test_subset))

        test_right_factor = np.linalg.pinv(test_subset) @ A
        test_error = np.linalg.norm(test_subset @ test_right_factor - A)
        if test_error < best_subset_error:
            best_subset_error = test_error
            best_subset = computed_subset
    choosing_subset_time = time.process_time() - choosing_subset_time
    span += choosing_subset_time

    work = time.process_time() - work_start

    if S == None:
        return best_subset, work, span
    else:
        return np.hstack((S, best_subset)), work, span

"""
Compute the l1 regression error when trying to fit
the columns of A using the columns of U. This is
similar to the parallel_lra_error method in the 
Coordinator class.
"""
def compute_l1_regression_error(U, A):
    work_start = time.process_time()
    span = 0

    preprocessing_start_time = time.process_time()
    _, d = A.shape

    # Pre-conditioning: replace A with its left singular
    # vectors, removing the ones corresponding to 0
    # singular values.
    conditioned_U, _, _ = np.linalg.svd(U)
    rank = np.linalg.matrix_rank(U)
    conditioned_U = conditioned_U[:, :rank]

    # Other preprocessing steps
    reconstruction = np.zeros(shape=A.shape)
    preprocessing_end_time = time.process_time()
    span += preprocessing_end_time - preprocessing_start_time

    max_l1_regression_time = 0
    for c_idx in range(d):
        individual_regression_start_time = time.process_time()

        # Conditioning for the column of A
        q = np.array(A[:, c_idx])
        scale = np.linalg.norm(q)
        if scale == 0.0:
            solution = np.zeros(rank)
        else:
            q = q/scale
            _, solution = solve_l1_regression_v2(conditioned_U, q, c_idx)

        # Update column of reconstruction
        reconstruction[:, c_idx] = scale * (conditioned_U @ solution)

        # Span updates
        individual_regression_end_time = time.process_time()
        individual_regression_time = individual_regression_end_time - individual_regression_start_time
        max_l1_regression_time = max(max_l1_regression_time, individual_regression_time)
    
    total_l1_error = np.sum(np.abs(A - reconstruction))
    work = time.process_time() - work_start
    span += max_l1_regression_time
    return total_l1_error, reconstruction, work, span

####################################################

def run_trials_frobenius_greedy(A, A_i, approx_rank, num_trials, setting_num):
    """
    Run num_trials trials of one-round distributed greedy for the
    Frobenius norm (lazier-than-lazy variant, fixed partition A_i) and
    score each selected subset by its l1 regression error on A.

    Returns a SettingResults object holding the per-trial errors, the
    work/span of the greedy round and the work/span of the l1 regression.
    """
    trial_errors = []
    greedy_works = []
    greedy_spans = []
    regression_works = []
    regression_spans = []

    for trial in range(num_trials):
        print("Distributed greedy for Frobenius norm, Trial Number: %d" % trial)
        trial_began = time.time()

        chosen_columns, greedy_work, greedy_span = dist_greedy_one_round(
            A=A,
            num_servers=len(A_i),
            S=None,
            k=approx_rank,
            A_i=A_i,
            random=False,
            lazier_than_lazy=True,
        )

        # Do l1 regression on the subset
        l1_error, _, regression_work, regression_span = compute_l1_regression_error(chosen_columns, A)

        # Record this trial
        trial_errors.append(l1_error)
        greedy_works.append(greedy_work)
        greedy_spans.append(greedy_span)
        regression_works.append(regression_work)
        regression_spans.append(regression_span)

        trial_ended = time.time()
        print('l1 error : {}'.format(l1_error))
        print("Baseline Time: ", trial_ended - trial_began)

    return SettingResults(
        setting_num=setting_num,
        approx_rank=approx_rank,
        error_list=trial_errors,
        UV_list=None,
        work_list=greedy_works,
        span_list=greedy_spans,
        l1_regression_work_list=regression_works,
        l1_regression_span_list=regression_spans,
    )

"""
Runs num_trials many trials of our l1 protocol using
greedy, our l1 protocol using sketching/sampling for
{1, 2}-CSS. Runs SVD also. Uses a single set of
hyperparameters for {1, 2}-CSS based on sketching
and sampling.
"""
def run_experiments_on_dataset(A, A_i, approx_rank, num_trials, dataset_name):

    # Set hyperparameters for our l1 protocol
    # based on settings which have worked well
    # for real datasets
    cauchy_size = 8 * approx_rank
    coreset_size = 10 * approx_rank
    sketch_size = approx_rank // 3
    sparsity = min(5, sketch_size)
    lewis_weight_size = approx_rank

    # Our l1 protocol using Greedy {1, 2}-CSS
    l1_protocol_greedy_results = greedy_multiple_trials(num_trials=num_trials, Ais=A_i, cauchy_size=cauchy_size, coreset_size=coreset_size, approx_rank=approx_rank, setting_num=0)
    with open("Synthetic_Checkpoints/%s_l1_greedy.json" % dataset_name, "w") as outfile:
        l1_protocol_greedy_results.UV_list = None
        json.dump(l1_protocol_greedy_results.__dict__, outfile)

    # Our l1 protocol using Sketch-and-sample {1, 2}-CSS
    l1_protocol_sketching_results = sparse_embedding_multiple_trials(num_trials=num_trials, Ais=A_i, cauchy_size=cauchy_size, coreset_size=coreset_size, sketch_size=sketch_size, sparsity=sparsity, lewis_weight_size=lewis_weight_size, setting_num=1)
    with open("Synthetic_Checkpoints/%s_l1_sketch.json" % dataset_name, "w") as outfile:
        l1_protocol_sketching_results.UV_list = None
        json.dump(l1_protocol_sketching_results.__dict__, outfile)

    # Distributed greedy for Frobenius norm
    frobenius_results = run_trials_frobenius_greedy(A=A, A_i=A_i, approx_rank=approx_rank, num_trials=num_trials, setting_num=2)
    with open("Synthetic_Checkpoints/%s_frobenius_greedy.json" % dataset_name, "w") as outfile:
        frobenius_results.UV_list = None
        json.dump(frobenius_results.__dict__, outfile)

    # SVD - not a distributed algorithm but just for comparison purposes.
    svd_approx = rank_k_svd(A, approx_rank)
    svd_error = np.sum(np.abs(A - svd_approx))
    with open("Synthetic_Checkpoints/%s_svd_error.txt" % dataset_name, "w") as outfile:
        outfile.write(str(svd_error))

    print("SVD error: ", svd_error)

    # Return
    return l1_protocol_greedy_results, l1_protocol_sketching_results, frobenius_results, svd_error

def min_setting_error(results, dict=True):
    """
    Return the smallest entry of a results structure's error list, or
    float("inf") when the list is empty.

    If dict is True, results is indexed as results["error_list"];
    otherwise the attribute results.error_list is read.
    """
    error_list = results["error_list"] if dict else results.error_list
    # Seeding the candidates with infinity covers the empty-list case.
    return min((float("inf"), *error_list))

def plot_results(dataset_name, results_tuple, approx_rank):
    """
    Save a bar chart comparing the minimum l1 error of each method.

    Parameters:
        dataset_name: used in the plot title and the output file name.
        results_tuple: (l1 greedy results, l1 sketching results,
            distributed greedy results, svd_error). Each of the first
            three may be either a results object exposing an error_list
            attribute (as returned by run_experiments_on_dataset) or the
            dictionary form loaded back from a JSON checkpoint.
        approx_rank: rank shown in the plot title.

    The figure is written to "Synthetic_Checkpoints/<dataset_name>_plot.png";
    the directory is created if it does not exist.
    """
    names = ["l1 Greedy", "l1 SE", "DistGreedy", "SVD"]

    def _min_error(results):
        # min_setting_error defaults to dictionary access, which does not
        # fit results objects; pick the access mode from the actual type.
        return min_setting_error(results, dict=isinstance(results, dict))

    # Collect errors
    l1_greedy_error = _min_error(results_tuple[0])
    l1_sketching_error = _min_error(results_tuple[1])
    frobenius_error = _min_error(results_tuple[2])
    svd_error = results_tuple[3]
    errors = [l1_greedy_error, l1_sketching_error, frobenius_error, svd_error]

    # Make plot
    x_locations = np.arange(start=0, stop=len(names), step=1)
    plt.bar(x_locations, errors)
    plt.xticks(x_locations, names)
    plt.ylabel("l1 Error")
    plt.title("Error on %s for rank %d" % (dataset_name, approx_rank))
    os.makedirs("Synthetic_Checkpoints", exist_ok=True)
    plt.savefig("Synthetic_Checkpoints/%s_plot.png" % dataset_name)
    plt.cla()
    plt.clf()
    plt.close()

"""
Run all synthetic data experiments
- For Gaussian noise added to the counterexample
- Synthetic example in both cases is (2000 + 60) x (2000 + 60)
- All experiments run with 2 servers
"""
def run_synthetic_data_experiments(n=2000, k=10, num_trials=8):

    # Create dataset with Gaussian noise
    if not os.path.isfile("gaussian_synthetic_%d.npy" % k):
        gaussian_noise_dataset = synthetic_matrix(n=n, k=k, noise="Gaussian")
        np.save("gaussian_synthetic_%d.npy" % k, gaussian_noise_dataset)
    
    # Load both matrices
    gaussian_noise_dataset = np.load("gaussian_synthetic_%d.npy" % k)

    # Run Gaussian noise experiments - save results to file
    # Save as a tuple.
    _, num_cols = gaussian_noise_dataset.shape
    cutoff = num_cols//2
    gaussian_1 = gaussian_noise_dataset[:, :cutoff]
    gaussian_2 = gaussian_noise_dataset[:, cutoff:]
    server_list = [gaussian_1, gaussian_2]
    gaussian_results = run_experiments_on_dataset(A=gaussian_noise_dataset, A_i=server_list, approx_rank=k, num_trials=num_trials, dataset_name="Counterexample with Gaussian Noise")
    
    # Generate bar graph of Gaussian results
    plot_results("Counterexample with Gaussian Noise", gaussian_results, k)

def plot_3_for_main_paper(dataset_name, approx_rank):
    """
    Save a bar chart of the l1 greedy, l1 sketching and SVD errors that
    were checkpointed for dataset_name at the given rank.

    The errors are read from "Synthetic_Checkpoints_rank_<approx_rank>/"
    and the figure is saved there as "<dataset_name>_plot_new.png".
    """
    bar_labels = ["l1 Greedy", "l1 SE", "SVD"]
    dir_name = "Synthetic_Checkpoints_rank_%d" % approx_rank

    # Collect errors (greedy, sketching, SVD) from the checkpoint files.
    bar_heights = list(get_errors_for_main_paper(dataset_name, approx_rank))

    # Make plot
    bar_positions = np.arange(start=0, stop=len(bar_labels), step=1)
    plt.bar(bar_positions, bar_heights)
    plt.xticks(bar_positions, bar_labels)
    plt.ylabel("l1 Error")
    plt.title("Error on %s for rank %d" % (dataset_name, approx_rank))
    plt.savefig("%s/%s_plot_new.png" % (dir_name, dataset_name))
    plt.cla()
    plt.clf()
    plt.close()


def get_errors_for_main_paper(dataset_name, approx_rank):
    """
    Read the checkpointed errors for dataset_name at the given rank from
    "Synthetic_Checkpoints_rank_<approx_rank>/".

    Returns (l1_greedy_error, l1_sketching_error, svd_error): the first
    two are the minimum over the error list stored in the JSON
    checkpoint, the last is parsed from the first line of the SVD error
    text file.
    """
    dir_name = "Synthetic_Checkpoints_rank_%d" % approx_rank

    def _min_error_in(json_path):
        # Minimum error over the trials stored in one JSON checkpoint.
        with open(json_path) as handle:
            return min_setting_error(json.load(handle))

    greedy_min = _min_error_in("%s/%s_l1_greedy.json" % (dir_name, dataset_name))
    sketching_min = _min_error_in("%s/%s_l1_sketch.json" % (dir_name, dataset_name))

    with open("%s/%s_svd_error.txt" % (dir_name, dataset_name)) as handle:
        svd_value = float(handle.read().splitlines()[0])

    return greedy_min, sketching_min, svd_value

def plot_errors_for_main_paper(dataset_name, approx_ranks):
    """
    Make a line graph with the 3 different algorithms (l1 greedy, l1
    sketching, SVD): checkpointed l1 error against rank.

    The figure is saved as "Synthetic-Counter-Errors.png" and shown.
    """
    # One (greedy, sketching, svd) triple per rank.
    per_rank_errors = [get_errors_for_main_paper(dataset_name, rank) for rank in approx_ranks]
    l1_greedy_errors = [triple[0] for triple in per_rank_errors]
    l1_sketching_errors = [triple[1] for triple in per_rank_errors]
    svd_errors = [triple[2] for triple in per_rank_errors]

    # Setup x labels
    x_tick_labels = [str(rank) for rank in approx_ranks]

    # Plot
    curves = (
        (l1_greedy_errors, "Greedy"),
        (l1_sketching_errors, "Regular"),
        (svd_errors, "SVD"),
    )
    for curve_values, curve_label in curves:
        plt.plot(approx_ranks, curve_values, label=curve_label)
    plt.xlabel("k, and Rank of Output")
    plt.xticks(ticks=approx_ranks, labels=x_tick_labels)
    plt.ylabel("l1 Error")
    plt.legend(loc="upper right")
    plt.title("Synthetic-Counter")
    plt.savefig("Synthetic-Counter-Errors.png")
    plt.show()
    plt.clf()
    plt.cla()
    plt.close()

def frobenius_greedy_l1_protocol_exps(k, num_trials, dataset_name):
    """
    Compare distributed greedy for the Frobenius norm with our l1
    protocol (greedy variant) on a real dataset, using two servers.

    Only the datasets "secom" and "gastro_lesions" are supported; any
    other name raises NotImplementedError. Each trial shuffles the
    columns of the loaded matrix in place, splits them into two halves,
    runs both methods and records error, work, span and the l1
    regression work/span. The two result objects are dumped as JSON into
    "checkpoints_<dataset_name>_distributed_greedy_l1_protocol_comparison/".
    """

    def _fresh_results(setting_num):
        # Empty per-setting container that the trials append to.
        return SettingResults(setting_num=setting_num, approx_rank=k, error_list=[], UV_list=None, work_list=[], span_list=[], l1_regression_work_list=[], l1_regression_span_list=[])

    def _record(results, error, work, span, regression_work, regression_span):
        # Append one trial's measurements to a results container.
        results.error_list.append(error)
        results.work_list.append(work)
        results.span_list.append(span)
        results.l1_regression_work_list.append(regression_work)
        results.l1_regression_span_list.append(regression_span)

    # Preparation
    A = load_additional_dataset(dataset_name)
    _, num_cols = A.shape
    cutoff = num_cols//2

    # Track the min l1 error for distributed greedy for
    # Frobenius norm, and for our protocol for l1 norm.
    min_frobenius_error = float("inf")
    frobenius_setting_results = _fresh_results(0)
    min_l1_protocol_error = float("inf")
    l1_protocol_setting_results = _fresh_results(1)

    cauchy_size = 2 * k
    coreset_size = 5 * k

    # The PCP-sketch width is a dataset-specific multiple of k.
    if dataset_name == "secom":
        pcps_multiplier = 7
    elif dataset_name == "gastro_lesions":
        pcps_multiplier = 8
    else:
        raise NotImplementedError
    pcps_size = pcps_multiplier * k

    for trial in range(num_trials):
        print(trial)

        # Shuffle columns at the beginning of each trial for
        # fairness to the distributed greedy algorithm.
        np.random.shuffle(A.T)
        first_half = A[:, :cutoff]
        second_half = A[:, cutoff:]

        # Distributed greedy for Frobenius norm
        stopwatch = time.time()
        subset, frob_work, frob_span = dist_greedy_one_round(A, 2, random=False, A_i=[first_half, second_half], S=None, k=k, lazier_than_lazy=True, num_cols_in_pcps=pcps_size)
        frob_error, _, frob_reg_work, frob_reg_span = compute_l1_regression_error(subset, A)
        print("Frobenius greedy error: ", frob_error)
        min_frobenius_error = min(min_frobenius_error, frob_error)
        stopwatch = time.time() - stopwatch
        print("Time for Frobenius distributed greedy: ", stopwatch)

        _record(frobenius_setting_results, frob_error, frob_work, frob_span, frob_reg_work, frob_reg_span)

        # l1 distributed protocol
        stopwatch = time.time()
        l1_error, _, l1_work, l1_span, l1_reg_work, l1_reg_span = distributed_protocol(Ais=[first_half, second_half], cauchy_size=cauchy_size, coreset_size=coreset_size, approx_rank=k, greedy=True, sketch_size=-1, sparsity=-1, lewis_weight_size=-1)
        print("l1 Protocol Error: ", l1_error)
        min_l1_protocol_error = min(min_l1_protocol_error, l1_error)
        stopwatch = time.time() - stopwatch
        print("Time for l1 distributed protocol with greedy: ", stopwatch)

        _record(l1_protocol_setting_results, l1_error, l1_work, l1_span, l1_reg_work, l1_reg_span)

    print("Best Frobenius greedy error: ", min_frobenius_error)
    print("Best l1 protocol error: ", min_l1_protocol_error)

    # Save to directory
    SAVE_DIR_NAME = "checkpoints_%s_distributed_greedy_l1_protocol_comparison/" % dataset_name
    SAVE_FILE_NAME = "%s_rank_%d" % (dataset_name, k)

    if not os.path.isdir(SAVE_DIR_NAME):
        os.mkdir(SAVE_DIR_NAME)

    checkpoints = (
        ("_frobenius.json", frobenius_setting_results),
        ("_l1_protocol.json", l1_protocol_setting_results),
    )
    for extension, setting_results in checkpoints:
        with open(SAVE_DIR_NAME + SAVE_FILE_NAME + extension, "w+") as json_file:
            json.dump(setting_results.__dict__, json_file)
    
###########################################################################
# Plots For l1 Protocol and Frobenius distributed greedy comparison

def plot_min_errors_supplementary(dataset_name, frobenius_stats, l1_stats, approx_ranks):
    """
    Plot, per rank, the minimum l1 error of the Frobenius protocol and
    of our protocol as two lines.

    frobenius_stats[i] and l1_stats[i] are the result dictionaries for
    approx_ranks[i]. The figure is saved as
    "<dataset_name>-Frobenius-l1-comparison-Min-Errors.png" and shown.
    """
    frobenius_min_errors = []
    l1_min_errors = []
    for position, _ in enumerate(approx_ranks):
        frobenius_min_errors.append(min_setting_error(frobenius_stats[position]))
        l1_min_errors.append(min_setting_error(l1_stats[position]))

    # Setup x labels
    x_tick_labels = [str(rank) for rank in approx_ranks]

    # Plot
    plt.plot(approx_ranks, frobenius_min_errors, label="Frobenius Protocol")
    plt.plot(approx_ranks, l1_min_errors, label="Our Protocol")
    plt.xlabel("Rank")
    plt.xticks(ticks=approx_ranks, labels=x_tick_labels)
    plt.ylabel("l1 Error")
    plt.legend(loc="upper right")
    plt.title("Min l1 Error on %s" % dataset_name)
    plt.savefig("%s-Frobenius-l1-comparison-Min-Errors.png" % dataset_name)
    plt.show()
    plt.clf()
    plt.cla()
    plt.close()

def mean_std_dev_setting_error(results, dict=True):
    """
    Return (mean, standard deviation) of a results structure's error
    list, computed with np.mean and np.std.

    If dict is True, results is indexed as results["error_list"];
    otherwise the attribute results.error_list is read.
    """
    error_list = results["error_list"] if dict else results.error_list
    return np.mean(error_list), np.std(error_list)

def plot_mean_std_errors_supplementary(dataset_name, frobenius_stats, l1_stats, approx_ranks):
    """
    Plot, per rank, the mean l1 error of the Frobenius protocol and of
    our protocol, with the standard deviation as error bars.

    frobenius_stats[i] and l1_stats[i] are the result dictionaries for
    approx_ranks[i]. The figure is saved as
    "<dataset_name>-Frobenius-l1-comparison-mean-error-bar-plot.png" and
    shown.
    """
    frobenius_mean_errors, frobenius_std_devs = [], []
    l1_mean_errors, l1_std_devs = [], []

    for position, _ in enumerate(approx_ranks):
        mean_value, std_value = mean_std_dev_setting_error(frobenius_stats[position])
        frobenius_mean_errors.append(mean_value)
        frobenius_std_devs.append(std_value)

        mean_value, std_value = mean_std_dev_setting_error(l1_stats[position])
        l1_mean_errors.append(mean_value)
        l1_std_devs.append(std_value)

    # Setup x labels
    x_tick_labels = [str(rank) for rank in approx_ranks]

    # Plot
    plt.errorbar(x=approx_ranks, y=frobenius_mean_errors, yerr=frobenius_std_devs, label="Frobenius Protocol")
    plt.errorbar(x=approx_ranks, y=l1_mean_errors, yerr=l1_std_devs, label="Our Protocol")
    plt.xlabel("Rank")
    plt.xticks(ticks=approx_ranks, labels=x_tick_labels)
    plt.ylabel("l1 Error")
    plt.legend(loc="upper right")
    plt.title("Mean l1 Error on %s" % dataset_name)
    plt.savefig("%s-Frobenius-l1-comparison-mean-error-bar-plot.png" % dataset_name)
    plt.show()
    plt.clf()
    plt.cla()
    plt.close()

def mean_runtime(results, option):
    """
    Return (mean, standard deviation) of the timings stored under
    results[option + "_list"].

    option is either "work" or "span". It is assumed that the results
    structure is an object dictionary.
    """
    time_list = results[option + "_list"]
    return np.mean(time_list), np.std(time_list)

def plot_mean_std_work_span_supplementary(dataset_name, frobenius_stats, l1_stats, approx_ranks, option):
    """
    Plot, per rank, the mean work or span of the Frobenius protocol and
    of our protocol, with the standard deviation as error bars.

    option is either "work" or "span". frobenius_stats[i] and l1_stats[i]
    are the result dictionaries for approx_ranks[i]. The figure is saved
    as "<dataset_name>-Frobenius-l1-comparison-average-<option>-error-bar-plot.png"
    and shown.
    """
    frobenius_mean_times, frobenius_std_devs = [], []
    l1_mean_times, l1_std_devs = [], []

    for position, _ in enumerate(approx_ranks):
        mean_value, std_value = mean_runtime(frobenius_stats[position], option)
        frobenius_mean_times.append(mean_value)
        frobenius_std_devs.append(std_value)

        mean_value, std_value = mean_runtime(l1_stats[position], option)
        l1_mean_times.append(mean_value)
        l1_std_devs.append(std_value)

    # Setup x labels
    x_tick_labels = [str(rank) for rank in approx_ranks]

    # Plot
    plt.errorbar(x=approx_ranks, y=frobenius_mean_times, yerr=frobenius_std_devs, label="Frobenius Protocol")
    plt.errorbar(x=approx_ranks, y=l1_mean_times, yerr=l1_std_devs, label="Our Protocol")
    plt.xlabel("Rank")
    plt.xticks(ticks=approx_ranks, labels=x_tick_labels)
    plt.ylabel("Relative Runtime")
    plt.legend(loc="lower right")
    # NOTE(review): the title mentions time.time(), while
    # dist_greedy_one_round in this file measures work/span with
    # time.process_time() - confirm which clock the plotted data used.
    plt.title("Average %s on %s Using time.time()" % (option, dataset_name))
    plt.savefig("%s-Frobenius-l1-comparison-average-%s-error-bar-plot.png" % (dataset_name, option))
    plt.show()
    plt.clf()
    plt.cla()
    plt.close()

def plot_results_l1_protocol_frobenius_greedy_comparison(dataset_name, approx_ranks):
    """
    Plot the min error, the mean error with std. dev as error bars, and
    the average work/span for each rank across all the trials, for our
    comparison between our l1 protocol and the distributed greedy
    protocol for CSS in the Frobenius norm. The plots are line graphs.

    The per-rank results are read from the JSON files in
    "checkpoints_<dataset_name>_distributed_greedy_l1_protocol_comparison/".
    """
    SAVE_DIR_NAME = "checkpoints_%s_distributed_greedy_l1_protocol_comparison/" % dataset_name
    frobenius_extension = "_frobenius.json"
    l1_protocol_extension = "_l1_protocol.json"

    def _load_json(path):
        with open(path) as json_file:
            return json.load(json_file)

    # Get result files, one pair per rank
    frobenius_stats = []
    l1_protocol_stats = []
    for rank in approx_ranks:
        file_stem = SAVE_DIR_NAME + "%s_rank_%d" % (dataset_name, rank)
        frobenius_stats.append(_load_json(file_stem + frobenius_extension))
        l1_protocol_stats.append(_load_json(file_stem + l1_protocol_extension))

    shared_kwargs = {
        "dataset_name": dataset_name,
        "frobenius_stats": frobenius_stats,
        "l1_stats": l1_protocol_stats,
        "approx_ranks": approx_ranks,
    }

    # Plot min errors
    plot_min_errors_supplementary(**shared_kwargs)

    # Plot mean errors with std. dev as error bars
    plot_mean_std_errors_supplementary(**shared_kwargs)

    # Plot mean work and span with error bars
    for timing_kind in ("work", "span"):
        plot_mean_std_work_span_supplementary(option=timing_kind, **shared_kwargs)

if __name__ == "__main__":
    # Command-line entry point. --main_paper selects what to run:
    #   0: comparison between our protocol and distributed greedy for
    #      Frobenius on a real dataset
    #      (needs --dataset_name, --test_rank, --num_trials)
    #   1: synthetic data experiments for a given rank and number of
    #      trials (needs --test_rank, --num_trials)
    #   2: plot the synthetic results for ranks 10, 20, 30
    #   3: plot the protocol comparison for gastro_lesions and secom
    # Any other value raises NotImplementedError.
    parser = argparse.ArgumentParser()
    parser.add_argument("--main_paper", type=int, required=True, help="which experiment or plot to run (0, 1, 2 or 3)")
    parser.add_argument("--dataset_name", help="dataset to load (mode 0)")
    parser.add_argument("--test_rank", type=int, help="rank to test (modes 0 and 1)")
    parser.add_argument("--num_trials", type=int, help="number of trials (modes 0 and 1)")
    args = parser.parse_args()
    main_paper_arg = args.main_paper

    # Report missing mode-specific arguments as a usage error instead of
    # failing later on a None value.
    if main_paper_arg in (0, 1) and (args.test_rank is None or args.num_trials is None):
        parser.error("--test_rank and --num_trials are required when --main_paper is 0 or 1")
    if main_paper_arg == 0 and args.dataset_name is None:
        parser.error("--dataset_name is required when --main_paper is 0")

    if main_paper_arg == 0:
        # In this case, run the comparison between our
        # protocol and distributed greedy for Frobenius.
        frobenius_greedy_l1_protocol_exps(k=args.test_rank, num_trials=args.num_trials, dataset_name=str(args.dataset_name))
    elif main_paper_arg == 1:
        # main_paper_arg is equal to 1 - in this case,
        # run the synthetic data experiments for a given
        # choice of k, for a given number of trials.
        run_synthetic_data_experiments(k=args.test_rank, num_trials=args.num_trials)
    elif main_paper_arg == 2:
        # In this case, plot the results from the desired ranks.
        # For the main paper, the ranks are 10, 20, 30
        plot_errors_for_main_paper(dataset_name="Counterexample with Gaussian Noise", approx_ranks=[10, 20, 30])
    elif main_paper_arg == 3:
        # In this case, plot the results on the desired
        # ranks for our comparison between our l1 protocol
        # and distributed greedy for Frobenius norm.
        plot_results_l1_protocol_frobenius_greedy_comparison(dataset_name="gastro_lesions", approx_ranks=[10, 20, 30])
        plot_results_l1_protocol_frobenius_greedy_comparison(dataset_name="secom", approx_ranks=[30, 60, 90, 120])
    else:
        raise NotImplementedError