import copy
import os
from typing import Any, Mapping, Optional

import jax
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import utils


# Fallback gin search location: the parent of the directory that holds this
# file. `_main` appends it after the user's --gin_search_paths, so
# user-supplied prefixes are tried first when relative paths collide.
_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.normpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
    ),
]


def encode(
    *,
    model: models.BaseTransformerModel,
    restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    input_shapes: Optional[Mapping[str, Any]] = None,
    input_types: Optional[Mapping[str, Any]] = None,
):
    """Restores a model's train state from a checkpoint.

    Builds a `utils.TrainStateInitializer` with no optimizer and loads the
    checkpoint described by `restore_checkpoint_cfg` with `strict` forced to
    False.

    Args:
      model: Model whose `get_initial_variables` is used as the init function.
      restore_checkpoint_cfg: Checkpoint restore configuration. A shallow copy
        is modified (`strict = False`), so the caller's object is untouched.
      partitioner: Partitioner forwarded to the `TrainStateInitializer`.
      input_shapes: Mapping forwarded as `input_shapes` to the
        `TrainStateInitializer`. Defaults to an empty dict.
        NOTE(review): a model's `get_initial_variables` may expect specific
        keys here; an empty mapping may not be enough for every model —
        confirm for the model in use.
      input_types: Mapping forwarded as `input_types` to the
        `TrainStateInitializer`. Defaults to an empty dict.

    Returns:
      Whatever `TrainStateInitializer.from_checkpoint` returns (presumably
      the restored train state, possibly None if no checkpoint is found —
      confirm against the t5x API).
    """
    # Fresh empty mappings per call when not supplied; this preserves the
    # previous behavior, which always passed `{}`.
    if input_shapes is None:
        input_shapes = {}
    if input_types is None:
        input_types = {}

    train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=None,  # No optimizer is constructed.
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        input_types=input_types,
        partitioner=partitioner,
    )

    # Work on a shallow copy so that forcing a non-strict restore does not
    # mutate the caller's (possibly shared) configuration object.
    restore_cfg = copy.copy(restore_checkpoint_cfg)
    restore_cfg.strict = False
    train_state = train_state_initializer.from_checkpoint(
        [restore_cfg], init_rng=jax.random.PRNGKey(0)
    )

    print('Loaded')
    return train_state


if __name__ == '__main__':
    # pylint:disable=g-import-not-at-top
    from absl import app
    from absl import flags
    import gin
    # pylint:enable=g-import-not-at-top

    FLAGS = flags.FLAGS

    jax.config.parse_flags_with_absl()

    flags.DEFINE_multi_string(
        'gin_file',
        default=None,
        help='Path to gin configuration file. Multiple paths may be passed and '
        'will be imported in the given order, with later configurations '
        'overriding earlier ones.')

    flags.DEFINE_multi_string(
        'gin_bindings', default=[], help='Individual gin bindings.')

    flags.DEFINE_list(
        'gin_search_paths',
        default=['.'],
        help='Comma-separated list of gin config path prefixes to be prepended '
        'to suffixes given via `--gin_file`. If a file appears in more than '
        'one prefix, only the first prefix that produces a valid path for '
        'each suffix will be used.')

    # NOTE(review): FLAGS.tfds_data_dir is never read in this script, so this
    # flag currently has no effect despite what its help text says.
    flags.DEFINE_string(
        'tfds_data_dir', None,
        'If set, this directory will be used to store datasets prepared by '
        'TensorFlow Datasets that are not available in the public TFDS GCS '
        'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
        'all `Task`s.')


    def main(argv):
        """Entry point handed to `gin_utils.run`; delegates to `_main`."""
        _main(argv)

    def _main(argv):
        """Parses the gin configuration and runs `encode` with its bindings.

        Raises:
          app.UsageError: If any positional command-line arguments are given.
        """
        if len(argv) > 1:
            raise app.UsageError('Too many command-line arguments.')

        # Create a gin-configurable version of `encode`. It is registered
        # before the gin files are parsed so bindings naming it resolve.
        encode_using_gin = gin.configurable(encode)

        gin_utils.parse_gin_flags(
            # User-provided gin paths take precedence if relative paths conflict.
            FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
            FLAGS.gin_file,
            FLAGS.gin_bindings)

        # All of `encode`'s keyword arguments are supplied via gin bindings.
        encode_using_gin()


    gin_utils.run(main)
