from Trainer import Trainer

import schedulers
import callbacks
import losses
import tasks
import torch
import utils
import click
import os

def _make_dataloader(data, split, device, batch_size, drop_last):
    """Build a shuffled ``DataLoader`` over one split of the modadd data.

    Args:
        data: The ``tasks.modadd.classification.datasets.Data`` object.
        split: Split name forwarded to ``Dataset`` (``"train"`` / ``"valid"``).
        device: Device argument forwarded to ``Dataset``.
        batch_size: Number of samples per batch.
        drop_last: Whether an incomplete final batch is dropped.

    Returns:
        A ``torch.utils.data.DataLoader`` that uses the dataset's own
        ``collate`` method as ``collate_fn``.
    """
    dataset = tasks.modadd.classification.datasets.Dataset(data, split, device)
    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate,
        batch_size=batch_size,
        drop_last=drop_last,
        shuffle=True,
    )


@click.group(invoke_without_command=True)
@utils.options.modadd
@click.option("--embeddings-path", "emb_path", type=click.Path(), default=None, help="path to model from which embedding are loaded")
def modadd(dir, seed, epochs, train_batch_size, eval_batch_size, src_symbols, tgt_symbols, etv, compile, data_device, emb_path, learning_rate, weight_decay, device):
    """Train the modadd classification Transformer with a contrastive loss.

    \f

    Args:
        dir: Output directory; receives ``cdist.log``, ``train.log``,
            ``valid.log``, ``test.log`` and the ``model.pkl`` checkpoint.
            (Shadows the builtin ``dir``; the name is bound by click via
            keyword, so it is kept as is.)
        seed: Seed passed to ``utils.seed_everything``.
        epochs: Number of epochs given to the trainer and the progress bar.
        train_batch_size: Batch size of the training dataloader.
        eval_batch_size: Batch size of the validation dataloader.
        src_symbols: Number of source symbols (used for data, vocabulary
            size and logging callbacks).
        tgt_symbols: Number of target symbols (same uses as above).
        etv: Forwarded to ``Trainer.set_etv``.
            NOTE(review): presumably "epochs to validate" - confirm.
        compile: When truthy, the model is wrapped with ``torch.compile``.
        data_device: Device on which ``Data`` is built.
        emb_path: Optional checkpoint path; when given, its
            ``["modelsd"]["symbol_embeddings.weight"]`` tensor replaces the
            model's embedding weights.
        learning_rate: Learning rate of the Adam optimizer.
        weight_decay: Weight decay of the Adam optimizer.
        device: Device used for the model and the datasets.
    """
    utils.seed_everything(seed)
    torch.set_float32_matmul_precision("high")
    torch.set_printoptions(sci_mode=False, linewidth=1000)

    # 80/20 train/valid split; no test split is generated.
    data = tasks.modadd.classification.datasets.Data(
        train_dataset_frac=.8,
        valid_dataset_frac=.2,
        test_dataset_frac=.0,
        src_symbols=src_symbols,
        tgt_symbols=tgt_symbols,
        device=data_device,
    )

    model = tasks.modadd.classification.models.Transformer(
        vocab_size=src_symbols + tgt_symbols + 1,
        layers=3,
        embedding_size=128,
        feedforward_size=1024,
        heads=4,
        activation="gelu",
        dropout=0.0,
        device=device,
    )

    if emb_path:
        # map_location makes the checkpoint loadable on machines that lack the
        # device it was saved from, and keeps the injected embedding on the
        # same device as the rest of the model.
        checkpoint = torch.load(emb_path, weights_only=True, map_location=device)
        with torch.no_grad():
            model.embeddings.weight.data = checkpoint["modelsd"]["symbol_embeddings.weight"]

    # Embeddings are injected before (optional) compilation.
    if compile:
        model = torch.compile(model)

    cb = tasks.modadd.classification.callbacks

    def file_logger(filename):
        """Return a file logger writing to ``filename`` inside ``dir``."""
        return utils.get_file_logger(os.path.join(dir, filename))

    (
        Trainer()
        .set_epochs(epochs)
        .set_etv(etv)
        .set_train_dataloader(
            _make_dataloader(data, "train", device, train_batch_size, drop_last=True)
        )
        # The validation loader is shuffled too, as in the original setup.
        .set_valid_dataloader(
            _make_dataloader(data, "valid", device, eval_batch_size, drop_last=False)
        )
        # NOTE(review): the spelling "dataloaer" presumably matches the Trainer
        # API - confirm before renaming. The test split is empty (fraction 0).
        .set_test_dataloaer([])
        .set_model(model)
        .set_optimizer(
            torch.optim.Adam(
                model.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
            )
        )
        .set_scheduler(schedulers.ConstantScheduler())
        .set_loss_fn(losses.ContrastiveLoss())
        .set_train_callback(
            utils.chain_callbacks(
                cb.LogCdists(model, src_symbols=src_symbols, tgt_symbols=tgt_symbols, logger=file_logger("cdist.log"), etc=100),
                cb.TrainLogger(file_logger("train.log")),
                callbacks.Checkpoint(model, etc=1, path=os.path.join(dir, "model.pkl")),
                cb.ProgressBar(length=epochs),
            )
        )
        .set_valid_callback(
            cb.EvalLogger(src_symbols=src_symbols, tgt_symbols=tgt_symbols, logger=file_logger("valid.log"))
        )
        .set_test_callback(
            cb.EvalLogger(src_symbols=src_symbols, tgt_symbols=tgt_symbols, logger=file_logger("test.log"))
        )
        .run()
    )


# Script entry point: invoke the click command only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    modadd()
    
