import numpy as np
from coordinator import Coordinator
from server import Server
from dataloader import * 
from utils import *
import pickle
import time
import os

class SettingResults(object):
    """Bundle of per-trial measurements collected for one experiment setting.

    Every ``*_list`` argument is stored as-is (no copy), so callers that keep
    a reference to a list will see the same object on the instance.

    Args:
        setting_num: Identifier of the experiment setting.
        approx_rank: Rank label recorded for this setting.
        error_list: One error value per trial.
        UV_list: One reconstruction per trial.
        work_list: One total-work measurement per trial.
        span_list: One span measurement per trial.
        l1_regression_work_list: Work spent in the l1-regression error step, per trial.
        l1_regression_span_list: Span of the l1-regression error step, per trial.
        svd_work: Optional extra work figure; ``None`` when not supplied.
    """

    def __init__(self, setting_num, approx_rank, error_list, UV_list, work_list, span_list, l1_regression_work_list, l1_regression_span_list, svd_work=None):
        # Setting identity.
        self.setting_num, self.approx_rank = setting_num, approx_rank
        # Quality of each trial's approximation.
        self.error_list, self.UV_list = error_list, UV_list
        # Cost of the protocol itself.
        self.work_list, self.span_list = work_list, span_list
        # Cost of the (separately measured) l1-regression evaluation.
        self.l1_regression_work_list, self.l1_regression_span_list = l1_regression_work_list, l1_regression_span_list
        self.svd_work = svd_work

def _process_seconds(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)``, discard its result, and return the CPU time it took.

    Measured with ``time.process_time`` (CPU time of the current process).
    """
    started = time.process_time()
    fn(*args, **kwargs)
    return time.process_time() - started


def distributed_protocol(Ais, cauchy_size, coreset_size, approx_rank, greedy, sketch_size, sparsity, lewis_weight_size):
    """Run one pass of the coordinator/server protocol and measure its cost.

    Each element of ``Ais`` is handed to its own ``Server``; ``Ais[0].shape[0]``
    is used as the column count of the Cauchy matrix the coordinator sends out.

    Cost model (all measured with ``time.process_time``):
      * ``work`` - total CPU time from entry until the CSS step finishes.
        The l1-regression error computation is excluded.
      * ``span`` - simulated parallel time. It is the sum of three parts:
        the coordinator's Cauchy step, the *slowest* server's coreset step
        (servers are modeled as running in parallel), and the coordinator's
        CSS step.

    Args:
        Ais: Per-server data blocks.
        cauchy_size: Number of rows of the Cauchy matrix.
        coreset_size: Coreset size each server generates.
        approx_rank: Target rank forwarded to the coordinator.
        greedy: Forwarded to ``Coordinator.run_CSS_l12``.
        sketch_size, sparsity, lewis_weight_size: Forwarded to ``Coordinator``.

    Returns:
        Tuple ``(error, reconstruction, work, span, l1_regression_work,
        l1_regression_span)``; the last two come from
        ``Coordinator.parallel_lra_error``.
    """
    work_start = time.process_time()

    num_rows = Ais[0].shape[0]
    servers = [Server(Ais[idx]) for idx in range(len(Ais))]
    coordinator = Coordinator(servers=servers, approx_rank=approx_rank, sketch_size=sketch_size, sparsity=sparsity, lewis_weight_size=lewis_weight_size)

    # The coordinator (not the servers) builds the Cauchy matrix and sends it
    # out; everyone waits on it, so its full cost goes into the span.
    cauchy_time = _process_seconds(coordinator.send_cauchy_matrix, n_row=cauchy_size, n_col=num_rows)

    # Servers build coresets "in parallel": only the slowest one counts.
    coreset_times = [_process_seconds(server.generate_coreset, coreset_size) for server in servers]
    slowest_coreset = max([0, *coreset_times])

    # l_{1,2} column subset selection on the coordinator; servers wait for it.
    css_time = _process_seconds(coordinator.run_CSS_l12, greedy=greedy)

    span = cauchy_time + slowest_coreset + css_time

    # Work stops here; the l1-regression evaluation below is reported
    # separately by parallel_lra_error.
    work = time.process_time() - work_start

    error, reconstruction, l1_regression_work, l1_regression_span = coordinator.parallel_lra_error()
    print('l1 error : {}'.format(error))

    return error, reconstruction, work, span, l1_regression_work, l1_regression_span

def distributed_via_greedy(Ais, cauchy_size, coreset_size, approx_rank):
    """Run ``distributed_protocol`` with greedy CSS and print its wall-clock time.

    The sparse-embedding parameters (``sketch_size``, ``sparsity``,
    ``lewis_weight_size``) are passed as ``-1`` placeholders.

    Returns:
        Whatever ``distributed_protocol`` returns, unchanged.
    """
    began = time.time()
    outcome = distributed_protocol(Ais=Ais, cauchy_size=cauchy_size, coreset_size=coreset_size, approx_rank=approx_rank, greedy=True, sketch_size=-1, sparsity=-1, lewis_weight_size=-1)
    print("Greedy Time: ", time.time() - began)
    return outcome

def distributed_via_sparse_embedding(Ais, cauchy_size, coreset_size, sketch_size, sparsity, lewis_weight_size):
    """Run ``distributed_protocol`` with non-greedy CSS and print its wall-clock time.

    ``approx_rank`` is passed as the ``-1`` placeholder; the sketch parameters
    are forwarded as given.

    Returns:
        Whatever ``distributed_protocol`` returns, unchanged.
    """
    began = time.time()
    outcome = distributed_protocol(Ais=Ais, cauchy_size=cauchy_size, coreset_size=coreset_size, approx_rank=-1, greedy=False, sketch_size=sketch_size, sparsity=sparsity, lewis_weight_size=lewis_weight_size)
    print("Sparse Embedding Time: ", time.time() - began)
    return outcome

# Returns the results from all of the trials
def sparse_embedding_multiple_trials(num_trials, Ais, cauchy_size, coreset_size, sketch_size, sparsity, lewis_weight_size, setting_num):
    """Repeat the sparse-embedding protocol ``num_trials`` times.

    ``lewis_weight_size`` doubles as the rank label, both in the progress
    message and in the returned ``SettingResults.approx_rank``.

    Returns:
        A ``SettingResults`` whose list fields hold one entry per trial, in
        trial order.
    """
    trial_rows = []
    for trial_idx in range(num_trials):
        print("Rank: %d, Setting: %d, Trial Number: %d" % (lewis_weight_size, setting_num, trial_idx))
        (err, uv, trial_work, trial_span,
         reg_work, reg_span) = distributed_via_sparse_embedding(Ais=Ais, cauchy_size=cauchy_size, coreset_size=coreset_size, sketch_size=sketch_size, sparsity=sparsity, lewis_weight_size=lewis_weight_size)
        trial_rows.append((err, uv, trial_work, trial_span, reg_work, reg_span))

    # Transpose the per-trial tuples into one list per metric.
    errors, UVs, works, spans, reg_works, reg_spans = [[row[col] for row in trial_rows] for col in range(6)]

    return SettingResults(setting_num=setting_num, approx_rank=lewis_weight_size, error_list=errors, UV_list=UVs, work_list=works, span_list=spans, l1_regression_work_list=reg_works, l1_regression_span_list=reg_spans)

def greedy_multiple_trials(num_trials, Ais, cauchy_size, coreset_size, approx_rank, setting_num=0):
    """Repeat the greedy protocol ``num_trials`` times.

    Returns:
        A ``SettingResults`` whose list fields hold one entry per trial, in
        trial order, labelled with ``approx_rank`` and ``setting_num``.
    """
    trial_rows = []
    for trial_idx in range(num_trials):
        print("Rank: %d, Setting: %d, Trial Number: %d" % (approx_rank, setting_num, trial_idx))
        (err, uv, trial_work, trial_span,
         reg_work, reg_span) = distributed_via_greedy(Ais=Ais, cauchy_size=cauchy_size, coreset_size=coreset_size, approx_rank=approx_rank)
        trial_rows.append((err, uv, trial_work, trial_span, reg_work, reg_span))

    # Transpose the per-trial tuples into one list per metric.
    errors, UVs, works, spans, reg_works, reg_spans = [[row[col] for row in trial_rows] for col in range(6)]

    return SettingResults(setting_num=setting_num, approx_rank=approx_rank, error_list=errors, UV_list=UVs, work_list=works, span_list=spans, l1_regression_work_list=reg_works, l1_regression_span_list=reg_spans)

def run_sparse_embedding_setting(setting):
    """Run ``sparse_embedding_multiple_trials`` using parameters from a mapping.

    ``setting`` must provide the keys ``num_trials``, ``Ais``, ``cauchy_size``,
    ``coreset_size``, ``sketch_size``, ``sparsity``, ``lewis_weight_size`` and
    ``setting_num``; a missing key raises ``KeyError``. The single-argument
    signature presumably exists so this can be used with map-style
    dispatch - confirm with callers.
    """
    return sparse_embedding_multiple_trials(
        num_trials=setting['num_trials'],
        Ais=setting['Ais'],
        cauchy_size=setting['cauchy_size'],
        coreset_size=setting['coreset_size'],
        sketch_size=setting['sketch_size'],
        sparsity=setting['sparsity'],
        lewis_weight_size=setting['lewis_weight_size'],
        setting_num=setting['setting_num'],
    )

# Pickles the result from a particular setting and places it in a checkpoint file.
def save_checkpoint(setting_result, dataset_name, test_rank, cauchy_size, setting_num, checkpoint_dir="Checkpoints/"):
    """Pickle ``setting_result`` into a per-setting checkpoint file.

    The file is named
    ``<dataset_name>_rank_<test_rank>_results_cauchy_size_<cauchy_size>_setting_num_<setting_num>.pickle``
    and is written inside ``checkpoint_dir``. Any existing file with that name
    is overwritten.

    Args:
        setting_result: Any picklable object (typically a ``SettingResults``).
        dataset_name: Dataset label used in the file name.
        test_rank: Rank used in the file name (formatted with ``%d``).
        cauchy_size: Cauchy sketch size used in the file name (``%d``).
        setting_num: Setting index used in the file name (``%d``).
        checkpoint_dir: Target directory; created, including parents, if
            missing. Defaults to ``"Checkpoints/"``, the original location.

    Returns:
        The path of the written checkpoint file.
    """
    # makedirs(exist_ok=True) replaces the isdir()/mkdir() pair, which could
    # raise FileExistsError when several processes save checkpoints at once.
    os.makedirs(checkpoint_dir, exist_ok=True)

    basename = "%s_rank_%d_results_cauchy_size_%d_setting_num_%d.pickle" % (dataset_name, test_rank, cauchy_size, setting_num)
    filename = os.path.join(checkpoint_dir, basename)
    with open(filename, "wb") as outfile:
        pickle.dump(setting_result, outfile)
    return filename

"""

def run_caltech101_experiments(): 
    A = load_image_as_np_matrix("101_ObjectCategories/car_side/image_0002.jpg")
    A = np.asarray(A)
    print("Image size: ", A.shape) 
    A_rows, A_cols = A.shape
    cutoff = A_cols // 2
    Ais = [A[:, 0:cutoff], A[:, cutoff:]]

    num_trials = 1
    approx_rank = 20
    cauchy_size = 8 * approx_rank 
    coreset_size = 10 * approx_rank 

    # Setting 1
    greedy_error = float('inf')
    greedy_work = float('inf') 
    greedy_span = float('inf') 
    greedy_UV = None

    for k in range(num_trials): 
        print("greedy trial number: %d" % k)
        error, UV, work, span = distributed_via_greedy(Ais=Ais, cauchy_size=cauchy_size, coreset_size=coreset_size, approx_rank=approx_rank)
        if error < greedy_error: 
            greedy_error = error
            greedy_work = work
            greedy_span = span 
            greedy_UV = UV
    print(greedy_UV.shape)
    approx_A = greedy_UV
    save_image(approx_A, "rank_20.png")

def main():
    A = load_bcsstk13("bcsstk13.mtx")
    # A = load_isolet("isolet1+2+3+4.csv")
    print(type(A))
    # with open('3s_bbc.pkl', 'rb') as f:
    #     A = pickle.load(f)
    A_rows, A_cols = A.shape 
    cutoff = A_cols // 2
    Ais = [A[:, 0:cutoff], A[:, cutoff:]]

    test_ranks = [10, 20]
    
    errors = []
    works = []
    spans = []
    
    for approx_rank in test_ranks:
        new_error, _, _, new_work, new_span = conduct_trials(Ais, approx_rank, num_trials=2)
        print(new_error)
        errors.append(new_error)
        works.append(new_work)
        spans.append(new_span)
    print(errors)

    # # # hyperparams
    # # n_server = 2
    # # n = 10
    # # mis = [30, 50]
    # # # hyperparams related to rank
    # rank = 10
    # sketch_size = 5 * rank
    # sparsity = min(sketch_size, 20)
    # cauchy_size = 8 * rank
    # lewis_weight_size = rank
    # coreset_size = 10 * rank

    # distributed_protocol(Ais=Ais, cauchy_size=cauchy_size, coreset_size=coreset_size, approx_rank=rank,
    #                      greedy=False, sketch_size=sketch_size, sparsity=sparsity,
    #                      lewis_weight_size=lewis_weight_size)


# # gen data matrix
    # Ais = [np.random.randn(n, mis[i]) * 10. for i in range(n_server)]
    # servers = [Server(Ais[i]) for i in range(n_server)]
    # coordinator = Coordinator(servers, rank, sketch_size, sparsity, lewis_weight_size)

    # # proceed to interaction
    # coordinator.send_cauchy_matrix(cauchy_size, n)
    # for server in servers:
    #     server.generate_coreset(coreset_size)

    # # greedy
    # coordinator.run_CSS_l12(greedy=True)
    # err = coordinator.check_lra_error()
    # print('greedy error : {}'.format(err))

    # # CSS 
    # coordinator.run_CSS_l12(greedy=False)
    # err = coordinator.check_lra_error()
    # print('CSS error : {}'.format(err))


if __name__ == '__main__':
    main()
    # run_caltech101_experiments()

"""