from utils import rank_k_svd_with_work_time
from dataloader import *
from main_protocol import SettingResults, save_checkpoint
from caltech101_experiments import ALL_IMAGE_FILEPATHS, checkpoint_setting_results

def ran_svd_experiment(A, approx_rank, dataset_name, setting_num, include_UV=False, caltech101=False, image_num=-1):
    """Run the rank-k SVD baseline on ``A`` and checkpoint the result.

    Args:
        A: Matrix to approximate (presumably a NumPy array returned by the
            dataloader helpers -- confirm against the loaders).
        approx_rank: Target rank passed to ``rank_k_svd_with_work_time``.
        dataset_name: Dataset label passed to ``save_checkpoint``. It is not
            used on the Caltech101 path.
        setting_num: Setting identifier. It is stored in the
            ``SettingResults`` record and forwarded to whichever checkpoint
            helper is used.
        include_UV: If True, the rank-k approximation matrix is stored in
            ``UV_list``. Otherwise ``UV_list`` is left empty.
        caltech101: If True, save via ``checkpoint_setting_results`` keyed by
            ``image_num``. Otherwise save via ``save_checkpoint``.
        image_num: Image index, used only on the Caltech101 path.
    """
    svd_approx, work = rank_k_svd_with_work_time(A, approx_rank)
    # Entrywise L1 norm of the residual: sum of absolute differences.
    svd_error = np.sum(np.abs(svd_approx - A))

    # Only the approximation itself (not separate U/V factors) is kept,
    # and only when requested.
    UV_list = [svd_approx] if include_UV else []
    # The remaining list fields are left empty for this baseline; the SVD
    # work measurement goes in ``svd_work``.
    setting_result = SettingResults(setting_num=setting_num, approx_rank=approx_rank,
        error_list=[svd_error], UV_list=UV_list, work_list=[], span_list=[],
        l1_regression_work_list=[], l1_regression_span_list=[], svd_work=work)

    if caltech101:
        # Forward the caller's setting_num (previously hard-coded to -1) so the
        # checkpoint key agrees with the SettingResults record above.
        checkpoint_setting_results(image_num=image_num, approx_rank=approx_rank,
            setting_num=setting_num, setting_data=setting_result)
    else:
        save_checkpoint(setting_result=setting_result, dataset_name=dataset_name,
            test_rank=approx_rank, cauchy_size=-1, setting_num=setting_num)

if __name__ == "__main__":
    # Ranks evaluated for every dataset below.
    test_ranks = [10, 20, 30, 40, 50, 60]

    def _run_all_ranks(matrix, name, **extra_kwargs):
        """Run the SVD baseline on ``matrix`` once for each rank in ``test_ranks``."""
        for rank in test_ranks:
            print("rank: %d" % rank)
            ran_svd_experiment(A=matrix, approx_rank=rank, dataset_name=name,
                setting_num=-1, **extra_kwargs)

    # forest cover (loader returns a pair; only the matrix is used)
    matrix, _ = load_forest_cover("forest_cover_500x3000.npy")
    _run_all_ranks(matrix, "forest_cover", include_UV=False)

    # bcsstk13 (dataset label kept as "bcsstk13s" since it names checkpoint files)
    matrix = load_bcsstk13("bcsstk13.mtx")
    _run_all_ranks(matrix, "bcsstk13s", include_UV=False)

    # isolet transpose (loader returns a pair; only the matrix is used)
    matrix, _ = load_isolet_transpose("isolet1+2+3+4.csv")
    _run_all_ranks(matrix, "isolet_transpose", include_UV=False)

    # caltech101: each image is its own matrix; approximations are stored.
    # NOTE(review): each entry of ALL_IMAGE_FILEPATHS is unpacked as
    # (image_num, image_filepath). If it is a plain list of paths, this
    # should iterate over enumerate(ALL_IMAGE_FILEPATHS) -- confirm against
    # caltech101_experiments.
    for image_num, image_filepath in ALL_IMAGE_FILEPATHS:
        matrix = load_image(image_filepath)
        _run_all_ranks(matrix, "caltech101", include_UV=True,
            caltech101=True, image_num=image_num)







