import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Group_SAE.SAETran_model_v2 import HDF5Dataset
import yaml

# Sweep configuration: which layers and which corpora get a metadata file.
# Alternative settings kept for quick switching:
#   L_ls = [2, 13, 26]
#   data_name_ls = ['Pile_github']
L_ls = [26]
data_name_ls = ['Pile_github', 'Pile_wiki']

# For every (layer, corpus) pair, open the HDF5 training set and write a small
# dataset_info.yaml describing it into the directory that holds the .h5 file.
for L in L_ls:
    for data_name in data_name_ls:
        data_path = (
            "/path/to/your/scratch/"
            f"{data_name}-Qwen2.5-1.5B-L{L}-mlp-out-2048/train_data.h5"
        )
        dataset = HDF5Dataset(data_path)

        # Sample count first, then the first sample, whose leading axis length
        # is recorded as the dimensionality.
        num_samples = len(dataset)
        first_sample = dataset[0]

        # NOTE(review): 'float32' is written unconditionally; it is not read
        # from the dataset -- confirm it matches the stored dtype.
        data_info = {}
        data_info['data_type'] = 'float32'
        data_info['dimensions'] = first_sample.shape[0]
        data_info['length'] = num_samples

        # The metadata file lives alongside train_data.h5.
        target_dir = os.path.dirname(data_path)
        data_info_path = os.path.join(target_dir, 'dataset_info.yaml')

        with open(data_info_path, 'w') as out_file:
            yaml.dump(data_info, out_file)