
from utils.util import load_jsonl, re_keywords, save_jsonl, save_data_to_fac
from utils.db_utils import get_column_table_from_full_schema, get_pretty_schema
import os
import json
from src.prompts.react.generate_data_with_code_mp import initial_prompt
import re
import argparse

# Matches the body of every <selected_table>...</selected_table> tag in a meta plan.
# Compiled once at import time instead of once per example.
_SELECTED_TABLE_RE = re.compile(r"<selected_table>(.*?)</selected_table>", re.DOTALL | re.IGNORECASE)


def _extract_selected_tables(meta_plan):
    """Return the table names listed in the LAST ``<selected_table>`` tag of *meta_plan*.

    The tag body is treated as a comma-separated list. Surrounding whitespace is
    stripped from each name (so ``"a, b"`` yields ``["a", "b"]``) and empty
    entries are dropped. Returns ``[]`` when no tag is present.
    """
    matches = _SELECTED_TABLE_RE.findall(meta_plan)
    if not matches:
        return []
    names = (name.strip() for name in matches[-1].split(','))
    return [name for name in names if name]


def _build_selected_schema(full_schema, plan_tables):
    """Build the ``[table, column]`` list passed to ``get_pretty_schema``.

    Tables named in *plan_tables* that exist in *full_schema* come first, in plan
    order and de-duplicated. They are followed by the remaining tables in
    *full_schema*'s own key order. Returns ``None`` when no planned table exists
    in the schema, in which case the caller falls back to the full default
    schema.
    """
    db_tables = list(full_schema.keys())
    known_tables = set(db_tables)
    # dict.fromkeys de-duplicates while preserving the plan's order.
    chosen = [tb for tb in dict.fromkeys(plan_tables) if tb in known_tables]
    if not chosen:
        return None
    chosen_set = set(chosen)
    ordered = chosen + [tb for tb in db_tables if tb not in chosen_set]
    # info[0] is used as the column identifier, presumably the column name; verify
    # against the producer of 'columns_info'.
    return [[tbl, info[0]] for tbl in ordered for info in full_schema[tbl]['columns_info']]


def generate_initial_prompt(data_jsonl, full_schema_jsonl, meta_plan_jsonl, save_path, turns=5):
    """Render one ``initial_prompt`` per example and save them as JSONL.

    For each example, the meta plan's ``<selected_table>`` list is used to put the
    planned tables first in the pretty-printed schema. If none of the planned
    tables exist, the default full schema is shown.

    Args:
        data_jsonl: records with ``query``, ``issue_sql``, ``sol_sql`` and
            ``db_id`` keys.
        full_schema_jsonl: records indexed in step with *data_jsonl*. Each has a
            ``schema_info`` dict mapping table name -> dict with ``columns_info``.
        meta_plan_jsonl: records indexed in step with *data_jsonl*. Each has a
            ``response`` string.
        save_path: output JSONL path. Each line is ``{'idx': int, 'prompt': str}``.
        turns: value substituted for ``{turn}`` in the prompt template.
    """
    prompt_list = []

    for idx, content in enumerate(data_jsonl):
        meta_plan = meta_plan_jsonl[idx]['response']
        # Ensure the plan always opens with the summary tag the template expects.
        if not meta_plan.startswith('<user_issue_summary>'):
            meta_plan = '<user_issue_summary>' + meta_plan

        selected_schema = _build_selected_schema(
            full_schema_jsonl[idx]['schema_info'],
            _extract_selected_tables(meta_plan),
        )
        schema = get_pretty_schema(idx, full_schema_jsonl, column_meaning_dict=None,
                                   selected_schema=selected_schema,
                                   show_data=True, show_meaning=False)

        prompt = initial_prompt.format(user_issue=content['query'], schema=schema,
                                       issue_sql=content['issue_sql'],
                                       sol_sql=content['sol_sql'], db_id=content['db_id'],
                                       meta_plan=meta_plan, turn=turns)
        prompt_list.append({'idx': idx, 'prompt': prompt})

    # Preview the first prompt for a quick sanity check; skip it on empty input
    # rather than raising IndexError before anything is saved.
    if prompt_list:
        print(prompt_list[0]['prompt'])
    save_jsonl(prompt_list, save_path)
    
def parse_args():
    """Read the command-line options for prompt generation.

    Returns:
        argparse.Namespace with the required string options ``data_path``,
        ``schema_path``, ``meta_plan_path`` and ``save_path``, plus ``turns``
        (int, default 5).
    """
    cli = argparse.ArgumentParser(
        description="Generate prompts for ReAct with code multi-planning."
    )
    # (flag, help text) for every mandatory path option, in declaration order.
    required_paths = (
        ("--data_path", "dev_data.jsonl"),
        ("--schema_path", "dev_full_schema_info.jsonl"),
        ("--meta_plan_path", "backward_mp_code.jsonl"),
        ("--save_path", "Output jsonl path"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, required=True, help=help_text)
    cli.add_argument("--turns", type=int, default=5,
                     help="Number of ReAct turns to include")
    return cli.parse_args()
    
if __name__ == '__main__':
    cli_args = parse_args()

    # Load the three inputs in order: examples, schemas, meta plans.
    # generate_initial_prompt reads them index-by-index in parallel.
    data_records, schema_records, plan_records = (
        load_jsonl(path)
        for path in (cli_args.data_path, cli_args.schema_path, cli_args.meta_plan_path)
    )
    generate_initial_prompt(data_records, schema_records, plan_records,
                            save_path=cli_args.save_path, turns=cli_args.turns)