import os
import re

import accelerate.logging
import transformers
from transformers.trainer_utils import get_last_checkpoint

# torch._dynamo.config.optimize_ddp = False
# torch._dynamo.config.capture_scalar_outputs = True

from preamble import get_args, get_run_name, get_tokenizer, get_all_datasets, get_model, prepare_train_args, get_trainer
import datasets

# datasets.logging.set_verbosity_info()
# Raise the transformers library's log verbosity to INFO for this run.
transformers.logging.set_verbosity_info()

# get_args() (from preamble) returns two objects: the script-level `args`
# and the `train_args` that are later handed to the trainer.
args, train_args = get_args()

# Let preamble post-process the training arguments; the returned object
# replaces `train_args` for the remainder of the script.
train_args = prepare_train_args(args, train_args)

# Resolve automatic checkpoint resumption before any expensive setup.
#
# `resume_from_checkpoint is True` means "resume from the newest checkpoint in
# output_dir". When no such checkpoint can be found the flag is reset to None
# so that training starts from scratch. When the newest checkpoint has already
# reached the configured step budget, the script exits successfully without
# building the tokenizer, datasets, or model.
if train_args.resume_from_checkpoint is True:
    try:
        ckpt = get_last_checkpoint(train_args.output_dir)
    except OSError:
        # output_dir does not exist or cannot be listed: nothing to resume.
        # Only OSError is caught so that Ctrl-C / SystemExit still propagate.
        ckpt = None

    if ckpt is None:
        train_args.resume_from_checkpoint = None
    else:
        # Checkpoint directories are named "checkpoint-<global_step>".
        checkpoint_number = re.search(r"checkpoint-(\d+)", os.path.basename(ckpt))
        if checkpoint_number:
            checkpoint_step = int(checkpoint_number.group(1))
            # A non-positive max_steps means "no step budget" (training length
            # is governed by epochs), so the comparison is only meaningful
            # when max_steps is positive.
            if train_args.max_steps > 0 and checkpoint_step >= train_args.max_steps:
                print(f"Skipping training because checkpoint {checkpoint_step} already exceeds max_steps {train_args.max_steps}")
                # SystemExit instead of the site-provided exit() helper, which
                # is not guaranteed to be available in every interpreter setup.
                raise SystemExit(0)

# Build the training components with the preamble helpers. The order matters:
# each later call consumes objects produced by the earlier ones.
tokenizer = get_tokenizer(args)

# The tokenizer is passed in to dataset construction; the helper returns the
# training dataset together with the evaluation dataset(s).
train_dataset, eval_datasets = get_all_datasets(args, train_args, tokenizer)

# The model is constructed with access to the tokenizer as well.
model = get_model(args, train_args, tokenizer)

# Bundle model, tokenizer, arguments, and datasets into the trainer object
# that drives train() / evaluate() below.
trainer = get_trainer(args, model, tokenizer, train_args, train_dataset, eval_datasets)

# Weights & Biases is initialised on one process only: either LOCAL_RANK is
# not set at all, or this is local rank 0.
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is None or local_rank == "0":
    import wandb

    run_name = get_run_name(args, train_args)
    wandb.init(entity="<WANDB_ENTITY>", name=run_name, reinit=True)

    # Workaround for incorrect global metrics: declare a custom x-axis metric
    # and make every other logged metric use it as its step.
    wandb.define_metric("train/global_step")
    wandb.define_metric("*", step_metric="train/global_step")

    # Record both argument sets in the run configuration.
    wandb.config.update(vars(args))
    wandb.config.update(vars(train_args))

# Run training when requested; otherwise fall back to evaluation only.
if train_args.do_train:
    # Forward resume_from_checkpoint only when it has been set, so the
    # trainer's own default applies otherwise.
    train_kwargs = {}
    if train_args.resume_from_checkpoint is not None:
        train_kwargs["resume_from_checkpoint"] = train_args.resume_from_checkpoint
    trainer.train(**train_kwargs)
elif train_args.do_eval:
    # A literal True means "use the newest checkpoint in output_dir": resolve
    # it to a path and load those weights before evaluating.
    if train_args.resume_from_checkpoint is True:
        print(train_args.output_dir)
        train_args.resume_from_checkpoint = get_last_checkpoint(train_args.output_dir)
        trainer._load_from_checkpoint(resume_from_checkpoint=train_args.resume_from_checkpoint)
    trainer.evaluate()

# Close the wandb run on the same process that opened it (LOCAL_RANK unset
# or equal to "0").
if os.environ.get("LOCAL_RANK", "0") == "0":
    wandb.finish(quiet=True)
