# intervention/tictactoe_task.py
import json
import os
import re
from typing import List, Dict, Any, Optional, Tuple

# --- User-Provided Prompt Generation Functions ---

# Definition of "best move" that is embedded verbatim in every best-move prompt
# (see prepare_best_move_prompt). Any edit here changes the text the model sees.
BEST_MOVE_DEFINITION_PROMPT = """The determination of a "best move" follows a strict hierarchy of objectives:

1.  **Priority 1: Fastest Win.** If the current player can force a win, the 'best_move' will be the move that leads to the quickest possible victory (a win in the minimum number of subsequent turns).

2.  **Priority 2: Secure a Draw.** If a win is not possible, but the player can force a draw, the 'best_move' will contain all moves that guarantee at least a draw.

3.  **Priority 3: Slowest Loss.** If the player is in a losing position where every move leads to an eventual loss, the 'best_move' will be the move that prolongs the game as long as possible before the loss occurs.

4.  **Terminal State: No Moves.** If the game has already concluded (a player has won, or the board is full), no further moves can be made. In this case, the 'best_move' will be None."""

def prepare_best_move_prompt(sample, representation_mode, instruct_model, style="text_instruction"):
    """Build the best-move prompt for a single dataset sample.

    Args:
        sample: Dataset record. Reads ``sample[style]`` (board text shown to
            the model), ``sample["board"]`` (cell values; the counts of 1 and 2
            decide whose turn it is) and ``sample["best_moves"]``.
        representation_mode: Move representation; only ``"nl"`` is supported.
        instruct_model: If true, prepend a system message and ask for exactly
            one <think>/<answer> pair; otherwise emit a plain prompt.
        style: Key of the board-text field inside ``sample``.

    Returns:
        dict with ``"prompt"`` (str) and ``"allowed_moves"`` (list of str: the
        stringified ``best_moves``, or ``["None"]`` when there are none).

    Raises:
        ValueError: If ``representation_mode`` is not ``"nl"``.
    """
    board_state = sample.get(style, "")
    board = sample.get("board", [])
    # Equal numbers of 1s and 2s on the board means it is Player 1's turn.
    current_player = 1 if board.count(1) == board.count(2) else 2

    if representation_mode != "nl":
        raise ValueError(f"Representation mode '{representation_mode}' not implemented for this integration.")

    mapping_str = (
        "Mapping:\n"
        "Player 1 (X) Tokens: 1 -> (0,0), 2 -> (0,1), 3 -> (0,2), 4 -> (1,0), 5 -> (1,1), 6 -> (1,2), 7 -> (2,0), 8 -> (2,1), 9 -> (2,2)\n"
        "Player 2 (O) Tokens: 10 -> (0,0), 11 -> (0,1), 12 -> (0,2), 13 -> (1,0), 14 -> (1,1), 15 -> (1,2), 16 -> (2,0), 17 -> (2,1), 18 -> (2,2), None -> No Move can be played"
    )
    # `or []` also tolerates an explicit null in the record.
    allowed_moves = [str(token) for token in (sample.get("best_moves") or [])]
    if not allowed_moves:
        # Terminal boards have no best move; the mapping exposes "None" for this.
        allowed_moves = ["None"]

    # Instructions shared by both prompt variants. The "\n" after the
    # definition keeps its last sentence from running into the "Mapping:" header.
    task_body = (
        f"It is Player {current_player}'s turn.\n"
        "Recommend the best move which the player can play. Here is the definition of best move:\n"
        f"{BEST_MOVE_DEFINITION_PROMPT}\n"
        f"{mapping_str}\n"
        "Please provide your reasoning in the following format:\n"
        "<think> Your chain-of-thought reasoning here </think>\n"
        "<answer> Your final move here </answer>\n"
        "Remember to output exactly one of the best moves."
    )

    if instruct_model:
        system_msg = "You are a helpful assistant skilled at reasoning for tic tac toe. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>."
        user_msg = (
            f"Board state:\n{board_state}\n"
            + task_body
            + " You can only have one set of <think>...</think> and <answer>...</answer> in your response."
            + " The think section should be at the beginning of your response."
        )
        # For probing, we typically use a single combined prompt string
        return {"prompt": system_msg + "\n" + user_msg, "allowed_moves": allowed_moves}

    # Non-instruct model prompt generation
    prompt_str = f"Current board state:\n{board_state}\n" + task_body
    return {"prompt": prompt_str, "allowed_moves": allowed_moves}


def generate_tictactoe_prompt(sample, representation_mode, instruct_model, experiment_mode="best_move", style="text_instruction"):
    """Route a sample to the prompt builder for ``experiment_mode``.

    Only ``"best_move"`` is wired up. It returns the dict produced by
    ``prepare_best_move_prompt``. Any other mode raises ``ValueError``.
    """
    if experiment_mode != "best_move":
        # A 'legal_move' builder could be dispatched from here if reinstated.
        raise ValueError(f"Experiment mode {experiment_mode} not recognized for this integration.")
    return prepare_best_move_prompt(sample, representation_mode, instruct_model, style=style)

# --- Task Class Wrapper ---

class TicTacToeTask:
    """
    Handles loading the Tic-Tac-Toe dataset and using the provided
    functions to generate prompts for the model.
    """

    # Matches a standalone X or O token. With \b on both sides, letters inside
    # longer words (or next to other word characters) are left alone.
    _PIECE_TOKEN_RE = re.compile(r"\b[XO]\b")

    def __init__(self, dataset_path: str, representation_mode="nl", instruct_model=True, invariance_token_pair: Optional[Tuple[str, str]] = None):
        """Load the dataset and store the prompt-generation settings.

        Args:
            dataset_path: Path to the JSON dataset file.
            representation_mode: Forwarded to the prompt builder.
            instruct_model: Whether to build instruct-style prompts.
            invariance_token_pair: Optional ``(x_token, o_token)``. When set,
                standalone ``X``/``O`` tokens in the board text are replaced
                by these tokens before prompting (e.g. ``('P', 'Q')``).

        Raises:
            FileNotFoundError: If ``dataset_path`` does not exist.
        """
        self.dataset = self._load_dataset(dataset_path)
        # Board cells are stored as ints: 0 = empty, 1 = X, 2 = O.
        self.piece_map = {0: 'empty', 1: 'X', 2: 'O'}
        self.piece_map_rev = {'empty': 0, 'X': 1, 'O': 2}
        self.representation_mode = representation_mode
        self.instruct_model = instruct_model
        # For this experiment, we are focused on finding the best move
        self.experiment_mode = "best_move"
        # Optional invariance substitution (e.g., ('P','Q') or ('A','B'))
        self.invariance_token_pair = invariance_token_pair

    def _load_dataset(self, path: str) -> List[Dict[str, Any]]:
        """Load the JSON dataset at ``path``.

        Returns:
            The parsed JSON (expected to be a list of board records), or ``[]``
            if the file exists but cannot be read or decoded.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        print(f"Attempting to load dataset from: {path}")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Dataset file not found at {path}. Please check the path.")
        try:
            # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
            with open(path, 'r', encoding='utf-8') as f:
                dataset = json.load(f)
        except (OSError, ValueError) as e:
            # Best-effort: report the problem and fall back to an empty dataset.
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading dataset: {e}")
            return []
        print(f"Successfully loaded {len(dataset)} records.")
        return dataset

    @classmethod
    def _substitute_piece_tokens(cls, text: str, x_tok: str, o_tok: str) -> str:
        """Replace standalone ``X``/``O`` tokens in ``text`` with ``x_tok``/``o_tok``.

        This is a single pass, so swapped or overlapping pairs such as
        ``('O', 'X')`` are handled correctly. Because the replacement is a
        callable, the tokens are inserted literally and are not parsed as
        ``re.sub`` templates.
        """
        replacements = {"X": x_tok, "O": o_tok}
        return cls._PIECE_TOKEN_RE.sub(lambda m: replacements[m.group(0)], text)

    def get_prompt(self, board_data: Dict[str, Any], style="text_instruction") -> str:
        """
        Generates a standardized prompt for a given board state using the
        user-provided functions. Returns only the prompt string.

        If ``invariance_token_pair`` is set, only the board-state text
        (``board_data[style]``) is token-substituted. The instructions and
        move mapping that the prompt builder adds are unchanged.
        """
        # Work on a shallow copy so we never mutate the original dataset entry
        sample_local = dict(board_data)
        if self.invariance_token_pair:
            x_tok, o_tok = self.invariance_token_pair
            sample_local[style] = self._substitute_piece_tokens(sample_local.get(style, ""), x_tok, o_tok)

        prompt_data = generate_tictactoe_prompt(
            sample=sample_local,
            representation_mode=self.representation_mode,
            instruct_model=self.instruct_model,
            experiment_mode=self.experiment_mode,
            style=style
        )
        return prompt_data['prompt']

    def get_best_move_str(self, board_data: Dict[str, Any]) -> str:
        """Return the first ground-truth best move as a string.

        Returns ``"None"`` when ``best_moves`` is empty (e.g. a finished game).
        This matches the ``allowed_moves`` fallback used when building prompts.

        Raises:
            KeyError: If the record has no ``'best_moves'`` key.
        """
        best_moves = board_data['best_moves']
        if not best_moves:
            return "None"
        return str(best_moves[0])

    def find_boards_by_square_state(self, square_index: int, piece: str) -> List[Dict[str, Any]]:
        """Return all boards whose cell ``square_index`` holds ``piece``.

        ``piece`` is one of ``'empty'``, ``'X'`` or ``'O'``. Any other value
        yields ``[]``. Boards are returned in dataset order.
        """
        if piece not in self.piece_map_rev:
            return []
        piece_val = self.piece_map_rev[piece]
        return [b for b in self.dataset if b['board'][square_index] == piece_val]

    def find_board_pair_differing_at_square(self, square_index: int) -> Optional[Tuple[Dict, Dict]]:
        """
        Finds a "dirty" (O) and "clean" (X) pair of boards that are identical
        except at ``square_index``, where the dirty board has O and the clean
        board has X. This is crucial for a clean causal intervention.

        Boards are scanned in dataset order, and the first matching pair is
        returned as ``(dirty_board, clean_board)``. Returns ``None`` if no
        pair exists.

        NOTE(review): turning one O into an X changes (X count - O count) by 2.
        If the dataset only holds reachable positions (difference 0 or 1), no
        such pair can exist, and this will always return None. Confirm the
        dataset contains the needed boards.
        """
        x_val = self.piece_map_rev['X']
        # Index the X-at-square boards once (first occurrence wins) instead of
        # rescanning the dataset for every O-at-square board.
        clean_by_cells: Dict[Tuple, Dict[str, Any]] = {}
        for candidate in self.find_boards_by_square_state(square_index, 'X'):
            clean_by_cells.setdefault(tuple(candidate['board']), candidate)

        for dirty_board in self.find_boards_by_square_state(square_index, 'O'):
            target_cells = list(dirty_board['board'])
            target_cells[square_index] = x_val  # the hypothetical clean board
            clean_board = clean_by_cells.get(tuple(target_cells))
            if clean_board is not None:
                return dirty_board, clean_board  # Found a pair
        return None  # No matching pair found

