import torch
import numpy as np

from simulator.functions import QuadraticFunction, StochasticQuadraticFunction


def test_smoke_random_quadratic_function():
    """Smoke-test value() and gradient() of a randomly generated QuadraticFunction.

    Builds a QuadraticFunction via ``create_random`` (fixed seed), evaluates it at a
    seeded random point, and checks that the value is a finite scalar and the
    gradient is 1-D with one entry per input dimension.
    """
    dim = 10
    f = QuadraticFunction.create_random(dim=dim, seed=42)
    # Seed the evaluation point as well, so any failure reproduces run-to-run.
    rng = np.random.default_rng(0)
    point = torch.tensor(rng.random(dim).astype(np.float32))

    v = f.value(point)
    # float() accepts a Python number or a single-element tensor/array.
    # NOTE(review): assumes value() returns a single scalar — confirm.
    assert np.isfinite(float(v))

    g = f.gradient(point)
    assert tuple(g.shape) == (dim,)


def test_smoke_random_stochastic_quadratic_function():
    """Smoke-test the gradients of a randomly generated StochasticQuadraticFunction.

    Checks that stochastic_gradient() returns a 1-D result with one entry per
    input dimension, that two successive stochastic gradients at the same point
    differ, and that the exact gradient() is deterministic at a fixed point.
    """
    dim = 10
    f = StochasticQuadraticFunction.create_random(dim=dim, seed=42)
    # Seed the evaluation point as well, so any failure reproduces run-to-run.
    rng = np.random.default_rng(0)
    point = torch.tensor(rng.random(dim).astype(np.float32))

    g = f.stochastic_gradient(point)
    assert tuple(g.shape) == (dim,)
    g_new = f.stochastic_gradient(point)
    assert tuple(g_new.shape) == (dim,)

    # Two stochastic draws at the same point are required to differ ...
    assert not np.array_equal(g, g_new)
    # ... while the exact gradient must be identical across calls.
    assert np.array_equal(f.gradient(point), f.gradient(point))
