from Trainer import Trainer

import schedulers
import callbacks
import losses
import tasks
import torch
import utils
import click
import os

# CLI entry point for the contrastive modular-addition experiment.
#
# Seeds the run, builds the data container and the Transformer, then wires
# them into a Trainer together with dataloaders, optimizer, scheduler, loss
# and callbacks, and finally runs it. Log files and the model checkpoint are
# written to paths joined onto `dir`.
#
# NOTE: documented with comments rather than a docstring on purpose -- click
# would surface a docstring as the command's --help text.
@click.group(invoke_without_command=True)
@utils.options.modadd
def modadd(dir, seed, epochs, train_batch_size, eval_batch_size, src_symbols, tgt_symbols, etv, compile, data_device, weight_decay, learning_rate, device):
    # Global setup: seeding first, then torch numeric/printing preferences.
    utils.seed_everything(seed)
    torch.set_float32_matmul_precision("high")
    torch.set_printoptions(sci_mode=False, linewidth=1000)

    # `compile` is the CLI flag (it shadows the builtin inside this function):
    # either wrap the model with torch.compile or pass it through untouched.
    if compile:
        wrap_model = torch.compile
    else:
        def wrap_model(module):
            return module

    contrastive = tasks.modadd.contrastive

    def file_logger(filename):
        # Logger writing to `filename` inside the run directory.
        return utils.get_file_logger(os.path.join(dir, filename))

    # 80/20 train/valid split; the test fraction is zero.
    splits = contrastive.datasets.Data(
        train_dataset_frac=.8,
        valid_dataset_frac=.2,
        test_dataset_frac=.0,
        src_symbols=src_symbols,
        tgt_symbols=tgt_symbols,
        device=data_device,
    )

    network = contrastive.models.Transformer(
        # presumably the +1 reserves one extra token id -- TODO confirm
        vocab_size=src_symbols + tgt_symbols + 1,
        layers=3,
        symbol_embedding_size=8,
        context_embedding_size=128,
        feedforward_size=1024,
        heads=4,
        activation="gelu",
        dropout=0.0,
        device=device,
    )
    model = wrap_model(network)

    # The Trainer is configured step by step; every setter's return value is
    # carried forward, exactly as in a fluent call chain.
    trainer = Trainer().set_epochs(epochs).set_etv(etv)

    train_set = contrastive.datasets.DatasetSuperMisAligned(splits, "train", device)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        collate_fn=train_set.collate,
        batch_size=train_batch_size,
        drop_last=True,
        shuffle=True,
    )
    trainer = trainer.set_train_dataloader(train_loader)

    valid_set = contrastive.datasets.Dataset(splits, "valid", device)
    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        collate_fn=valid_set.collate,
        batch_size=eval_batch_size,
        drop_last=False,
        shuffle=True,
    )
    trainer = trainer.set_valid_dataloader(valid_loader)

    # No test data: an empty list stands in for the test dataloader.
    # NOTE(review): setter name is spelled exactly as this script has always
    # called it -- verify against the Trainer class before renaming.
    trainer = trainer.set_test_dataloaer([])
    trainer = trainer.set_model(model)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    trainer = trainer.set_optimizer(optimizer)
    trainer = trainer.set_scheduler(schedulers.ConstantScheduler())
    trainer = trainer.set_loss_fn(losses.NoiseContrastiveLoss())

    # Training-time callbacks, chained in this order.
    cdist_logging = contrastive.callbacks.LogCdists(
        model,
        src_symbols=src_symbols,
        tgt_symbols=tgt_symbols,
        logger=file_logger("cdist.log"),
        etc=100,
    )
    train_logging = contrastive.callbacks.TrainLogger(file_logger("train.log"))
    checkpointing = callbacks.Checkpoint(model, etc=1, path=os.path.join(dir, "model.pkl"))
    progress = contrastive.callbacks.ProgressBar(length=epochs)
    trainer = trainer.set_train_callback(
        utils.chain_callbacks(cdist_logging, train_logging, checkpointing, progress)
    )

    # Evaluation callbacks: same logger class, separate files per split.
    valid_logging = contrastive.callbacks.EvalLogger(
        src_symbols=src_symbols,
        tgt_symbols=tgt_symbols,
        logger=file_logger("valid.log"),
    )
    trainer = trainer.set_valid_callback(valid_logging)

    test_logging = contrastive.callbacks.EvalLogger(
        src_symbols=src_symbols,
        tgt_symbols=tgt_symbols,
        logger=file_logger("test.log"),
    )
    trainer = trainer.set_test_callback(test_logging)

    trainer.run()


# Script entry point: only when this file is executed directly (not imported),
# invoke the click command, which reads its options from the command line.
if __name__ == "__main__":
    modadd()
    
