import os
import sys
dirof = os.path.dirname
# Put the parent of this file's directory at the front of sys.path so the
# project-local imports below (`m2m.*`, `utils_common.*`) can be resolved when
# this file is run directly as a script (presumably that is where they live).
sys.path.insert(0, dirof(dirof(os.path.abspath(__file__))))
if not len(sys.argv) == 2 and __name__ == '__main__': # Debug runs: script not given exactly one (config path) argument
    # Pin debug runs to GPU 3. This is set before `import torch` below so the
    # device mask is in place before any CUDA initialisation happens.
    os.environ["CUDA_VISIBLE_DEVICES"] = '3'

import torch
from torch import utils

from lightning.pytorch import seed_everything
from utils_common.utils import jpath, read_yaml
from m2m.lightning_dataset import *
from m2m.lightning_model import get_lit_model, load_lit_model
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from transformers import AutoTokenizer
import mlconfig
from lightning_dataset import get_dataloader
# from transformers.utils import logging
# logging.get_logger("transformers").setLevel(logging.ERROR)

# Opt in to TensorFloat-32 math for float32 work on CUDA, for both cuDNN ops
# and cuBLAS matmuls: faster on GPUs that support TF32, slightly less precise.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True


def _load_config():
    '''Load the evaluation config from the command line, or a hard-coded debug config.

    Usage: ``python <script> <config.yaml>``. With no argument, a hard-coded
    debug config is loaded and a few values are overridden for a quick run.

    Returns:
        The config object produced by ``mlconfig.load``.

    Raises:
        SystemExit: if more than one command-line argument is given. Previously
            such a call silently fell back to the debug config and ignored the
            user's config path.
    '''
    if len(sys.argv) == 2:
        return mlconfig.load(sys.argv[1])
    if len(sys.argv) > 2:
        raise SystemExit(f'Usage: python {sys.argv[0]} <config.yaml>')

    # Debug run (no CLI argument): hard-coded config plus lightweight overrides.
    config_fp = '/home/[anonymous]/work/[anonymous]/[anonymous]/m2m/hparams/band_obj/remi_plus.yaml'
    config = mlconfig.load(config_fp)
    config['num_workers'] = 0    # presumably DataLoader workers; 0 keeps loading in-process for debugging
    config['fast_dev_run'] = 10  # forwarded to Trainer(fast_dev_run=...) in main: run only 10 test batches
    config['bs_test'] = 2        # presumably the test batch size -- confirm in get_dataloader
    return config


def _load_tokenizer(config):
    '''Load the tokenizer from ``config['tokenizer_fp']``, choosing the padding side.

    Probing runs (``'probing'`` in ``config['out_dir']``) keep the tokenizer's
    default padding side. All other runs use left padding, which is the
    generation setting described in the padding note in ``main``.
    '''
    tk_fp = config['tokenizer_fp']
    if 'probing' in config['out_dir']:
        return AutoTokenizer.from_pretrained(tk_fp)
    return AutoTokenizer.from_pretrained(tk_fp, padding_side='left')


def main():
    '''Evaluate a Lightning model loaded via ``load_lit_model`` on the test split.

    Usage: ``python <script> <config.yaml>``. Run without an argument to use the
    hard-coded debug config.

    Note on padding -- why do the ppl and loss differ between left and right padding?
    Calling ``model.forward`` directly needs ``position_ids`` to indicate the
    position of each token. If the data is left-padded and no matching
    left-padded ``position_ids`` is provided (which is the case here), the model
    generates them itself, counting from the first token (a pad token). That is
    not correct.

    So:
    - In fine-tuning and ppl evaluation, use right-side padding. Take care to
      use a padding token different from EOS, and mask pad tokens in the loss.
    - In generation, use left-side padding.
    '''
    seed_everything(42, workers=True)

    config = _load_config()
    tk = _load_tokenizer(config)

    # Load a lightning model from checkpoint
    lit_model = load_lit_model(tk, config)

    # Setup data
    test_loader = get_dataloader(config, 'test')

    # Single-GPU test run; no logger is needed for evaluation.
    trainer = L.Trainer(
        logger=False,
        fast_dev_run=config.get('fast_dev_run', False),
        # Explicit Lightning 2.x spelling of the legacy 'bf16' alias
        # (bfloat16 mixed precision); avoids the deprecation warning.
        precision='bf16-mixed',
        accelerator='gpu',
        devices=1,
    )
    trainer.test(model=lit_model, dataloaders=test_loader)



# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()