
import gin
from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver

from base import idefics_transforms
from config.gin_template import parse_config
from igl import igl_transform
from utils import XTrainingArguments


class Config(absltest.TestCase):
    """Checks that gin config files yield the expected XTrainingArguments.

    Each test points the absl `gin_file` / `gin_param` flags at a config,
    calls `parse_config`, then inspects a freshly built `XTrainingArguments`.
    """

    def setUp(self):
        super().setUp()
        # Snapshot every absl flag and restore it when the test finishes, so
        # --gin_file / --gin_param overrides set by one test can never leak
        # into another (the outcome must not depend on test order).
        self.enter_context(flagsaver.flagsaver())

    def tearDown(self):
        # Drop all gin bindings so the next test parses its config from
        # scratch.
        gin.clear_config()
        super().tearDown()

    def _load_training_args(self, gin_file, gin_params=None):
        """Parse `gin_file` (plus optional bindings) and build the arguments.

        Args:
            gin_file: Path of the gin config file to parse.
            gin_params: Optional iterable of gin binding strings. When None,
                the `gin_param` flag is left at its (restored) default value.

        Returns:
            An `XTrainingArguments` instance constructed after parsing.
        """
        flags.FLAGS.gin_file = [gin_file]
        if gin_params is not None:
            flags.FLAGS.gin_param = list(gin_params)
        # NOTE(review): parse_config presumably reads the gin_file/gin_param
        # flags set above — confirm against config.gin_template.
        parse_config(verbose=False)
        return XTrainingArguments()

    def test_kto(self):
        # NOTE(review): the override suggests kto.gin needs a non-empty
        # LocalConfig.checkpoint_path to parse — confirm.
        xargs = self._load_training_args(
            'config/kto.gin',
            gin_params=['LocalConfig.checkpoint_path="none-empty"'])
        self.assertEqual(xargs.learning_rate, 1e-6)
        self.assertEqual(xargs.per_device_train_batch_size, 2)

    def test_base(self):
        xargs = self._load_training_args('config/multitask.gin')
        self.assertEqual(xargs.learning_rate, 1e-4)
        self.assertEqual(xargs.per_device_train_batch_size, 2)


if __name__ == "__main__":
    # Hand control to absltest, which parses absl flags before running the
    # test cases defined above.
    absltest.main()
