import re
from typing import Dict, List, Optional, Set, Tuple

from rllm import Action
from rllm.rewards.reward_types import RewardOutput

# Module-level cache of vocabulary words, keyed by word length (values are
# upper-case word sets). None means "not loaded yet"; _get_word_sets() fills
# it on first use and returns the cached mapping afterwards.
_word_sets_cache: Optional[Dict[int, Set[str]]] = None

def word_ladder_reward_fn_eval(task_info: dict, action):
    """Score ``action`` with ``word_ladder_reward_fn`` in strict (eval) mode.

    Identical to calling ``word_ladder_reward_fn(task_info, action, eval=True)``:
    ladders containing out-of-vocabulary words score 0.0 rather than earning
    partial credit.
    """
    return word_ladder_reward_fn(task_info=task_info, action=action, eval=True)


def extract_solution(solution_str: str) -> Optional[str]:
    """Return the stripped contents of the last ``<answer>...</answer>`` block.

    If the text contains an ``Assistant:`` marker (checked first) or, failing
    that, a ``<|im_start|>assistant`` marker, only the text after the first
    occurrence of that marker is searched. Tag matching is case-insensitive
    and may span multiple lines.

    Returns:
        The last answer's inner text with surrounding whitespace removed, or
        None when no answer block is present.
    """
    for marker in ("Assistant:", "<|im_start|>assistant"):
        if marker in solution_str:
            # Keep only what follows the first occurrence of the marker.
            _, _, solution_str = solution_str.partition(marker)
            break

    answers = re.findall(r"<answer>(.*?)</answer>", solution_str, re.IGNORECASE | re.DOTALL)
    if not answers:
        return None
    return answers[-1].strip()


def parse_word_ladder(ladder_str: str) -> List[str]:
    """Parse a word ladder string into a list of lower-case words.

    Handles various formats:
    - "word1 -> word2 -> word3"
    - "word1, word2, word3"
    - "word1 word2 word3"
    - "['word1', 'word2', 'word3']"
    - "[word1, word2, word3]"  (brackets without quotes)

    Entries that are empty or whitespace-only are dropped. Returns an empty
    list for empty input.
    """
    if not ladder_str:
        return []

    ladder_str = ladder_str.strip()

    # Try to parse as a Python list literal first.
    if ladder_str.startswith("[") and ladder_str.endswith("]"):
        import ast

        try:
            parsed = ast.literal_eval(ladder_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a valid literal (e.g. "[cold, cord]" with unquoted words).
            # Drop the outer brackets so they don't end up glued to the first
            # and last words, then fall through to the delimiter formats.
            ladder_str = ladder_str[1:-1].strip()
        else:
            if isinstance(parsed, list):
                words = (str(w).strip().lower() for w in parsed if w)
                return [w for w in words if w]

    # Pick the delimiter: arrows take precedence over commas; otherwise split
    # on any whitespace (separator=None).
    if "->" in ladder_str:
        separator: Optional[str] = "->"
    elif "," in ladder_str:
        separator = ","
    else:
        separator = None

    return [token.strip().lower() for token in ladder_str.split(separator) if token.strip()]


def is_valid_word_ladder(ladder: List[str], start_word: str, end_word: str) -> Tuple[bool, str]:
    """Check a word ladder and report the first rule it breaks.

    Rules, checked in this order:
    1. The ladder is non-empty.
    2. Its first word equals ``start_word`` (case-insensitive).
    3. Its last word equals ``end_word`` (case-insensitive).
    4. Every word has the same length as ``start_word``.
    5. Each consecutive pair differs in exactly one position.
    6. Every word appears in the vocabulary returned by ``_get_word_sets``.

    Returns:
        ``(True, "")`` if all rules hold, otherwise ``(False, reason)``.
    """
    if not ladder:
        return False, "Empty ladder"

    source = start_word.lower().strip()
    target = end_word.lower().strip()

    if ladder[0].lower() != source:
        return False, f"Does not start with '{source}'"
    if ladder[-1].lower() != target:
        return False, f"Does not end with '{target}'"

    expected_len = len(source)
    wrong_length = next((candidate for candidate in ladder if len(candidate) != expected_len), None)
    if wrong_length is not None:
        return False, f"Word '{wrong_length}' has incorrect length (expected {expected_len})"

    # Walk adjacent pairs: (w0, w1), (w1, w2), ...
    for previous, current in zip(ladder, ladder[1:]):
        prev_low = previous.lower()
        curr_low = current.lower()
        changed = sum(a != b for a, b in zip(prev_low, curr_low))
        if changed != 1:
            return False, f"Words '{prev_low}' and '{curr_low}' differ by {changed} letters (expected 1)"

    # Vocabulary entries are stored upper-case, grouped by length.
    vocabulary = _get_word_sets().get(expected_len, set())
    for candidate in ladder:
        if candidate.upper().strip() not in vocabulary:
            return False, f"Word '{candidate}' is not in the vocabulary"

    return True, ""


def normalize_ladder(ladder: List[str]) -> str:
    """Render a ladder as lower-cased, whitespace-trimmed words joined by " -> ".

    Two ladders that differ only in letter case or stray whitespace normalize
    to the same string.
    """
    cleaned = [word.lower().strip() for word in ladder]
    return " -> ".join(cleaned)


def _get_word_sets(words_csv_path: Optional[str] = None) -> Dict[int, Set[str]]:
    """Return vocabulary words grouped by length, loaded from a words CSV file.

    The CSV has one column per word length, with headers of the form
    ``"<N>_letter"`` (e.g. ``3_letter``, ``4_letter``). Every non-blank cell is
    stripped, upper-cased and added to the set for length ``N``. Columns whose
    header does not start with a decimal number are ignored.

    Args:
        words_csv_path: Optional explicit path to the CSV. When omitted, the
            bundled ``examples/word_ladder/words.csv`` (resolved relative to
            the project root) is used, and the result is cached in the
            module-level ``_word_sets_cache`` for later calls. An explicit
            path is always read fresh and never touches the cache.

    Returns:
        Mapping of word length -> set of upper-case words.

    Raises:
        OSError: if the CSV file cannot be opened.
    """
    global _word_sets_cache

    import csv
    import os

    use_default_path = words_csv_path is None
    if use_default_path:
        if _word_sets_cache is not None:
            return _word_sets_cache
        # The file is at examples/word_ladder/words.csv relative to project root.
        # This module is at rllm/rewards/word_ladder_reward.py, so go up:
        # rllm/rewards -> rllm -> project root.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(os.path.dirname(current_dir))
        words_csv_path = os.path.join(project_root, "examples", "word_ladder", "words.csv")

    word_sets: Dict[int, Set[str]] = {}

    # "utf-8-sig" drops a leading byte-order mark, which would otherwise be
    # glued onto the first header ("\ufeff3_letter") and break length parsing.
    # newline="" is the open mode the csv module expects.
    with open(words_csv_path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)

        # Parse each header's length prefix once (instead of once per cell) and
        # skip columns that don't follow the "<N>_..." naming convention.
        column_lengths: Dict[str, int] = {}
        for col_name in reader.fieldnames or []:
            prefix = col_name.strip().split("_", 1)[0]
            if prefix.isdecimal():
                column_lengths[col_name] = int(prefix)

        for row in reader:
            for col_name, word_length in column_lengths.items():
                word = row.get(col_name)
                # Short rows leave missing cells as None; blank cells are skipped.
                if word and word.strip():
                    word_sets.setdefault(word_length, set()).add(word.strip().upper())

    if use_default_path:
        _word_sets_cache = word_sets
    return word_sets


def score_answer(answer: Optional[str], entry: Dict, eval: bool = False) -> float:
    """Score a comma-separated word ladder answer.

    Args:
        answer: The answer string (comma-separated words, any case).
        entry: Dictionary whose ``"metadata"`` entry holds ``start_word`` and
            ``end_word``. A missing or None ``"metadata"`` scores 0.0.
        eval: If True, return 0.0 for any ladder that fails
              ``is_valid_word_ladder`` (strict evaluation, vocabulary included).
              If False, return 0.5**n for a structurally valid ladder with n
              words (start and end included) missing from the vocabulary.

    Returns:
        1.0 if perfect; 0.5**n for n out-of-vocabulary words (structure valid
        and eval=False); 0.0 if the structure is invalid or, with eval=True,
        the ladder is invalid.
    """
    if not isinstance(answer, str):
        return 0.0

    # "metadata" may be absent or explicitly None; treat both as empty.
    metadata = entry.get("metadata") or {}
    start_word = metadata.get("start_word")
    end_word = metadata.get("end_word")
    if not start_word or not end_word:
        return 0.0

    start_word = start_word.upper().strip()
    end_word = end_word.upper().strip()
    word_length = len(end_word)
    answer_words = tuple(s.strip() for s in answer.upper().split(","))

    # Structural checks (cheap, no vocabulary needed):
    # 1. at least two words
    # 2. start and end word match the question
    # 3. all words have the correct length
    # 4. every word is a single-letter change from the previous one
    if len(answer_words) < 2:
        return 0.0
    if answer_words[0] != start_word or answer_words[-1] != end_word:
        return 0.0
    if any(len(w) != word_length for w in answer_words):
        return 0.0
    for prev, curr in zip(answer_words, answer_words[1:]):
        if sum(a != b for a, b in zip(prev, curr)) != 1:
            return 0.0

    # Strict mode: every word must also be in the vocabulary.
    if eval:
        ladder_lower = [w.lower() for w in answer_words]
        is_valid, _ = is_valid_word_ladder(ladder_lower, start_word.lower(), end_word.lower())
        return 1.0 if is_valid else 0.0

    # Partial credit: halve the reward for each out-of-vocabulary word. The
    # vocabulary is only loaded once the structure has been validated.
    known_words = _get_word_sets().get(word_length, set())
    invalid_count = sum(1 for w in answer_words if w not in known_words)
    return 0.5 ** invalid_count


# Heuristic patterns for pulling start/end words out of a free-form question,
# tried in order; the first match wins.
_ENDPOINT_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"from\s+(\w+)\s+to\s+(\w+)",
        r"(\w+)\s+->\s+(\w+)",
        r"transform\s+(\w+)\s+to\s+(\w+)",
        r"(\w+)\s+to\s+(\w+)",
    )
)


def _resolve_endpoints(task_info: dict) -> Tuple[Dict, Optional[str], Optional[str]]:
    """Find the start and end words for a task.

    Sources are tried in order: ``task_info["metadata"]``, top-level
    ``start_word`` / ``end_word`` keys, regex patterns over
    ``task_info["question"]``, then the first/last entries of
    ``task_info["ground_truth"]`` (or ``task_info["answer"]``).

    Returns:
        ``(metadata, start_word, end_word)``. ``metadata`` is a shallow copy of
        the task's metadata (never the caller's dict); either word may be None
        if no source provided it.
    """
    # Copy so that recording the resolved endpoints later never mutates the
    # caller's task_info["metadata"].
    metadata = dict(task_info.get("metadata") or {})
    start_word = metadata.get("start_word") or task_info.get("start_word")
    end_word = metadata.get("end_word") or task_info.get("end_word")

    # If not in metadata, try to extract from the question text.
    if not start_word or not end_word:
        question = task_info.get("question") or ""
        for pattern in _ENDPOINT_PATTERNS:
            match = pattern.search(question)
            if match:
                start_word, end_word = match.group(1), match.group(2)
                break

    # If still not found, fall back to the ground-truth ladder's endpoints.
    if not start_word or not end_word:
        ground_truth = task_info.get("ground_truth") or task_info.get("answer")
        if ground_truth:
            gt_ladder = parse_word_ladder(ground_truth) if isinstance(ground_truth, str) else ground_truth
            if isinstance(gt_ladder, (list, tuple)) and len(gt_ladder) >= 2:
                start_word, end_word = str(gt_ladder[0]), str(gt_ladder[-1])

    return metadata, start_word, end_word


def word_ladder_reward_fn(task_info: dict, action: str | Action, eval: bool = False) -> RewardOutput:
    """
    A specialized reward function for word ladder tasks.

    Evaluates whether the agent correctly finds a valid word ladder from start
    to end word, using ``score_answer`` (structure + vocabulary checks).

    Args:
        task_info: Dictionary containing question, ground_truth, metadata (with
            start_word/end_word). It is not modified.
        action: The agent's solution string (or an ``Action`` wrapping it);
            the ladder is read from the last ``<answer>...</answer>`` block.
        eval: If True, return 0.0 for any invalid word ladder (strict evaluation).
              If False, return partial credit for valid structure with invalid words.

    Returns:
        RewardOutput whose metadata["validation"] is "correct_solution",
        "valid_structure_partial_vocab" or "invalid_solution". Any exception is
        reported as reward 0.0 with metadata {"error": message}.
    """
    import logging

    logger = logging.getLogger(__name__)

    try:
        if isinstance(action, Action):
            action = action.action

        # Extract solution from action (may contain <answer> tags)
        extracted = extract_solution(action)
        if extracted is None:
            return RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={
                    "validation": "invalid_solution",
                    "error": "No solution extracted from answer tags",
                },
            )

        # Parse the answer; several formats are accepted (see parse_word_ladder).
        ladder = parse_word_ladder(extracted)
        logger.debug("extracted=%r ladder=%r", extracted, ladder)
        if not ladder:
            return RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={
                    "validation": "invalid_solution",
                    "error": "Could not parse word ladder",
                },
            )

        # score_answer expects a comma-separated, upper-case ladder.
        answer_str = ",".join(w.upper().strip() for w in ladder)

        metadata, start_word, end_word = _resolve_endpoints(task_info)
        if not start_word or not end_word:
            return RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={
                    "validation": "invalid_solution",
                    "error": "Could not extract start_word and end_word",
                },
            )

        # Ensure the (copied) metadata carries the resolved endpoints.
        metadata["start_word"] = start_word
        metadata["end_word"] = end_word

        score = score_answer(answer_str, {"metadata": metadata}, eval=eval)
        logger.debug(
            "answer_str=%r start_word=%r end_word=%r score=%s",
            answer_str, start_word, end_word, score,
        )

        if score >= 1.0:
            return RewardOutput(
                reward=score,
                is_correct=True,
                metadata={
                    "validation": "correct_solution",
                    "ladder": ladder,
                    "length": len(ladder),
                    "score": score,
                },
            )

        if score > 0.0:
            # Valid structure but some words not in vocabulary
            return RewardOutput(
                reward=score,
                is_correct=False,
                metadata={
                    "validation": "valid_structure_partial_vocab",
                    "ladder": ladder,
                    "error": f"Some words not in vocabulary (score: {score})",
                    "score": score,
                },
            )

        # Invalid structure. Only this path reports a reason, so the extra
        # validation pass (which reads the vocabulary) is run only here.
        is_valid, error_msg = is_valid_word_ladder(ladder, start_word, end_word)
        return RewardOutput(
            reward=0.0,
            is_correct=False,
            metadata={
                "validation": "invalid_solution",
                "error": error_msg if not is_valid else "Invalid word ladder structure",
                "ladder": ladder,
                "score": score,
            },
        )

    except Exception as e:
        # Any failure is reported as a zero reward carrying the error message.
        return RewardOutput(reward=0.0, is_correct=False, metadata={"error": str(e)})

