import logging
import os
import re
import random
import trl
import re
import hydra
import torch
import numpy as np

from custom_data.countdown_data import find_equation, add_solutions
from omegaconf import DictConfig, OmegaConf
from datetime import datetime
from datasets import load_dataset
from transformers.trainer_utils import get_last_checkpoint
from transformers import AutoTokenizer
from custom_data import sft_data
from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig
from train import compute_accumulation_steps
from hydra import compose, initialize


from datasets import load_dataset, DatasetDict, Dataset
from custom_data.sft_data import add_indices
from custom_data.reasoning_datasets_info import ReasoningData, DATA_CONFIGS

import omegaconf

from custom_data.pretraining_data import load_pretraining_dataset


# Extra overrides expressed as a mapping for convenience; each entry is
# rendered below into Hydra's "key=value" override string form.
# After the comprehension this name holds the rendered list of strings.
hydra_overrides_from_dict = {}
hydra_overrides_from_dict = [
    f"{key}={value}" for key, value in hydra_overrides_from_dict.items()
]

# Hydra command-line-style overrides for this scratch run.
# `group@package=option` selects a config-group option and places it in
# the given package (`_global_` = config root); `null` parses to None.
hydra_overrides = [
    # Per-device micro-batch size.
    "per_device_train_batch_size=2",
    # Data: FineWeb, 350BT sample configuration.
    "data_cfg@_global_=fineweb",
    "dataset_configuration=sample-350BT",
    # Trainer defaults.
    "trainer_cfg@_global_=base_trainer",
    # Logging / experiment naming.
    "exp_name=scratch",
    "wandb_project=scratch",
    "report_to=null",
    # Append the dict-derived overrides last so they take precedence.
    *hydra_overrides_from_dict,
]


# Compose the "train" config from ./cfgs using Hydra's Compose API (no
# @hydra.main entry point), applying the overrides assembled above.
with initialize(version_base=None, config_path="cfgs", job_name="test_r1"):
    cfg = compose(config_name="train", overrides=hydra_overrides)

    # If the config leaves gradient_accumulation_steps as MISSING ("???"),
    # derive it from the global and per-device batch sizes.
    # NOTE(review): the exact formula lives in train.compute_accumulation_steps.
    if OmegaConf.is_missing(cfg, "gradient_accumulation_steps"):
        accumulation_steps = compute_accumulation_steps(
            train_batch_size=cfg.train_batch_size,
            per_device_train_batch_size=cfg.per_device_train_batch_size,
        )
        cfg.gradient_accumulation_steps = accumulation_steps
        print(cfg.gradient_accumulation_steps)

    # Dump the composed config for inspection.
    print(OmegaConf.to_yaml(cfg))

# Build the tokenizer and dataset(s) from the factory targets named in the
# config; the tokenizer instance is forwarded to the dataset factory.
tokenizer = hydra.utils.instantiate(cfg.make_tokenizer_fn)
datasets = hydra.utils.instantiate(cfg.make_dataset_fn, tokenizer=tokenizer)