#!/usr/bin/env python

################################################################################
#
# HOW TO INTERPRET THE FINE-GRAINED EVALUATION RESULTS
#
# This guide explains how to analyze the detailed statistics produced by this script.
#
################################################################################
"""
--------------------------------------------------------------------------------
I. CORE CONCEPTS
--------------------------------------------------------------------------------

There are two fundamental concepts to keep in mind:

1.  THE EXPERIMENT MODE DEFINES "CORRECTNESS":
    - If `--experiment_mode legal_move`: "Correct" means the model predicted ANY valid move from the list of legal moves.
    - If `--experiment_mode best_move`: "Correct" means the model predicted an OPTIMAL move from the list of best moves.

2.  STATE COMPLEXITY vs. TASK PERFORMANCE:
    - The statistic tables (by Ply, by Legal Moves, etc.) classify the *complexity of the game state*. They describe the properties of the board itself.
    - The accuracy scores within those tables measure the model's *performance on its assigned task* for that specific type of board.

--------------------------------------------------------------------------------
II. INTERPRETING EACH TABLE
--------------------------------------------------------------------------------

---
## 1. Overall Accuracy
---
* In `legal_move` mode: "What percentage of the time did the model successfully predict a valid move?" (This should be very high).
* In `best_move` mode: "What percentage of the time did the model successfully predict an optimal, perfect move?" (This is the main measure of strategic ability).

---
## 2. Accuracy by Game Ply
---
This table shows how performance changes as the game progresses.

* In `legal_move` mode: Less insightful, but a dip might indicate confusion in complex mid-game states.
* In `best_move` mode:
    -   **Low Ply (1-3):** Measures skill in the "opening" phase.
    -   **Mid Ply (4-6):** Measures skill in complex "mid-game" tactics. A dip here is common and shows a struggle with strategic calculation.
    -   **High Ply (7-9):** Measures skill in the "endgame". High accuracy shows it's good at finding obvious wins/blocks.

---
## 3. Accuracy by Number of Legal Moves
---
This table stress-tests the model against a changing number of choices (branching factor).

* In `legal_move` mode: Shows if the model struggles to find even one valid move when the board is very open and there are many options.
* In `best_move` mode (CRITICAL INTERPRETATION):
    -   This tests the model's ability to **focus**.
    -   A sharp drop in accuracy as `num_legal_moves` increases means the model is **distracted by suboptimal choices**. It struggles to differentiate the best move from a large pool of merely legal ones.

---
## 4. Accuracy by Number of Best Moves
---
This table shows how the model performs in critical vs. flexible situations.

* In `legal_move` mode: Shows if the model is more/less accurate at finding a legal move on boards that happen to be strategically simple (many best moves) or complex (one best move).
* In `best_move` mode:
    -   **Accuracy when `num_best_moves` is 1:** This is a measure of **precision under pressure**. A low score here indicates the model fails when only one specific move can win or save the game.
    -   **Accuracy when `num_best_moves` > 1:** The model has a higher chance of success here. Comparing this to the case above shows how much the model relies on "luck" vs. precise calculation.

---
## 5. Accuracy by Minimax Outcome Score
---
This is the deepest diagnostic, revealing the quality of the model's strategic reasoning.

* **Reference "Cheat Sheet":**
    #   Score | Interpretation           | Search Depth | Complexity
    #  -------|--------------------------|--------------|------------
    #   +ve Hi| Fast Win                 | Shallow      | Low
    #   +ve Lo| Slow, Forced Win         | Deep         | High
    #   0     | Forced Draw              | Very Deep    | Very High
    #   -ve Lo| Slow, Stalled Loss       | Deep         | High
    #   -ve Hi| Fast, Unavoidable Loss   | Shallow      | Low

* In `legal_move` mode: Shows if the model's basic legality check is affected by the underlying strategic nature of the position (e.g., "Does the model fail to find legal moves more often when it's in a losing position?").
* In `best_move` mode:
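    -   **High Positive Scores (e.g., 5):** Tests ability to find **short-term, tactical wins**.
    -   **Low Positive Scores (e.g., 1, 3):** Tests ability to find **long-term, strategic wins** that require deep lookahead.
    -   **Score of 0:** Tests **robust defensive play** and the ability to secure a draw.
    -   **Negative Scores (e.g., -2, -4):** Tests ability to play optimally in **losing positions** (e.g., choosing the move that prolongs the game the most).
    -   **Note on perspective:** Scores are always computed from Player 1's point of view (Player 1 win = `10 - final ply`, Player 2 win = `-10 + final ply`). The "win"/"loss" wording above therefore applies when Player 1 is to move; when Player 2 is to move, a positive score marks a losing position for the mover and a negative score a winning one.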

"""

import argparse
import logging
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

import hashlib

import math  # <--- Add math import
from collections import defaultdict # <--- Add defaultdict import
import os    # <--- Add this
import json  # <--- Add this

# Compatibility shim: when the installed prometheus_client does not expose
# `disable_created_metrics`, install a no-op stand-in before vllm is imported
# (presumably the vllm import below expects it to exist — TODO confirm).
import prometheus_client
if not hasattr(prometheus_client, "disable_created_metrics"):
    prometheus_client.disable_created_metrics = lambda: None

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

from grpo_data_util import generate_tictactoe_prompt, extract_final_answer, evaluate_legal_move

from logging_util import setup_logging, get_logger

# Configure logging as soon as the module is imported and announce start-up.
# main() calls setup_logging() again with a per-run log file and binds its own
# local `logger`, so this module-level logger only covers import-time messages.
setup_logging()
logger = get_logger(__name__)
logger.info("GRPO Inference Started")

#################################################################
# Builder class for structured inference for off the shelf models
#################################################################
from pydantic import BaseModel

class ReasoningOutput(BaseModel):
    # Output schema for structured generation: main() passes this model's JSON
    # schema to the guided-decoding parameters when --structured_generation is
    # set. Documented with comments rather than a class docstring because
    # pydantic copies a docstring into the schema's "description" field, which
    # would alter the schema handed to the decoder.
    reasoning: str  # presumably the model's free-text reasoning — confirm against the prompt format
    final_answer: str  # presumably the chosen move — confirm against extract_final_answer


#################################################################
# HELPER FUNCTIONS (Copied from dataset generator for on-the-fly analysis)
#################################################################

def check_winner(board):
    """
    Report whether a 3x3 board is finished and, if so, who won.

    Args:
        board: Flat sequence of 9 cells, row by row; 0 = empty, 1 = player 1,
            2 = player 2.

    Returns:
        (winner, is_terminal): winner is 1 or 2 when that player owns a full
        line and 0 otherwise; is_terminal is True for a win or a full board.
    """
    lines = [
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    ]
    for a, b, c in lines:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a], True
    if 0 not in board:
        return 0, True
    return 0, False


def _terminal_score(winner, depth):
    """Score a finished game from player 1's perspective (faster wins score higher)."""
    if winner == 1:
        return 10 - depth
    if winner == 2:
        return -10 + depth
    return 0


def _minimax_value(board, player, depth, alpha, beta):
    """Return only the minimax value of `board`, searched with alpha-beta pruning.

    The value is exact when it lies strictly inside the (alpha, beta) window;
    outside the window it is merely a bound, which is all a pruned search needs.
    """
    winner, is_terminal = check_winner(board)
    if is_terminal:
        return _terminal_score(winner, depth)

    maximizing = player == 1
    mark = 1 if maximizing else 2
    opponent = 2 if maximizing else 1
    best = -math.inf if maximizing else math.inf

    for move, cell in enumerate(board):
        if cell != 0:
            continue
        child = list(board)
        child[move] = mark
        value = _minimax_value(child, opponent, depth + 1, alpha, beta)
        if maximizing:
            best = max(best, value)
            alpha = max(alpha, best)
        else:
            best = min(best, value)
            beta = min(beta, best)
        if beta <= alpha:
            break
    return best


def minimax_all_moves_depth_sensitive(board, player, depth=0, alpha=-math.inf, beta=math.inf):
    """
    Minimax that prefers faster wins and slower losses, returning every optimal move.

    Scores are from player 1's perspective: a player-1 win is worth
    `10 - depth`, a player-2 win `-10 + depth` and a draw 0, where `depth` is
    the depth at which the game ends.

    Each candidate move is scored with the caller's window left un-narrowed.
    Narrowing the window between sibling moves (classic alpha-beta at the root)
    makes a pruned sibling return only a bound, and that bound can coincide
    with the best value found so far, so a strictly worse move would be
    reported as "equally best". Pruning is therefore confined to the
    sub-searches, which only need a value.

    Args:
        board: Flat list of 9 cells (0 empty, 1 player 1, 2 player 2).
        player: Side to move; 1 maximizes, any other value plays mark 2 and minimizes.
        depth: Depth assigned to `board` (e.g. the number of moves already played).
        alpha, beta: Search window forwarded to the sub-searches. With the
            defaults every move is scored exactly.

    Returns:
        (value, best_moves): the minimax value of the position and the
        ascending list of cell indices achieving it. Terminal boards return
        their score and an empty list.
    """
    winner, is_terminal = check_winner(board)
    if is_terminal:
        return _terminal_score(winner, depth), []

    maximizing = player == 1
    mark = 1 if maximizing else 2
    opponent = 2 if maximizing else 1
    best_value = -math.inf if maximizing else math.inf
    best_moves = []

    for move, cell in enumerate(board):
        if cell != 0:
            continue
        child = list(board)
        child[move] = mark
        value = _minimax_value(child, opponent, depth + 1, alpha, beta)
        improved = value > best_value if maximizing else value < best_value
        if improved:
            best_value, best_moves = value, [move]
        elif value == best_value:
            best_moves.append(move)
    return best_value, best_moves

# In inference.py

def calculate_complexity_metrics(sample):
    """
    Calculate complexity metrics for a single game state on the fly.

    Terminal and non-terminal states are both handled.

    Game ply is the number of moves already played:
    - Low ply (e.g., 1-3): early game. How well does the model handle open, strategic positions?
    - Mid ply (e.g., 4-6): mid-game. This is often where the most complex calculations for
      setting up wins or forcing blocks occur.
    - High ply (e.g., 7-9): late game. How well does the model perform when the board is
      nearly full and moves are more forced or obvious?

    The outcome score encodes search depth, always from Player 1's perspective:
        Player 1 wins: 10 - depth
        Player 2 wins: -10 + depth
        Draw:          0
    where depth is the total number of moves on the board when the game ends.

    - A high positive score (like 5) means a fast win found by a shallow search (low complexity).
    - A low positive score (like 1) means a slow, forced win needing a deep search (high complexity).
    - A score of 0 often means a deep search was needed to confirm the draw (very high complexity).
    - A low negative score means a slow, stalled loss; a deep search prolongs the game (high complexity).
    - A high negative score means a fast, unavoidable loss found by a shallow search (low complexity).

    Args:
        sample: Mapping with 'board' (flat list of 9 cells: 0 empty, 1 player 1,
            2 player 2) and 'is_terminal'; terminal samples also need 'winner'.
            'next_legal_moves' is optional, and a missing or None value counts
            as no legal moves.

    Returns:
        Dict with 'game_ply', 'num_legal_moves', 'num_best_moves' and 'outcome_score'.
    """
    board = sample['board']
    is_terminal = sample['is_terminal']

    # Game ply is defined for every state, terminal or not.
    game_ply = sum(1 for cell in board if cell != 0)

    # `or []` also covers a key that is present but None, which would otherwise
    # make len() raise.
    num_legal_moves = len(sample.get('next_legal_moves') or [])

    outcome_score = 0
    num_best_moves = 0  # a terminal state has no best "moves"

    if not is_terminal:
        # Player 1 moves first, so player 2 is to move exactly when player 1
        # has placed more marks.
        p1_moves = board.count(1)
        p2_moves = board.count(2)
        current_player = 2 if p1_moves > p2_moves else 1
        initial_depth = p1_moves + p2_moves

        outcome_score, best_move_indices = minimax_all_moves_depth_sensitive(
            board, current_player, depth=initial_depth
        )
        num_best_moves = len(best_move_indices)
    else:
        # Use the same scoring formula as the minimax for finished games.
        winner = sample['winner']
        if winner == 1:
            outcome_score = 10 - game_ply
        elif winner == 2:
            outcome_score = -10 + game_ply
        # A draw (winner == 0) keeps the score at 0.

    return {
        "game_ply": game_ply,
        "num_legal_moves": num_legal_moves,
        "num_best_moves": num_best_moves,
        "outcome_score": outcome_score
    }

def print_stats(metric_name, stats_dict, logger):
    """Log a header and then one accuracy line per bucket, in ascending key order.

    Args:
        metric_name: Label used in the header and at the start of every line.
        stats_dict: Mapping of bucket key -> {'total': int, 'correct': int}.
        logger: Object whose `info` method receives each formatted line.
    """
    logger.info(f"\n--- Accuracy by {metric_name} ---")
    for bucket in sorted(stats_dict):
        counts = stats_dict[bucket]
        n_total, n_correct = counts['total'], counts['correct']
        # An empty bucket reports 0% rather than dividing by zero.
        if n_total > 0:
            pct = n_correct / n_total * 100
        else:
            pct = 0
        logger.info(f"{metric_name} {bucket}: {pct:.2f}% ({n_correct}/{n_total})")




def load_test_dataset(args):
    """
    Load the test dataset from a HuggingFace dataset ID/path or a local JSON file.

    `args.dataset_id_or_path` takes precedence and is loaded with split
    `args.dataset_splits`; otherwise `args.test_dataset_path` is loaded through
    the "json" builder using its "train" split.

    Raises:
        ValueError: If neither `dataset_id_or_path` nor `test_dataset_path` is set.
    """
    if args.dataset_id_or_path:
        return load_dataset(args.dataset_id_or_path, split=args.dataset_splits)

    # Fail with an actionable message instead of passing data_files=None on.
    if not args.test_dataset_path:
        raise ValueError(
            "No dataset specified: pass either --dataset_id_or_path or --test_dataset_path."
        )
    return load_dataset("json", data_files=args.test_dataset_path, split="train")

def parse_args(argv=None):
    """
    Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads `sys.argv[1:]`.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    def parse_bool(text):
        # Only the literal "true" (case-insensitive) enables the flag.
        return text.lower() == "true"

    parser = argparse.ArgumentParser(description="GRPO Inference Script")

    # --- Model ---
    parser.add_argument("--model_checkpoint", type=str, required=True,
                        help="Path to the trained model checkpoint")
    parser.add_argument("--off_the_shelf", action="store_true",
                        help="If set, the model is an off-the-shelf model and will be loaded differently.")
    parser.add_argument("--instruct_model", type=parse_bool, default=False,
                        help="Whether the model uses instruct prompt style")
    parser.add_argument("--model_mark", type=str, required=True,
                        help="Unique identifier for the model to use for log file naming")

    # --- Data ---
    parser.add_argument("--test_dataset_path", type=str, required=False,
                        help="Path to the test dataset JSON file")
    parser.add_argument("--dataset_id_or_path", type=str, default="",
                        help="Dataset identifier if using a HuggingFace dataset")
    parser.add_argument("--dataset_splits", type=str, default="test",
                        help="Dataset split to use")

    # --- Prompting / generation ---
    parser.add_argument("--representation_mode", type=str, default="nl", choices=["nl", "special"],
                        help="Representation mode")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Batch size for inference (currently processes samples sequentially)")
    parser.add_argument("--max_new_tokens", type=int, default=1024,
                        help="Maximum tokens to generate per prompt")
    parser.add_argument("--experiment_mode", type=str, default="legal_move",
                        help="What kind of inference to run, default is legal move, and can be set to best_move for best move prediction instead.")
    parser.add_argument("--random_xy_moves", action="store_true",
                        help="If set, use random moves like A and B instead of X and O.")
    parser.add_argument("--structured_generation", action="store_true",
                        help="If set, use structured JSON generation.")
    # The hyphenated spelling is the original one; the underscore alias matches
    # the naming of every other option. Both populate `use_ascii_board`.
    parser.add_argument("--use-ascii-board", "--use_ascii_board", dest="use_ascii_board",
                        action="store_true",
                        help="If set, use the ascii_board field from the dataset for board representation instead of text_instruction.")

    # --- Output ---
    parser.add_argument("--results_folder", type=str, required=True,
                        help="Path to the folder where results will be saved.")
    parser.add_argument("--save_individual_outputs", action="store_true",
                        help="If set, save the detailed model output for each sample in the results JSON.")

    return parser.parse_args(argv)

def make_banned_tokens_processor(tokenizer, banned_words):
    """
    Create a logits processor that bans tokens which would complete a banned word.

    Args:
        tokenizer: Object with `encode(text, add_special_tokens=False)` returning
            the token ids of `text`.
        banned_words: Iterable of strings that must not be generated. Words that
            encode to no tokens are ignored.

    Returns:
        A function `(token_ids, logits, **kwargs) -> logits` that sets, in place,
        the logit of every banned completion token to -infinity and returns the
        same `logits` object.
    """
    # Pre-encode each banned word once; stored as lists so comparisons below
    # are list-to-list.
    banned_sequences = []
    for word in banned_words:
        encoded = tokenizer.encode(word, add_special_tokens=False)
        if encoded:
            banned_sequences.append(list(encoded))

    def banned_tokens_processor(token_ids, logits, **kwargs):
        """
        Mask the completing token of every banned sequence whose prefix was just generated.

        Args:
            token_ids: Sequence (list or tuple) of token ids generated so far.
            logits: Indexable, mutable collection of logits, one per vocabulary id.
            **kwargs: Ignored; accepted for compatibility with callers that pass extras.

        Returns:
            The (mutated) logits.
        """
        for seq in banned_sequences:
            # A single-token word is banned unconditionally.
            if len(seq) == 1:
                logits[seq[0]] = -float("inf")
                continue

            prefix_len = len(seq) - 1
            if len(token_ids) < prefix_len:
                continue
            # list(...) makes the comparison work for tuples as well; a tuple
            # never compares equal to a list, so the ban would otherwise be
            # skipped silently.
            if list(token_ids[-prefix_len:]) == seq[:-1]:
                logits[seq[-1]] = -float("inf")
        return logits

    return banned_tokens_processor


def main():
    """
    Run the tic-tac-toe evaluation end to end.

    Steps: parse the CLI options, configure file logging, load the vLLM engine,
    generate one completion per dataset sample (greedy decoding), score the
    extracted move against the sample's allowed moves, aggregate accuracy
    overall and per complexity bucket, and write everything to
    `<results_folder>/<experiment_mode>/random_xy_moves_<flag>/<model_mark>_results.json`.

    Raises:
        ValueError: If chat-style generation is used (no --structured_generation)
            with an experiment mode other than `legal_move` or `best_move`.
            The check runs before the model is loaded.
    """
    args = parse_args()

    # Greedy decoding settings; they also appear in the log file name.
    temperature = 0.0
    top_k = 1
    top_p = 1.0

    # Set up logging with a file based on the model mark.
    log_file = f"logs/{args.model_mark}_{temperature}_{top_k}_{top_p}_{args.max_new_tokens}_{args.experiment_mode}_{args.random_xy_moves}.log"
    setup_logging(log_file=log_file)
    logger = get_logger(__name__)
    logger.info("GRPO Inference Started")

    # Chat-style runs need a system prompt. Resolve it once, up front, so an
    # unsupported mode is rejected before the (expensive) model load rather
    # than on the first sample.
    system_prompt = None
    if not args.structured_generation:
        mode_labels = {"legal_move": "legal move", "best_move": "best move"}
        if args.experiment_mode not in mode_labels:
            raise ValueError(f"Experiment Mode {args.experiment_mode} not supported!!")
        logger.info(f"Using SYSTEM MESSAGE from {mode_labels[args.experiment_mode]}")
        # Both supported modes currently share the same system prompt.
        system_prompt = "You are a helpful assistant skilled at reasoning for tic tac toe. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

    # --- 1. CONSTRUCT RESULTS PATH ---
    results_dir = os.path.join(args.results_folder, args.experiment_mode, f"random_xy_moves_{args.random_xy_moves}")
    os.makedirs(results_dir, exist_ok=True)
    results_filepath = os.path.join(results_dir, f"{args.model_mark}_results.json")
    logger.info(f"Results will be saved to: {results_filepath}")

    # --- 2. TOKENIZER DIAGNOSTICS ---
    # The tokenizer loaded here is only used for the log output below; it is
    # not passed to the vLLM engine.
    tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, trust_remote_code=True)

    if "Qwen" in args.model_checkpoint and not args.off_the_shelf:
        config = AutoConfig.from_pretrained(args.model_checkpoint, trust_remote_code=True)
        logger.info("Tokenizer Vocab size before: %d", tokenizer.vocab_size)
        logger.info("Config vocab size: %d", config.vocab_size)
    else:
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
            logger.info("Added padding token to tokenizer.")

        # No phrases are banned at present; the sub-word list is logged only.
        banned_phrases = []
        banned_subwords = set()
        for phrase in banned_phrases:
            for token_id in tokenizer.encode(phrase, add_special_tokens=False):
                banned_subwords.add(tokenizer.decode([token_id]))
        banned_subwords = sorted(banned_subwords)
        logger.info(f"Bad words: {banned_subwords}")

    # --- 3. ENGINE AND SAMPLING PARAMETERS ---
    engine = LLM(model=args.model_checkpoint, enforce_eager=True, gpu_memory_utilization=0.5)

    sampling_kwargs = {
        "max_tokens": args.max_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "repetition_penalty": 1.0,
    }
    if args.structured_generation:
        # Constrain the output to the ReasoningOutput JSON schema.
        json_schema = ReasoningOutput.model_json_schema()
        sampling_kwargs["guided_decoding"] = GuidedDecodingParams(json=json_schema)
    sampling_params = SamplingParams(**sampling_kwargs)

    dataset = load_test_dataset(args)
    logger.info(f"Loaded test dataset with {len(dataset)} samples.")

    # --- 4. EVALUATION LOOP ---
    stats_by_ply = defaultdict(lambda: {'correct': 0, 'total': 0})
    stats_by_legal_moves = defaultdict(lambda: {'correct': 0, 'total': 0})
    stats_by_best_moves = defaultdict(lambda: {'correct': 0, 'total': 0})
    stats_by_outcome = defaultdict(lambda: {'correct': 0, 'total': 0})

    total_correct = 0
    total_samples = 0

    # Maps str(board) -> list of per-sample records (only filled on request).
    individual_outputs = {}

    for i, sample in enumerate(dataset):
        complexity = calculate_complexity_metrics(sample)

        prompt_dict = generate_tictactoe_prompt(
            sample,
            args.representation_mode,
            args.instruct_model,
            experiment_mode=args.experiment_mode,
            random_xy_moves=args.random_xy_moves,
            structured_generation=args.structured_generation,
            use_ascii_board=args.use_ascii_board
        )
        prompt_text = prompt_dict.get("prompt")
        allowed_moves = prompt_dict.get("allowed_moves", [])

        if args.structured_generation:
            outputs = engine.generate(prompt_text, sampling_params)
        else:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt_text},
            ]
            outputs = engine.chat(messages, sampling_params)

        logger.debug("Raw engine outputs: %s", outputs)

        # Start from an empty completion so a request that yields no output is
        # scored as a miss instead of reusing the previous sample's text (or
        # failing with an undefined name on the first sample).
        completion = ""
        if not outputs:
            logger.warning(f"Sample {i}: engine returned no outputs; treating completion as empty.")
        for output in outputs:
            completion = output.outputs[0].text
            logger.info(f"Prompt: {output.prompt!r}, Generated text: {completion!r}")

        logger.info(f"Prompt: {prompt_text}")
        logger.info(f"Completion: {completion}")
        predicted_move = extract_final_answer(completion)
        is_correct = evaluate_legal_move(predicted_move, allowed_moves, args.representation_mode)

        ply = complexity['game_ply']
        num_legal = complexity['num_legal_moves']
        num_best = complexity['num_best_moves']
        score = complexity['outcome_score']

        # Update every fine-grained table with this sample's result.
        for table, bucket in (
            (stats_by_ply, ply),
            (stats_by_legal_moves, num_legal),
            (stats_by_best_moves, num_best),
            (stats_by_outcome, score),
        ):
            table[bucket]['total'] += 1
            if is_correct:
                table[bucket]['correct'] += 1

        if is_correct:
            total_correct += 1
        total_samples += 1

        logger.info(f"Sample {total_samples} | Allowed Moves: {allowed_moves} | Ply={ply} | Predicted move: '{predicted_move}', Is_Correct: {is_correct}")

        if args.save_individual_outputs:
            board_key = str(sample['board'])

            output_record = {
                "sample_index": i,
                "prompt": prompt_text,
                "full_completion": completion,
                "predicted_move": predicted_move,
                "is_correct": is_correct,
                "allowed_moves": allowed_moves,
                "complexity_metrics": complexity
            }
            individual_outputs.setdefault(board_key, []).append(output_record)

            # Periodic progress report.
            if total_samples % 100 == 0:
                logger.info(f"Processed {total_samples} samples, individual outputs so far: {len(individual_outputs)} unique boards.")
                logger.debug(f"Current individual outputs: {individual_outputs.keys()}")
                logger.info(f"Output for board {board_key}: {output_record}")

    # --- 5. SUMMARY ---
    accuracy = (total_correct / total_samples * 100) if total_samples > 0 else 0.0
    logger.info("\n#############################################")
    logger.info("           EVALUATION SUMMARY")
    logger.info("#############################################")
    logger.info(f"Overall Accuracy: {accuracy:.2f}% ({total_correct}/{total_samples})")

    print_stats("Game Ply", stats_by_ply, logger)
    print_stats("Number of Legal Moves", stats_by_legal_moves, logger)
    print_stats("Number of Best Moves", stats_by_best_moves, logger)
    print_stats("Minimax Outcome Score", stats_by_outcome, logger)
    logger.info("#############################################\n")

    # --- 6. SAVE RESULTS ---
    results_data = {
        "metadata": {
            "model_checkpoint": args.model_checkpoint,
            "test_dataset_path": args.test_dataset_path,
            "dataset_id_or_path": args.dataset_id_or_path,
            "representation_mode": args.representation_mode,
            "experiment_mode": args.experiment_mode,
            "model_mark": args.model_mark,
            "max_new_tokens": args.max_new_tokens,
            "structured_generation": args.structured_generation,
        },
        "overall_stats": {
            "accuracy_percent": round(accuracy, 4),
            "correct_predictions": total_correct,
            "total_samples": total_samples,
        },
        # Sorted so the buckets appear in ascending order in the JSON file.
        "fine_grained_stats": {
            "by_game_ply": dict(sorted(stats_by_ply.items())),
            "by_num_legal_moves": dict(sorted(stats_by_legal_moves.items())),
            "by_num_best_moves": dict(sorted(stats_by_best_moves.items())),
            "by_minimax_outcome_score": dict(sorted(stats_by_outcome.items())),
        }
    }

    if args.save_individual_outputs:
        results_data["individual_outputs"] = individual_outputs

    # A failed write is logged rather than raised so the summary above is not
    # lost to a traceback.
    try:
        with open(results_filepath, 'w') as f:
            json.dump(results_data, f, indent=4)
        logger.info(f"Successfully saved evaluation results to {results_filepath}")
    except IOError as e:
        logger.error(f"Failed to save results to {results_filepath}. Error: {e}")

if __name__ == "__main__":
    main()