import hydra
import hydra.experimental
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random

from torchfly.flyconfig import GlobalFlyConfig
from torchfly.training import TrainerLoop
from torchfly.common import set_random_seed, launch_distributed
import logging

from compressive_transformer.recurrent.recurrent_trainer_loop import RecurrentTrainerLoop
from compressive_transformer.compressive_transformer_flymodel import CTransFlyModel
from dataloaders.wiki103_nosent_bound import Wiki103NoSentBoundDataLoader

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Configuration path handed to both GlobalFlyConfig (in main) and
# launch_distributed (in the __main__ guard).
# NOTE(review): None presumably makes torchfly fall back to its default
# config lookup -- confirm against the torchfly version in use.
config_path = None


def main():
    """Build the config, data, model and trainer, then run training.

    Steps, in order:
      1. Load the user configuration through ``GlobalFlyConfig`` using the
         module-level ``config_path``.
      2. Copy the training dataloader's batch size into
         ``config.training.batch_size``.
      3. Seed the random number generators from ``config.training.random_seed``.
      4. Construct the Wiki103 dataloader handler and the ``CTransFlyModel``.
      5. Construct a ``RecurrentTrainerLoop`` with the train/valid dataloader
         factories, register the ``attn_loss`` and ``word_loss`` log keys and
         call ``train()``.
    """
    # Debug aid, left disabled:
    # torch.autograd.set_detect_anomaly(True)
    fly_config = GlobalFlyConfig(config_path=config_path)
    cfg = fly_config.user_config

    # Keep the trainer-level batch size in sync with the dataloader setting.
    cfg.training.batch_size = cfg.flydata.training.dataloader.batch_size

    # Seed first so the data handler and model are constructed afterwards.
    set_random_seed(cfg.training.random_seed)
    data_handler = Wiki103NoSentBoundDataLoader(cfg)
    net = CTransFlyModel(cfg)

    loop = RecurrentTrainerLoop(
        cfg,
        net,
        train_dataloader_fn=data_handler.train_dataloader_fn,
        valid_dataloader_fn=data_handler.valid_dataloader_fn,
    )
    # NOTE(review): presumably these are extra metrics the loop should report
    # -- confirm against RecurrentTrainerLoop.update_log_keys.
    loop.update_log_keys(["attn_loss", "word_loss"])

    loop.train()


if __name__ == "__main__":
    # Script entry point: pass `main` to torchfly's launcher as the worker
    # function, together with the module-level config path.
    launch_distributed(worker_fn=main, config_path=config_path)
