from utils import *
from lewis_sampling import *
import cvxpy as cp
from greedy import *
from l1_solver import *
import multiprocessing
import time

class Coordinator(object):

    def __init__(self, servers, approx_rank, sketch_size, sparsity, lewis_weight_size):
        """Store the servers and the column-selection parameters.

        Args:
            servers: Collection of server objects this coordinator talks to.
            approx_rank: Number of columns requested from the greedy selector.
            sketch_size: ``sketch_size`` argument forwarded to ``CSS_l12``.
            sparsity: ``sparsity`` argument forwarded to ``CSS_l12``.
            lewis_weight_size: Forwarded to ``CSS_l12`` as ``lewis_sample_rows``.
        """
        self.servers = servers
        # Selected columns gathered from the servers; stays None until
        # run_CSS_l12 has been executed.
        self.A_I = None

        # Parameter used by the greedy selection path.
        self.approx_rank = approx_rank

        # Parameters used by the CSS_l12 selection path.
        self.sketch_size, self.sparsity, self.lewis_weight_size = (
            sketch_size,
            sparsity,
            lewis_weight_size,
        )

    def send_cauchy_matrix(self, n_row, n_col):
        """Create one sketching matrix and hand the same object to every server.

        ``dense_cauchy_matrix(size=(n_row, n_col))`` is called exactly once and
        its result is assigned to the ``S`` attribute of each entry in
        ``self.servers``, so all servers share a single matrix.

        Args:
            n_row: First component of the ``size`` passed to ``dense_cauchy_matrix``.
            n_col: Second component of the ``size`` passed to ``dense_cauchy_matrix``.
        """
        shared_sketch = dense_cauchy_matrix(size=(n_row, n_col))
        for srv in self.servers:
            srv.S = shared_sketch

    def run_CSS_l12(self, greedy=False):
        """Select columns on the coordinator and gather them into ``self.A_I``.

        The ``get_SAiTi()`` results of all servers are concatenated along the
        last axis, a column-selection routine is run on that matrix, and every
        selected position is translated back into a (server, local column)
        pair.  Each server is then asked for its share through
        ``get_selected_columns`` and the returned pieces are concatenated
        along the last axis and stored in ``self.A_I``.

        Args:
            greedy: If true, select ``self.approx_rank`` columns with
                ``greedy_approx_l12``; otherwise run ``CSS_l12`` with
                ``self.sketch_size``, ``self.sparsity`` and
                ``self.lewis_weight_size``.

        Raises:
            IndexError: If a selected index does not fall inside the column
                range of any server (negative, or past the last prefix sum).
        """
        # Function-scope import so this method does not depend on the file's
        # import block being changed.
        from bisect import bisect_right

        SAT = np.concatenate([server.get_SAiTi() for server in self.servers], axis=-1)

        if greedy:
            _, _, coordinator_column_indices = greedy_approx_l12(
                A=SAT, num_cols=self.approx_rank)
        else:
            _, coordinator_column_indices = CSS_l12(
                A=SAT, sketch_size=self.sketch_size, sparsity=self.sparsity,
                lewis_sample_rows=self.lewis_weight_size)

        # Note that the first entry of coreset_sizes_scanned is 0, and the
        # remaining |# of servers| entries are the prefix sums of the coreset
        # sizes (see scan in utils.py).  Entry i is therefore treated as the
        # position in SAT where server i's columns begin -- this assumes each
        # server contributes exactly get_coreset_size() columns to SAT.
        coreset_sizes_scanned = scan([server.get_coreset_size() for server in self.servers])

        num_servers = len(self.servers)
        selected_coreset_indices = [[] for _ in range(num_servers)]
        for sel_idx in coordinator_column_indices:
            # The owning server is the last one whose start offset is
            # <= sel_idx; bisect_right returns the first offset > sel_idx.
            server_index = bisect_right(coreset_sizes_scanned, sel_idx) - 1
            if not 0 <= server_index < num_servers:
                # Fail loudly instead of mapping a bad index to a wrong
                # server/column or crashing later on a None index.
                raise IndexError(
                    "selected column index {} is outside the concatenated sketch".format(sel_idx))
            column_index = sel_idx - coreset_sizes_scanned[server_index]
            selected_coreset_indices[server_index].append(column_index)

        # Coordinator collects the selected columns from every server, in
        # server order.
        A_selected_columns = [
            server.get_selected_columns(selected_coreset_indices=local_indices)
            for server, local_indices in zip(self.servers, selected_coreset_indices)
        ]
        self.A_I = np.concatenate(A_selected_columns, axis=-1)

        # print("left factor shape: ", self.A_I.shape)
    
    # Faster l1 regression. Returns the l1 regression error, the
    # reconstruction, the work done for l1 regression, and the span
    # of all of these regressions.
    def parallel_lra_error(self):
        """Reconstruct A column by column from the span of ``self.A_I``.

        ``A`` is the horizontal stack of every server's ``Ai``.  ``self.A_I``
        is replaced by its left singular vectors belonging to non-negligible
        singular values, and each non-zero column of ``A`` (scaled to unit
        Euclidean norm) is passed to ``solve_l1_regression_v2`` together with
        that basis; the returned coefficients are used to rebuild the column.
        All-zero columns are reconstructed as zero without calling the solver.

        Returns:
            tuple: ``(total_l1_error, reconstruction, work, span)`` where
            ``total_l1_error`` is the sum of absolute entries of
            ``A - reconstruction``, ``reconstruction`` has the shape of ``A``,
            ``work`` is the process CPU time spent in the whole call, and
            ``span`` is the preprocessing CPU time plus the CPU time of the
            slowest single-column regression.
        """
        work_start = time.process_time()

        preprocessing_start_time = time.process_time()
        # Gather A_i from the servers.
        A = np.hstack([server.Ai for server in self.servers])
        _, d = A.shape

        # Replace A_I with its left singular vectors (dropping the ones that
        # correspond to zero singular values).  The thin SVD is enough because
        # only the first `rank` columns of u are kept; the full SVD would
        # materialise an n-by-n matrix for an n-row A_I.
        A_I = np.asarray(self.A_I)
        u, singular_values, _ = np.linalg.svd(A_I, full_matrices=False)
        # Same default threshold as np.linalg.matrix_rank, applied to the
        # singular values already at hand instead of running a second SVD.
        rank_tol = (singular_values.max() * max(A_I.shape[-2:])
                    * np.finfo(singular_values.dtype).eps)
        rank = int(np.count_nonzero(singular_values > rank_tol))
        u = u[:, :rank]

        reconstruction = np.zeros(shape=A.shape)
        span = time.process_time() - preprocessing_start_time

        max_l1_regression_time = 0
        for c_idx in range(d):
            individual_regression_start_time = time.process_time()

            q = np.array(A[:, c_idx])
            scale = np.linalg.norm(q)
            if scale == 0.0:
                # Nothing to fit for an all-zero column.
                solution = np.zeros(rank)
            else:
                # Solve on the unit-norm column; the scale is re-applied below.
                q = q / scale
                _, solution = solve_l1_regression_v2(u, q, c_idx)

            # Update column of reconstruction.
            reconstruction[:, c_idx] = scale * (u @ solution)

            # Span only pays for the slowest column, as if all columns were
            # solved at the same time.
            individual_regression_time = time.process_time() - individual_regression_start_time
            max_l1_regression_time = max(max_l1_regression_time, individual_regression_time)

        total_l1_error = np.sum(np.abs(A - reconstruction))

        work = time.process_time() - work_start
        span += max_l1_regression_time

        return total_l1_error, reconstruction, work, span
    
    """
    Obsolete:

    def huber_regression_l1_error(self): 
        # gather A_i from server
        A_overall = np.hstack([server.Ai for server in self.servers])
        # left factor 
        U = self.A_I
        U_num_cols = U.shape[-1]
        A_num_cols = A_overall.shape[-1]
        V = np.zeros((U_num_cols, A_num_cols))
        # Find each column of V with Huber regression.
        for i in range(A_num_cols): 
            regression_column = A_overall[:, i]
            regression_column = np.ravel(regression_column)
            regression_coeffs = huber_regression_coeffs(A=U, y=regression_column)
            V[:, i] = regression_coeffs
        error = np.sum(np.abs(U @ V - A_overall)) 
        return error, U, V

    def check_lra_error(self):
        print('Checking lra error ...')
        # gather A_i from server
        A = np.hstack([server.Ai for server in self.servers])
        # solve for right factor
        # solve X = argmin |A_IX - A|_1
        X = cp.Variable(shape=(self.A_I.shape[1], A.shape[1]))
        loss = cp.pnorm(cp.matmul(self.A_I, X) - A, p=1)
        problem = cp.Problem(cp.Minimize(loss))
        print('Sovling problem ...')
        print('A_I size: {}, A size: {}'.format(self.A_I.shape, A.shape))
        problem.solve(solver=cp.CVXOPT)
        X_val = X.value
        # compute error 
        err = np.sum(np.abs(np.matmul(self.A_I, X_val) - A))
        return err


        ================================================================

        Used to be part of parallel_lra_error: 

        c_indices = np.arange(10)
        P = np.array(self.A_I) + np.random.randn(self.A_I.shape)
        for c_idx in c_indices:
            q = np.array(A[:, c_idx])
            res = solve_l1_regression(P, q, c_idx)
        # processes = []
        # for c_idx in range(10):
        #     q = np.array(A[:, c_idx])
        #     p = multiprocessing.Process(target=solve_l1_regression, args=(P, q, c_idx))
        #     processes.append(p)
        #     p.start()
        #
        # for process in processes:
        #     process.join()
    """