import os
import pandas as pd



# Source datasets to combine. 'tablebench' contributes only its test split.
datasets = ['wikitq', 'tabfact', 'finqa', 'tablebench']
train_counts = {}
test_counts = {}
train = []
test = []

for dataset in datasets:
    split_dir = f'data/processed_data/formula/qwen/{dataset}'

    if dataset != 'tablebench':
        df_train = pd.read_parquet(f'{split_dir}/train.parquet')
        train_counts[dataset] = len(df_train)
        train.append(df_train)

    df_test = pd.read_parquet(f'{split_dir}/test.parquet')
    test_counts[dataset] = len(df_test)
    test.append(df_test)


def _merge_and_shuffle(frames):
    """Concatenate *frames*, shuffle rows with seed 42, and renumber the index."""
    merged = pd.concat(frames)
    return merged.sample(frac=1, random_state=42).reset_index(drop=True)


# merge and shuffle (train first, then test)
train = _merge_and_shuffle(train)
test = _merge_and_shuffle(test)


def transform_data(df):
    """Convert formula-task rows into SQL-task rows.

    Every ``'formula'`` substring in ``data_source`` becomes ``'sql'``, and
    every ``prompt`` is rewritten by the nested ``transform_prompt``.

    Args:
        df: DataFrame with string columns ``data_source`` and ``prompt``.
            Both columns are overwritten on this object (in place).

    Returns:
        The same ``df`` object, after modification.

    Raises:
        IndexError: if a prompt lacks the ``# Table`` or
            ``Here is the question:`` marker.
    """
    # SQL-specific "# Notes" section that replaces the formula notes. The
    # explicit "\n" plus the literal line break leaves a blank line before
    # the question marker.
    new_note = '''# Notes
Generate a SQL query that can be executed onto the table to answer the question.
The result of executing the SQL query should be the final answer.
You must output the SQL query as a single line in the SQL field of the JSON.
The table name in the SQL query must be 'TMP_TABLE'.\n
'''

    def to_sql_wording(text):
        """Swap formula wording for SQL wording in instruction text."""
        return (
            text.replace('spreadsheet formula', 'SQL query')
            .replace('formula', 'SQL')
            .replace('"=......."', '"......."')
        )

    def transform_prompt(prompt):
        """Rewrite one spreadsheet-formula prompt as a SQL-query prompt.

        Keeps the intro (text before ``# Spreadsheet Formula Operator List``)
        and drops the operator list. Keeps the ``# Table`` section, replaces
        the ``# Notes`` section with ``new_note``, and keeps everything from
        ``Here is the question:`` onwards.
        """
        # NOTE(review): if the operator-list header were absent, split()[0]
        # would be the whole prompt and table/question would be duplicated;
        # presumably every input prompt contains it — confirm.
        intro = prompt.split('# Spreadsheet Formula Operator List')[0]
        intro = intro.replace('Spreadsheet formulas', 'SQL query').replace(
            'spreadsheet formula', 'SQL query'
        )

        # The table section is copied verbatim: rewriting "formula" here would
        # corrupt cell values / column names (e.g. "formula one") that the
        # generated SQL has to match against the real TMP_TABLE.
        # maxsplit=1 keeps all text after the first marker even if the marker
        # text recurs later.
        table = prompt.split('# Table', 1)[1].split('# Notes', 1)[0]

        # Question plus whatever follows it (e.g. output-format instructions).
        # NOTE(review): the question text is still rewritten together with the
        # trailing instructions, because SOURCE gives no marker separating
        # them; a question mentioning "formula" will read "SQL" — confirm
        # whether that is acceptable.
        tail = prompt.split('Here is the question:', 1)[1]

        return (
            to_sql_wording(intro)
            + '# Table' + table
            + new_note
            + 'Here is the question:' + to_sql_wording(tail)
        )

    df['data_source'] = df['data_source'].apply(lambda x: x.replace('formula', 'sql'))
    df['prompt'] = df['prompt'].apply(transform_prompt)
    return df


train = transform_data(train)
test = transform_data(test)


# save: both splits go under the symbolic/sql output directory
out_dir = 'data/processed_data/symbolic/sql'
os.makedirs(out_dir, exist_ok=True)
for split_name, split_df in (('train', train), ('test', test)):
    split_df.to_parquet(f'{out_dir}/{split_name}.parquet', index=False)

# print statistics: per-dataset row counts, then the merged total
report = (
    ("Train set sizes:", train_counts, train),
    ("Test set sizes:", test_counts, test),
)
for title, counts, split_df in report:
    print(title)
    for name, size in counts.items():
        print(f"  {name}: {size}")
    print(f"  Total: {len(split_df)}")