from main_protocol import * 
from dataloader import save_image 
import os 
import math

# (image_num, filepath) pairs for the test images; image_num is simply the
# position of the path in the list below.
ALL_IMAGE_FILEPATHS = list(enumerate([
    "101_ObjectCategories/car_side/image_0002.jpg",
    "101_ObjectCategories/ceiling_fan/image_0034.jpg",
    "101_ObjectCategories/barrel/image_0038.jpg",
    "101_ObjectCategories/elephant/image_0060.jpg",
    "101_ObjectCategories/helicopter/image_0088.jpg",
]))

# Directory that receives one pickle per (image, rank, setting) combination.
CHECKPOINT_DIR = "Caltech101_Checkpoints/"

def checkpoint_setting_results(image_num, approx_rank, setting_num, setting_data, checkpoint_dir=None):
    """Pickle the results of one (image, rank, setting) run to a checkpoint file.

    The file is named ``image<image_num>_rank<approx_rank>_setting<setting_num>.pickle``
    and is overwritten if it already exists.

    Args:
        image_num: integer id of the image (used in the file name).
        approx_rank: integer target rank (used in the file name).
        setting_num: integer id of the setting (used in the file name).
        setting_data: any picklable object holding the results to store.
        checkpoint_dir: directory to write into; defaults to the module-level
            ``CHECKPOINT_DIR`` when omitted.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR

    # makedirs(exist_ok=True) avoids the check-then-create race of
    # isdir() + mkdir() and also creates missing parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    ckpt_filename = "image%d_rank%d_setting%d.pickle" % (image_num, approx_rank, setting_num)
    # save results
    with open(os.path.join(checkpoint_dir, ckpt_filename), 'wb') as outfile:
        pickle.dump(setting_data, outfile)

# Setting 0: Greedy, 
# Setting 1: Sketch size = rank, sparsity <= 20
# Setting 2: Sketch size = rank/3, sparsity <= 2
# Setting 3: Sketch size = rank/5, sparsity <= 2
# Setting 4: Sketch size = rank/3, sparsity <= 5
# Setting 5: Sketch size = rank/5, sparsity <= 5
# Setting 6: SVD
# the cauchy size will be 2 * rank, and the coreset size will be rank. 
# 
# run experiments for one image, one rank, all settings, all trials
# RETURNS: None -- the result of each setting (0-6) is written to disk by checkpoint_setting_results
def conduct_trials_one_image_one_rank(image_num, A, Ais, approx_rank, greedy_num_trial, sparse_embedd_num_trial):
    """Run every setting for one image at one target rank, checkpointing each.

    Setting 0 is the greedy method, settings 1-5 are sparse-embedding
    variants that differ in sketch size and sparsity cap, and setting 6 is
    the rank-k SVD of ``A``. Each setting's result is written to disk through
    ``checkpoint_setting_results`` as soon as it is available.

    Args:
        image_num: id of the image, forwarded to the checkpoint file name.
        A: matrix approximated by the SVD setting.
        Ais: passed unchanged to the greedy and sparse-embedding runs (the
            caller in this file passes three column blocks of ``A``).
        approx_rank: target rank; also determines cauchy/coreset/sketch sizes.
        greedy_num_trial: number of trials for the greedy setting.
        sparse_embedd_num_trial: number of trials for each sparse-embedding
            setting.

    Returns:
        None.
    """
    def _checkpoint(num, data):
        # Every setting of this call shares the same image and rank.
        checkpoint_setting_results(image_num=image_num, approx_rank=approx_rank,
                                   setting_num=num, setting_data=data)

    cauchy_size = 2 * approx_rank
    coreset_size = approx_rank

    # ---------------------------------------------------------
    # Setting 0: greedy
    # ---------------------------------------------------------
    greedy_results = greedy_multiple_trials(
        num_trials=greedy_num_trial, Ais=Ais, cauchy_size=cauchy_size,
        coreset_size=coreset_size, approx_rank=approx_rank, setting_num=0)
    _checkpoint(0, greedy_results)

    # ---------------------------------------------------------
    # Settings 1-5: sparse embedding
    # ---------------------------------------------------------
    one_third = math.ceil(approx_rank/3)
    one_fifth = math.ceil(approx_rank/5)
    # (setting_num, sketch_size, upper bound on sparsity)
    sparse_plan = [
        (1, approx_rank, 20),
        (2, one_third, 2),
        (3, one_fifth, 2),
        (4, one_third, 5),
        (5, one_fifth, 5),
    ]
    sparse_embed_settings = [
        {'num_trials': sparse_embedd_num_trial,
         'Ais': Ais,
         'cauchy_size': cauchy_size,
         'coreset_size': coreset_size,
         'sketch_size': sketch_size,
         'lewis_weight_size': approx_rank,
         'sparsity': min(sketch_size, sparsity_cap),
         'setting_num': num}
        for num, sketch_size, sparsity_cap in sparse_plan
    ]
    for setting in sparse_embed_settings:
        outcome = run_sparse_embedding_setting(setting=setting)
        _checkpoint(setting['setting_num'], outcome)

    # ---------------------------------------------------------
    # Setting 6: SVD
    # ---------------------------------------------------------
    svd_setting_num = 6
    svd_approx = rank_k_svd(A, approx_rank)
    # Sum of absolute entrywise differences between the approximation and A.
    svd_error = np.sum(np.abs(svd_approx - A))
    _checkpoint(svd_setting_num, (svd_approx, svd_error, svd_setting_num))

# run experiments for image_filepaths, all ranks, all settings, all trials
# len(image_filepaths) * 6 * 6 * 15 = len(image_filepaths) * 540 trial results
# (presumably 6 ranks x 6 trial-based settings x 15 trials, as in the default run)
def run_caltech101_experiments(results_filepath, image_filepaths, test_ranks, num_trials):
    """Run all settings for every image and every rank in ``test_ranks``.

    All images are loaded (and split into three column blocks) before any
    experiment starts; results are persisted per setting by
    ``conduct_trials_one_image_one_rank``.

    Args:
        results_filepath: currently unused -- the aggregated results dump
            that used this path is disabled.
        image_filepaths: iterable of ``(image_num, filepath)`` pairs.
        test_ranks: iterable of target ranks to evaluate.
        num_trials: number of trials used for both the greedy and the
            sparse-embedding settings.

    Returns:
        None.
    """
    # The same trial count is used for both families of methods.
    greedy_num_trial = sparse_embedd_num_trial = num_trials

    # Phase 1: load every image up front.
    images = []
    for image_num, image_filepath in image_filepaths:
        A = load_image(image_filepath)

        print(image_filepath)
        print("Image size: ", A.shape)
        print("dtype: ", A.dtype)

        _, n_cols = A.shape
        assert(n_cols == 300)
        cutoff = n_cols // 3
        # Three column blocks; the last one runs to the final column.
        block_bounds = [(0, cutoff), (cutoff, 2*cutoff), (2*cutoff, None)]
        Ais = [A[:, lo:hi] for lo, hi in block_bounds]
        images.append((image_num, A, Ais))

    # Phase 2: run the experiments for each (image, rank) pair.
    for image_num, A, Ais in images:
        for approx_rank in test_ranks:
            print("Image: %d, approx_rank: %d" % (image_num, approx_rank))
            conduct_trials_one_image_one_rank(
                image_num=image_num, A=A, Ais=Ais, approx_rank=approx_rank,
                greedy_num_trial=greedy_num_trial,
                sparse_embedd_num_trial=sparse_embedd_num_trial)

    print("Caltech101 Experiments Completed")

def save_image_from_results(save_dir, results_filepath, image_filepaths, test_ranks):
    """Save the best approximation image for every (image, rank, setting).

    The pickle at ``results_filepath`` must map ``(image_num, approx_rank)``
    to a pair ``(advanced_methods_data, svd_data)`` where

    * ``advanced_methods_data`` is a sequence indexed by setting number whose
      items expose ``setting_num``, ``error_list`` and ``UV_list``, and
    * ``svd_data`` is ``(svd_approx, svd_error, setting_num)``.

    For each non-SVD setting the trial with the smallest error is saved; the
    SVD approximation is saved as-is. Output files are named
    ``image<image_num>_rank<approx_rank>_setting<setting_num>`` inside
    ``save_dir`` and written through ``save_image``.

    NOTE(review): ``run_caltech101_experiments`` in this file has its
    aggregated results dump commented out, so ``results_filepath`` must be
    produced some other way -- confirm before relying on this function.

    Args:
        save_dir: output directory; created if missing.
        results_filepath: path of the pickled results dictionary.
        image_filepaths: iterable of ``(image_num, filepath)`` pairs; only
            the image number is used.
        test_ranks: iterable of ranks to export.

    Raises:
        ValueError: if a setting has no trial with a finite error.
    """
    # load results
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(results_filepath, 'rb') as infile:
        results = pickle.load(infile)

    # save images
    os.makedirs(save_dir, exist_ok=True)
    for image_num, _image_filepath in image_filepaths:
        for approx_rank in test_ranks:
            advanced_methods_data, svd_data = results[(image_num, approx_rank)]

            # greedy + sparse embedding
            for setting_num, setting_result in enumerate(advanced_methods_data):
                assert(setting_result.setting_num == setting_num)

                # find the trial with the least error (first one wins on ties)
                best_trial = None
                best_error = float('inf')
                for trial_num, error in enumerate(setting_result.error_list):
                    if error < best_error:
                        best_error = error
                        best_trial = trial_num
                if best_trial is None:
                    # Empty error_list, or every error is inf/NaN: indexing
                    # UV_list with None would fail obscurely (or silently
                    # misbehave on array-like containers).
                    raise ValueError(
                        "no trial with a finite error for image %d, rank %d, setting %d"
                        % (image_num, approx_rank, setting_num))
                best_image_data = setting_result.UV_list[best_trial]

                # save best image for this setting
                out_filepath = os.path.join(
                    save_dir, "image%d_rank%d_setting%d" % (image_num, approx_rank, setting_num))
                save_image(data=best_image_data, filepath=out_filepath)

            # svd
            svd_approx, _svd_error, setting_num = svd_data
            out_filepath = os.path.join(
                save_dir, "image%d_rank%d_setting%d" % (image_num, approx_rank, setting_num))
            save_image(data=svd_approx, filepath=out_filepath)

if __name__ == '__main__':
    # Full run: every image in ALL_IMAGE_FILEPATHS, ranks 10..60 in steps
    # of 10, 15 trials per setting; afterwards export the best images.
    results_filepath = "Caltech101_All_Results.pickle"
    image_filepaths = ALL_IMAGE_FILEPATHS
    test_ranks = list(range(10, 70, 10))
    num_trials = 15

    run_caltech101_experiments(
        results_filepath=results_filepath,
        image_filepaths=image_filepaths,
        test_ranks=test_ranks,
        num_trials=num_trials)
    save_image_from_results(
        save_dir="caltech101_exp_images/",
        results_filepath=results_filepath,
        image_filepaths=image_filepaths,
        test_ranks=test_ranks)

    # # TESTING 
    # image_filepaths = ALL_IMAGE_FILEPATHS[:1]
    # test_ranks = [10]
    # num_trials = 1
    # results_filepath = "caltech101_all_results.pickle"
    # run_caltech101_experiments(results_filepath=results_filepath, image_filepaths=image_filepaths, 
    #     test_ranks=test_ranks, num_trials=num_trials)
    # save_image_from_results(save_dir="caltech101_exp_images/", results_filepath=results_filepath, 
    #     image_filepaths=image_filepaths, test_ranks=test_ranks)


