import pytest

import numpy as np
import torch

from simulator.worker import Worker
from simulator.fixed_time_utils import run_pipeline
from simulator.algorithms.synchronized_sgd import SynchronizedSGDServer
from simulator.algorithms.rennala_sgd import RennalaSGDServer, RennalaSGDWorker
from simulator.algorithms.ringmaster_sgd import RingmasterSGDServer
from simulator.functions import StochasticQuadraticFunction, generate_random_vector


def test_synchronized_centralized_sgd():
    """Run synchronized SGD on a shared noisy quadratic with one slow worker.

    All 100 workers share the same stochastic quadratic objective. The point
    returned by the pipeline is compared, to one decimal place, with the
    solution of the linear system formed from the quadratic's ``_A`` and
    ``_b``. The reported number of steps must lie in ``[4990, 5010]``.
    """
    n_workers = 100
    dimension = 10
    quadratic = StochasticQuadraticFunction.create_random(dimension, seed=42, noise=0.5, reg=0.1)
    objectives = [quadratic] * n_workers
    step_size = 0.1
    start = torch.tensor(generate_random_vector(dimension, seed=53))
    # Every worker takes one time unit per computation, except the last (slow) one.
    compute_times = [1.] * (n_workers - 1) + [10.]
    horizon = 50000

    solution, stats = run_pipeline(
        SynchronizedSGDServer,
        Worker,
        objectives,
        start,
        step_size,
        sim_time=horizon,
        times_to_calculate=compute_times,
        optimizer_cls=torch.optim.SGD,
    )

    inner = quadratic._quadratic_function
    expected = np.linalg.solve(inner._A, inner._b)
    np.testing.assert_array_almost_equal(solution, expected, decimal=1)
    # NOTE(review): 5000 matches horizon / slowest compute time (50000 / 10),
    # i.e. presumably every step waits for the slow worker -- confirm against the server.
    assert 4990 <= stats['steps'] <= 5010


def test_synchronized_centralized_sgd_with_communication():
    """Compare metric timestamps of synchronized SGD with and without communication.

    The pipeline is run twice on the same problem: first with a communication
    time of 1.0 per worker, then with 0. Every time the pipeline invokes the
    ``calculate_metrics`` callback the current ``env.now`` is recorded. Each
    timestamp of the run with communication must be exactly three times the
    timestamp at the same position of the run without communication.

    The test also requires that the callback fired at least once and that the
    run without communication produced at least as many timestamps, so the
    comparison can never pass vacuously or fail with an ``IndexError``.
    """
    num_workers = 100
    dim = 10
    func = StochasticQuadraticFunction.create_random(dim, seed=42, noise=0.5, reg=0.1)
    functions = [func for _ in range(num_workers)]
    gamma = 0.1
    point = torch.tensor(generate_random_vector(dim, seed=53))
    times_to_calculate = [1.] * num_workers
    sim_time = 10

    def record_metric_times(communication_time):
        """Run the pipeline with the given per-worker communication time.

        Returns the list of ``env.now`` values seen by the metrics callback.
        """
        timestamps = []

        # The callback keeps the original parameter names in case the pipeline
        # passes them by keyword.
        def calculate_metrics(env, iter, point):
            timestamps.append(env.now)

        run_pipeline(SynchronizedSGDServer, Worker,
                     functions, point, gamma, sim_time=sim_time,
                     times_to_calculate=times_to_calculate,
                     times_to_communicate=[communication_time] * num_workers,
                     calculate_metrics=calculate_metrics,
                     optimizer_cls=torch.optim.SGD)
        return timestamps

    # Same order as before: the run with communication comes first.
    times_with = record_metric_times(1.0)
    times_without = record_metric_times(0.)

    assert times_with, "metrics callback was never invoked in the run with communication"
    assert len(times_without) >= len(times_with), (
        "the run without communication recorded fewer timestamps than the run with communication"
    )

    # Due to communication, every recorded time is three times the one without communication.
    # NOTE(review): the factor 3 presumably is one compute unit plus two unit-cost
    # communication legs per step -- confirm against the simulator.
    for time_with, time_without in zip(times_with, times_without):
        assert time_with == 3 * time_without


@pytest.mark.parametrize("alg", [RennalaSGDServer, RingmasterSGDServer])
def test_rennala_and_ringmaster_sgd_same_times_except_last(alg):
    """Run Rennala / Ringmaster SGD with nine equally fast workers and one slow worker.

    The returned point must match, to one decimal place, the solution of the
    linear system formed from the quadratic's ``_A`` and ``_b``. In the
    ``worker_useful`` statistics the first worker must have a positive entry
    and the last (slow) worker an entry of exactly zero.
    """
    n_workers = 10
    dimension = 10
    quadratic = StochasticQuadraticFunction.create_random(dimension, seed=42, noise=0.5, reg=0.1)
    objectives = [quadratic] * n_workers
    start = torch.tensor(generate_random_vector(dimension, seed=53))
    # All workers need one time unit per computation, except the last (slow) one.
    compute_times = [1.] * (n_workers - 1) + [10.]
    horizon = 10000
    batch = 30
    step_size = 0.1 / batch

    # The two algorithms differ only in the server / worker pair handed to the pipeline.
    if alg == RennalaSGDServer:
        server_cls, worker_cls = RennalaSGDServer, RennalaSGDWorker
    elif alg == RingmasterSGDServer:
        server_cls, worker_cls = RingmasterSGDServer, Worker
    else:
        assert False

    solution, stats = run_pipeline(
        server_cls,
        worker_cls,
        objectives,
        start,
        step_size,
        sim_time=horizon,
        times_to_calculate=compute_times,
        server_params={'num_grads': batch},
        optimizer_cls=torch.optim.SGD,
    )

    inner = quadratic._quadratic_function
    expected = np.linalg.solve(inner._A, inner._b)
    np.testing.assert_array_almost_equal(solution, expected, decimal=1)

    useful = stats["worker_useful"]
    assert useful[0] > 0
    assert useful[-1] == 0  # the last is too slow


@pytest.mark.parametrize("alg", [RennalaSGDServer, RingmasterSGDServer])
def test_rennala_and_ringmaster_sgd_increasing_times(alg):
    """Run Rennala / Ringmaster SGD with compute times growing as sqrt(worker index).

    Worker ``i`` (0-based) needs ``sqrt(i + 1)`` time units per computation.
    The first worker must have a strictly larger ``worker_useful`` entry than
    the last one, and the returned point must match, to one decimal place, the
    solution of the linear system formed from the quadratic's ``_A`` and ``_b``.
    """
    n_workers = 10
    dimension = 10
    quadratic = StochasticQuadraticFunction.create_random(dimension, seed=42, noise=1.0, reg=0.1)
    objectives = [quadratic] * n_workers
    start = torch.tensor(generate_random_vector(dimension, seed=53))
    compute_times = np.sqrt(np.arange(1, n_workers + 1))
    horizon = 50000
    batch = 30
    step_size = 0.1 / batch

    # The two algorithms differ only in the server / worker pair handed to the pipeline.
    if alg == RennalaSGDServer:
        server_cls, worker_cls = RennalaSGDServer, RennalaSGDWorker
    elif alg == RingmasterSGDServer:
        server_cls, worker_cls = RingmasterSGDServer, Worker
    else:
        assert False

    solution, stats = run_pipeline(
        server_cls,
        worker_cls,
        objectives,
        start,
        step_size,
        sim_time=horizon,
        times_to_calculate=compute_times,
        server_params={'num_grads': batch},
        optimizer_cls=torch.optim.SGD,
    )

    useful = stats["worker_useful"]
    assert useful[0] > useful[-1]

    inner = quadratic._quadratic_function
    expected = np.linalg.solve(inner._A, inner._b)
    np.testing.assert_array_almost_equal(solution, expected, decimal=1)
