import copy
import os
import re
from itertools import zip_longest      # pairs last chunk with None

import pandas as pd
from tqdm import tqdm

# Build a "gold hint" variant of the TinyZero NuminaMath-CoT training set: each
# prompt's first message gets the leading steps of its reference answer
# (``extra_info['answer']``) appended, leaving roughly REMAINING_STEPS_RATIO of
# the steps out of the prompt.

INPUT_PATH = "/home/ubuntu/TinyZero_NuminaMath-CoT/train.parquet"
OUTPUT_DIR = "/home/ubuntu/TinyZero_NuminaMath-CoT_gold"
# Fraction of the answer's steps that are NOT revealed in the prompt.
REMAINING_STEPS_RATIO = 0.9
# Reference answers are split into steps on blank lines.
STEP_SEPARATOR = "\n\n"


def count_hint_steps(num_steps, remaining_steps_ratio=REMAINING_STEPS_RATIO):
    """Return how many leading steps of a ``num_steps``-step answer to reveal.

    ``int(num_steps * remaining_steps_ratio)`` steps (truncated, but at least
    one) are held back; everything before them becomes the hint.  Returns 0
    when nothing would be left to reveal.

    Raises:
        ValueError: if ``remaining_steps_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= remaining_steps_ratio <= 1:
        raise ValueError(
            f"remaining_steps_ratio must be in [0, 1], got {remaining_steps_ratio!r}"
        )
    # Always hold back at least one step so the model has something to produce.
    remaining = max(int(num_steps * remaining_steps_ratio), 1)
    return max(num_steps - remaining, 0)


def add_solution_hint(prompt, answer, remaining_steps_ratio=REMAINING_STEPS_RATIO):
    """Return ``(new_prompt, num_steps)`` with leading answer steps appended.

    ``prompt`` is indexed as ``prompt[0]['content']``.  It is deep-copied
    before modification, so the caller's object is never mutated.  When no
    step would be revealed, ``prompt`` itself is returned unchanged.
    ``num_steps`` is the number of STEP_SEPARATOR-delimited steps in ``answer``.
    """
    steps = answer.split(STEP_SEPARATOR)
    n_hint = count_hint_steps(len(steps), remaining_steps_ratio)
    if not n_hint:
        return prompt, len(steps)
    new_prompt = copy.deepcopy(prompt)
    # NOTE(review): the hint is appended with no separator after the existing
    # content — confirm the prompt template already ends with one.
    new_prompt[0]["content"] += STEP_SEPARATOR.join(steps[:n_hint])
    return new_prompt, len(steps)


def build_gold_dataframe(df, remaining_steps_ratio=REMAINING_STEPS_RATIO, progress=None):
    """Return ``(new_df, step_counts)`` for a frame with ``prompt``/``extra_info`` columns.

    Every row of ``df`` is copied into ``new_df`` with the same column order.
    Only ``prompt`` is replaced, via :func:`add_solution_hint` on
    ``extra_info['answer']``.  ``step_counts`` lists each row's answer step
    count, in order.  ``progress``, if given, wraps the row iterable (e.g. a
    tqdm progress bar).  ``df`` is not modified.
    """
    columns = list(df.columns)
    new_data = {key: [] for key in columns}
    step_counts = []
    rows = df.to_dict("records")
    if progress is not None:
        rows = progress(rows)
    for row in rows:
        new_prompt, num_steps = add_solution_hint(
            row["prompt"], row["extra_info"]["answer"], remaining_steps_ratio
        )
        step_counts.append(num_steps)
        for key in columns:
            new_data[key].append(new_prompt if key == "prompt" else row[key])
    return pd.DataFrame(new_data), step_counts


def output_path(remaining_steps_ratio=REMAINING_STEPS_RATIO, output_dir=OUTPUT_DIR):
    """Return the output parquet path, e.g. ratio 0.9 -> ``<dir>/train_0-9.parquet``."""
    # Derive the filename from the ratio so changing it never overwrites old output.
    tag = f"{remaining_steps_ratio:g}".replace(".", "-")
    return os.path.join(output_dir, f"train_{tag}.parquet")


def main():
    """Read INPUT_PATH, add solution hints, write the parquet, print the max step count."""
    df_pandas = pd.read_parquet(INPUT_PATH)
    df, steps_ls = build_gold_dataframe(
        df_pandas,
        REMAINING_STEPS_RATIO,
        progress=lambda rows: tqdm(rows, desc="Processing rows", unit="row"),
    )
    path = output_path(REMAINING_STEPS_RATIO)
    # Create the output directory if it is missing.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    df.to_parquet(path, index=False, engine="pyarrow", compression="snappy")
    # default=0 avoids ValueError on an empty input file.
    print(max(steps_ls, default=0))


if __name__ == "__main__":
    main()