
import argparse
import ast
import json
import os
import re
import time

from vllm import LLM, SamplingParams


def environment_HRL(x_min, x_max, y_min, y_max, new_desc, llm_low):
    """Narrow the obstacle list in *new_desc* to one sub-task's region.

    The obstacle list following "obstacles that you have to avoid at: " is
    sent to ``llm_low`` twice:
    - first with the bounds ``x_min - 1`` and ``x_max + 1``;
    - then, using the list returned by the first call, with the bounds
      ``y_min`` and ``y_max``.

    Presumably the model keeps only obstacles inside the given range; this
    depends on how the model was trained (TODO confirm). The obstacles that
    survive are written back into the description in place of the original
    list.

    Args:
        x_min, x_max: x bounds of the sub-task (widened by 1 in the first query).
        y_min, y_max: y bounds of the sub-task.
        new_desc: natural-language environment description.
        llm_low: object with a vLLM-style ``chat(messages, sampling_params)``.

    Returns:
        ``(for_low_desc, response_text, high_level_obstacles_prompt)``: the
        rewritten description, the raw reply to the second (y-bounds) query,
        and that second prompt. If *new_desc* has no obstacle clause, returns
        ``(new_desc, "", "")`` without calling the model.
    """
    obstacles_match = re.search(
        r"obstacles that you have to avoid at: ([^\.]*)", new_desc)
    if not obstacles_match:
        return new_desc, "", ""
    obstacles_str = obstacles_match.group(1)

    # First pass: x bounds, widened by one cell on each side.
    obstacles, _ = _query_obstacle_filter(
        llm_low, f"[{obstacles_str}] {x_min-1} {x_max+1}")
    # Second pass: y bounds, applied to the first pass's result.
    high_level_obstacles_prompt = f"[{obstacles}] {y_min} {y_max}"
    obstacles, response_text = _query_obstacle_filter(
        llm_low, high_level_obstacles_prompt)

    save_obstacles = _parse_obstacle_list(obstacles)
    for_low_desc = _rewrite_obstacles(new_desc, save_obstacles)
    return for_low_desc, response_text, high_level_obstacles_prompt


def _query_obstacle_filter(llm_low, prompt):
    """Send one filtering *prompt* to *llm_low* and parse its reply.

    Returns ``(value, response_text)``:
    - *value* is the stripped segment after the first ``=`` (up to any
      second ``=``) on the last line of the reply, or ``''`` when that line
      has no ``=``;
    - *response_text* is the raw reply text.
    """
    sampling_params = SamplingParams(temperature=0.0, max_tokens=1000)
    responses = llm_low.chat(format_chat_message(prompt), sampling_params)
    response_text = responses[0].outputs[0].text
    last_line = response_text.strip().split('\n')[-1]
    try:
        value = last_line.split('=')[1].strip()
    except IndexError as e:
        # No '=' on the last line: treat it as an empty obstacle list.
        print(f"Error parsing sub-environment: {e}")
        value = ''
    return value, response_text


def _parse_obstacle_list(text):
    """Parse the model's obstacle-list literal into a set of coordinate tuples.

    *text* is model output, so it is parsed with ``ast.literal_eval`` rather
    than ``eval``. Returns an empty set when *text* cannot be parsed or is not
    a list/tuple/set. Elements that are not lists or tuples are ignored, and
    ``[x, y]`` lists are normalised to ``(x, y)`` tuples so they compare equal
    to the coordinates found in the description.
    """
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        # The message means "error parsing obstacle list".
        print(f"解析障碍物列表出错: {e}")
        return set()
    if not isinstance(parsed, (list, tuple, set)):
        return set()
    obstacles = set()
    for item in parsed:
        if isinstance(item, (list, tuple)):
            try:
                obstacles.add(tuple(item))
            except TypeError:  # unhashable nested element
                continue
    return obstacles


def _rewrite_obstacles(new_desc, keep):
    """Replace the obstacle clause in *new_desc* with only the obstacles in *keep*.

    Kept obstacles retain their original spelling and order and are joined
    as "a", "a and b" or "a, b and c". If none are kept, the whole clause
    (up to the next '.') is removed. The result is stripped. If *new_desc*
    has no obstacle clause, it is returned unchanged.
    """
    clause_pattern = r"(obstacles that you have to avoid at:)([^\.]*)"
    match = re.search(clause_pattern, new_desc)
    if not match:
        return new_desc
    prefix = match.group(1)
    kept = [
        obs
        for obs in re.findall(
            r"\(\s*-?[0-9]+\s*,\s*-?[0-9]+\s*\)", match.group(2))
        if ast.literal_eval(obs) in keep
    ]
    if len(kept) <= 1:
        joined = "".join(kept)
    else:
        joined = ", ".join(kept[:-1]) + f" and {kept[-1]}"
    replacement = f"{prefix} {joined}" if joined else ""
    # A callable replacement bypasses re's backslash/group-reference handling.
    return re.sub(clause_pattern, lambda _m: replacement, new_desc).strip()

def extract_actions_from_text(text):
    """Return the action string from the last non-blank line of a model reply.

    Steps:
    1. Strip surrounding whitespace and take the last line.
    2. Keep the part after the first ':' on that line (the whole line if it
       has no ':').
    3. Remove Markdown bold markers ("**") and strip whitespace.

    Args:
        text: raw model output; may be ``None`` or empty.

    Returns:
        The extracted action string, or ``""`` when nothing is found.
    """
    if not text:
        return ""
    # Strip first so a trailing newline does not make the "last line" empty.
    last_line = text.strip().split('\n')[-1]
    return last_line.split(":", 1)[-1].strip().replace("**", "").strip()

def format_chat_message(content):
    """Wrap *content* as a single-turn chat conversation.

    Returns a one-element list holding a user message dict with ``role`` and
    ``content`` keys, which is the shape passed to ``LLM.chat`` in this file.
    """
    return [dict(role="user", content=content)]

def _plan_low_level(coord_list, nl_description, llm_low, low_level_prompt_template):
    """Run the low-level planner over consecutive anchor points of one example.

    For each pair of consecutive anchors:
    - the "Go from (a,b) to (c,d)" clause of *nl_description* is rewritten to
      that pair;
    - the obstacle list is narrowed with ``environment_HRL``;
    - ``llm_low`` is asked for moves.

    When a sub-task yields no recognised move, the following sub-tasks keep
    starting from the anchor where that run of failures began. If the run
    reaches the final anchor, the last pair is retried once more using the
    unfiltered description (no obstacle narrowing).

    Returns:
        ``(actions_str, response_low)``: every extracted move joined by single
        spaces, and a dict keyed by anchor index with per-sub-task details.
    """
    go_from_pattern = re.compile(r"Go from \(\d+,\d+\) to \(\d+,\d+\)")
    direction_pattern = re.compile(r'\b(left|right|up|down)\b')
    sampling_params = SamplingParams(temperature=0.0, max_tokens=5000)

    response_low = {}
    all_moves = []
    last_idx = len(coord_list) - 1
    # Length of the current run of sub-tasks that produced no move.
    failed_streak = 0
    for idx in range(len(coord_list)):
        is_last = idx == last_idx
        if is_last and failed_streak == 0:
            # The final anchor was reached; no sub-task starts here.
            continue
        start_point = tuple(coord_list[idx - failed_streak])
        # The final iteration retries start -> final anchor (idx + 1 would
        # be out of range there).
        end_point = tuple(coord_list[idx if is_last else idx + 1])
        x_start, y_start = start_point
        x_end, y_end = end_point

        env_desc = go_from_pattern.sub(
            f"Go from {start_point} to {end_point}", nl_description)
        if is_last:
            # The retry uses the full description; no obstacle query is made.
            new_desc, response_high_env_text, high_level_obstacles_prompt = env_desc, "", ""
        else:
            new_desc, response_high_env_text, high_level_obstacles_prompt = environment_HRL(
                min(x_start, x_end), max(x_start, x_end),
                min(y_start, y_end), max(y_start, y_end),
                env_desc, llm_low,
            )

        low_level_plan_prompt = f"{low_level_prompt_template} {new_desc}"
        responses = llm_low.chat(format_chat_message(low_level_plan_prompt), sampling_params)
        response_text = responses[0].outputs[0].text
        actions = extract_actions_from_text(response_text)
        matches = direction_pattern.findall(actions)
        failed_streak = 0 if matches else failed_streak + 1

        response_low[idx] = {
            "subtask_low": f"{start_point} -> {end_point}",
            "actions": actions,
            "matches": matches,
            "low_level_desc": low_level_plan_prompt,
            "response_low": response_text,
            "env_desc": env_desc,
            "high_level_obstacles_prompt": high_level_obstacles_prompt,
            "response_high_env": response_high_env_text,
        }
        all_moves.extend(matches)
    return " ".join(all_moves), response_low


def main():
    """Run low-level planning for every example in ``--input_file``.

    Each example must provide ``coord_list``, ``nl_description`` and
    ``ground_truth``. Results are written as JSON to
    ``outputs/<output_filename>`` (".json" is appended if missing).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_low", type=str, required=True)
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--output_filename", type=str, required=True)
    parser.add_argument("--prompt_file", type=str, required=True)
    parser.add_argument("--tensor_parallel_size", type=int, default=2)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.9)
    args = parser.parse_args()

    llm_low = LLM(
        model=args.model_name_low,
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    os.makedirs('outputs', exist_ok=True)

    with open(args.input_file, 'r', encoding='utf-8') as f:
        high_level_results = json.load(f)
    with open(args.prompt_file, 'r', encoding='utf-8') as f:
        low_level_prompt_template = f.read()

    final_results = []
    for i, data in enumerate(high_level_results):
        print(f"Processing low-level actions for example #{i}")
        start_time = time.time()
        try:
            coord_list = data['coord_list']
            nl_description = data['nl_description']
            ground_truth = data['ground_truth']
            response_low = {}
            actions_str = ''
            if len(coord_list) >= 2:
                actions_str, response_low = _plan_low_level(
                    coord_list, nl_description, llm_low, low_level_prompt_template)
                print("Final actions list:", actions_str)
            else:
                print("Insufficient anchor points to create sub-tasks. Skipping.")
            final_results.append({
                "english": nl_description,
                "ground_truth": ground_truth,
                "generated": actions_str.strip(),
                "subtasks_low": response_low,
                "coord_list": coord_list,
                "nl_description": nl_description,
                "time": time.time() - start_time,
            })
        except Exception as e:
            # Top-level boundary: record a failed example and keep going.
            print(f"Error processing low-level tasks for example #{i}: {e}")
            final_results.append({
                "english": data.get('nl_description', ''),
                "ground_truth": data.get('ground_truth', ''),
                "generated": "",
                "time": time.time() - start_time,
            })

    output_filename = args.output_filename if args.output_filename.endswith('.json') else args.output_filename + '.json'
    with open(f'outputs/{output_filename}', 'w', encoding='utf-8') as fo:
        json.dump(final_results, fo, indent=4, ensure_ascii=False)
    print(f"Final results saved to outputs/{output_filename}")

if __name__ == "__main__":
    # Run the low-level planning pipeline only when executed as a script.
    main()

