import hydra
import hydra.experimental
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random

from torchfly.flyconfig import GlobalFlyConfig
from torchfly.training import TrainerLoop
from torchfly.common import set_random_seed, launch_distributed
import logging

from transformerxl.recurrent.recurrent_trainer_loop import RecurrentTrainerLoop
from transformerxl.transformer_xl_flymodel import TransformerXLFlyModel
from dataloaders.wiki103_nosent_bound import Wiki103NoSentBoundDataLoader

# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)

# Config file handed to both GlobalFlyConfig (in `main`) and launch_distributed.
# NOTE(review): left as None, so the choice of config is deferred to torchfly
# (presumably resolved from command-line arguments) — confirm against torchfly.
config_path = None


def main():
    """Build the fly config, data pipeline and Transformer-XL model, then train.

    Used as the ``worker_fn`` passed to ``launch_distributed`` at the bottom of
    this file. It performs these steps in order:

    1. Load the user config through ``GlobalFlyConfig``.
    2. Copy the dataloader's batch size into ``training.batch_size``.
    3. Seed via ``set_random_seed``.
    4. Build the WikiText-103 dataloaders and the model.
    5. Run a ``RecurrentTrainerLoop`` to completion.
    """
    # Debugging aid: re-enable to turn on autograd anomaly detection when
    # hunting NaN/inf gradients (it slows training noticeably).
    # torch.autograd.set_detect_anomaly(True)
    fly_config = GlobalFlyConfig(config_path=config_path)
    user_config = fly_config.user_config

    # Mirror the dataloader's batch size into the training section so the
    # trainer sees the same value the dataloader was configured with.
    loader_batch_size = user_config.flydata.training.dataloader.batch_size
    user_config.training.batch_size = loader_batch_size

    # Seed before building the data pipeline and model, presumably so any
    # randomness during their construction is reproducible.
    set_random_seed(user_config.training.random_seed)

    data_handler = Wiki103NoSentBoundDataLoader(user_config)
    flymodel = TransformerXLFlyModel(user_config)

    # The trainer receives dataloader *factories* rather than built loaders.
    loop = RecurrentTrainerLoop(
        user_config,
        flymodel,
        train_dataloader_fn=data_handler.train_dataloader_fn,
        valid_dataloader_fn=data_handler.valid_dataloader_fn,
    )
    loop.train()


if __name__ == "__main__":
    # Hand `main` to torchfly's launcher instead of calling it directly; the
    # launcher is given the same `config_path` that `main` reads.
    # NOTE(review): presumably this spawns one worker per process/GPU — confirm.
    launch_distributed(
        worker_fn=main,
        config_path=config_path,
    )
