import json
import numpy as np 
import os 
from caltech101_experiments import ALL_IMAGE_FILEPATHS
from dataloader import load_image

# Output directory (a relative path, so resolved against the current working
# directory) for the aggregated experiment-stats JSON files written by the
# __main__ section of this script. The trailing "/" matters: file names are
# appended to it by plain string concatenation.
EXPERIMENT_STATS_DIR = "Experiment_Stats/"

def get_svd_error(filepath):
    """Return the single error value stored in an SVD results JSON file.

    Args:
        filepath: path to a JSON file whose ``error_list`` holds exactly one entry.

    Returns:
        ``error_list[0]`` from the file.

    Raises:
        AssertionError: if ``error_list`` does not contain exactly one entry.
    """
    # Context manager so the handle is closed deterministically
    # (``json.load(open(...))`` leaves closing to the garbage collector).
    with open(filepath, 'rb') as json_file:
        data = json.load(json_file)
    error_list = data['error_list']
    assert len(error_list) == 1, (
        "expected exactly one SVD error in %s, found %d" % (filepath, len(error_list)))
    return error_list[0]

def get_greedy_or_sparse_embedding_work_span_stats(filepath):
    """Return the mean work and mean span recorded in a greedy or sparse-embedding results file.

    Args:
        filepath: path to a JSON results file with non-empty ``work_list`` and
            ``span_list`` arrays.

    Returns:
        ``(work_mean, span_mean)`` computed with ``np.mean``.

    Raises:
        AssertionError: if either list is empty.
    """
    # Context manager so the handle is closed deterministically
    # (``json.load(open(...))`` leaves closing to the garbage collector).
    with open(filepath, 'rb') as json_file:
        data = json.load(json_file)
    work_list = data['work_list']
    span_list = data['span_list']
    assert len(work_list) >= 1 and len(span_list) >= 1, (
        "empty work_list or span_list in %s" % filepath)

    return np.mean(work_list), np.mean(span_list)

def get_greedy_or_sparse_embedding_error_stats(filepath):
    """Return min, mean and standard deviation of the errors in a greedy or sparse-embedding results file.

    Args:
        filepath: path to a JSON results file with a non-empty ``error_list``.

    Returns:
        ``(error_min, error_mean, error_std)`` computed with NumPy
        (``np.std`` uses the population definition, ddof=0).

    Raises:
        AssertionError: if ``error_list`` is empty.
    """
    # Context manager so the handle is closed deterministically
    # (``json.load(open(...))`` leaves closing to the garbage collector).
    with open(filepath, 'rb') as json_file:
        data = json.load(json_file)
    error_list = data['error_list']
    assert len(error_list) >= 1, "empty error_list in %s" % filepath
    return np.min(error_list), np.mean(error_list), np.std(error_list)

def get_greedy_or_sparse_embedding_error_work_span_stats(filepath):
    """Combine the error and work/span statistics of one greedy or sparse-embedding results file.

    Returns:
        ``(error_min, error_mean, error_std, work_mean, span_mean)``: the output
        of get_greedy_or_sparse_embedding_error_stats followed by that of
        get_greedy_or_sparse_embedding_work_span_stats, both on *filepath*.
    """
    error_part = get_greedy_or_sparse_embedding_error_stats(filepath=filepath)
    cost_part = get_greedy_or_sparse_embedding_work_span_stats(filepath=filepath)
    return (*error_part, *cost_part)

def get_sparse_embedding_last_setting_num(filepath_expr):
    """Return the largest N such that the files for setting numbers 1..N all exist.

    Args:
        filepath_expr: %-format path template with a single ``%d`` slot for
            the setting number.

    Returns:
        The count of consecutive existing settings starting at 1. Probing stops
        at the first missing file, so later settings after a gap are ignored.
        Returns 0 when setting 1 does not exist.
    """
    found = 0
    while os.path.exists(filepath_expr % (found + 1)):
        found += 1
    return found

def get_sparse_embedding_error_stats_top2(filepath_expr):
    """Return the error stats of the two sparse-embedding settings with the lowest min error.

    Settings are probed as ``filepath_expr % setting_num`` for setting_num = 1, 2, ...
    until the first missing file.

    Args:
        filepath_expr: %-format path template with a single ``%d`` slot.

    Returns:
        A list of two ``(setting_num, error_min, error_mean, error_std)`` tuples,
        sorted by ascending ``error_min`` (ties keep the lower setting_num first).

    Raises:
        ValueError: if fewer than two consecutive settings are found.
    """
    all_error_stats_tuples = []
    setting_num = 1
    while os.path.exists(filepath_expr % setting_num):
        error_min, error_mean, error_std = get_greedy_or_sparse_embedding_error_stats(filepath_expr % setting_num)
        all_error_stats_tuples.append((setting_num, error_min, error_mean, error_std))
        setting_num += 1

    print("Last setting_num for sparse embedding found: %d" % (setting_num-1))

    if len(all_error_stats_tuples) < 2:
        raise ValueError(
            "need at least two sparse embedding settings matching %r, found %d"
            % (filepath_expr, len(all_error_stats_tuples)))

    # A stable sort on error_min replaces np.argpartition(kth=2), which requires
    # at least three elements and so failed when exactly two settings exist.
    ranked = sorted(all_error_stats_tuples, key=lambda stats: stats[1])
    return ranked[:2]

def get_all_errors(dataset_name, test_ranks, cauchy_factor):
    """Collect SVD, greedy and top-2 sparse-embedding errors for every rank of a dataset.

    The two sparse-embedding settings are chosen once, at the highest rank
    (``test_ranks[-1]``), by lowest min error, and then reused for every rank.

    Args:
        dataset_name: dataset prefix used in the checkpoint directory and file names.
        test_ranks: ranks to evaluate; the last entry selects the sparse settings.
        cauchy_factor: multiplier giving ``cauchy_size = cauchy_factor * rank``.

    Returns:
        ``(error_mins, error_means, error_stds, top2_sparse_embed_setting_nums)``:
          error_mins: one tuple per rank,
              (svd_error, greedy_error_min, sparse1_error_min, sparse2_error_min)
          error_means: one tuple per rank,
              (svd_error, greedy_error_mean, sparse1_error_mean, sparse2_error_mean)
          error_stds: one tuple per rank,
              (0, greedy_error_std, sparse1_error_std, sparse2_error_std)
          top2_sparse_embed_setting_nums: ``[(sparse1_setting_num, sparse2_setting_num)]``
    """
    def paths_for_rank(approx_rank):
        # cauchy_size, checkpoint directory and the %d-templated sparse-embedding
        # file path for one rank.
        cauchy_size = cauchy_factor * approx_rank
        directory_name = "Checkpoints_%s_rank_%d_json/" % (dataset_name, approx_rank)
        sparse_expr = (directory_name + dataset_name + "_rank_" + str(approx_rank)
                       + "_results_cauchy_size_" + str(cauchy_size) + "_setting_num_%d.json")
        return cauchy_size, directory_name, sparse_expr

    # Pick the two best sparse-embedding settings at the highest rank only.
    _, _, top_rank_sparse_expr = paths_for_rank(test_ranks[-1])
    sparse_error_stats_top2 = get_sparse_embedding_error_stats_top2(filepath_expr=top_rank_sparse_expr)
    sparse1_setting_num = sparse_error_stats_top2[0][0]
    sparse2_setting_num = sparse_error_stats_top2[1][0]

    error_mins = []
    error_means = []
    error_stds = []
    for approx_rank in test_ranks:
        cauchy_size, directory_name, sparse_expr = paths_for_rank(approx_rank)

        # SVD results file uses cauchy_size -1 and setting_num -1.
        svd_filename = "%s_rank_%d_results_cauchy_size_-1_setting_num_-1.json" % (dataset_name, approx_rank)
        svd_error = get_svd_error(filepath=directory_name + svd_filename)

        # Greedy results file uses setting_num 0.
        greedy_filename = "%s_rank_%d_results_cauchy_size_%d_setting_num_0.json" % (dataset_name, approx_rank, cauchy_size)
        greedy_error_min, greedy_error_mean, greedy_error_std = get_greedy_or_sparse_embedding_error_stats(filepath=directory_name + greedy_filename)

        sparse1_error_min, sparse1_error_mean, sparse1_error_std = get_greedy_or_sparse_embedding_error_stats(filepath=sparse_expr % sparse1_setting_num)
        sparse2_error_min, sparse2_error_mean, sparse2_error_std = get_greedy_or_sparse_embedding_error_stats(filepath=sparse_expr % sparse2_setting_num)

        error_mins.append((svd_error, greedy_error_min, sparse1_error_min, sparse2_error_min))
        error_means.append((svd_error, greedy_error_mean, sparse1_error_mean, sparse2_error_mean))
        error_stds.append((0, greedy_error_std, sparse1_error_std, sparse2_error_std))

    top2_sparse_embed_setting_nums = [(sparse1_setting_num, sparse2_setting_num)]
    return error_mins, error_means, error_stds, top2_sparse_embed_setting_nums


def get_error_work_span_stats_all_settings(dataset_name, test_ranks, cauchy_factor):
    """Collect error, work and span stats for SVD, greedy and every sparse-embedding setting, per rank.

    Unlike get_all_errors, no top-2 selection is made: for each rank every
    consecutive sparse setting 1..N found on disk is included.

    Returns:
        ``(error_mins, error_means, error_stds, work_means, span_means)``; each
        is a list with one inner list per rank in ``test_ranks``:
          error_mins[r]  = [svd_error, greedy_min, sparse_1_min, ..., sparse_N_min]
          error_means[r] = [svd_error, greedy_mean, sparse_1_mean, ..., sparse_N_mean]
          error_stds[r]  = [0, greedy_std, sparse_1_std, ..., sparse_N_std]
          work_means[r]  = [greedy_work, sparse_1_work, ...]   (no SVD entry)
          span_means[r]  = [greedy_span, sparse_1_span, ...]   (no SVD entry)
    """
    per_rank_mins, per_rank_means, per_rank_stds = [], [], []
    per_rank_work, per_rank_span = [], []

    for rank in test_ranks:
        sketch_size = cauchy_factor * rank
        ckpt_dir = "Checkpoints_%s_rank_%d_json/" % (dataset_name, rank)

        # SVD baseline file (cauchy_size -1, setting_num -1).
        svd_err = get_svd_error(
            filepath=ckpt_dir + "%s_rank_%d_results_cauchy_size_-1_setting_num_-1.json" % (dataset_name, rank))

        # Greedy file (setting_num 0).
        greedy_path = ckpt_dir + "%s_rank_%d_results_cauchy_size_%d_setting_num_0.json" % (dataset_name, rank, sketch_size)
        g_min, g_mean, g_std, g_work, g_span = get_greedy_or_sparse_embedding_error_work_span_stats(filepath=greedy_path)

        # Every consecutive sparse-embedding setting present on disk.
        sparse_expr = "".join([ckpt_dir, dataset_name, "_rank_", str(rank),
                               "_results_cauchy_size_", str(sketch_size), "_setting_num_%d.json"])
        n_sparse = get_sparse_embedding_last_setting_num(filepath_expr=sparse_expr)
        sparse_stats = [
            get_greedy_or_sparse_embedding_error_work_span_stats(filepath=sparse_expr % k)
            for k in range(1, n_sparse + 1)
        ]

        per_rank_mins.append([svd_err, g_min] + [s[0] for s in sparse_stats])
        per_rank_means.append([svd_err, g_mean] + [s[1] for s in sparse_stats])
        per_rank_stds.append([0, g_std] + [s[2] for s in sparse_stats])
        per_rank_work.append([g_work] + [s[3] for s in sparse_stats])  # no SVD work
        per_rank_span.append([g_span] + [s[4] for s in sparse_stats])  # no SVD span

    return per_rank_mins, per_rank_means, per_rank_stds, per_rank_work, per_rank_span


def get_svd_error_and_UV(filepath):
    """Return the single error value and single ``UV_list`` entry from an SVD results JSON file.

    Args:
        filepath: path to a JSON file whose ``error_list`` and ``UV_list`` each
            hold exactly one entry.

    Returns:
        ``(error_list[0], UV_list[0])`` as loaded from JSON.

    Raises:
        AssertionError: if either list does not hold exactly one entry.
    """
    # Context manager so the handle is closed deterministically
    # (``json.load(open(...))`` leaves closing to the garbage collector).
    with open(filepath, 'rb') as json_file:
        data = json.load(json_file)
    error_list = data['error_list']
    UV_list = data['UV_list']
    assert len(error_list) == 1 and len(UV_list) == 1, (
        "expected exactly one SVD error and one UV entry in %s" % filepath)
    return error_list[0], UV_list[0]

def get_greedy_or_sparse_embedding_error_stats_and_UV(filepath):
    """Return error min/mean/std and the stored UV entry from a greedy or sparse-embedding results file.

    Args:
        filepath: path to a JSON results file with a non-empty ``error_list``
            and a ``UV_list`` holding exactly one entry.

    Returns:
        ``(error_min, error_mean, error_std, UV_list[0])``; the statistics are
        computed with NumPy (``np.std`` uses ddof=0).

    Raises:
        AssertionError: if ``error_list`` is empty or ``UV_list`` does not hold
            exactly one entry.
    """
    # Context manager so the handle is closed deterministically
    # (``json.load(open(...))`` leaves closing to the garbage collector).
    with open(filepath, 'rb') as json_file:
        data = json.load(json_file)
    error_list = data['error_list']
    UV_list = data['UV_list']
    assert len(error_list) >= 1 and len(UV_list) == 1, (
        "expected a non-empty error_list and exactly one UV entry in %s" % filepath)
    return np.min(error_list), np.mean(error_list), np.std(error_list), UV_list[0]

def get_sparse_embedding_error_stats_and_UV_top2(filepath_expr):
    """Return stats and UV entries of the two sparse-embedding settings with the lowest min error.

    Settings are probed as ``filepath_expr % setting_num`` for setting_num = 1, 2, ...
    until the first missing file.

    Args:
        filepath_expr: %-format path template with a single ``%d`` slot.

    Returns:
        A list of two ``(setting_num, error_min, error_mean, error_std, best_UV)``
        tuples sorted by ascending ``error_min`` (ties keep the lower setting_num first).

    Raises:
        ValueError: if fewer than two consecutive settings are found.
    """
    all_error_stats_tuples = []
    setting_num = 1
    while os.path.exists(filepath_expr % setting_num):
        error_min, error_mean, error_std, best_UV = get_greedy_or_sparse_embedding_error_stats_and_UV(filepath_expr % setting_num)
        all_error_stats_tuples.append((setting_num, error_min, error_mean, error_std, best_UV))
        setting_num += 1

    print("Last setting_num for sparse embedding found: %d" % (setting_num-1))

    if len(all_error_stats_tuples) < 2:
        raise ValueError(
            "need at least two sparse embedding settings matching %r, found %d"
            % (filepath_expr, len(all_error_stats_tuples)))

    # A stable sort on error_min replaces np.argpartition(kth=2), which requires
    # at least three elements and so failed when exactly two settings exist.
    ranked = sorted(all_error_stats_tuples, key=lambda stats: stats[1])
    return ranked[:2]

def get_all_errors_and_best_UVs_caltech101(image_num, test_ranks, directory_name):
    """Collect SVD, greedy and top-2 sparse-embedding errors and UV entries for one Caltech101 image.

    Files are read as ``directory_name + "image<N>_rank<R>_setting<S>.json"``,
    with S = -1 for SVD, 0 for greedy and 1, 2, ... for sparse-embedding
    settings. The two sparse settings are chosen once, at the highest rank
    (``test_ranks[-1]``), and reused for every rank.

    Returns:
        ``(error_mins, error_means, error_stds, best_UVs, top2_sparse_embed_setting_nums)``.
        The first four are lists with one 4-tuple per rank, ordered
        (svd, greedy, sparse1, sparse2); error_stds uses 0 in the SVD slot.
        The last is ``[(sparse1_setting_num, sparse2_setting_num)]``.
    """
    def name_template(rank):
        # File-name template for this image and rank; the %d slot is the setting number.
        return "image" + str(image_num) + "_rank" + str(rank) + "_setting%d.json"

    top2 = get_sparse_embedding_error_stats_and_UV_top2(
        filepath_expr=directory_name + name_template(test_ranks[-1]))
    first_setting, second_setting = top2[0][0], top2[1][0]

    mins, means, stds, uvs = [], [], [], []
    for rank in test_ranks:
        template = name_template(rank)
        path_template = directory_name + template

        svd_err, svd_uv = get_svd_error_and_UV(filepath=directory_name + template % -1)
        g_min, g_mean, g_std, g_uv = get_greedy_or_sparse_embedding_error_stats_and_UV(
            filepath=directory_name + template % 0)
        s1_min, s1_mean, s1_std, s1_uv = get_greedy_or_sparse_embedding_error_stats_and_UV(
            filepath=path_template % first_setting)
        s2_min, s2_mean, s2_std, s2_uv = get_greedy_or_sparse_embedding_error_stats_and_UV(
            filepath=path_template % second_setting)

        mins.append((svd_err, g_min, s1_min, s2_min))
        means.append((svd_err, g_mean, s1_mean, s2_mean))
        stds.append((0, g_std, s1_std, s2_std))
        uvs.append((svd_uv, g_uv, s1_uv, s2_uv))

    return mins, means, stds, uvs, [(first_setting, second_setting)]

def get_caltech101_average_top2(directory_name, test_ranks, save_dir=None):
    """Average SVD / greedy / top-2 sparse-embedding errors over all Caltech101 images and save them.

    Each image's errors are divided by that image's entrywise l1 norm
    (``np.sum(np.abs(A))``). The two sparse-embedding settings are those with
    the smallest sum, over all images, of normalised min error at the highest
    rank (``test_ranks[-1]``). They are then reused for every rank and image.
    The number of sparse settings is the count of consecutive ``setting1``,
    ``setting2``, ... files found for the first image at the first rank.

    Args:
        directory_name: directory (ending in a path separator) holding the
            ``image<N>_rank<R>_setting<S>.json`` result files.
        test_ranks: ranks to aggregate; the last one drives the top-2 choice.
        save_dir: output directory (ending in a path separator); defaults to
            EXPERIMENT_STATS_DIR.

    Returns:
        The dict written to ``caltech101_image_set_experiment_stats.json``.
        It holds ``test_ranks`` plus, per rank, the mean over images of the
        normalised error mins / means / stds, each ordered
        (svd, greedy, sparse1, sparse2).

    Raises:
        ValueError: if fewer than two sparse-embedding settings are found.
    """
    if save_dir is None:
        # Fall back to the directory __main__ writes to, so the function also
        # works when imported (it used to rely on a global `save_dir` that only
        # exists when this file is run as a script).
        save_dir = EXPERIMENT_STATS_DIR

    image_nums = [image_num for (image_num, _) in ALL_IMAGE_FILEPATHS]

    def filepath_expr_for(image_num, approx_rank):
        # Full path template; the %d slot takes the setting number
        # (-1 = SVD, 0 = greedy, 1.. = sparse-embedding settings).
        return directory_name + "image" + str(image_num) + "_rank" + str(approx_rank) + "_setting%d.json"

    # Entrywise l1 norm of each original image, used to normalise its errors.
    original_image_l1s = dict()
    for image_num, image_filepath in ALL_IMAGE_FILEPATHS:
        A = load_image(image_filepath)
        original_image_l1s[image_num] = np.sum(np.abs(A))

    # Number of consecutive sparse-embedding settings on disk, probed on the
    # first image at the first rank.
    last_setting_num = get_sparse_embedding_last_setting_num(
        filepath_expr=filepath_expr_for(image_nums[0], test_ranks[0]))
    if last_setting_num < 2:
        raise ValueError(
            "need at least two sparse embedding settings to pick a top 2, found %d"
            % last_setting_num)
    setting_nums = list(range(1, last_setting_num + 1))

    # normalised_min_errors[k][i]: min error of setting setting_nums[k] on image
    # image_nums[i] at the highest rank, divided by that image's l1 norm.
    approx_rank = test_ranks[-1]
    normalised_min_errors = [[] for _ in setting_nums]
    for image_num in image_nums:
        original_image_l1 = float(original_image_l1s[image_num])
        top_rank_expr = filepath_expr_for(image_num, approx_rank)
        for k, setting_num in enumerate(setting_nums):
            sparse_error_min, _, _, _ = get_greedy_or_sparse_embedding_error_stats_and_UV(
                filepath=top_rank_expr % setting_num)
            normalised_min_errors[k].append(sparse_error_min / original_image_l1)
    total_error_by_setting_num = np.sum(np.array(normalised_min_errors), axis=-1)

    # Stable argsort instead of np.argpartition(kth=2): it also works with
    # exactly two settings and breaks ties toward the lower setting number.
    # Index k corresponds to setting_num k + 1.
    top2_indices = np.argsort(total_error_by_setting_num, kind="stable")[:2]
    sparse1_setting_num = int(top2_indices[0]) + 1
    sparse2_setting_num = int(top2_indices[1]) + 1
    print(sparse1_setting_num, sparse2_setting_num)

    # Per image: array of shape (len(test_ranks), 4) with columns
    # (svd, greedy, sparse1, sparse2), divided by the image's l1 norm.
    all_images_error_mins_percent = []
    all_images_error_means_percent = []
    all_images_error_stds_percent = []
    for image_num in image_nums:
        error_mins = []
        error_means = []
        error_stds = []
        for approx_rank in test_ranks:
            expr = filepath_expr_for(image_num, approx_rank)

            svd_error, _ = get_svd_error_and_UV(filepath=expr % -1)
            greedy_error_min, greedy_error_mean, greedy_error_std, _ = get_greedy_or_sparse_embedding_error_stats_and_UV(filepath=expr % 0)
            sparse1_error_min, sparse1_error_mean, sparse1_error_std, _ = get_greedy_or_sparse_embedding_error_stats_and_UV(filepath=expr % sparse1_setting_num)
            sparse2_error_min, sparse2_error_mean, sparse2_error_std, _ = get_greedy_or_sparse_embedding_error_stats_and_UV(filepath=expr % sparse2_setting_num)

            error_mins.append((svd_error, greedy_error_min, sparse1_error_min, sparse2_error_min))
            error_means.append((svd_error, greedy_error_mean, sparse1_error_mean, sparse2_error_mean))
            error_stds.append((0, greedy_error_std, sparse1_error_std, sparse2_error_std))

        original_image_l1 = original_image_l1s[image_num]
        all_images_error_mins_percent.append(np.array(error_mins) / original_image_l1)
        all_images_error_means_percent.append(np.array(error_means) / original_image_l1)
        all_images_error_stds_percent.append(np.array(error_stds) / original_image_l1)

    all_images_error_mins_percent = np.array(all_images_error_mins_percent)
    all_images_error_means_percent = np.array(all_images_error_means_percent)
    all_images_error_stds_percent = np.array(all_images_error_stds_percent)

    print(all_images_error_mins_percent.shape)

    # Average over images (axis 0).
    image_set_stats = dict()
    image_set_stats['test_ranks'] = test_ranks
    image_set_stats['image_set_mean_error_mins_percent'] = np.mean(all_images_error_mins_percent, axis=0).tolist()
    image_set_stats['image_set_mean_error_means_percent'] = np.mean(all_images_error_means_percent, axis=0).tolist()
    image_set_stats['image_set_mean_error_stds_percent'] = np.mean(all_images_error_stds_percent, axis=0).tolist()

    with open(save_dir + "caltech101_image_set_experiment_stats.json", "w") as json_file:
        json.dump(image_set_stats, json_file)

    return image_set_stats

def get_caltech101_error_work_span_stats_all_settings(directory_name, test_ranks, save_dir=None):
    """Average error / work / span stats over all Caltech101 images for every setting and save them.

    For each image and rank this collects SVD and greedy stats plus stats for
    every sparse-embedding setting 1..N. N is the count of consecutive
    ``setting1``, ``setting2``, ... files found for the first image at the
    first rank. Errors are divided by each image's entrywise l1 norm; work and
    span are averaged without normalisation.

    Args:
        directory_name: directory (ending in a path separator) holding the
            ``image<N>_rank<R>_setting<S>.json`` result files.
        test_ranks: ranks to aggregate.
        save_dir: output directory (ending in a path separator); defaults to
            EXPERIMENT_STATS_DIR.

    Returns:
        The dict written to ``caltech101_image_set_experiment_stats_all_settings.json``.
        It holds ``test_ranks`` plus per-rank lists averaged over images:
          image_set_mean_error_{mins,means,stds}_percent[r] =
              [svd, greedy, sparse_1, ..., sparse_N]   (the SVD std is 0)
          image_set_mean_{work,span}_means[r] = [greedy, sparse_1, ..., sparse_N]
              (no SVD entry)
    """
    if save_dir is None:
        # Fall back to the directory __main__ writes to, so the function also
        # works when imported (it used to rely on a global `save_dir` that only
        # exists when this file is run as a script).
        save_dir = EXPERIMENT_STATS_DIR

    image_nums = [image_num for (image_num, _) in ALL_IMAGE_FILEPATHS]

    def filepath_expr_for(image_num, approx_rank):
        # Full path template; the %d slot takes the setting number
        # (-1 = SVD, 0 = greedy, 1.. = sparse-embedding settings).
        return directory_name + "image" + str(image_num) + "_rank" + str(approx_rank) + "_setting%d.json"

    # Entrywise l1 norm of each original image, used to normalise its errors.
    original_image_l1s = dict()
    for image_num, image_filepath in ALL_IMAGE_FILEPATHS:
        A = load_image(image_filepath)
        original_image_l1s[image_num] = np.sum(np.abs(A))

    # Number of consecutive sparse-embedding settings on disk, probed on the
    # first image at the first rank.
    last_setting_num = get_sparse_embedding_last_setting_num(
        filepath_expr=filepath_expr_for(image_nums[0], test_ranks[0]))

    all_images_error_mins_percent = []
    all_images_error_means_percent = []
    all_images_error_stds_percent = []
    all_images_work_means = []
    all_images_span_means = []

    for image_num in image_nums:
        error_mins = []
        error_means = []
        error_stds = []
        work_means = []
        span_means = []

        for approx_rank in test_ranks:
            expr = filepath_expr_for(image_num, approx_rank)

            svd_error, _ = get_svd_error_and_UV(filepath=expr % -1)
            greedy_error_min, greedy_error_mean, greedy_error_std, greedy_work_mean, greedy_span_mean = get_greedy_or_sparse_embedding_error_work_span_stats(filepath=expr % 0)

            approx_rank_error_mins = [svd_error, greedy_error_min]
            approx_rank_error_means = [svd_error, greedy_error_mean]
            approx_rank_error_stds = [0, greedy_error_std]
            approx_rank_work_means = [greedy_work_mean]  # no SVD work entry
            approx_rank_span_means = [greedy_span_mean]  # no SVD span entry

            for setting_num in range(1, last_setting_num + 1):
                sparse_error_min, sparse_error_mean, sparse_error_std, sparse_work_mean, sparse_span_mean = get_greedy_or_sparse_embedding_error_work_span_stats(filepath=expr % setting_num)
                approx_rank_error_mins.append(sparse_error_min)
                approx_rank_error_means.append(sparse_error_mean)
                approx_rank_error_stds.append(sparse_error_std)
                approx_rank_work_means.append(sparse_work_mean)
                approx_rank_span_means.append(sparse_span_mean)

            error_mins.append(approx_rank_error_mins)
            error_means.append(approx_rank_error_means)
            error_stds.append(approx_rank_error_stds)
            work_means.append(approx_rank_work_means)
            span_means.append(approx_rank_span_means)

        original_image_l1 = original_image_l1s[image_num]
        all_images_error_mins_percent.append(np.array(error_mins) / original_image_l1)
        all_images_error_means_percent.append(np.array(error_means) / original_image_l1)
        all_images_error_stds_percent.append(np.array(error_stds) / original_image_l1)
        all_images_work_means.append(work_means)
        all_images_span_means.append(span_means)

    all_images_error_mins_percent = np.array(all_images_error_mins_percent)
    all_images_error_means_percent = np.array(all_images_error_means_percent)
    all_images_error_stds_percent = np.array(all_images_error_stds_percent)
    all_images_work_means = np.array(all_images_work_means)
    all_images_span_means = np.array(all_images_span_means)

    print(all_images_error_mins_percent.shape)

    # Average over images (axis 0).
    image_set_stats = dict()
    image_set_stats['test_ranks'] = test_ranks
    image_set_stats['image_set_mean_error_mins_percent'] = np.mean(all_images_error_mins_percent, axis=0).tolist()
    image_set_stats['image_set_mean_error_means_percent'] = np.mean(all_images_error_means_percent, axis=0).tolist()
    image_set_stats['image_set_mean_error_stds_percent'] = np.mean(all_images_error_stds_percent, axis=0).tolist()
    image_set_stats['image_set_mean_work_means'] = np.mean(all_images_work_means, axis=0).tolist()
    image_set_stats['image_set_mean_span_means'] = np.mean(all_images_span_means, axis=0).tolist()

    with open(save_dir + "caltech101_image_set_experiment_stats_all_settings.json", "w") as json_file:
        json.dump(image_set_stats, json_file)

    return image_set_stats

if __name__ == "__main__":
    # `save_dir` is bound at module level (not inside a main() function) so that
    # module-level helpers which look it up as a global can still find it.
    save_dir = EXPERIMENT_STATS_DIR
    os.makedirs(save_dir, exist_ok=True)
    test_ranks = [10, 20, 30, 40, 50, 60]

    # cauchy_size = cauchy_factor * rank, per dataset.
    cauchy_factors = {"bcsstk13s": 8, "isolet_transpose": 4, "forest_cover": 4}
    dataset_names = ["bcsstk13s", "isolet_transpose", "forest_cover"]

    ###########################################################################
    # stats for: bcsstk13s, isolet_transpose, forest_cover (top2 sparse embedding settings)
    ###########################################################################
    for dataset_name in dataset_names:
        error_mins, error_means, error_stds, top2_sparse_embed_setting_nums = get_all_errors(dataset_name=dataset_name, test_ranks=test_ranks, cauchy_factor=cauchy_factors[dataset_name])

        data = dict()
        data['test_ranks'] = test_ranks
        data['error_mins'] = error_mins
        data['error_means'] = error_means
        data['error_stds'] = error_stds
        data['top2_sparse_embed_setting_nums'] = top2_sparse_embed_setting_nums

        print(dataset_name)
        print(top2_sparse_embed_setting_nums)

        with open(save_dir + "%s_experiment_stats.json" % dataset_name, "w") as json_file:
            json.dump(data, json_file)

    ###########################################################################
    # stats for: caltech101 individual images (top2 sparse embedding settings)
    ###########################################################################
    for image_num, image_filepath in ALL_IMAGE_FILEPATHS:
        print("image_num: %d" % image_num)

        # Entrywise l1 norm of the original image. It is cast to a Python float
        # so it can be written to JSON whatever dtype load_image returns.
        A = load_image(image_filepath)
        original_image_l1 = float(np.sum(np.abs(A)))

        error_mins, error_means, error_stds, best_UVs, top2_sparse_embed_setting_nums = get_all_errors_and_best_UVs_caltech101(image_num=image_num,
            test_ranks=test_ranks, directory_name="Checkpoints_caltech101_json/")
        # The errors are lists of tuples; convert them to arrays before dividing
        # (a plain list cannot be divided by a Python float).
        error_mins_percent = (np.asarray(error_mins) / original_image_l1).tolist()
        error_means_percent = (np.asarray(error_means) / original_image_l1).tolist()
        error_stds_percent = (np.asarray(error_stds) / original_image_l1).tolist()

        print(top2_sparse_embed_setting_nums)

        data = dict()
        data['test_ranks'] = test_ranks
        data['error_mins'] = error_mins
        data['error_means'] = error_means
        data['error_stds'] = error_stds
        data['best_UVs'] = best_UVs
        data['top2_sparse_embed_setting_nums'] = top2_sparse_embed_setting_nums
        data['error_mins_percent'] = error_mins_percent
        data['error_means_percent'] = error_means_percent
        data['error_stds_percent'] = error_stds_percent
        data['original_image_l1'] = original_image_l1

        with open(save_dir + "caltech101_image%d_experiment_stats.json" % image_num, "w") as json_file:
            json.dump(data, json_file)

    ###########################################################################
    # stats for: caltech101 image set (top2 sparse embedding settings)
    ###########################################################################
    get_caltech101_average_top2(directory_name="Checkpoints_caltech101_json/", test_ranks=test_ranks)

    ###########################################################################
    # stats for: bcsstk13s, isolet_transpose, forest_cover (all settings)
    ###########################################################################
    for dataset_name in dataset_names:
        error_mins, error_means, error_stds, work_means, span_means = get_error_work_span_stats_all_settings(dataset_name=dataset_name, test_ranks=test_ranks, cauchy_factor=cauchy_factors[dataset_name])

        data = dict()
        data['test_ranks'] = test_ranks
        data['error_mins'] = error_mins
        data['error_means'] = error_means
        data['error_stds'] = error_stds
        data['work_means'] = work_means
        data['span_means'] = span_means

        with open(save_dir + "%s_experiment_stats_all_settings.json" % dataset_name, "w") as json_file:
            json.dump(data, json_file)

    ###########################################################################
    # stats for: caltech101 image set (all settings)
    ###########################################################################
    get_caltech101_error_work_span_stats_all_settings(directory_name="Checkpoints_caltech101_json/", test_ranks=test_ranks)

        