import importlib
import os
import tensorflow as tf

from absl import app, flags
from ml_collections import ConfigDict, config_flags
from torch.cuda import device_count


# Hide every GPU from TensorFlow; GPU counting in main() goes through
# torch.cuda instead.
# NOTE(review): presumably this keeps TensorFlow from claiming GPU memory
# that the torch trainer needs -- confirm.
tf.config.set_visible_devices([], 'GPU')

# --config: path to the ml_collections config file consumed by main().
CONFIG = config_flags.DEFINE_config_file("config")
FLAGS = flags.FLAGS

# --trainer: suffix of the `train_<trainer>` module that main() imports.
flags.DEFINE_string(
    "trainer", default="transformer_decoder_torchrun", help="The trainer type."
)

# --testmode: when passed, main() sets cfg.testmode = True before training.
flags.DEFINE_boolean(
    'testmode', False, 'Set cfg.testmode = True before training.'
)

def main(_):
    """Import the selected trainer module and start training.

    Builds a ``ConfigDict`` from the ``--config`` flag, imports the module
    ``train_<trainer>`` named by the ``--trainer`` flag, and calls its
    ``train_distributed(cfg)`` or ``train(cfg)`` function depending on
    ``cfg.trainer_config.distributed``. In the distributed case
    ``cfg.trainer_config.batch_size`` is divided by the number of visible
    GPUs before the trainer is called.

    Args:
        _: Positional arguments forwarded by ``absl.app.run``; unused.

    Raises:
        ValueError: In distributed mode, if the number of visible GPUs
            differs from ``cfg.trainer_config.world_size``, if no GPU is
            visible at all, or if the batch size is not evenly divisible
            by the number of visible GPUs.
    """
    cfg = ConfigDict(CONFIG.value)
    # import_module is the documented API for programmatic imports and,
    # unlike __import__, returns the named module itself for dotted names.
    trainer = importlib.import_module(f"train_{FLAGS.trainer}")
    if FLAGS.testmode:
        print("Enabling testmode!")
        cfg.testmode = True

    # LOCAL_RANK is only present when a launcher exported it; fall back to
    # rank 0 so a plain `python` launch does not die with a KeyError.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if local_rank == 0:
        # Print the banner once per node rather than once per process.
        print("##########################################")
        print(f"Training with train_{FLAGS.trainer}.")
        print("##########################################")

    # Setup training
    world_size = device_count()
    distributed = cfg.trainer_config.distributed
    if distributed and world_size != cfg.trainer_config.world_size:
        raise ValueError(
            "Requested world size is not the same as number of visible GPUs."
        )

    if not distributed:
        trainer.train(cfg)
        return

    # A single visible device is accepted, but zero devices would make the
    # divisibility check below fail with an opaque ZeroDivisionError.
    if world_size < 1:
        raise ValueError(
            "Distributed training cannot be run on machine "
            f"with {world_size} device(s)."
        )
    if cfg.trainer_config.batch_size % world_size != 0:
        raise ValueError(
            f"Batch size {cfg.trainer_config.batch_size} is not evenly "
            f"divisble by # GPUs = {world_size}."
        )
    # Convert the global batch size into a per-device batch size.
    cfg.trainer_config.batch_size = cfg.trainer_config.batch_size // world_size
    trainer.train_distributed(cfg)


# Script entry point: pass main to absl's app.run when this file is executed
# directly (not when it is imported).
if __name__ == "__main__":
    app.run(main)
            
