import os
import json
import glob
import argparse
import pdb

def merge_feedback_files_tabmwp(directory):
    """Merge per-turn TabMWP feedback files into one record per question.

    Every ``turn_<N>_feedback.jsonl`` file in *directory* is read in
    ascending turn order and its JSON records are grouped by
    ``question_id``.

    Args:
        directory: Path of the directory containing the feedback files.

    Returns:
        A dict mapping each ``question_id`` to a record holding:

        * the shared fields ``question_id``, ``question``, ``data_path``,
          ``data_title`` (read from the input key ``table_title``),
          ``data_overview``, ``answer`` and ``answer_type``, taken from the
          first record seen for that question;
        * ``trad``: the retained attempts sorted by turn. Every attempt
          whose ``require_opt`` equals the string ``'true'`` is retained,
          but only the first attempt whose ``require_opt`` equals
          ``'false'``; attempts with any other value are dropped;
        * ``trad_seq``: a list parallel to ``trad`` containing
          ``'correct'`` for a ``'false'`` attempt and ``'wrong'`` otherwise.

    Raises:
        ValueError: If the file-name segment after ``turn_`` is not an
            integer.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # glob returns paths in arbitrary order; parsing the turn first and
    # sorting makes both the output order and the choice of "first record"
    # reproducible across file systems.
    turn_files = sorted(
        (int(os.path.basename(path).split('_')[1]), path)
        for path in glob.glob(os.path.join(directory, "turn_*_feedback.jsonl"))
    )

    merged_data = {}
    for turn, feedback_file in turn_files:
        with open(feedback_file, 'r', encoding='utf-8') as file:
            for line in file:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                data = json.loads(line)
                question_id = data.get('question_id')

                if question_id not in merged_data:
                    merged_data[question_id] = {
                        'question_id': question_id,
                        'question': data.get('question'),
                        'data_path': data.get('data_path', ''),
                        'data_title': data.get('table_title', ''),
                        'data_overview': data.get('data_overview', ''),
                        'answer': data.get('answer', ''),
                        'answer_type': data.get('answer_type', ''),
                        'trad': [],
                    }

                merged_data[question_id]['trad'].append({
                    'question': data.get('question'),
                    'function_code': data.get('function_code'),
                    'final_code': data.get('final_code'),
                    'result': data.get('result', ''),
                    'require_opt': data.get('require_opt'),
                    'turn': turn,
                })

    # Order attempts by turn, filter them, and derive 'trad_seq'.
    for record in merged_data.values():
        record['trad'].sort(key=lambda attempt: attempt['turn'])

        kept = []
        found_correct = False
        for attempt in record['trad']:
            # NOTE: 'require_opt' is compared against the strings
            # 'true'/'false'; any other value means the attempt is skipped.
            flag = attempt['require_opt']
            if flag == 'true':
                kept.append(attempt)
            elif flag == 'false' and not found_correct:
                kept.append(attempt)
                found_correct = True

        record['trad'] = kept
        record['trad_seq'] = [
            'correct' if attempt['require_opt'] == 'false' else 'wrong'
            for attempt in kept
        ]

    return merged_data


def merge_feedback_files_wiki(directory):
    """Merge per-turn wiki feedback files into one record per question.

    Every ``turn_<N>_feedback.jsonl`` file in *directory* is read in
    ascending turn order and its JSON records are grouped by
    ``question_id``.

    Args:
        directory: Path of the directory containing the feedback files.

    Returns:
        A dict mapping each ``question_id`` to a record holding:

        * the shared fields ``question_id``, ``question``, ``data_path``,
          ``data_overview`` and ``answer``, taken from the first record
          seen for that question;
        * ``trad``: the retained attempts sorted by turn, each carrying
          ``question``, ``function_code``, ``final_code``, ``execution``,
          ``result``, ``require_opt`` and ``turn``. Every attempt whose
          ``require_opt`` equals the string ``'true'`` is retained, but
          only the first attempt whose ``require_opt`` equals ``'false'``;
          attempts with any other value are dropped;
        * ``trad_seq``: a list parallel to ``trad`` containing
          ``'correct'`` for a ``'false'`` attempt and ``'wrong'`` otherwise.

    Raises:
        ValueError: If the file-name segment after ``turn_`` is not an
            integer.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # glob returns paths in arbitrary order; parsing the turn first and
    # sorting makes both the output order and the choice of "first record"
    # reproducible across file systems.
    turn_files = sorted(
        (int(os.path.basename(path).split('_')[1]), path)
        for path in glob.glob(os.path.join(directory, "turn_*_feedback.jsonl"))
    )

    merged_data = {}
    for turn, feedback_file in turn_files:
        with open(feedback_file, 'r', encoding='utf-8') as file:
            for line in file:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                data = json.loads(line)
                question_id = data.get('question_id')

                if question_id not in merged_data:
                    merged_data[question_id] = {
                        'question_id': question_id,
                        'question': data.get('question'),
                        'data_path': data.get('data_path', ''),
                        'data_overview': data.get('data_overview', ''),
                        'answer': data.get('answer', ''),
                        'trad': [],
                    }

                merged_data[question_id]['trad'].append({
                    'question': data.get('question'),
                    'function_code': data.get('function_code'),
                    'final_code': data.get('final_code'),
                    'execution': data.get('execution'),
                    'result': data.get('result', ''),
                    'require_opt': data.get('require_opt'),
                    'turn': turn,
                })

    # Order attempts by turn, filter them, and derive 'trad_seq'.
    for record in merged_data.values():
        record['trad'].sort(key=lambda attempt: attempt['turn'])

        kept = []
        found_correct = False
        for attempt in record['trad']:
            # NOTE: 'require_opt' is compared against the strings
            # 'true'/'false'; any other value means the attempt is skipped.
            flag = attempt['require_opt']
            if flag == 'true':
                kept.append(attempt)
            elif flag == 'false' and not found_correct:
                kept.append(attempt)
                found_correct = True

        record['trad'] = kept
        record['trad_seq'] = [
            'correct' if attempt['require_opt'] == 'false' else 'wrong'
            for attempt in kept
        ]

    return merged_data

def main():
    """Command-line entry point: merge feedback files and save them as JSONL.

    Parses ``--trad_directory`` and ``--output_file`` (both required) plus
    the optional ``--dataset`` selector, merges the feedback files found in
    the directory with the matching merge function, and writes one JSON
    object per question to the output file.
    """
    parser = argparse.ArgumentParser(description='Merge feedback JSONL files.')
    # Both paths are mandatory: without them the merge would fail later with
    # an opaque TypeError instead of a usage message.
    parser.add_argument('--trad_directory', type=str, required=True,
                        help='Directory containing feedback JSONL files.')
    parser.add_argument('--output_file', type=str, required=True,
                        help='Output file for merged JSONL data.')
    # Defaults to 'wiki' so existing invocations keep their behaviour.
    parser.add_argument('--dataset', choices=('wiki', 'tabmwp'), default='wiki',
                        help='Which feedback format to merge (default: wiki).')

    args = parser.parse_args()

    merge_functions = {
        'wiki': merge_feedback_files_wiki,
        'tabmwp': merge_feedback_files_tabmwp,
    }
    merged_data = merge_functions[args.dataset](args.trad_directory)

    # Save the merged data to a JSONL file, one record per line.
    with open(args.output_file, 'w', encoding='utf-8') as outfile:
        outfile.writelines(
            json.dumps(record) + '\n' for record in merged_data.values()
        )
    print(f"Merged data has been saved to {args.output_file}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

