import os
import sys

# Make modules reachable from the working directory and from its parent directory.
sys.path.extend(['.', os.path.abspath('..')])
# Debug runs (file executed as a script without exactly one CLI argument):
# expose only CUDA device 3. This is set before torch gets imported below.
if __name__ == '__main__' and len(sys.argv) != 2:
    os.environ["CUDA_VISIBLE_DEVICES"] = '3'

import torch
from torch import utils

from lightning.pytorch import seed_everything
from utils_common.utils import jpath, read_yaml
from m2m.lightning_dataset import *
from m2m.lightning_model import get_lit_model, load_lit_model
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from transformers import AutoTokenizer
import mlconfig
# from transformers.utils import logging
# logging.get_logger("transformers").setLevel(logging.ERROR)

# Allow TF32 for CUDA matmuls and for cuDNN operations (PyTorch backend flags).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def main():
    '''
    Run the test loop of a checkpointed lightning model on the test split.

    With exactly one command-line argument, that argument is the path of the
    config file. Otherwise a hard-coded debug config is loaded and overridden
    for a quick run (no dataloader workers, fast_dev_run of 5, test batch size 2).

    Note on padding side -- why ppl and loss differ between left and right padding:
    calling model.forward directly requires position_ids that indicate the
    position of each token. When the data is left-padded and no matching
    left-padded position_ids are passed (which is the case here), the model
    generates position_ids itself, counting from the first token (a pad token),
    which is not correct.

    Therefore:
    - For finetuning and ppl evaluation, use right-side padding. Take care to use
      a padding token different from eos, and to mask it in the loss calculation.
    - For generation, use left-side padding.
    '''
    seed_everything(42, workers=True)

    # Anything other than exactly one CLI argument is treated as a debug run
    is_debug_run = len(sys.argv) != 2
    if is_debug_run:
        config_fp = '/home/[anonymous]/work/[anonymous]/m2m/hparams/drum_arrange/direct_opd.yaml'
    else:
        config_fp = sys.argv[1]
    config = mlconfig.load(config_fp)
    if is_debug_run:
        # Shrink the run so a debug session finishes quickly
        config['num_workers'] = 0
        config['fast_dev_run'] = 5
        config['bs_test'] = 2

    # Tokenizer: keep its default padding side for probing runs, pad left otherwise
    tokenizer_fp = config['tokenizer_fp']
    tokenizer_kwargs = {}
    if 'probing' not in config['out_dir']:
        tokenizer_kwargs['padding_side'] = 'left'
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_fp, **tokenizer_kwargs)

    # Lightning model loaded from a checkpoint
    lit_model = load_lit_model(tokenizer, config)

    # Data: pass 'valid' instead of 'test' to evaluate on the validation split
    test_loader = get_dataloader(config, 'test')

    # fast_dev_run is optional in the config and is off when absent
    fast_dev_run = False
    if 'fast_dev_run' in config:
        fast_dev_run = config['fast_dev_run']

    # Single-GPU bf16 trainer without a logger, used for testing only
    trainer = L.Trainer(
        logger=False,
        fast_dev_run=fast_dev_run,
        precision='bf16',
        accelerator="gpu",
        devices=1,
    )
    trainer.test(model=lit_model, dataloaders=test_loader)


def _identity_collate(batch):
    '''
    Collate function that hands the list of samples to the caller unchanged.

    Defined at module level (instead of as a lambda) so that it can be pickled
    when DataLoader workers are started with the spawn / forkserver method.
    '''
    return batch


def get_dataloader(config, split):
    '''
    Build the DataLoader for one data split.

    config: mapping-like config object. Reads 'bs' (or 'bs_test' when split is
        'test'), 'data_root', 'dataset_class', and optionally 'num_workers'
        (4 when absent).
    split: split name; the data file is '<split>.txt' under config['data_root'].

    Returns a torch DataLoader whose batches are plain lists of dataset samples
    (no tensor collation is performed).
    '''
    bs = config['bs_test'] if split == 'test' else config['bs']
    data_fp = jpath(config['data_root'], '{}.txt'.format(split))

    # NOTE(review): eval() resolves the dataset class from a config string, so the
    # config file can execute arbitrary code. Only use trusted config files;
    # consider an explicit name -> class registry instead.
    dataset_class = eval(config['dataset_class'])
    dataset = dataset_class(data_fp=data_fp, split=split, config=config)

    num_workers = config['num_workers'] if 'num_workers' in config else 4
    return utils.data.DataLoader(
        dataset=dataset,
        batch_size=bs,
        num_workers=num_workers,
        collate_fn=_identity_collate,
    )


# Run the test routine only when this file is executed as a script.
if __name__ == '__main__':
    main()