import os
import sys

# Make modules next to the working directory and its parent importable
# (presumably where piano_dataset, lightning_model_moyu and utils_common live —
# verify). Both entries are relative to the current working directory, not to
# this file's location.
sys.path.append('.')
sys.path.append(os.path.abspath('..'))
# When executed as a script without exactly one (config-path) argument, this
# is a debug run (mirrors the debug check in main()) and is pinned to GPU 3.
# This must stay above `import torch` so the device mask is in place before
# CUDA gets initialized.
if not len(sys.argv) == 2 and __name__ == '__main__': # For debug runs
    os.environ["CUDA_VISIBLE_DEVICES"] = '3'

import torch
from torch import utils

from piano_dataset import get_dataloader
from lightning_model_moyu import get_lit_model

from lightning.pytorch import seed_everything
from utils_common.utils import jpath, read_yaml
# from m2m.lightning_dataset import *
# from m2m.lightning_model import get_lit_model
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from transformers import AutoTokenizer
import mlconfig

# Opt in to TensorFloat-32 for CUDA matmuls and cuDNN kernels: faster on GPUs
# that support it, at slightly reduced float32 precision. (Chained assignment
# sets the cuBLAS flag first, then the cuDNN flag.)
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True


def main():
    """Train the Lightning model described by a YAML config.

    Usage: ``python <this_script> CONFIG.yaml``.

    When the script is not given exactly one command-line argument, it falls
    back to a debug run: a hard-coded config file is loaded, ``num_workers``
    is forced to 0, and the Trainer gets ``fast_dev_run=4`` (Lightning then
    runs only a few batches instead of real training).

    Config keys read directly here: ``train_with``, ``valid_with``,
    ``result_root``, ``out_dir``, ``monitor``, ``mode``,
    ``early_stop_patience``, ``n_epoch``, ``val_check_interval`` and
    ``check_val_every_n_epoch``. ``get_lit_model`` and ``get_dataloader``
    presumably read further keys — see those helpers.
    """
    seed_everything(42, workers=True)

    # Same "not exactly one argument" test as the module-level
    # CUDA_VISIBLE_DEVICES override; keep the two in sync.
    if len(sys.argv) != 2:  # Debug
        config_fp = '/home/[anonymous]/work/[anonymous]/[anonymous]/baselines/moyu_piano/hparams/baseline_with_octave_shift.yaml'
        config = mlconfig.load(config_fp)
        config['num_workers'] = 0  # load data in-process: easier to debug
        fast_dev_run = 4
        # Uncomment to train on the validation split while debugging:
        # config['train_with'] = 'valid'
    else:
        config_fp = sys.argv[1]
        config = mlconfig.load(config_fp)
        fast_dev_run = False

    # Init the model
    lit_model = get_lit_model(config)

    # Setup data
    train_loader = get_dataloader(config, config['train_with'])
    valid_loader = get_dataloader(config, config['valid_with'])

    # Train the model
    out_dir = jpath(config['result_root'], config['out_dir'])
    monitor = config['monitor']
    checkpoint_callback = ModelCheckpoint(
        monitor=monitor,
        mode=config['mode'],
        # Embed the metric that actually selects the best checkpoint, rather
        # than always valid_loss. For monitor == 'valid_loss' this is exactly
        # '{epoch:02d}-{valid_loss:.2f}'.
        # NOTE(review): a monitor name containing '/' yields subdirectories.
        filename='{epoch:02d}-{' + monitor + ':.2f}',
        save_top_k=1,
    )
    earlystop_callback = EarlyStopping(
        monitor=monitor,
        mode=config['mode'],
        patience=config['early_stop_patience'],
    )
    trainer = L.Trainer(
        max_epochs=config['n_epoch'],
        default_root_dir=out_dir,  # output and log dir
        callbacks=[checkpoint_callback, earlystop_callback],
        fast_dev_run=fast_dev_run,
        val_check_interval=config['val_check_interval'],
        check_val_every_n_epoch=config['check_val_every_n_epoch'],
        # Explicit form of the legacy 'bf16' alias (same behavior, without
        # Lightning's discouragement warning).
        precision='bf16-mixed',
        accelerator="gpu",
    )
    trainer.fit(
        model=lit_model,
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader,
    )





# Script entry point: `python <this_file> CONFIG.yaml`. With any argument
# count other than one, main() takes its debug branch instead.
if __name__ == '__main__':
    main()