from utils.util import load_jsonl, re_keywords, save_jsonl, save_data_to_fac
from utils.db_utils import get_column_table_from_full_schema, get_pretty_schema
import os
import json
from utils.util import save_data_to_fac
from src.prompts.react.lf_prompt.generate_data_for_train import system_prompt,traj_prompt_wo_mp
import re 
import argparse

# Template for one ReAct step rendered as tagged text. The placeholders
# `thought`, `action` and `observation` are filled in via str.format().
# NOTE(review): not referenced anywhere in the visible part of this file.
REACT_PROMPT = """<thought>{thought}</thought>
<action>{action}</action>
<observation>
{observation}
</observation>
"""



def construct_data(data_jsonl, full_schema_jsonl, turns = 5):
    """Build one prompt record per sample in ``data_jsonl``.

    Args:
        data_jsonl: Sequence of dicts; each must provide the keys
            ``'query'``, ``'issue_sql'`` and ``'db_id'``.
        full_schema_jsonl: Schema records aligned by position with
            ``data_jsonl``; handed to ``get_pretty_schema`` together with the
            sample index.
        turns: Value substituted for the ``turn`` placeholder of the prompt
            template.

    Returns:
        A ``(data_list, record_list)`` tuple. ``data_list`` holds dicts with
        the keys ``'instruction'``, ``'output'`` (always empty) and
        ``'system'``; ``record_list`` holds ``{'idx', 'db_id', 'prompt'}``
        dicts describing the same prompts.
    """
    data_list = []
    record_list = []
    for idx, content in enumerate(data_jsonl):
        schema = get_pretty_schema(
            idx,
            full_schema_jsonl,
            column_meaning_dict=None,
            selected_schema=None,
            show_data=True,
            show_meaning=False,
        )

        db_id = content['db_id']
        prompt = traj_prompt_wo_mp.format(
            user_issue=content['query'],
            schema=schema,
            issue_sql=content['issue_sql'],
            db_id=db_id,
            turn=turns,
        )
        data_list.append({
            'instruction': prompt,
            'output': "",
            'system': system_prompt,
        })
        record_list.append({"idx": idx, "db_id": db_id, "prompt": prompt})

    # Print one sample prompt as a sanity check. The second entry is shown
    # when it exists (as before); with a single sample fall back to it, and
    # with empty input print nothing instead of raising IndexError.
    if data_list:
        sample = data_list[min(1, len(data_list) - 1)]
        print(sample['instruction'])
    return data_list, record_list

def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with ``data_path``, ``schema_path``, ``save_dir``,
        ``file_name``, ``record_save_path`` (all required) and ``turns``
        (int, defaults to 5).
    """
    cli = argparse.ArgumentParser(
        description="Generate prompts for ReAct with code multi-planning."
    )

    # Every mandatory option is a plain string flag; register them in one pass.
    required_options = (
        ("--data_path", "dev_data.jsonl"),
        ("--schema_path", "dev_full_schema_info.jsonl"),
        ("--save_dir", "save directory for llama-factory"),
        ("--file_name", "file name for llama-factory"),
        ("--record_save_path", "record of prompt"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, required=True, help=help_text)

    cli.add_argument(
        "--turns",
        type=int,
        default=5,
        help="Number of ReAct turns to include",
    )
    return cli.parse_args()

if __name__ == '__main__':
    # Script entry point: load the samples and schema records, build the
    # prompts, then write the training file and the prompt record.
    args = parse_args()

    data_jsonl = load_jsonl(args.data_path)
    full_schema_jsonl = load_jsonl(args.schema_path)

    # Forward --turns so the CLI option takes effect; it used to be parsed
    # but never passed on, so the function default (5) was always used.
    data, record = construct_data(data_jsonl, full_schema_jsonl, turns=args.turns)
    save_data_to_fac(data, args.save_dir, args.file_name)
    save_jsonl(record, args.record_save_path)