import torch
from src import IM_data, get_train_test_idx
from src import Light_Net
from src import Grapher
from src import Scheduler_manager
from src import verbose, get_args
from torch.utils.data import DataLoader
import importlib
import lightning as L


if __name__ == '__main__':
    # Script entry point: read the run configuration, build the datasets,
    # network, optimizer, scheduler and Lightning trainer, fit the model,
    # then save the grapher's data.

    # run configuration and result grapher
    model_info = get_args()
    grapher = Grapher(model_info=model_info, base_pt='./result_graphs')

    # module in which the network class is looked up by name
    nets_module = importlib.import_module('src.models')

    # datasets -- both splits share the same preprocessing options; the
    # validation split additionally receives the training split's label
    # encoder
    data_path = model_info['DATA_PATH']
    dataset_options = {
        'normalize': model_info['NORMALIZE_INPUT'],
        'shape': model_info['SHAPE'],
        'aug_dict': model_info['AUG'],
    }
    train_set = IM_data(data_path, train=True, **dataset_options)
    val_set = IM_data(data_path, train=False,
                      label_encoder=train_set.label_encoder,
                      **dataset_options)

    # data loaders -- only the training loader shuffles
    loader_options = {
        'batch_size': model_info['BATCH_SIZE'],
        'num_workers': 12,
        'persistent_workers': True,
    }
    train_dataloader = DataLoader(train_set, shuffle=True, **loader_options)
    val_dataloader = DataLoader(val_set, shuffle=False, **loader_options)

    # network -- the torch RNG is seeded right before the network is built
    torch.manual_seed(model_info['DATA_SEED'])
    network_cls = getattr(nets_module, model_info['MODEL_NAME'])
    # presumably data_shape[1] is the number of target classes -- TODO confirm
    model = network_cls(output_size=train_set.data_shape[1])

    # loss function
    loss = torch.nn.CrossEntropyLoss()

    # optimizer -- NOTE: the 'Adam' option selects AdamW; any other value
    # falls back to SGD with weight decay set to LR / EPOCHS
    optimizer_name = model_info['OPT']
    learning_rate = model_info['LR']
    if optimizer_name != 'Adam':
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=learning_rate,
            weight_decay=learning_rate / model_info['EPOCHS'],
        )
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=0.004,
            amsgrad=False,
        )

    # learning-rate scheduler wrapper around the optimizer
    scheduler_manager = Scheduler_manager(optimizer=optimizer, model_info=model_info)

    # Lightning module bundling network, loss, optimizer and scheduler
    light_model = Light_Net(
        network=model,
        loss_fn=loss,
        optimizer=optimizer,
        scheduler=scheduler_manager.scheduler,
        model_info=model_info,
        dataset=train_set,
        wn_data_path=model_info['WN_PATH'],
        glove_data_path=model_info['GLOVE_PATH'],
        glove_embeddings_data_file=model_info['GLOVE_embeddings_PATH'],
        grapher=grapher,
    )

    # TensorBoard logger, versioned with the grapher's time name
    logger = L.pytorch.loggers.TensorBoardLogger(save_dir='./', version=grapher.time_name)

    # trainer -- one device with the 'ddp' strategy, validation every epoch
    trainer_options = {
        'accelerator': model_info['DEVICE'],
        'devices': 1,
        'strategy': 'ddp',
        'max_epochs': model_info['EPOCHS'],
        'limit_train_batches': 500,
        'limit_val_batches': 800,
        'check_val_every_n_epoch': 1,
        'log_every_n_steps': 20,
        'enable_progress_bar': True,
        'logger': logger,
    }
    lightning_trainer = L.Trainer(**trainer_options)

    # train
    lightning_trainer.fit(
        model=light_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )

    # save the grapher's data and report completion
    grapher.save_data()
    print('[INFO] Done!')
