
import argparse
import ast
import json
import os
import re
import time

from vllm import LLM, SamplingParams

def _query_obstacle_filter(llm, prompt):
    """Send one obstacle-filter *prompt* to *llm* and parse its final line.

    Returns ``(response_text, obstacles)``, where ``obstacles`` is the text to
    the right of the first ``=`` on the response's last line, or ``''`` (after
    printing the error) when that line has no ``=``.
    """
    messages = format_chat_message(prompt)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=1000)
    responses = llm.chat(messages, sampling_params)
    response_text = responses[0].outputs[0].text
    last_line = response_text.strip().split('\n')[-1]
    try:
        # maxsplit=1 so an '=' inside the value does not truncate it.
        obstacles = last_line.split('=', 1)[1].strip()
    except IndexError as e:
        print(f"Error parsing sub-environment: {e}")
        obstacles = ''
    return response_text, obstacles


def _to_coordinate_set(obstacles):
    """Parse the model's obstacle literal into a set of 2-element coordinate tuples.

    ``ast.literal_eval`` is used instead of ``eval`` because the text is model
    output. A lone pair such as ``(3, 4)`` counts as one obstacle. Any result
    that is not a collection of numeric pairs contributes nothing, so callers
    never see a ``TypeError`` from membership tests. Returns an empty set when
    parsing fails.
    """
    try:
        parsed = ast.literal_eval(obstacles)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        print(f"解析障碍物列表出错: {e}")
        return set()

    def _is_pair(value):
        return (
            isinstance(value, (list, tuple))
            and len(value) == 2
            and all(isinstance(v, (int, float)) for v in value)
        )

    if _is_pair(parsed):
        # Without this, "(3, 4)" would be iterated as two scalars and lost.
        parsed = [parsed]
    if not isinstance(parsed, (list, tuple, set)):
        return set()
    return {tuple(point) for point in parsed if _is_pair(point)}


def environment_HRL(x_min, x_max, y_min, y_max, new_desc, llm_low):
    """Prune the obstacle list in *new_desc* using two chained model queries.

    The obstacle clause ("obstacles that you have to avoid at: ...", up to the
    next '.') is sent to *llm_low* with the x-bounds widened by one
    (``x_min-1``, ``x_max+1``). The model's surviving list is then re-sent
    with the y-bounds (not widened). Obstacles written in the description
    that appear in the second query's parsed list are kept, and the clause is
    rewritten to list only those. It is removed entirely when none survive.

    NOTE(review): the "[obstacles] lo hi" prompt format is interpreted by the
    model; presumably it keeps obstacles whose coordinate lies in [lo, hi].
    Confirm against the model's training format.

    Args:
        x_min, x_max, y_min, y_max: Bounds used to build the two filter prompts.
        new_desc: Natural-language task description.
        llm_low: Object exposing ``chat(messages, sampling_params)`` whose
            result is indexed as ``responses[0].outputs[0].text``.

    Returns:
        ``(description, response_text, prompt)`` with the rewritten, stripped
        description plus the raw response and prompt of the second query.
        If *new_desc* has no obstacle clause, returns ``(new_desc, "", "")``.
    """
    obstacles_match = re.search(r"obstacles that you have to avoid at: ([^\.]*)", new_desc)
    if not obstacles_match:
        return new_desc, "", ""
    obstacles_str = obstacles_match.group(1)

    # Query 1: x-bounds, widened by one on each side.
    high_level_obstacles_prompt = f"[{obstacles_str}] {x_min-1} {x_max+1}"
    response_text, obstacles = _query_obstacle_filter(llm_low, high_level_obstacles_prompt)

    # Query 2: feed query 1's survivors back with the y-bounds.
    high_level_obstacles_prompt = f"[{obstacles}] {y_min} {y_max}"
    response_text, obstacles = _query_obstacle_filter(llm_low, high_level_obstacles_prompt)
    keep = _to_coordinate_set(obstacles)

    obstacles_pattern2 = r"(obstacles that you have to avoid at:)([^\.]*)"
    match = re.search(obstacles_pattern2, new_desc)
    if not match:
        return new_desc, response_text, high_level_obstacles_prompt

    prefix = match.group(1)
    # Keep the description's own "(x,y)" spelling for each surviving obstacle.
    kept = [
        m.group(0)
        for m in re.finditer(r"\((\d+),(\d+)\)", match.group(2))
        if (int(m.group(1)), int(m.group(2))) in keep
    ]
    if not kept:
        replacement = ''
    elif len(kept) == 1:
        replacement = f"{prefix} {kept[0]}"
    else:
        replacement = f"{prefix} {', '.join(kept[:-1])} and {kept[-1]}"

    # A function replacement is inserted verbatim (no backslash/group expansion).
    for_low_desc = re.sub(obstacles_pattern2, lambda _m: replacement, new_desc).strip()
    return for_low_desc, response_text, high_level_obstacles_prompt

def format_chat_message(content):
    """Wrap *content* as a single-turn chat transcript.

    Returns a one-element list holding a ``{"role": "user", "content": ...}``
    dict, the message shape this script passes to ``llm.chat``.
    """
    return [dict(role="user", content=content)]

def _extract_bounds(description):
    """Return ``(x_min, x_max, y_min, y_max)`` from the first "Go from (a,b) to (c,d)".

    Returns ``None`` when *description* contains no such phrase.
    """
    match = re.search(r"Go from \((\d+),(\d+)\) to \((\d+),(\d+)\)", description)
    if match is None:
        return None
    start_x, start_y, end_x, end_y = map(int, match.groups())
    return (
        min(start_x, end_x),
        max(start_x, end_x),
        min(start_y, end_y),
        max(start_y, end_y),
    )


def _parse_anchor_points(response_text):
    """Parse the literal after the first ``=`` on the last line of *response_text*.

    ``ast.literal_eval`` is used instead of ``eval`` because the text is model
    output. Returns ``[]`` (after printing the error) when parsing fails.
    """
    last_line = response_text.strip().split('\n')[-1]
    try:
        return ast.literal_eval(last_line.split('=', 1)[1].strip())
    except (IndexError, ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        print(f"Error parsing anchor points: {e}")
        return []


def main():
    """Run obstacle filtering and high-level planning over an input JSON file.

    Reads a JSON list of examples from ``--input_file``. Each example needs an
    ``nl_description`` key, and ``agent_as_a_point`` is copied as ground truth
    if present. The script queries ``--model_name_high`` for anchor points and
    writes one result record per example to ``outputs/<--output_file>``.
    """
    parser = argparse.ArgumentParser()
    # Not `required`: with required=True the default could never take effect.
    parser.add_argument("--model_name_high", type=str, default='Qwen/Qwen3-4B-Instruct-2507')
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    parser.add_argument("--tensor_parallel_size", type=int, default=2)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.9)
    args = parser.parse_args()

    # Load the input before the model so a bad path fails before GPU setup.
    try:
        with open(args.input_file, 'r', encoding='utf-8') as fi:
            eng = json.load(fi)
    except FileNotFoundError:
        print(f"Error: Input file '{args.input_file}' not found.")
        return
    except json.JSONDecodeError as e:
        print(f"Error: Input file '{args.input_file}' is not valid JSON: {e}")
        return

    llm_high = LLM(
        model=args.model_name_high,
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    os.makedirs('outputs', exist_ok=True)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=1000)
    high_level_results = []

    for i, example in enumerate(eng):
        print(f"Processing high-level plan for example #{i}")
        description = example['nl_description']
        start_time = time.time()

        try:
            bounds = _extract_bounds(description)
            if bounds is None:
                # No start/goal means no bounds to filter against; plan on
                # the full description instead of reusing another example's.
                print(f"Warning: no start/goal found in example #{i}; skipping obstacle filtering.")
                new_desc, obstacles_response, high_level_obstacles_prompt = description, "", ""
            else:
                new_desc, obstacles_response, high_level_obstacles_prompt = environment_HRL(
                    *bounds, description, llm_high
                )

            high_level_plan_prompt = f"{new_desc}"
            messages = format_chat_message(high_level_plan_prompt)
            responses = llm_high.chat(messages, sampling_params)
            response_text = responses[0].outputs[0].text
            coord_list = _parse_anchor_points(response_text)

            high_level_results.append({
                "nl_description": description,
                "ground_truth": example.get('agent_as_a_point'),
                "coord_list": coord_list,
                "time": time.time() - start_time,
                "new_desc": new_desc,
                "response_text": response_text,
                "high_level_obstacles_prompt": high_level_obstacles_prompt,
                "high_level_obstacles_response": obstacles_response,
                "high_level_plan_prompt": high_level_plan_prompt,
            })
            print(f"High-level planning completed for example #{i}.")

        except Exception as e:
            # Per-example boundary: record the failure and keep processing.
            print(f"Error during high-level planning for example #{i}: {e}")
            high_level_results.append({
                "nl_description": description,
                "ground_truth": example.get('agent_as_a_point', ''),
                "coord_list": [],
                "time": time.time() - start_time,
            })

    with open(f'outputs/{args.output_file}', 'w', encoding='utf-8') as fo:
        json.dump(high_level_results, fo, indent=4, ensure_ascii=False)

    print(f"High-level plans saved to outputs/{args.output_file}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()