import os
import pandas as pd



# Merge the per-dataset "formula" parquet splits into one shuffled train file
# and one shuffled test file, then report how many rows each dataset
# contributed.
datasets = ['wikitq', 'tabfact', 'finqa', 'tablebench']

# Per-dataset frames, keyed by dataset name (insertion order == read order).
train_frames = {}
test_frames = {}

for dataset in datasets:
    # 'tablebench' is read for its test split only; no train split is loaded.
    wants_train_split = dataset != 'tablebench'
    if wants_train_split:
        train_frames[dataset] = pd.read_parquet(
            f'data/processed_data/formula/qwen/{dataset}/train.parquet'
        )
    test_frames[dataset] = pd.read_parquet(
        f'data/processed_data/formula/qwen/{dataset}/test.parquet'
    )

# Row counts per dataset, taken before the frames are merged.
train_counts = {name: len(frame) for name, frame in train_frames.items()}
test_counts = {name: len(frame) for name, frame in test_frames.items()}

# Concatenate, then shuffle with a fixed seed so the output is reproducible;
# the index is rebuilt so it runs 0..n-1 after shuffling.
train = pd.concat(list(train_frames.values()))
train = train.sample(frac=1, random_state=42)
train = train.reset_index(drop=True)

test = pd.concat(list(test_frames.values()))
test = test.sample(frac=1, random_state=42)
test = test.reset_index(drop=True)

# Write both merged splits, creating the output directory if needed.
os.makedirs('data/processed_data/symbolic/formula', exist_ok=True)
train.to_parquet('data/processed_data/symbolic/formula/train.parquet', index=False)
test.to_parquet('data/processed_data/symbolic/formula/test.parquet', index=False)

# Report per-dataset and total sizes for each split.
for header, counts, merged in (
    ("Train set sizes:", train_counts, train),
    ("Test set sizes:", test_counts, test),
):
    print(header)
    for k, v in counts.items():
        print(f"  {k}: {v}")
    total = len(merged)
    print(f"  Total: {total}")