#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import glob
import sys
import csv
import re
import json
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import llm_plan_bench as lpb
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import concurrent.futures

def parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Every option carries a default, so the script can be run with no
    arguments at all.
    """
    cli = argparse.ArgumentParser(description='LLM Comparison for Spatial Reasoning')

    # (flag, default, help) for each option, registered in this order.
    option_specs = (
        ('--model_path',
         "gpt-4o",
         'Path to the LLM model'),
        ('--folder_path_prob',
         "final_datasets/floortile/spatial_instances_instances/problems",
         'Directory containing PDDL problem files'),
        ('--folder_path_plan',
         "final_datasets/floortile/spatial_instances_instances/solutions",
         'Directory containing plan files'),
        ('--comparison_folder',
         "position_comparison_deepseek_r1",
         'Output directory for comparison results'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, help=help_text)

    return cli.parse_args()

def extract_init_block(content):
    """
    Return the complete ``(:init ...)`` block of a PDDL problem text.

    Locates the first ``(:init`` (case-insensitive, whitespace allowed after
    the opening parenthesis) and uses parenthesis counting to find its
    matching closing parenthesis, so nested facts are included.

    Args:
        content: Full text of a PDDL problem file.

    Returns:
        The substring from the opening ``(`` of ``(:init`` through its
        matching ``)``. Returns ``None`` if no ``:init`` is found or its
        parentheses are unbalanced.

    Note: parentheses inside ``;`` comments are counted like any others.
    """
    pattern_init = re.compile(r'\(\s*:init\b', re.IGNORECASE)
    match_init = pattern_init.search(content)
    if not match_init:
        return None

    start_idx = match_init.start()
    depth = 0
    for idx in range(start_idx, len(content)):
        ch = content[idx]
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            # The scan starts on the "(" of "(:init", so depth first returns
            # to zero exactly at its matching ")".
            if depth == 0:
                return content[start_idx:idx + 1]

    # Ran off the end of the text without closing the block.
    return None

def parse_pddl_problem(problem_file_path):
    """
    Parse a PDDL problem file and extract the movement-relevant initial facts.

    Returns a tuple ``(adjacency, initial_positions)``:

    1) adjacency: ``{'up': {...}, 'down': {...}, 'left': {...}, 'right': {...}}``.
       Each ``:init`` fact ``(<dir> X Y)`` is stored as
       ``adjacency[<dir>][Y] = X``. In other words, it is keyed by the third
       token and maps to the second. This presumably reads "moving <dir> from
       Y reaches X" (floortile convention); confirm against the domain file.
    2) initial_positions: ``{'robot1': ..., 'robot2': ...}``, taken from
       ``(robot-at <robot> <tile>)`` facts. Robot ids are lower-cased; a robot
       stays ``None`` if no such fact exists.

    If the file has no balanced ``:init`` block, both structures are returned
    unpopulated.
    """
    with open(problem_file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    adjacency = {'up': {}, 'down': {}, 'left': {}, 'right': {}}
    initial_positions = {'robot1': None, 'robot2': None}

    init_block = extract_init_block(content)
    if not init_block:
        return adjacency, initial_positions

    # Drop the "(:init" header. extract_init_block matches it
    # case-insensitively, so a plain str.index(':init') would raise
    # ValueError for a block written as "(:INIT ...".
    init_content = re.sub(
        r'^\(\s*:init\b', '', init_block, count=1, flags=re.IGNORECASE
    ).strip()

    # Innermost parenthesised groups only, e.g. "up tile_1-1 tile_0-1".
    fact_pattern = re.compile(r'\(([^()]*)\)')
    for fact in fact_pattern.findall(init_content):
        parts = fact.split()
        if not parts:
            continue

        predicate = parts[0].lower()

        if predicate in adjacency:
            if len(parts) == 3:
                # (<dir> to_loc from_loc) -> adjacency[dir][from_loc] = to_loc
                adjacency[predicate][parts[2]] = parts[1]

        elif predicate == 'robot-at':
            if len(parts) == 3:
                robot_id = parts[1].lower()
                if robot_id in initial_positions:
                    initial_positions[robot_id] = parts[2]

    return adjacency, initial_positions

def parse_plan_file(plan_file: str) -> List[Tuple[str, str]]:
    """
    Parse a plan (solution) file into the ordered list of robot moves.

    Only move actions of the form ``(<dir> robot1|robot2 <from> <to>)``, with
    ``<dir>`` one of up/down/left/right, are kept. Other actions (e.g.
    painting) and lines starting with ``;`` are ignored. Matching is
    case-insensitive, because some planners emit upper-case action names. The
    returned robot ids and directions are always lower-case.

    Returns:
        List[Tuple[str, str]]: List of (robot_id, direction) tuples, in plan order.
    """
    # IGNORECASE so that e.g. "(UP ROBOT1 ...)" is not silently dropped; the
    # captured groups are normalised with .lower() below.
    pattern = re.compile(
        r'\(\s*(up|down|left|right)\s+(robot[12])\s+([^\s]+)\s+([^\s]+)\s*\)',
        re.IGNORECASE,
    )

    moves = []
    with open(plan_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith(';'):
                continue

            match = pattern.search(line)
            if match:
                direction = match.group(1).lower()
                robot_id = match.group(2).lower()
                moves.append((robot_id, direction))

    return moves

def move_robots(adjacency: Dict, initial_positions: Dict, 
                moves: List[Tuple[str, str]]) -> Tuple[str, str]:
    """
    Replay a sequence of moves and report where each robot ends up.

    Works on a copy of ``initial_positions``; the caller's dict is untouched.
    A move is skipped (the robot stays put) when the robot id is unknown,
    the direction is not a key of ``adjacency``, the robot's current
    position is falsy, or there is no neighbour in that direction.

    Returns:
        Tuple[str, str]: Final positions of robot1 and robot2
    """
    positions = dict(initial_positions)

    for robot, step_dir in moves:
        if robot in positions and step_dir in adjacency:
            here = positions[robot]
            neighbours = adjacency[step_dir]
            if here and here in neighbours:
                positions[robot] = neighbours[here]

    return positions['robot1'], positions['robot2']

def create_prompt(adjacency: Dict, initial_positions: Dict, 
                 moves: List[Tuple[str, str]]) -> str:
    """
    Render the LLM prompt for one step of a problem.

    The adjacency map, initial positions and move list are embedded via their
    Python ``repr``. The model is asked to answer as
    ``(final_r1_llm=..., final_r2_llm=...)``.
    """
    segments = [
        "We have a PDDL-like environment.\n",
        f"Adjacency relations are: {adjacency}\n",
        f"The robot initial positions is at: {initial_positions}\n",
        f"Movement sequence: {moves}\n",
        "Predict final positions of robot1 and robot2 in the form:\n",
        "(final_r1_llm=..., final_r2_llm=...)\n",
        "No extra explanation.\n",
    ]
    return "".join(segments)

def parse_llm_response(response: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Parse LLM response to get predicted positions, handling quoted strings and None values

    For each tag (``final_r1_llm=`` / ``final_r2_llm=``), the first occurrence
    is used. Its value runs up to the NEAREST following ``,``, ``)`` or newline
    (or the end of the text). The value is then stripped of surrounding
    whitespace and quotes.

    Args:
        response: Raw LLM response string
        
    Returns:
        Tuple[Optional[str], Optional[str]]: Clean tile positions for robot1 and robot2.
        An entry is None when its tag is absent, or when its value is empty or
        one of none/null/unknown (case-insensitive).
    """
    if not response:
        return None, None

    def _extract(tag: str) -> Optional[str]:
        """Return the cleaned value following ``tag``, or None."""
        idx_start = response.find(tag)
        if idx_start == -1:
            return None
        idx_start += len(tag)

        # Use the closest delimiter. Preferring "," regardless of position
        # would swallow text like "a) (final_r2_llm=b" when the two answers
        # are written in separate parentheses or on separate lines.
        candidates = [
            pos
            for pos in (response.find(delim, idx_start) for delim in (",", ")", "\n"))
            if pos != -1
        ]
        idx_end = min(candidates) if candidates else len(response)

        value = response[idx_start:idx_end].strip().strip("'\"").strip()
        if value.lower() in ('none', 'null', 'unknown', ''):
            return None
        return value

    return _extract("final_r1_llm="), _extract("final_r2_llm=")

def call_llm_for_positions(adjacency: Dict, initial_positions: Dict,
                          moves: List[Tuple[str, str]], model_path: str) -> Tuple[str, str]:
    """
    Ask the LLM to predict both robots' final positions for ``moves``.

    The prompt is wrapped in a JSON object ``{"prompt": ...}`` (non-ASCII kept
    as-is). This object is sent to a freshly constructed
    ``lpb.BlackboxLLM(model_path)``, and the raw reply is parsed with
    ``parse_llm_response``. Either returned position may be None when the
    reply lacks a usable value.
    """
    payload = json.dumps(
        {"prompt": create_prompt(adjacency, initial_positions, moves)},
        ensure_ascii=False,
    )
    # NOTE(review): a new client is built on every call; presumably cheap,
    # confirm whether BlackboxLLM can be reused across calls/threads.
    llm = lpb.BlackboxLLM(model_path)
    raw_reply = llm(payload)
    return parse_llm_response(raw_reply)

def process_problem_plan_pair(
    model_path: str,
    comparison_folder: Path,
    pair: Tuple[str, str]
) -> Tuple[str, List[str]]:
    """
    Evaluate the LLM on every prefix of one plan.

    For step k (1-based), the first k moves of the plan are replayed with
    ``move_robots`` to get the ground truth, and the same prefix is sent to the
    LLM. A step is "correct" only if both robots' predicted positions equal
    the ground truth.

    Args:
        model_path: Model identifier passed to the LLM client.
        comparison_folder: Directory for the per-problem comparison file.
        pair: ``(problem_file, plan_file)`` paths.

    Returns:
        ``(problem_basename, step_results)``, where ``step_results`` holds one
        "correct"/"wrong" entry per plan step.

    Side effects:
        Writes ``<comparison_folder>/<problem_basename>_comparison.txt``
        (UTF-8), with one line per step.
    """
    problem_file, plan_file = pair
    problem_basename = Path(problem_file).stem
    print(f"\nProcessing {problem_file}...")
    
    # Parse problem and plan
    adjacency, initial_positions = parse_pddl_problem(problem_file)
    all_moves = parse_plan_file(plan_file)
    total_steps = len(all_moves)
    
    step_results: List[str] = []
    comparison_lines: List[str] = []
    
    # Slice each prefix on demand rather than building a list of all
    # prefixes up front, which costs O(n^2) memory for long plans.
    for step in range(1, total_steps + 1):
        moves = all_moves[:step]

        # Get ground truth and LLM predictions
        final_r1, final_r2 = move_robots(adjacency, initial_positions, moves)
        final_r1_llm, final_r2_llm = call_llm_for_positions(
            adjacency, initial_positions, moves, model_path
        )
        
        # Record results
        comparison_lines.append(
            f"-step_{step}: final_r1={final_r1}, final_r2={final_r2}, "
            f"final_r1_llm={final_r1_llm}, final_r2_llm={final_r2_llm}"
        )
        
        is_correct = (final_r1 == final_r1_llm) and (final_r2 == final_r2_llm)
        step_results.append("correct" if is_correct else "wrong")
        print(f"{problem_basename} - Step {step}/{total_steps}: {'✓' if is_correct else '✗'}")
    
    # Explicit UTF-8: LLM output may contain non-ASCII text, which would raise
    # UnicodeEncodeError under a non-UTF-8 locale default encoding.
    comparison_file = Path(comparison_folder) / f"{problem_basename}_comparison.txt"
    comparison_file.write_text('\n'.join(comparison_lines), encoding='utf-8')
    
    return problem_basename, step_results

def llm_comparison(
    model_path: str,
    folder_path_prob: str,
    folder_path_plan: str,
    comparison_folder: str,
    max_workers: int = 1,
) -> Path:
    """
    Run the LLM comparison over all problem/plan pairs and write a CSV summary.

    ``*.pddl`` files from ``folder_path_prob`` and ``*.sol`` files from
    ``folder_path_plan`` are each sorted by path and paired positionally. If
    the counts differ, a warning is printed and the surplus files are ignored.
    Pairs are processed on a thread pool of ``max_workers`` threads. A pair
    that raises is reported and left out of the table.

    Args:
        model_path: Model identifier passed to the LLM client.
        folder_path_prob: Directory containing PDDL problem files.
        folder_path_plan: Directory containing plan files.
        comparison_folder: Output directory (created if missing).
        max_workers: Number of worker threads (default 1, i.e. sequential).

    Returns:
        Path to ``problem_compare_table.csv``. Each row is one problem; the
        ``step_k`` columns hold "correct"/"wrong", padded with "" for shorter
        plans.
    """
    comparison_folder = Path(comparison_folder)
    os.makedirs(comparison_folder, exist_ok=True)
    
    # Collect and pair files
    problem_files = sorted(glob.glob(os.path.join(folder_path_prob, "*.pddl")))
    plan_files = sorted(glob.glob(os.path.join(folder_path_plan, "*.sol")))
    if len(problem_files) != len(plan_files):
        # zip() would otherwise truncate silently and hide the mismatch.
        print(
            f"Warning: found {len(problem_files)} problem files but "
            f"{len(plan_files)} plan files; pairing by sorted order and "
            f"ignoring the surplus."
        )
    problem_plan_pairs = list(zip(problem_files, plan_files))
    
    # Create a partial function with fixed arguments
    process_pair = partial(
        process_problem_plan_pair,
        model_path,
        comparison_folder
    )
    
    results_dict: Dict[str, List[str]] = {}
    print(f"Processing {len(problem_plan_pairs)} problems with {max_workers} workers...")
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all pairs for processing
        future_to_pair = {
            executor.submit(process_pair, pair): pair 
            for pair in problem_plan_pairs
        }
        
        # Collect results as they complete
        for future in concurrent.futures.as_completed(future_to_pair):
            pair = future_to_pair[future]
            try:
                problem_basename, step_results = future.result()
                results_dict[problem_basename] = step_results
                print(f"Completed {problem_basename}")
            except Exception as e:
                print(f"Error processing {Path(pair[0]).stem}: {str(e)}")
    
    # Generate CSV. default=0 avoids a ValueError when no pair succeeded (or
    # none were found); the CSV then contains only the header column.
    csv_path = comparison_folder / "problem_compare_table.csv"
    max_steps = max((len(steps) for steps in results_dict.values()), default=0)
    
    with csv_path.open('w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["Problem"] + [f"step_{s}" for s in range(1, max_steps + 1)])
        for problem, results in sorted(results_dict.items()):
            writer.writerow([problem] + results + [""] * (max_steps - len(results)))
    
    print(f"\nAll results saved to {csv_path}")
    return csv_path


def main():
    """Command-line entry point: parse the options and run the comparison."""
    cli = parse_args()
    csv_location = llm_comparison(
        cli.model_path,
        cli.folder_path_prob,
        cli.folder_path_plan,
        cli.comparison_folder,
    )
    print(f"\nResults saved to {csv_location}")


if __name__ == "__main__":
    main()