from vllm import LLM, SamplingParams
import os 
import json
import time
import argparse
from tqdm import tqdm
from utils import *
import ast

def format_chat_message(content):
    """Wrap ``content`` as a single-turn chat conversation.

    Args:
        content: The message body to send under the ``"user"`` role.

    Returns:
        A one-element list holding a dict with ``"role"`` set to ``"user"``
        and ``"content"`` set to the given value.
    """
    return [dict(role="user", content=content)]

def extract_list_literal(response_text):
    """Pull the bracketed list text out of a raw model response.

    Takes the text after the first ``'Output:'`` marker, keeps what lies
    after the last ``'['`` and before the first ``']'`` that follows it, and
    returns that re-wrapped in square brackets. Nested lists are therefore
    not supported. If no ``'['`` is present, the whole text after the marker
    (up to the first ``']'``) is used.

    Args:
        response_text: Raw text generated by the model.

    Returns:
        A string of the form ``'[...]'`` ready to be parsed as a literal.

    Raises:
        IndexError: If ``response_text`` contains no ``'Output:'`` marker.
    """
    after_marker = response_text.split('Output:')[1].strip()
    inner = after_marker.split('[')[-1].strip()
    inner = inner.split(']')[0].strip()
    return '[' + inner + ']'


def main():
    """Generate intermediate states for every example and save the results.

    Reads the JSON dataset given by ``--input_data``, sends each example's
    ``'prompt'`` to the model named by ``--model_name_high``, parses the list
    found after ``'Output:'`` in the response, and writes the annotated
    examples to ``output/<output_filename>.json``. Each example gains the keys
    ``'intermediate_states'``, ``'high_level_success'`` and
    ``'high_level_processing_time'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_high", type=str, required=True)
    parser.add_argument("--output_filename", default="hier_high_output", type=str)
    parser.add_argument("--input_data", default="data/test.json", type=str)
    args = parser.parse_args()

    model_name_high = args.model_name_high
    output_filename = args.output_filename
    input_data_path = args.input_data

    # Read the dataset before loading the model so a bad path fails fast,
    # and close the file handle deterministically.
    with open(input_data_path, encoding='utf-8') as f:
        test_data = json.load(f)

    llm_high = LLM(model=model_name_high, trust_remote_code=True, tensor_parallel_size=4, gpu_memory_utilization=0.9)

    # The prompt template is just the task text itself.
    high_level_prompt = "{task}"

    # Sampling settings do not change between examples, so build them once.
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=3000
    )

    result = []
    successful_count = 0
    failed_count = 0

    for i, example in enumerate(tqdm(test_data, desc="Processing high-level planning")):
        print(f'Processing example #{i}')
        start_time = time.time()

        task = example['prompt']
        messages = format_chat_message(high_level_prompt.format(task=task))
        responses = llm_high.chat(messages, sampling_params)
        response_text = responses[0].outputs[0].text
        try:
            response_text = extract_list_literal(response_text)
            # SECURITY: the model output is untrusted text. This used to be
            # passed to eval(), which would execute arbitrary code; only
            # plain Python literals are accepted now.
            intermediate_states = ast.literal_eval(response_text)
        except Exception as e:
            # Best-effort: a malformed response is recorded as a failure
            # rather than aborting the whole run.
            print(f"Intermediate states eval error: {e}")
            print("##################################################")
            print(response_text)
            print("##################################################")
            example['intermediate_states'] = []
            example['high_level_success'] = False
            failed_count += 1
        else:
            example['intermediate_states'] = intermediate_states
            example['high_level_success'] = True
            successful_count += 1
            print(f"Successfully generated {len(intermediate_states)} intermediate states")

        example['high_level_processing_time'] = time.time() - start_time
        result.append(example)

    os.makedirs("output", exist_ok=True)

    output_path = f"output/{output_filename}.json"
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4, ensure_ascii=False)

    total = len(test_data)
    # Avoid ZeroDivisionError when the dataset is empty.
    success_rate = successful_count / total * 100 if total else 0.0

    print(f'High-level planning results saved to {output_path}')
    print(f'Success: {successful_count}, Failed: {failed_count}, Total: {total}')
    print(f'Success rate: {success_rate:.2f}%')


if __name__ == "__main__":
    main()
