from sklearn.datasets import load_svmlight_file
import numpy as np
import os
import pickle
# Stopping criteria handed to solve_with_extragradient_real_data when the
# script computes z_true (passed as num_iter / max_time / tolerance).
NUM_ITER_SOLUTION = 500_000
MAX_TIME_SOLUTION = 60 * 60  # presumably seconds, i.e. one hour -- TODO confirm
TOLERANCE_SOLUTION = 1e-10
# Passed as `eps` to run_centralized_experiment_real_data.
EPS = 1e-6
from .experiment_utils import run_centralized_experiment_real_data, solve_with_extragradient_real_data, compute_robust_linear_normed_L, create_robust_linear_oracle
from oracles.saddle import create_robust_linear_oracle, OracleLinearComb, ArrayPair, BaseSmoothSaddleOracle


def load_and_process_a9a(training_path):
    """Load an svmlight-format dataset file as dense NumPy arrays.

    Parameters
    ----------
    training_path : str
        Path forwarded to ``sklearn.datasets.load_svmlight_file``.

    Returns
    -------
    X : numpy.ndarray
        Feature matrix, converted from the loaded sparse matrix with
        ``toarray()``.
    y : numpy.ndarray
        Labels flattened to a 1-D array.
    """
    X, y = load_svmlight_file(training_path)
    # Flatten with reshape(-1) instead of reshape(-1, 1) followed by
    # squeeze(): squeeze() drops *every* length-1 axis, so a file holding a
    # single sample would yield a 0-d label array rather than shape (1,).
    return X.toarray(), np.asarray(y).reshape(-1)


if __name__ == "__main__":
    # Script entry point: load the a9a data, compute z_true with
    # solve_with_extragradient_real_data, run the centralized experiment and
    # pickle the resulting loggers.

    # The dataset path is resolved relative to the current working directory.
    dataset_path = './a9a'
    A, b = load_and_process_a9a(dataset_path)

    # Settings forwarded to the oracle / solver / experiment helpers below.
    num_nodes = 500
    regcoef_x = 2.
    regcoef_y = 2.
    r_x = 1.
    r_y = 1.

    # Split the rows into `num_nodes` contiguous chunks whose sizes differ by
    # at most one: the first `remainder` chunks receive one extra row.
    base_size, remainder = divmod(A.shape[0], num_nodes)
    part_sizes = np.full(num_nodes, base_size, dtype=np.int32)
    part_sizes[:remainder] += 1
    chunk_ends = np.cumsum(part_sizes)
    chunk_starts = chunk_ends - part_sizes

    # One oracle per chunk of rows.
    oracles = [
        create_robust_linear_oracle(
            A[lo:hi], b[lo:hi], regcoef_x, regcoef_y, normed=True)
        for lo, hi in zip(chunk_starts, chunk_ends)
    ]
    # NOTE(review): oracle_mean and z_0 are not read again in this script;
    # they are kept because the side effects of these project helpers are not
    # visible from here.
    uniform_weights = [1 / num_nodes] * num_nodes
    oracle_mean = OracleLinearComb(oracles, uniform_weights)

    z_0 = ArrayPair.zeros(A.shape[1])

    # z_true is handed to the experiment below and also saved to disk.
    z_true = solve_with_extragradient_real_data(
        A, b, regcoef_x, regcoef_y, r_x, r_y,
        num_iter=NUM_ITER_SOLUTION,
        max_time=MAX_TIME_SOLUTION,
        tolerance=TOLERANCE_SOLUTION,
    )

    print('Running centralized experiment on a9a dataset with {} nodes'.format(num_nodes))
    experiment_kwargs = dict(
        A=A,
        b=b,
        num_nodes=num_nodes,
        regcoef_x=regcoef_x,
        regcoef_y=regcoef_y,
        r_x=r_x,
        r_y=r_y,
        eps=EPS,
        comm_budget_experiment=100,
        z_true=z_true,
    )
    extragrad, sliding, exsliding, sliding_vr, sliding_vrmb = (
        run_centralized_experiment_real_data(**experiment_kwargs)
    )

    # Output directory is keyed by the node count.
    folder = "./logs/centralized/node={}_a9a".format(num_nodes)
    os.makedirs(folder, exist_ok=True)

    # Dump each method's logger in a fixed order; `.logger` is looked up only
    # once the corresponding file has been opened.
    runs_by_file = (
        ("extragrad_th.pkl", extragrad),
        ("sliding_th.pkl", sliding),
        ("exsliding_th.pkl", exsliding),
        ("sliding_vr_th.pkl", sliding_vr),
        ("sliding_vrmb_th.pkl", sliding_vrmb),
    )
    for file_name, run in runs_by_file:
        with open(os.path.join(folder, file_name), "wb") as out:
            pickle.dump(run.logger, out)

    # z_true is stored alongside the logs (file name has no extension).
    with open(os.path.join(folder, "z_true"), "wb") as out:
        pickle.dump(z_true, out)

    print('Done')