import re
from typing import Tuple
import os, json
from copy import deepcopy
from utils.util import load_jsonl, save_jsonl
import argparse
import copy

# Template that wraps a single action string in <action>...</action> tags;
# fill it with ACTION.format(action=...).
# NOTE(review): no live code visible in this file uses it — confirm whether it
# is still needed.
ACTION = """<action>
{action}
</action>
"""

def get_obs(content):
    """Render one execution-status record as a textual observation.

    Args:
        content: A record providing the keys 'status', 'execution_result'
            and 'error_message'. All three are read regardless of status.

    Returns:
        For status 'success': the lines "execution status: <status>",
        "execution results:" and the preview, joined by newlines. The preview
        is the 'table_preview' entry of the evaluated 'execution_result'. It
        is replaced by '' when longer than 1000, and by 'NULL' when it cannot
        be obtained (unparseable result, non-dict result, or missing/None
        preview).
        For any other status: "execution status: <status>",
        "execution error:" and the 'error_message', joined by newlines.
    """
    exec_status = content['status']
    exec_res = content['execution_result']
    exec_error = content['error_message']

    if exec_status != 'success':
        return f"execution status: {exec_status}\nexecution error:\n{exec_error}"

    try:
        # NOTE(review): eval() executes arbitrary Python taken from the
        # execution-status file. If that file can ever hold untrusted text,
        # switch to ast.literal_eval (works when the result is a pure literal
        # repr) — confirm the format of 'execution_result' first.
        exec_results = eval(exec_res).get('table_preview')
        # Overly long previews are dropped entirely rather than truncated.
        if len(exec_results) > 1000:
            exec_results = ''
    except Exception:
        # Best-effort: any failure to extract a preview is reported as 'NULL'.
        # (Narrowed from a bare except so Ctrl-C / SystemExit still propagate.)
        exec_results = 'NULL'
    return f"execution status: {exec_status}\nexecution results:\n{exec_results}"

def generate_full_round(data_jsonl, thought_json, action_jsonl, status_jsonl, full_round_json):
    """Append this round's (thought, action, observation) turn to every history.

    Args:
        data_jsonl: Sample records; only their order/count is used here.
        thought_json: Mapping from str(index) to this round's thought.
        action_jsonl: Per-sample records (indexed by position) whose
            'pred_sqls' list is joined with blank lines to form the action.
        status_jsonl: Per-sample execution-status records (indexed by
            position), rendered via get_obs.
        full_round_json: Mapping from str(index) to the list of prior turns.

    Returns:
        A dict keyed by integer sample index. A history whose last turn has
        flag 'True' is passed through as-is (same list object). Every other
        history is deep-copied and extended with one new turn. That turn's
        flag is the string 'True' when the action contains '[DONE]' or
        '<DONE>', else 'False'.
    """
    updated = {}
    for sample_idx, _ in enumerate(data_jsonl):
        key = str(sample_idx)
        history = full_round_json[key]

        # A history that already ended with a finished turn is frozen.
        if history and history[-1]['flag'] == 'True':
            updated[sample_idx] = history
            continue

        thought = thought_json[key]
        action = '\n\n'.join(action_jsonl[sample_idx]['pred_sqls'])
        observation = get_obs(status_jsonl[sample_idx])
        finished = any(marker in action for marker in ('[DONE]', '<DONE>'))

        # Copy so the caller's input history is never mutated.
        extended = copy.deepcopy(history)
        extended.append({
            'turn': len(history),
            'thought': thought,
            'action': action,
            'observation': observation,
            'flag': 'True' if finished else 'False',
        })
        updated[sample_idx] = extended
    return updated

def parse_args(argv=None):
    """Parse command-line options for building the full-round history.

    Args:
        argv: Argument list to parse, without the program name. None (the
            default) parses sys.argv[1:], i.e. the original zero-argument
            behavior.

    Returns:
        An argparse.Namespace with the required string options data_path,
        thought_path, processed_action_path, processed_status_path,
        full_round_path and save_path.

    Raises:
        SystemExit: If a required option is missing (argparse behavior).
    """
    parser = argparse.ArgumentParser(
        description="Generate prompts for ReAct with code multi-planning."
    )
    parser.add_argument("--data_path", required=True, help="dev_data.jsonl")
    parser.add_argument("--thought_path", required=True, help="thought for this round")
    parser.add_argument("--processed_action_path", required=True, help="processed action path")
    parser.add_argument("--processed_status_path", required=True, help="status path")
    parser.add_argument("--full_round_path", required=True, help="Full-round history JSON to read")
    parser.add_argument("--save_path", required=True, help="Full-round history JSON to save")
    return parser.parse_args(argv)

if __name__ == '__main__':
    args = parse_args()
    data_jsonl = load_jsonl(args.data_path)
    # Context managers close every handle deterministically; in particular the
    # output file is flushed and closed before the script exits.
    with open(args.thought_path, 'r', encoding='utf-8') as f:
        thought_json = json.load(f)
    action_jsonl = load_jsonl(args.processed_action_path)
    status_jsonl = load_jsonl(args.processed_status_path)

    if os.path.exists(args.full_round_path):
        with open(args.full_round_path, 'r', encoding='utf-8') as f:
            history_json = json.load(f)
    else:
        # First round: every sample starts with an empty history, keyed by
        # str(index) to match the JSON object keys of a saved history file.
        history_json = {str(idx): [] for idx, _ in enumerate(data_jsonl)}

    full_round_dict = generate_full_round(data_jsonl, thought_json, action_jsonl, status_jsonl, history_json)
    with open(args.save_path, 'w', encoding='utf-8') as f:
        json.dump(full_round_dict, f, indent=4)
        