import logging 
import os 
import re 
import random 
import trl 
import re 
import hydra 
import torch 
import numpy as np 

from custom_data .countdown_data import find_equation ,add_solutions 
from omegaconf import DictConfig ,OmegaConf 
from datetime import datetime 
from datasets import load_dataset 
from transformers .trainer_utils import get_last_checkpoint 
from transformers import AutoTokenizer 
from custom_data import sft_data 
from trl import GRPOConfig ,GRPOTrainer ,get_peft_config ,ModelConfig 
from train import compute_accumulation_steps 
from hydra import compose ,initialize 


from datasets import load_dataset ,DatasetDict ,Dataset 
from custom_data .sft_data import add_indices 
from custom_data .reasoning_datasets_info import ReasoningData ,DATA_CONFIGS 

import omegaconf 

from custom_data .pretraining_data import load_pretraining_dataset 


# Base Hydra overrides, each written in Hydra's `key=value` override syntax.
hydra_overrides = [
    "per_device_train_batch_size=2",
    "data_cfg@_global_=fineweb",
    "dataset_configuration=sample-350BT",
    "trainer_cfg@_global_=base_trainer",
    "exp_name=scratch",
    "wandb_project=scratch",
    "report_to=null",
]

# Additional overrides expressed as a mapping; currently there are none.
hydra_overrides_from_dict = {}

# Render every mapping entry as a `key=value` string so it can be appended
# to the list above. The name is rebound from the mapping to the rendered list.
hydra_overrides_from_dict = [
    "{}={}".format(key, value)
    for key, value in hydra_overrides_from_dict.items()
]

# Final override list: the base entries followed by the mapping-derived ones.
hydra_overrides = [*hydra_overrides, *hydra_overrides_from_dict]


# Compose the `train` config from the `cfgs` directory, applying the entries
# in `hydra_overrides` on top of it.
with initialize(version_base=None, config_path="cfgs", job_name="test_r1"):
    cfg = compose(config_name="train", overrides=hydra_overrides)

    # If `gradient_accumulation_steps` is still a missing value after
    # composition, derive it from the total and per-device batch sizes and
    # store it back on the config.
    if OmegaConf.is_missing(cfg, "gradient_accumulation_steps"):
        accumulation_steps = compute_accumulation_steps(
            train_batch_size=cfg.train_batch_size,
            per_device_train_batch_size=cfg.per_device_train_batch_size,
        )
        cfg.gradient_accumulation_steps = accumulation_steps
        # Print the derived value once (it used to be printed twice).
        print(cfg.gradient_accumulation_steps)

    # Dump the composed config as YAML for inspection.
    print(OmegaConf.to_yaml(cfg))

# Instantiate the object described by the `make_tokenizer_fn` config node.
tokenizer = hydra.utils.instantiate(cfg.make_tokenizer_fn)

# Instantiate the `make_dataset_fn` config node, passing the tokenizer created
# above as the `tokenizer` keyword argument.
datasets = hydra.utils.instantiate(
    cfg.make_dataset_fn,
    tokenizer=tokenizer,
)