import unittest
from sbsep.config import (
    ArchConfig,
    load_yaml,
    TrainConfig,
    TableConfig,
    FooConfig,
    VarScalingConfig,
    BNNConfig,
    BNNConfigFactory,
)


class TestConfigs(unittest.TestCase):
    """Unit tests for the configuration classes imported from ``sbsep.config``."""

    # Loaded once, when the class body executes. The path is relative, so the
    # tests have to be started from a directory that contains ``conf/``.
    c = load_yaml("./conf/logistic_gen.yaml")

    @staticmethod
    def _arch_kwargs(**overrides):
        """Return a fresh dict of ``ArchConfig`` keyword arguments.

        A new dict (and a new ``dims``/``domain`` list) is built on every call
        so that no test shares mutable fixture data with another one.

        Args:
            **overrides: Extra or replacement keyword arguments.

        Returns:
            dict: ``name``, ``dims`` and ``domain`` entries plus *overrides*.
        """
        kwargs = {"name": "a", "dims": [1, 32, 128, 32, 1], "domain": [-5, 5]}
        kwargs.update(overrides)
        return kwargs

    @staticmethod
    def _foo_config():
        """Return a new ``FooConfig`` naming ``sbsep.util.scaled_logistic``."""
        return FooConfig(
            function_name="scaled_logistic",
            module_name="sbsep.util",
            kwargs={"a": 1000, "b": 10, "x0": 0.2, "scale": 0.2},
        )

    def test_archconfig(self):
        """Check explicit values, default values and incomplete input."""
        explicit = ArchConfig(
            **self._arch_kwargs(transform="ExpTransform", normalize=True)
        )
        self.assertEqual(explicit.normalize, True)
        self.assertEqual(explicit.transform, "ExpTransform")

        # ``transform`` and ``normalize`` omitted: defaults are expected.
        defaulted = ArchConfig(**self._arch_kwargs())
        self.assertEqual(defaulted.normalize, False)
        self.assertEqual(defaulted.transform, "IdTransform")

        # Only ``name`` is supplied. The earlier bare ``except`` combined with
        # ``assertTrue(True)`` passed even when nothing was raised.
        # NOTE(review): assumes construction fails with TypeError
        # (dataclass-style missing arguments) or ValueError (validating
        # model) -- confirm against ArchConfig.
        with self.assertRaises((TypeError, ValueError)):
            ArchConfig(name="a")

    def test_trainconfig(self):
        """Check the default values of ``TrainConfig``."""
        tc = TrainConfig()
        self.assertEqual(tc.nepochs, 1000)
        self.assertEqual(tc.pretrain, True)
        self.assertEqual(tc.sigma_obs, 0.002)

    def test_table(self):
        """Check that explicitly passed ``xs``/``ys`` keep their length."""
        tc = TableConfig(xs=[1, 2.0, 3.0], ys=[2.0, 2.0, 5.0])
        self.assertEqual(len(tc.xs), 3)
        self.assertEqual(len(tc.ys), 3)

    def test_table_from_function(self):
        """Check that ``from_function`` with ``ngrid=101`` gives 101 ``xs``."""
        ac = ArchConfig(**self._arch_kwargs())
        fooc = self._foo_config()

        tc = TableConfig(ngrid=101)
        tc.from_function(ac, fooc)
        self.assertEqual(len(tc.xs), 101)

    def test_bnconfig(self):
        """Check ``BNNConfig`` composition and its ``to_json`` output."""
        ac = ArchConfig(**self._arch_kwargs())
        tc = TrainConfig(sigma_obs=0.1)
        fooc = self._foo_config()

        bc = BNNConfig(architecture=ac, train=tc, table=TableConfig(), foo=fooc)
        self.assertEqual(len(bc.architecture.dims), 5)

        bc_json = bc.to_json()
        self.assertEqual(bc_json["train"]["sigma_obs"], 0.1)
        self.assertEqual(bc_json["architecture"]["name"], "a")

    def test_bnconfig_foo(self):
        """Check the table size of a ``BNNConfig`` given a default table."""
        ac = ArchConfig(**self._arch_kwargs())
        tc = TrainConfig(pretrain=True)
        fooc = self._foo_config()

        # The table is passed in as a default ``TableConfig()``; presumably
        # BNNConfig fills it from ``foo`` -- verify against BNNConfig.
        bc = BNNConfig(
            architecture=ac,
            train=tc,
            table=TableConfig(),
            foo=fooc,
            varscaling=VarScalingConfig(),
        )
        self.assertEqual(len(bc.table.xs), 101)

    def test_bnn_factory(self):
        """Check the table size of the config built from the YAML content."""
        bc = BNNConfigFactory.get_bnn_config(self.c)
        self.assertEqual(len(bc.table.xs), 101)

# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
