import hydra
import hydra.experimental
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random

from torchfly.flyconfig import GlobalFlyConfig
from torchfly.training import TrainerLoop
from torchfly.common import set_random_seed, launch_distributed
import logging

from model.recurrent.recurrent_trainer_loop import RecurrentTrainerLoop
from model.rnn_flymodel import RNNFlymodel
from dataloaders.memformer_dataloader import TextDataLoaderHelper

# Module-scoped logger, named after this module (not referenced by the code below).
logger = logging.getLogger(name=__name__)
# Config location handed to both GlobalFlyConfig (in main) and launch_distributed.
# NOTE(review): None presumably means "use torchfly's default config lookup" — confirm.
config_path = None


def main():
    """Build the RNN training pipeline from the user config and run it.

    Steps, in order:
      1. load the user config via ``GlobalFlyConfig`` using the module-level
         ``config_path``;
      2. copy ``flydata.training.dataloader.batch_size`` into
         ``training.batch_size``;
      3. seed via ``set_random_seed(training.random_seed)`` before anything
         else is constructed;
      4. build the ``TextDataLoaderHelper``, the ``RNNFlymodel`` and a
         ``RecurrentTrainerLoop``, then call ``train()``.

    This function is handed to ``launch_distributed`` as ``worker_fn`` by the
    ``__main__`` guard at the bottom of this file.
    """
    # Debugging aid: uncomment to make autograd report the forward op behind a
    # failing backward pass. This slows training, so keep it off for normal runs.
    # torch.autograd.set_detect_anomaly(True)
    fly_config = GlobalFlyConfig(config_path=config_path)
    cfg = fly_config.user_config

    # The dataloader section is the source of truth for batch size; mirror it
    # into the training section (presumably read by the trainer loop — confirm).
    loader_batch_size = cfg.flydata.training.dataloader.batch_size
    cfg.training.batch_size = loader_batch_size

    # Seed before building data/model so their construction sees the seeded
    # state (assuming set_random_seed seeds the RNGs they use).
    set_random_seed(cfg.training.random_seed)

    loaders = TextDataLoaderHelper(cfg)
    rnn_model = RNNFlymodel(cfg)
    trainer_loop = RecurrentTrainerLoop(
        cfg,
        rnn_model,
        train_dataloader_fn=loaders.train_dataloader_fn,
        valid_dataloader_fn=loaders.valid_dataloader_fn,
    )
    trainer_loop.train()


if __name__ == "__main__":
    # Script entry point: hand `main` to torchfly's launcher together with the
    # same config_path that main() reads. The launcher presumably decides
    # whether to spawn multiple worker processes — confirm in torchfly.
    launch_distributed(worker_fn=main, config_path=config_path)
