
from utils.util import load_jsonl, re_keywords, save_jsonl, save_data_to_fac
from utils.db_utils import get_column_table_from_full_schema, get_pretty_schema
import os
import json
from src.prompts.react.generate_data_with_code_mp import intermidate_prompt, intermidate_prompt_wo_turn
import re
import argparse

# Template for one ReAct round. It is filled via str.format with the
# `thought`, `action` and `observation` fields of a single history entry,
# and the rendered rounds are later joined to form the prompt history.
REACT_PROMPT = """<thought>{thought}</thought>
<action>{action}</action>
<observation>
{observation}
</observation>
"""

# Captures the body of every <selected_table>...</selected_table> tag. Compiled
# once at import time instead of once per processed instance.
_SELECTED_TABLE_RE = re.compile(r"<selected_table>(.*?)</selected_table>", re.DOTALL | re.IGNORECASE)


def _parse_selected_tables(meta_plan):
    """Return the table names listed in the last ``<selected_table>`` tag.

    The tag body is split on commas. Surrounding whitespace is stripped from
    each name, empty entries are dropped, and duplicates are removed while
    keeping first-seen order. Returns an empty list when no tag is present.
    """
    tag_bodies = _SELECTED_TABLE_RE.findall(meta_plan)
    if not tag_bodies:
        return []
    stripped = (name.strip() for name in tag_bodies[-1].split(','))
    return list(dict.fromkeys(name for name in stripped if name))


def generate_initial_prompt(full_round_history_json, data_jsonl, full_schema_jsonl, meta_plan_jsonl, save_path, turns = 5):
    """Build the next-round ReAct prompt for every unfinished instance.

    Args:
        full_round_history_json: Mapping from instance index (as a string) to
            the list of history entries of that instance. Each entry is read
            for its 'thought', 'action', 'observation' and 'flag' keys.
        data_jsonl: Sequence of data records indexed by instance index; the
            'query', 'issue_sql', 'sol_sql' and 'db_id' keys are read.
        full_schema_jsonl: Sequence of schema records indexed by instance
            index; 'schema_info' maps table name to a record whose
            'columns_info' entries carry the column name at position 0.
        meta_plan_jsonl: Sequence of records whose 'response' holds the meta
            plan text for the instance.
        save_path: Destination passed to ``save_jsonl`` for the prompt list.
        turns: Total turn budget; the prompt receives ``turns`` minus the
            number of history rounds that were kept.

    Returns:
        The list of ``{'idx': ..., 'prompt': ...}`` dicts that was saved.
    """
    prompt_list = []
    for key, history_list in full_round_history_json.items():
        idx = int(key)
        content = data_jsonl[idx]
        # The flag is compared as the string 'True'; instances whose latest
        # round carries it are skipped.
        if history_list[-1]['flag'] == 'True':
            continue

        # Rounds whose action contains 'MISS' are left out of the history.
        observation_list = [
            REACT_PROMPT.format(thought = obs['thought'], action = obs['action'], observation = obs['observation'])
            for obs in history_list
            if 'MISS' not in obs['action']
        ]

        meta_plan = meta_plan_jsonl[idx]['response']
        if not meta_plan.startswith('<user_issue_summary>'):
            meta_plan = '<user_issue_summary>' + meta_plan

        full_schema = full_schema_jsonl[idx]['schema_info']
        db_tables = list(full_schema.keys())
        # Keep only plan tables that really exist in this database.
        final_selected_table = [tb for tb in _parse_selected_tables(meta_plan) if tb in db_tables]

        if not final_selected_table:
            # No usable selection: let get_pretty_schema use its default.
            selected_schema = None
        else:
            # Every table is kept; the plan-selected ones are only moved to
            # the front of the ordering.
            chosen = set(final_selected_table)
            remaining_tables = [tb for tb in db_tables if tb not in chosen]
            selected_schema = [
                [tbl, info[0]]
                for tbl in final_selected_table + remaining_tables
                for info in full_schema[tbl]['columns_info']
            ]

        schema = get_pretty_schema(idx, full_schema_jsonl, column_meaning_dict= None, selected_schema= selected_schema,
                            show_data = True, show_meaning= False)

        current_turn = turns - len(observation_list)
        prompt = intermidate_prompt.format(
            db_id = content['db_id'],
            schema = schema,
            user_issue = content['query'],
            issue_sql = content['issue_sql'],
            sol_sql = content['sol_sql'],
            meta_plan = meta_plan,
            history = '\n\n'.join(observation_list),
            turn = current_turn,
        )

        prompt_list.append({'idx':idx, 'prompt':prompt})

    # Preview the first prompt only when there is one; an all-finished run
    # would otherwise fail here with IndexError.
    if prompt_list:
        print(prompt_list[0]['prompt'])
    save_jsonl(prompt_list, save_path)
    return prompt_list



def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the prompt-generation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which keeps ``parse_args()``
            working as before.

    Returns:
        Namespace with ``data_path``, ``full_schema_path``, ``meta_plan_path``,
        ``full_round_path`` and ``save_path``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or an
            unknown option is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Generate intermediate thought-action prompts for ReAct."
    )
    # Every path is consumed unconditionally by the script, so a missing one
    # is rejected here rather than surfacing later as a failure on None.
    parser.add_argument("--data_path", required=True,
        help="dev_data.jsonl file")
    parser.add_argument("--full_schema_path", required=True,
        help="dev_full_schema_info.jsonl file")
    parser.add_argument("--meta_plan_path", required=True,
        help="backward_mp_code.jsonl file")
    parser.add_argument("--full_round_path", required=True,
        help="Existing full-round history JSON")
    parser.add_argument("--save_path", required=True,
        help="Where to write the new intermediate prompts")
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Script entry point: load all inputs, build the prompts for the
    # unfinished instances and report how many were produced.
    args = parse_args()
    data_jsonl = load_jsonl(args.data_path)
    full_schema_jsonl = load_jsonl(args.full_schema_path)
    meta_plan_jsonl = load_jsonl(args.meta_plan_path)
    # Use a context manager so the history file handle is closed
    # deterministically; JSON text is decoded as UTF-8 regardless of locale.
    with open(args.full_round_path, 'r', encoding='utf-8') as history_file:
        full_round_history_json = json.load(history_file)
    prompt_list = generate_initial_prompt(full_round_history_json, data_jsonl, full_schema_jsonl, meta_plan_jsonl, args.save_path)
    print(f"This turn has {len(prompt_list)} instances.")