from utils.util import load_jsonl, re_keywords, save_jsonl, save_data_to_fac
from utils.db_utils import get_column_table_from_full_schema, get_pretty_schema
import os
import json
from src.prompts.react.generate_data_with_code_mp import action_prompt_wo_mp
import re
import argparse

# Template for one earlier round of the conversation: the thought, the action
# taken and the resulting observation, each wrapped in its own tag.
# generate_prompt() fills it once per history entry that is kept.
REACT_PROMPT = """<thought>{thought}</thought>
<action>{action}</action>
<observation>
{observation}
</observation>
"""

# Tail of the prompt for the current turn: the thought for this turn followed
# by an opening <action> tag that is deliberately left unclosed
# (presumably so the model continues by writing the action -- confirm).
FOR_REACT_PROMPT = """<thought>{thought}</thought>
<action>
"""

def generate_prompt(full_round_history_json, data_jsonl, full_schema_jsonl, thought_json, turns = 5):
    """Build the action prompt for every sample that still needs another turn.

    Args:
        full_round_history_json: Mapping from sample index (as a string) to the
            list of earlier rounds for that sample. Each round is a dict with
            the keys 'thought', 'action', 'observation' and 'flag'.
        data_jsonl: Sequence of sample dicts, indexed by the integer sample
            index; each provides 'query', 'issue_sql' and 'db_id'.
        full_schema_jsonl: Schema records handed unchanged to
            get_pretty_schema().
        thought_json: Mapping from sample index (as a string) to the thought
            text for the current turn.
        turns: Total turn budget. The value passed to the prompt template is
            this budget minus the number of earlier rounds that were kept.

    Returns:
        A list of {'idx': int, 'prompt': str} dicts, one per sample whose most
        recent round is not flagged 'True'. The list is empty when there is
        nothing left to prompt for.

    Raises:
        KeyError: If a required key is missing from a round, a sample, or
            thought_json.
        IndexError: If a history key points past the end of data_jsonl.
    """
    data_list = []
    for raw_idx, history_list in full_round_history_json.items():
        idx = int(raw_idx)
        content = data_jsonl[idx]

        # Samples whose latest round is flagged 'True' are skipped
        # (presumably already resolved -- confirm with the producer of 'flag').
        if history_list and history_list[-1]['flag'] == 'True':
            continue

        # Render the earlier rounds, dropping those whose action contains 'MISS'.
        observation_list = [
            REACT_PROMPT.format(thought = obs['thought'], action = obs['action'],
                                observation = obs['observation'])
            for obs in history_list
            if 'MISS' not in obs['action']
        ]

        schema = get_pretty_schema(idx, full_schema_jsonl, column_meaning_dict= None, selected_schema= None,
                                show_data = True, show_meaning= False)

        # Only rounds that are actually shown count against the turn budget.
        current_turn = turns - len(observation_list)
        base_prompt = action_prompt_wo_mp.format(db_id = content['db_id'], schema = schema,
                                                 user_issue = content['query'],
                                                 issue_sql = content['issue_sql'],
                                                 turn = current_turn)
        thought_prompt = FOR_REACT_PROMPT.format(thought = thought_json[str(idx)])

        if observation_list:
            prompt = base_prompt + '\n' + '\n\n'.join(observation_list) + '\n\n' + thought_prompt
        else:
            prompt = base_prompt + '\n' + thought_prompt

        data_list.append({'idx': idx, 'prompt': prompt})

    # Debug preview of the second prompt; guarded so that a run producing
    # fewer than two prompts no longer crashes with IndexError.
    if len(data_list) > 1:
        print(data_list[1]['prompt'])
    return data_list

def parse_args():
    """Parse the command-line options of this script.

    All five options are mandatory string paths: --data_path, --schema_path,
    --full_round_path, --thought_path and --save_path.

    Returns:
        The argparse.Namespace produced from sys.argv.
    """
    # (flag, help text) pairs; every option is registered as required.
    required_options = (
        ("--data_path", "dev_data.jsonl"),
        ("--schema_path", "dev_full_schema_info.jsonl"),
        ("--full_round_path", "Full-round history JSON to read"),
        ("--thought_path", "thought for this turn"),
        ("--save_path", "action prompt path"),
    )

    cli = argparse.ArgumentParser(
        description="Generate prompts for ReAct with code multi-planning."
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, required=True, help=help_text)

    return cli.parse_args()
    
if __name__ == '__main__':
    # Script entry point: load the inputs, build the prompts and write them out.
    args = parse_args()

    data_jsonl = load_jsonl(args.data_path)
    full_schema_jsonl = load_jsonl(args.schema_path)
    # Context managers make sure the JSON files are closed after reading.
    with open(args.thought_path, 'r') as thought_file:
        thought_json = json.load(thought_file)

    if os.path.exists(args.full_round_path):
        with open(args.full_round_path, 'r') as history_file:
            history_json = json.load(history_file)
    else:
        # No history file yet: start every sample with an empty round list,
        # keyed by its position in the data file as a string.
        history_json = {str(idx): [] for idx, _ in enumerate(data_jsonl)}

    data_list = generate_prompt(history_json, data_jsonl, full_schema_jsonl, thought_json)
    print(len(data_list))
    save_jsonl(data_list, args.save_path)