import os
import sys
dirof = os.path.dirname
# Put the directory three levels above this file at the front of the import
# path, then make the current working directory and its parent importable too.
sys.path.insert(0, dirof(dirof(dirof(__file__))))
sys.path.extend(['.', os.path.abspath('..')])

# Debug runs (executed as a script without exactly one config-path argument)
# are pinned to GPU 3. This is done before torch is imported below, ahead of
# any CUDA initialisation.
if __name__ == '__main__' and len(sys.argv) != 2:
    os.environ["CUDA_VISIBLE_DEVICES"] = '3'

import torch
from torch import utils

from piano_dataset import get_dataloader
from lightning.pytorch import seed_everything
from utils_common.utils import jpath, read_yaml
# from m2m.lightning_dataset import *
# from m2m.lightning_model import get_lit_model, load_lit_model
from lightning_model_moyu import load_lit_model
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from transformers import AutoTokenizer
import mlconfig
# from transformers.utils import logging
# logging.get_logger("transformers").setLevel(logging.ERROR)

# Permit TensorFloat-32 math for CUDA matmuls and for cuDNN ops. On hardware
# that supports TF32 this trades some float32 precision for speed.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True


def _load_config():
    """Return the config for this test run.

    With exactly one command-line argument, that argument is the path of the
    config file to load. Any other argument count is treated as a debug run:
    a hard-coded config is loaded and a few entries are overridden to keep
    the run small.
    """
    if len(sys.argv) == 2:
        return mlconfig.load(sys.argv[1])

    # Debug run. This mirrors the module-level check that pins
    # CUDA_VISIBLE_DEVICES for debug runs.
    config_fp = '/home/[anonymous]/work/[anonymous]/[anonymous]/baselines/moyu_piano/hparams/baseline.yaml'
    config = mlconfig.load(config_fp)
    config['num_workers'] = 0   # presumably DataLoader workers; 0 keeps loading in-process
    config['fast_dev_run'] = 5  # passed to Trainer(fast_dev_run=...) below
    config['bs_test'] = 2       # presumably the test batch size; verify in get_dataloader
    return config


def main():
    """Run the test loop for a checkpointed Lightning model.

    Seeds everything with 42, loads the config via ``_load_config``, builds
    the model with ``load_lit_model`` and the 'test' dataloader with
    ``get_dataloader``, then runs ``Trainer.test`` on one GPU with bf16 mixed
    precision and no logger.
    """
    seed_everything(42, workers=True)

    config = _load_config()

    # Load a lightning model from checkpoint
    lit_model = load_lit_model(config)

    # Setup data
    test_loader = get_dataloader(config, 'test')

    # Prepare trainer for testing.
    # 'bf16-mixed' is the current Lightning spelling. The legacy 'bf16' alias
    # resolves to the same setting but is discouraged with a warning.
    trainer = L.Trainer(
        logger=False,
        fast_dev_run=config.get('fast_dev_run', False),
        precision='bf16-mixed',
        accelerator="gpu",
        devices=1,
    )
    trainer.test(
        model=lit_model,
        dataloaders=test_loader,
    )


if __name__ == '__main__':
    main()