import os
import pandas as pd



# Datasets to merge. A train split is read for every dataset except
# 'tablebench', which only contributes its test split here.
datasets = ['wikitq', 'tabfact', 'finqa', 'tablebench']
train, test = [], []
train_counts, test_counts = {}, {}

for dataset in datasets:
    source_dir = f'data/processed_data/formula/qwen/{dataset}'

    if dataset != 'tablebench':
        df_train = pd.read_parquet(f'{source_dir}/train.parquet')
        train_counts[dataset] = len(df_train)
        train.append(df_train)

    df_test = pd.read_parquet(f'{source_dir}/test.parquet')
    test_counts[dataset] = len(df_test)
    test.append(df_test)

# merge and shuffle (fixed seed so the row order is reproducible)
train = pd.concat(train)
train = train.sample(frac=1, random_state=42)
train = train.reset_index(drop=True)

test = pd.concat(test)
test = test.sample(frac=1, random_state=42)
test = test.reset_index(drop=True)


def transform_data(df):
    """Rewrite spreadsheet-formula prompts so they ask for Python code.

    ``df`` is modified in place and the same object is returned. Two columns
    are rewritten:

    * ``data_source``: every ``'formula'`` substring becomes ``'python'``.
    * ``prompt``: rebuilt from the text before
      ``# Spreadsheet Formula Operator List``, the table section (between
      ``# Table`` and ``# Notes``), a Python-specific ``# Notes`` section, and
      the text after ``Here is the question:``. Everything else in the
      original prompt is dropped, and remaining formula wording is replaced
      with code wording.

    Args:
        df: DataFrame with string-valued ``data_source`` and ``prompt``
            columns.

    Returns:
        The same DataFrame object, with both columns rewritten.

    Raises:
        IndexError: If a prompt contains no ``# Table`` or no
            ``Here is the question:`` marker.
    """
    # Constant text, so build it once per call instead of once per row.
    new_note = '''# Notes
1. Generate a Python code that can be executed to answer the question.
2. The result of executing the code should be the final answer.
3. You must output the Python code as a single line in the code field of the JSON, enclosed in triple backticks with the python tag (```python```).
4. If a question has multiple answers, concatenate them using ", " as the separator.
5. The input value for the table is already assigned to the variable 'table_data = [[...],[...],...]'.
6. The result of final answer must be assigned to a variable named 'answer'.\n
'''

    def transform_prompt(prompt):
        """Return ``prompt`` rewritten to request Python code."""
        # Instruction text that precedes the operator list.
        header = prompt.split('# Spreadsheet Formula Operator List')[0]
        header = header.replace('Spreadsheet formulas', 'Python code')
        header = header.replace('spreadsheet formula', 'Python code')

        # maxsplit=1 keeps everything after the FIRST marker, so a table or
        # question that happens to repeat the marker text is not truncated.
        table = prompt.split('# Table', 1)[1].split('# Notes')[0]
        question = prompt.split('Here is the question:', 1)[1]

        new_prompt = header + '# Table' + table + new_note + 'Here is the question:' + question

        # NOTE(review): this final pass also runs over the table and question
        # text, so a literal 'formula' there is rewritten to 'code' as well --
        # confirm that is intended.
        new_prompt = new_prompt.replace('spreadsheet formula', 'Python code')
        new_prompt = new_prompt.replace('formula', 'code')
        return new_prompt.replace('"=......."', '"......."')

    df['data_source'] = df['data_source'].apply(lambda x: x.replace('formula', 'python'))
    df['prompt'] = df['prompt'].apply(transform_prompt)
    return df


# Convert both splits from formula-style prompts to Python-style prompts.
train = transform_data(train)
test = transform_data(test)


# save
out_dir = 'data/processed_data/symbolic/python'
os.makedirs(out_dir, exist_ok=True)
train.to_parquet(f'{out_dir}/train.parquet', index=False)
test.to_parquet(f'{out_dir}/test.parquet', index=False)

# print statistics: per-dataset row counts, then the merged total per split
for heading, per_dataset, merged in (
    ("Train set sizes:", train_counts, train),
    ("Test set sizes:", test_counts, test),
):
    print(heading)
    for name, size in per_dataset.items():
        print(f"  {name}: {size}")
    print(f"  Total: {len(merged)}")