from vllm import LLM, SamplingParams
import os 
import json
import time
import argparse
from tqdm import tqdm
from utils import *

def environment_BW(s_1, s_2, llm_low):
    """Ask the model which object pairs are unchanged between two states.

    Both states are interpolated into a fixed prompt, which is sent as a
    single user chat message with temperature 0.0 and at most 1000 new
    tokens. The pair list is then extracted from the last line of the reply.

    Args:
        s_1: First state; inserted into the prompt as-is.
        s_2: Second state; inserted into the prompt as-is.
        llm_low: Object exposing ``chat(messages, sampling_params)`` whose
            result is read as ``[0].outputs[0].text`` (the vLLM ``LLM``
            instance in this script).

    Returns:
        A ``(pairs, response_text)`` tuple. ``pairs`` is the text following
        the first ``=`` on the last line of the reply, or the last line
        itself when it is a bare bracketed list, or ``''`` when neither form
        is found. ``response_text`` is the raw, unstripped model output.
    """
    bw_env_decom_prompt = f"""Given the 2 state of the same group of Blocks:
{s_1}
{s_2}
Please select all state pairs (including block–block pairs, block–table pairs, and block–clear pairs) whose relative positions remain unchanged.
Please strictly follow this format for the output, including only the selected pairs:
### Output Format
<pairs> = [(Object1,Object2),(Object3,Object4),...]


<assistant>
<pairs> = """

    messages = format_chat_message(bw_env_decom_prompt)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=1000)
    responses = llm_low.chat(messages, sampling_params)
    response_text = responses[0].outputs[0].text

    last_line = response_text.strip().split('\n')[-1].strip()
    # partition() keeps everything after the FIRST '=', so a value that
    # itself contains '=' is no longer truncated.
    _, separator, after_equals = last_line.partition('=')
    if separator:
        pairs = after_equals.strip()
    elif last_line.startswith('['):
        # The prompt already ends with "<pairs> = ", so the model may simply
        # continue with the bare list and never repeat the '='.
        pairs = last_line
    else:
        print(f"Error parsing pairs: no '=' or list found in last line {last_line!r}")
        pairs = ''

    return pairs, response_text

def format_chat_message(content):
    """Wrap *content* in a one-element chat history with the ``user`` role.

    Args:
        content: Value stored under the ``"content"`` key, unchanged.

    Returns:
        A new list holding a single ``{"role": "user", "content": content}``
        dictionary.
    """
    user_turn = dict(role="user", content=content)
    return [user_turn]

if __name__ == "__main__":
    # Script entry point: for every example whose high-level planning
    # succeeded, walk consecutive pairs of intermediate states, ask the model
    # which object pairs stay unchanged, then ask it for the actions of that
    # transition. Results are written to output/<output_filename>.json.

    # Local import: only this script section needs it (safe literal parsing
    # of model output instead of eval).
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_low", type=str, required=True)
    parser.add_argument("--input_filename", type=str, required=True)
    parser.add_argument("--output_filename", default="hier_low_output", type=str)
    args = parser.parse_args()

    model_name_low = args.model_name_low
    input_filename = args.input_filename
    output_filename = args.output_filename

    # Validate and load the input BEFORE loading the model, so a wrong file
    # name or malformed JSON fails immediately rather than after the
    # multi-GPU model has been initialised.
    input_path = f"output/{input_filename}.json"
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input file not found: {input_path}")

    # Context manager closes the handle; UTF-8 matches how results are
    # written at the end of this script.
    with open(input_path, encoding='utf-8') as f:
        test_data = json.load(f)
    print(f"Loaded {len(test_data)} examples from {input_path}")

    llm_low = LLM(model=model_name_low, trust_remote_code=True, tensor_parallel_size=4, gpu_memory_utilization=0.95)

    low_level_prompt = "{start} {goal}"
    # Loop-invariant: the same decoding settings are used for every
    # action-generation request.
    action_sampling_params = SamplingParams(
        temperature=0,
        max_tokens=2000
    )

    result = []
    successful_count = 0
    failed_count = 0
    skipped_count = 0

    for i, example in enumerate(tqdm(test_data, desc="Processing low-level planning")):
        print(f'Processing example #{i}')
        start_time = time.time()

        if not example.get('high_level_success', False):
            print(f"Skipping example #{i}: High-level planning failed")
            example['generated_actions'] = ""
            example['low_level_success'] = False
            example['skip_reason'] = "High-level planning failed"
            skipped_count += 1
            result.append(example)
            continue

        # "or []" also covers an explicit null in the JSON, which would
        # otherwise crash on len(None) outside the try block below.
        intermediate_states = example.get('intermediate_states') or []
        if not intermediate_states:
            print(f"Skipping example #{i}: No intermediate states found")
            example['generated_actions'] = ""
            example['low_level_success'] = False
            example['skip_reason'] = "No intermediate states"
            skipped_count += 1
            result.append(example)
            continue

        env_decomp_info = []
        # Consecutive (start, goal) pairs. NOTE: a single intermediate state
        # yields no transitions, and the example is then recorded as a
        # success with empty actions.
        questions_pairs = list(zip(intermediate_states, intermediate_states[1:]))

        print(f"Processing {len(questions_pairs)} state transitions")

        try:
            action_segments = []
            for j, (start_state, goal_state) in enumerate(questions_pairs):
                print(f"  Processing transition {j+1}/{len(questions_pairs)}")

                pairs, env_response = environment_BW(start_state, goal_state, llm_low)
                print(f"    Environment pairs: {pairs}")

                env_decomp_info.append({
                    "transition": j,
                    "start_state": start_state,
                    "goal_state": goal_state,
                    "pairs": pairs,
                    "env_response": env_response
                })

                modified_start = start_state
                modified_goal = goal_state

                if pairs:
                    try:
                        # SECURITY: `pairs` is raw model output. The previous
                        # eval() could execute arbitrary expressions and
                        # silently resolve names against this module's
                        # globals; literal_eval accepts Python literals only.
                        pairs_list = ast.literal_eval(pairs)
                    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                        print(f"    Error processing pairs: {e}")
                    else:
                        if isinstance(pairs_list, list) and len(pairs_list) > 0:
                            modified_start = f"{start_state}\nUnchanged pairs: {pairs}"
                            modified_goal = f"{goal_state}\nUnchanged pairs: {pairs}"

                messages = format_chat_message(low_level_prompt.format(
                    start=modified_start,
                    goal=modified_goal
                ))
                responses = llm_low.chat(messages, action_sampling_params)

                # Keep what follows the last "Actions:" marker, then the
                # content of the last '[' ... ']' group within it.
                response_text = responses[0].outputs[0].text.strip()
                response_text = response_text.rpartition('Actions:')[2].strip()
                response_text = response_text.rpartition('[')[2].strip()
                response_text = response_text.partition(']')[0].strip()
                print(f"    Response: {response_text}")

                if response_text:
                    action_segments.append(response_text)

            # Joining the non-empty segments avoids the dangling " | " the
            # incremental concatenation left behind when the final
            # transition produced no actions.
            actions_full = " | ".join(action_segments)

            example['generated_actions'] = actions_full
            example['low_level_success'] = True
            successful_count += 1
            print(f"Successfully generated actions: {actions_full}")

        except Exception as e:
            # Per-example boundary: record the failure and keep going so one
            # bad example does not abort the whole run.
            print(f"Low-level planning error for example #{i}: {e}")
            example['generated_actions'] = ""
            example['low_level_success'] = False
            example['error_message'] = str(e)
            failed_count += 1

        example['low_level_processing_time'] = time.time() - start_time
        example['system2_plan_v1'] = ''
        example['env_decomposition_info'] = env_decomp_info
        result.append(example)

    os.makedirs("output", exist_ok=True)

    output_path = f"output/{output_filename}.json"
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4, ensure_ascii=False)

    print(f'Low-level planning results saved to {output_path}')
    print(f'Success: {successful_count}, Failed: {failed_count}, Skipped: {skipped_count}, Total: {len(test_data)}')

    # Success rate is computed over attempted (non-skipped) examples only.
    attempted_count = len(test_data) - skipped_count
    if attempted_count > 0:
        print(f'Success rate: {successful_count/attempted_count*100:.2f}%')
    else:
        print('No valid examples to process')
