import os
import json
import pandas as pd



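# Build SFT data from distilled responses: the responses are row-aligned with the
# Qwen train split, and are matched to the Llama train split by
# (table_id, question, ground-truth answer).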



for dataset in ['wikitq', 'tabfact', 'finqa', 'hitab', 'multihiertt']:
    for task in ['formula', 'text']:
        # The distilled responses depend only on (task, dataset), and only the
        # Qwen pass indexes into them, so read the file once here instead of
        # once per model. JSON is UTF-8, so don't rely on the platform default
        # encoding.
        sft_response_path = f"data/distill_data/{task}/{dataset}/distill_data.json"
        with open(sft_response_path, 'r', encoding='utf-8') as f:
            sft_response = json.load(f)

        # "table_id - question - answer" -> Qwen response with '<think>' removed.
        # Filled during the Qwen pass and consumed during the Llama pass, so the
        # model order below matters: 'qwen' must come before 'llama'.
        distill_response_mapping = {}
        for model in ['qwen', 'llama']:
            distill_data = []

            original_train_data_path = f"data/processed_data/{task}/{model}/{dataset}/train.parquet"
            original_train_data = pd.read_parquet(original_train_data_path)

            # sft_response is indexed positionally against the Qwen rows, so the
            # two must line up. A shorter file cannot be aligned; fail with a clear
            # message instead of an IndexError halfway through the loop.
            if model == 'qwen' and len(sft_response) != len(original_train_data):
                if len(sft_response) < len(original_train_data):
                    raise ValueError(
                        f"{sft_response_path} has {len(sft_response)} responses but "
                        f"{original_train_data_path} has {len(original_train_data)} rows"
                    )
                print(
                    f"Warning: {sft_response_path} has {len(sft_response)} responses, "
                    f"more than the {len(original_train_data)} rows in "
                    f"{original_train_data_path}; extra responses are ignored"
                )

            for i, (_, item) in enumerate(original_train_data.iterrows()):
                # Key used to match the same example across the Qwen and Llama splits.
                data_id = (
                    str(item['table_id']) + ' - ' + str(item['question']) + ' - '
                    + str(item['reward_model']['ground_truth']['answer'])
                )

                if model == 'qwen':
                    response = sft_response[i]
                    if not response:  # empty (or missing) distillation: no SFT example
                        continue
                    # The Llama copy has '<think>' stripped; the Qwen copy keeps it.
                    distill_response_mapping[data_id] = response.replace('<think>', '')
                    item['extra_info']['sft_response'] = response
                elif model == 'llama':
                    if data_id not in distill_response_mapping:
                        continue
                    # pop: each stored response is used for at most one Llama row.
                    item['extra_info']['sft_response'] = distill_response_mapping.pop(data_id)

                distill_data.append(item)

            distill_data = pd.DataFrame(distill_data)
            distill_data.to_parquet(f"data/processed_data/{task}/{model}/{dataset}/sft.parquet")

            print(f'[Task {task}, Model {model}] number of distill SFT data in {dataset}: {len(distill_data)}')


