
from utils.util import load_jsonl, re_keywords, save_jsonl, save_data_to_fac
from utils.db_utils import get_column_table_from_full_schema, get_pretty_schema
import os
import json
from utils.util import save_data_to_fac
from src.prompts.react.lf_prompt.generate_data_for_train import system_prompt,traj_wo_mp_prompt

# Template for one completed step of a trajectory: the thought, the action and
# the resulting observation, each wrapped in its own tag. Filled with
# str.format using the keys ``thought``, ``action`` and ``observation``;
# construct_data uses it to render the history steps of a prompt.
REACT_PROMPT = """<thought>{thought}</thought>
<action>{action}</action>
<observation>
{observation}
</observation>
"""


# Template for a step without its observation (thought + action only). Filled
# with str.format using the keys ``thought`` and ``action``; construct_data
# uses it to render the target output of each training sample.
TA_PROMPT = """<thought>{thought}</thought>
<action>{action}</action>
"""

def build_instance(traj_history):
    """Expand a trajectory into one (history, target) pair per step.

    The steps are first ordered by their ``"turn"`` key (the input list itself
    is left untouched). For the step at position ``i`` in that ordering, the
    resulting pair holds the ``i`` preceding steps as ``'history'`` and the
    step itself as ``'target'``.

    Args:
        traj_history: Iterable of step dicts, each providing a ``"turn"`` key
            usable as a sort key.

    Returns:
        A list of ``{'history': [...], 'target': step}`` dicts, one per step,
        in turn order. Empty when ``traj_history`` is empty.
    """
    ordered_steps = sorted(traj_history, key=lambda step: step["turn"])
    return [
        {'history': ordered_steps[:position], 'target': step}
        for position, step in enumerate(ordered_steps)
    ]

def construct_data(data_jsonl, status_jsonl, full_schema_jsonl, full_round_json, turns = 5):
    """Build per-step training samples from successful trajectories.

    Every record of ``data_jsonl`` whose entry at the same index in
    ``status_jsonl`` has ``status == 'success'`` is processed; all other
    records are skipped. The record's trajectory, stored under ``str(idx)`` in
    ``full_round_json``, is expanded by ``build_instance`` into one sample per
    step: the instruction is the formatted ``traj_wo_mp_prompt`` followed by
    the rendered earlier steps and an opening ``<thought>`` tag, and the output
    is the step's thought and action rendered with ``TA_PROMPT``.

    Args:
        data_jsonl: Sequence of records providing the keys ``'query'``,
            ``'issue_sql'`` and ``'db_id'``.
        status_jsonl: Sequence indexed in parallel with ``data_jsonl``; each
            entry provides a ``'status'`` key.
        full_schema_jsonl: Passed through to ``get_pretty_schema`` together
            with the record index.
        full_round_json: Mapping from ``str(idx)`` to the list of trajectory
            steps; each step provides ``'turn'``, ``'thought'``, ``'action'``
            and ``'observation'``.
        turns: Value from which the number of already-taken steps is
            subtracted to fill the prompt's ``turn`` field (presumably the
            total turn budget -- confirm against the prompt template).

    Returns:
        A list of dicts with the keys ``'instruction'``, ``'output'`` and
        ``'system'``. Empty when no record yields a sample.
    """
    data_list = []
    for idx, content in enumerate(data_jsonl):
        if status_jsonl[idx]['status'] != 'success':
            continue
        traj_history = full_round_json[str(idx)]

        instances = build_instance(traj_history)
        user_issue = content['query']
        issue_sql = content['issue_sql']
        schema = get_pretty_schema(idx, full_schema_jsonl, column_meaning_dict= None, selected_schema= None,
                                show_data = True, show_meaning= False)
        db_id = content['db_id']

        for instance in instances:
            history = instance['history']
            target = instance['target']
            current_turn = turns - len(history)
            final_traj_prompt = traj_wo_mp_prompt.format(db_id = db_id, schema = schema, user_issue = user_issue, issue_sql = issue_sql, turn = current_turn)
            # NOTE(review): the instruction ends with an opening '<thought>'
            # tag and the output rendered from TA_PROMPT starts with one as
            # well -- confirm the repeated tag is intended.
            if not history:
                prompt = final_traj_prompt + '<thought>'
            else:
                react_traj = '\n\n'.join(
                    REACT_PROMPT.format(
                        thought = step['thought'],
                        action = step['action'],
                        observation = step['observation'],
                    )
                    for step in history
                )
                prompt = final_traj_prompt + react_traj + '\n\n' + '<thought>'
            output = TA_PROMPT.format(thought = target['thought'], action = target['action'])
            data_list.append({
                'instruction': prompt,
                'output': output,
                'system': system_prompt,
            })

    # Debug preview of the second sample. Guarded so that an empty or
    # single-sample result no longer raises IndexError; the count is always
    # printed.
    has_preview = len(data_list) > 1
    if has_preview:
        print(data_list[1]['instruction'])
    print(len(data_list))
    if has_preview:
        print(data_list[1]['output'])
    return data_list

