"""
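This module provides functions to process chat-style data dictionaries: it rewrites
the first user message to request a detailed reasoning process that leads to the
record's reference answer, and it can replace the content of the system message.

It emphasizes robustness, maintainability, and performance through modular design,
specific exception types for malformed input, and optional parallel execution for
batch processing. The modifying functions work on deep copies, so input data is
never mutated.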
"""

import copy
import logging
from collections.abc import Mapping
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, Any, List, MutableMapping, Optional

# --- Constants for better maintainability ---
# Using constants makes the code easier to read and update. If the role or
# template needs to change, it can be done in one place.
# Value of a message's "role" key that identifies the user turn.
USER_ROLE = "user"
# Value of a message's "role" key that identifies the system turn.
SYSTEM_ROLE = "system"
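# NOTE: the two sentences below appear to be extra prompt instructions kept for
# reference only; they are not part of the active REASONING_PROMPT_TEMPLATE.
# Explain how to frame the response, which points to include, and the rationale behind selecting each element of the answer.
# Crucially, do not merely repeat the answer. Instead, focus on making the thought process and reasoning transparent, demonstrating precisely how and why the reference answer is reached.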

# For ablation, we can use the following simple prompt template:
# The template is filled in with str.format(): {question} receives the original
# content of the user message and {ground_truth} the reference answer. Any literal
# brace added to the template text must be doubled ({{ or }}), otherwise format()
# will treat it as a replacement field.
REASONING_PROMPT_TEMPLATE = """Given the following question and its reference answer, your task is to provide a detailed, step-by-step explanation that logically leads to the reference answer.

**Question:** 
```
{question}
```

**Reference Answer:** {ground_truth}
"""

# REASONING_PROMPT_TEMPLATE = """Given the following question and its reference answer, your task is to provide a detailed, step-by-step explanation that logically leads to the reference answer, as if reasoning it out with full clarity.

# - For objective, precise, or factual reference answers (e.g., a numerical value, single-choice option, proper noun, date, or other fact-based answer), break down the derivation process clearly. Specify all relevant information, intermediate steps, and logical connections that yield the unique correct answer.

# - For discursive, open-ended, or analytical reference answers (e.g., an explanation, summary, opinion, or creative text), outline your strategic approach, identify key considerations, and explain the conceptual framework or reasoning process that would arrive at such an answer.

# **Question:** 
# ```
# {question}
# ```

# **Reference Answer:** {ground_truth}

# Imagine you are an expert tutor guiding a student. You already know the Reference Answer and may use that knowledge internally to construct the reasoning path. OUTPUT REQUIREMENTS:
# 1. Output ONLY the step-by-step explanatory reasoning (no restatement of these instructions, no extra preface or summary beyond the reasoning itself).
# 2. Do NOT mention or imply that you had prior access to the Reference Answer. Avoid any phrases like: "According to the answer...", "Since the correct result is...", "To get to that answer...".
# 3. Do NOT use meta-commentary such as: "as if we're working through it together", "for the first time", "let's figure this out", "we already know", or any phrasing that reveals hidden foreknowledge or simulates collaboration.
# 4. Do NOT merely restate the final answer; elucidate the pathway that justifies it so a learner could reproduce the method independently next time.
# """

# REASONING_PROMPT_TEMPLATE = """Given the following question and its Reference Answer, your task is to generate a high-fidelity, first-person think-aloud monologue from the perspective of a meticulous and brilliant thinker encountering the problem for the first time. The monologue should realistically and comprehensively simulate the full problem-solving cycle—including analysis, summarization, exploration, reassessment, reflection, backtracking, and iteration—and culminate in a well-considered line of reasoning. You are strongly encouraged to leverage the Reference Answer internally to inform and scaffold a successful reasoning, but the final monologue must read as a genuine, first-time, real-time discovery and must not explicitly reference, cite, or quote the Reference Answer.

# **Question:**
# ```
# {question}
# ```
# **Reference Answer:**
# ```
# {ground_truth}
# ```

# **OUTPUT REQUIREMENTS:**
# 1.  Output ONLY the first-person, think-aloud monologue. Do not include any preface, summary, or restatement of these instructions.
# 2.  Maintain the tone of a focused individual thinking to themself. Avoid meta-commentary like: "as if we're working through it together," “for the first time,” and any phrasing that reveals simulation.
# 3.  Do not mention, imply, or hint at prior access to the Reference Answer in the monologue. Avoid phrases like “according to the answer…” or “to get to that answer…”, and any euphemism that signals foreknowledge.
# 4.  Do not merely restate the final answer in the monologue; articulate the reasoning pathway with sufficient intermediate steps, rationale, decision points, verification, and any necessary error-correction or backtracking.
# """

# Default system prompt used by modify_system_message() when no replacement text
# is supplied. It asks the assistant to put its reasoning inside <think></think>
# tags before giving the answer.
SYSTEM_PROMPT = (
    "A conversation between user and assistant. "
    "The user asks a question, and the assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind "
    "and then provides the user with the answer. "
    "The reasoning process is enclosed within <think></think> tags, i.e., "
    "<think>\nThis is my reasoning.\n</think>\nThis is my answer."
)

# --- Configure Logging ---
# Using the logging module is a best practice over print() for tracking events
# and errors, as it's configurable and more powerful.
# NOTE: this call runs at import time and configures the root logger (INFO level,
# timestamped format). basicConfig() does nothing when the root logger already has
# handlers, so an application that configured logging earlier keeps its own setup.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


# --- Custom Exceptions for Clarity ---
# Custom exceptions make error handling more specific and informative.
class DataValidationError(ValueError):
    """Raised when a data dictionary lacks a required field or a field is malformed."""


class UserMessageNotFoundError(ValueError):
    """Raised when a prompt contains no message with the 'user' role."""


class SystemMessageNotFoundError(ValueError):
    """Raised when a prompt contains no message with the 'system' role."""


def _validate_data_dict(data_dict: Dict[str, Any]) -> None:
    """
    Validates that the fields required by this module are present and well-formed.

    The checks are:
      * 'prompt' exists and is a list;
      * 'reward_model' exists, is a mapping, and contains 'ground_truth'.

    Args:
        data_dict: The dictionary to validate.

    Raises:
        DataValidationError: If a required key is missing or has the wrong type.
    """
    if "prompt" not in data_dict:
        raise DataValidationError("Data dictionary is missing the 'prompt' field.")
    if not isinstance(data_dict["prompt"], list):
        raise DataValidationError("'prompt' field must be a list of messages.")

    reward_model = data_dict.get("reward_model")
    # Require a real mapping before testing for the key. With a bare `in` test a
    # string value is checked by substring (so "...ground_truth..." would pass),
    # and None or a number raises TypeError instead of a validation error.
    if reward_model is not None and not isinstance(reward_model, Mapping):
        raise DataValidationError("'reward_model' field must be a mapping.")
    if reward_model is None or "ground_truth" not in reward_model:
        raise DataValidationError("Data dictionary is missing the 'reward_model.ground_truth' field.")


def _find_message_by_role(messages: List[Dict[str, Any]], role: str) -> Optional[MutableMapping[str, Any]]:
    """
    Returns the first message whose 'role' value equals the requested role.

    Args:
        messages: A list of message dictionaries.
        role: The role to look for (e.g., 'user').

    Returns:
        The matching message object itself (not a copy), or None when no
        message carries that role.
    """
    # The generator is lazy, so the scan stops at the first match.
    candidates = (entry for entry in messages if entry.get("role") == role)
    return next(candidates, None)


def modify_user_message_for_reasoning(data_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Builds a copy of a data dictionary whose first 'user' message asks for reasoning.

    The user message content is replaced by REASONING_PROMPT_TEMPLATE, filled in
    with the original content (an empty string when the message has no 'content'
    key) and the value of reward_model['ground_truth'].

    Args:
        data_dict: The original data dictionary. It is never mutated.

    Returns:
        A deep copy of data_dict with the rewritten user message.

    Raises:
        DataValidationError: If the input dictionary is malformed.
        UserMessageNotFoundError: If no 'user' message is found in the prompt.
    """
    _validate_data_dict(data_dict)

    # Work on a deep copy so that nested messages of the input stay untouched.
    result = copy.deepcopy(data_dict)

    target = _find_message_by_role(result["prompt"], USER_ROLE)
    if not target:
        raise UserMessageNotFoundError(f"No message with role '{USER_ROLE}' found in the prompt.")

    # Both template arguments are read before the content is overwritten.
    target["content"] = REASONING_PROMPT_TEMPLATE.format(
        question=target.get("content", ""),
        ground_truth=result["reward_model"]["ground_truth"],
    )
    return result

def modify_system_message(data_dict: Dict[str, Any], new_system_content: Optional[str] = None) -> Dict[str, Any]:
    """
    Builds a copy of a data dictionary with the first 'system' message replaced.

    Args:
        data_dict: The original data dictionary. It is never mutated.
        new_system_content: Text to store as the system message content. When
            None, the module-level SYSTEM_PROMPT is used instead.

    Returns:
        A deep copy of data_dict with the updated system message.

    Raises:
        DataValidationError: If the input dictionary is malformed.
        SystemMessageNotFoundError: If no 'system' message is found in the prompt.
    """
    _validate_data_dict(data_dict)

    # Deep copy first so the caller's nested message dictionaries are not touched.
    updated = copy.deepcopy(data_dict)

    target = _find_message_by_role(updated["prompt"], SYSTEM_ROLE)
    if not target:
        raise SystemMessageNotFoundError(f"No message with role '{SYSTEM_ROLE}' found in the prompt.")

    # Only None selects the default; an explicit empty string is kept as given.
    if new_system_content is None:
        target["content"] = SYSTEM_PROMPT
    else:
        target["content"] = new_system_content

    return updated


def find_system_message(data_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Looks up the first 'system' message of a data dictionary.

    No copy is made: the returned object is the message stored inside
    data_dict, so changing it changes the input.

    Args:
        data_dict: The data dictionary to search.

    Returns:
        The system message dictionary if found, otherwise None.

    Raises:
        DataValidationError: If the input dictionary is malformed.
    """
    _validate_data_dict(data_dict)
    return _find_message_by_role(data_dict["prompt"], SYSTEM_ROLE)


def batch_modify_user_messages(data_list: List[Dict[str, Any]], parallel: bool = False, max_workers: Optional[int] = None) -> List[Dict[str, Any]]:
    """
    Applies modify_user_message_for_reasoning to every item of a list.

    Items that raise DataValidationError or UserMessageNotFoundError are logged
    as warnings and replaced by a deep copy of the original item, so the result
    has the same length and order as data_list. Any other exception propagates
    to the caller.

    Args:
        data_list: A list of data dictionaries to process.
        parallel: If True, every item is submitted as its own task to a
            ProcessPoolExecutor. Defaults to False. Each item is transferred
            to and from a worker process, so this presumably only pays off
            for large inputs; measure before enabling it.
        max_workers: The maximum number of processes to use in parallel mode.
                     None lets ProcessPoolExecutor pick its default.

    Returns:
        A new list of modified data dictionaries, with deep copies of the
        original items in the positions that failed.
    """
    if not parallel:
        # Sequential processing: simple, good for smaller datasets or debugging.
        modified_list = []
        for i, data_dict in enumerate(data_list):
            try:
                modified_list.append(modify_user_message_for_reasoning(data_dict))
            except (DataValidationError, UserMessageNotFoundError) as e:
                # Lazy %-style arguments: the message is only built if it is emitted.
                logging.warning("Skipping item %d due to an error: %s", i, e)
                # Keep the position filled with a copy of the unmodified item.
                modified_list.append(copy.deepcopy(data_dict))
        return modified_list

    # Nothing to do: skip building a process pool for an empty input.
    if len(data_list) == 0:
        return []

    logging.info("Starting parallel processing with up to %s workers.", max_workers or "default")
    modified_results = [None] * len(data_list)

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Futures complete in arbitrary order, so remember each item's index
        # to write its result back into the right slot.
        future_to_index = {
            executor.submit(modify_user_message_for_reasoning, data_dict): index
            for index, data_dict in enumerate(data_list)
        }

        for future in as_completed(future_to_index):
            index = future_to_index[future]
            try:
                modified_results[index] = future.result()
            except (DataValidationError, UserMessageNotFoundError) as e:
                logging.warning("Processing item %d failed with error: %s", index, e)
                # Keep the position filled with a copy of the unmodified item.
                modified_results[index] = copy.deepcopy(data_list[index])

    return modified_results


# --- Example Usage ---
if __name__ == "__main__":
    # Divider printed between the demo sections.
    section_break = "\n" + "=" * 80 + "\n"

    # Demo records: one well-formed entry plus two deliberately broken ones.
    well_formed = {
        "uid": "Math-500-0",
        "prompt": [
            {"role": "system", "content": "A system message."},
            {"role": "user", "content": "Convert the point $(0,3)$ in rectangular coordinates to polar coordinates."},
        ],
        "reward_model": {"ground_truth": "\\left( 3, \\frac{\\pi}{2} \\right)"},
    }
    # Broken: the prompt has no message with the 'user' role.
    no_user_message = {
        "uid": "Test-Malformed-1",
        "prompt": [{"role": "system", "content": "A system message."}],
        "reward_model": {"ground_truth": "Some answer"},
    }
    # Broken: the 'prompt' key is absent altogether.
    no_prompt_key = {
        "uid": "Test-Malformed-2",
        "reward_model": {"ground_truth": "Some answer"},
    }
    records = [well_formed, no_user_message, no_prompt_key]

    # --- Single item ---
    print("--- Testing Single Item Processing ---")
    print("Original User Message:")
    print(well_formed["prompt"][1]["content"])

    try:
        rewritten = modify_user_message_for_reasoning(well_formed)
        print("\nModified User Message:")
        print(rewritten["prompt"][1]["content"])
    except Exception as exc:
        logging.error(f"An error occurred during single item processing: {exc}")

    print(section_break)

    # --- System message replacement ---
    print("--- Testing System Message Modification ---")
    print("Original System Message:")
    current_system = find_system_message(well_formed)
    if not current_system:
        print("No system message found.")
    else:
        print(current_system["content"])

        replacement_text = "You are a helpful mathematics tutor. Please provide clear and detailed explanations."
        try:
            with_new_system = modify_system_message(well_formed, replacement_text)
            print("\nModified System Message:")
            print(with_new_system["prompt"][0]["content"])
        except Exception as exc:
            logging.error(f"An error occurred during system message modification: {exc}")

    print(section_break)

    # --- Batch, sequential ---
    # The two broken records are logged as warnings and come back unmodified.
    print("--- Testing Batch Processing (Sequential) ---")
    sequential_results = batch_modify_user_messages(records)
    print("Sequential batch processing complete.")
    print("\nModified content of first item in batch:")
    print(sequential_results[0]["prompt"][1]["content"])
    print("\nUID of second item (should be unchanged):")
    print(sequential_results[1]["uid"])

    print(section_break)

    # --- Batch, parallel ---
    # Three records are far too few to benefit; this only demonstrates the call.
    print("--- Testing Batch Processing (Parallel) ---")
    parallel_results = batch_modify_user_messages(records, parallel=True)
    print("Parallel batch processing complete.")
    print("\nModified content of first item in parallel batch:")
    print(parallel_results[0]["prompt"][1]["content"])
    print("\nUID of second item in parallel batch (should be unchanged):")
    print(parallel_results[1]["uid"])