
import numpy as np

from distributed_optimization_library.experiments.local_optimization_pytorch.optimize_model \
    import MomentumVarianceReduction, SGD
from distributed_optimization_library.function import StochasticQuadraticFunction



def test_mvr_with_stochastic_quadratic_function():
    """Check that the optimizer lands near the solution of ``A x = b``.

    A random stochastic quadratic function is built with a fixed seed, the
    optimizer is run from the zero vector for a fixed number of steps, and
    the final iterate is compared (to one decimal place) against the
    solution of the linear system defined by the underlying quadratic's
    ``_A`` and ``_b``.

    NOTE(review): despite the "mvr" in the name, the optimizer exercised
    here is ``SGD``. The ``MomentumVarianceReduction`` variant
    (``lr=0.01, momentum=1e-4``) is currently disabled -- confirm whether
    that is intentional.
    """
    problem_size = 100
    total_steps = 100000

    objective = StochasticQuadraticFunction.create_random(
        dim=problem_size, seed=42, reg=0.1, noise=0.5)

    # Reference answer: solve A x = b directly from the wrapped quadratic.
    # presumably this is the minimizer of the noiseless objective -- verify
    # against StochasticQuadraticFunction.
    quadratic = objective._quadratic_function
    expected = np.linalg.solve(quadratic._A, quadratic._b)

    start = np.zeros(problem_size, dtype=np.float32)
    solver = SGD(objective, start, lr=0.01, momentum=1e-2)

    for _step in range(total_steps):
        solver.step()

    found = solver.get_point()
    np.testing.assert_array_almost_equal(found, expected, decimal=1)
