import unittest

import torch
from torch import Tensor

from bbo.problems import Synthetic


class SyntheticTest(unittest.TestCase):
    """Smoke tests for every problem listed in ``Synthetic.options``."""

    # Problems that must be tested at a specific dimensionality; every other
    # option is instantiated with ``_DEFAULT_DIM``.
    _FIXED_DIMS = {'Branin': 2, 'Shekel': 4}
    _DEFAULT_DIM = 6
    # Observation noise passed to every ``Synthetic`` instance under test.
    _NOISE_STD = 0.01

    def get_dim(self, name):
        """Return the input dimensionality used to build problem ``name``.

        'Branin' uses 2 dims, 'Shekel' uses 4, and all others use 6.
        """
        return self._FIXED_DIMS.get(name, self._DEFAULT_DIM)

    def test_problems(self):
        """Each problem exposes ``dim``-length bounds with ``lb < ub``."""
        for name in Synthetic.options:
            # subTest reports the failing problem by name and keeps checking
            # the remaining options instead of aborting at the first failure.
            with self.subTest(name=name):
                dim = self.get_dim(name)
                problem = Synthetic(name, dim, self._NOISE_STD)
                self.assertEqual(problem.lb.shape, (dim, ))
                self.assertEqual(problem.ub.shape, (dim, ))
                # test_run samples uniformly inside [lb, ub]; that only makes
                # sense for a non-degenerate box. ``bool(... .all())`` works
                # for both torch tensors and numpy arrays.
                self.assertTrue(bool((problem.lb < problem.ub).all()))

    def test_run(self):
        """Evaluating a batch of in-bounds points yields finite ``(bs, 1)`` output."""
        bs = 32
        for name in Synthetic.options:
            with self.subTest(name=name):
                dim = self.get_dim(name)
                problem = Synthetic(name, dim, self._NOISE_STD)
                lb, ub = problem.lb, problem.ub
                # Uniform samples inside the box: rand is in [0, 1), so each
                # row is rescaled/shifted into [lb, ub).
                X = lb + (ub - lb) * torch.rand((bs, dim))
                Y = problem(X)
                self.assertTrue(torch.is_tensor(Y))
                self.assertEqual(Y.shape, (bs, 1))
                # Inputs are in-bounds, so objective values should never be
                # NaN/inf.
                self.assertTrue(bool(torch.isfinite(Y).all()))
            