#!/usr/bin/env python
"""
inference.py – Standalone inference script for evaluating a saved SFT model on a test dataset.

This script:
  - Imports process_raw_data, DataArguments, and related helpers from the shared data_utils module.
  - Loads the saved model and tokenizer from a provided model path.
  - Loads the test dataset (a pre-split JSON file) and processes it using a DataArguments instance.
  - Runs generation with a specified max_new_tokens.
  - Computes exact-match accuracy by comparing the generated continuation (the text produced after the prompt) with the target answer.
"""

import argparse
import json
import torch
import traceback
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser
from tqdm import tqdm
from typing import List, Dict

# Import the shared data processing functions and DataArguments.
from data_utils import preprocess_example, DataArguments, SYSTEM_MESSAGE_INSTRUCTION_MODEL, process_raw_data

from dataclasses import dataclass, field

from transformers import LogitsProcessor, LogitsProcessorList

# Logits processor used by constrained generation to restrict which tokens can be emitted.
class AllowedTokensLogitsProcessor(LogitsProcessor):
    """Restrict next-token scores to a fixed set of vocabulary ids.

    Every column of ``scores`` whose index is not in ``allowed_token_ids`` is set
    to ``-inf``; allowed columns keep their original score.

    Args:
        allowed_token_ids: Iterable of integer token ids that remain selectable.

    Raises:
        ValueError: if ``allowed_token_ids`` is empty (every token would be masked)
            or contains ``None`` (e.g. ``convert_tokens_to_ids`` on a token that is
            not in the vocabulary of a tokenizer without an unk token).
    """

    def __init__(self, allowed_token_ids):
        ids = list(allowed_token_ids)
        if not ids:
            raise ValueError("allowed_token_ids must contain at least one token id.")
        # A None id is dangerous: ``mask[:, None]`` indexes a new axis instead of a
        # column, which would silently un-mask the entire vocabulary.
        if any(token_id is None for token_id in ids):
            raise ValueError(f"allowed_token_ids contains None (unknown token?): {ids}")
        self.allowed_token_ids = ids

    def __call__(self, input_ids, scores):
        # Build the boolean "selectable" mask with one vectorized index instead of
        # a Python loop over the allowed ids.
        allowed = torch.zeros_like(scores, dtype=torch.bool)
        allowed[:, self.allowed_token_ids] = True
        # masked_fill (rather than adding a -inf tensor) avoids inf + -inf = NaN
        # and leaves allowed scores untouched. Returns a new tensor.
        return scores.masked_fill(~allowed, float("-inf"))


# Inference-only options, deliberately kept separate from (not inheriting) DataArguments.
@dataclass
class InferenceConfig:
    """Command-line options for this inference script, parsed with HfArgumentParser.

    main() copies the data-related fields (test_dataset_path, experiment_mode,
    representation_mode, random_seed) into a DataArguments instance; the other
    fields control model loading, prompt formatting and generation.
    """
    # Required (no default): directory passed to from_pretrained for tokenizer and model.
    model_path: str = field(metadata={"help": "Path to the saved model directory."})
    # Required (no default): JSON file loaded with json.load and preprocessed for evaluation.
    test_dataset_path: str = field(metadata={"help": "Path to the test dataset JSON file."})
    # Forwarded to DataArguments.
    experiment_mode: str = field(default="legal_move", metadata={"help": "Experiment mode: legal_move or best_move."})
    # Forwarded to DataArguments; also selects the constrained-generation strategy and,
    # for "move_seq_special", whether <move_*> tokens are added to the tokenizer.
    representation_mode: str = field(default="ascii", metadata={"help": "Representation mode: ascii, natural, move_seq_explained, or move_seq_special."})
    # Passed as max_new_tokens to model.generate.
    max_new_tokens: int = field(default=10, metadata={"help": "Max number of new tokens to generate."})
    # Number of test examples per generate call.
    batch_size: int = field(default=4, metadata={"help": "Batch size for inference."})
    # Forwarded to DataArguments; main() does not seed torch itself.
    random_seed: int = field(default=42, metadata={"help": "Random seed for any random operations."})
    # When True, prompts are wrapped with the tokenizer's chat template and
    # SYSTEM_MESSAGE_INSTRUCTION_MODEL before tokenization.
    instruction_model: bool = field(default=False, metadata={"help": "True if the model is instruction-finetuned."})
    # When True, generation uses beam search with force_words_ids constraints.
    constrained: bool = field(default=False, metadata={"help": "Enable constrained generation."})
    # NOTE(review): not read by main() in this file; confirm whether it should be
    # forwarded to DataArguments.
    random_moves: bool = field(default=False, metadata={"help": "Switch X and Y to random moves like A and B instead."})
    

# def prepare_test_dataset(raw_data, data_args: DataArguments) -> Dataset:
#     """
#     Processes the raw test data by applying preprocess_example to each raw example
#     and flattening the results.
#     """
#     flattened_examples = []
#     for ex in raw_data:
#         processed = preprocess_example(ex, data_args)
#         if processed:
#             if isinstance(processed, list):
#                 flattened_examples.extend(processed)
#             else:
#                 flattened_examples.append(processed)
#     return Dataset.from_list(flattened_examples)

def prepare_test_dataset(raw_data: List[Dict], data_args: DataArguments) -> Dataset:
    """Convert raw test records into a ``Dataset`` using the shared preprocessing.

    ``process_raw_data`` yields a two-element result; only the first element (the
    processed dataset) is needed for evaluation, so the second is discarded.
    """
    result = process_raw_data(raw_data, data_args)
    processed_dataset, _ = result
    return processed_dataset


# TODO: Move this to data utils, this is copied in sft and inference scripts

# Define the special token mapping string.
# Human-readable legend for the <move_1>..<move_18> special tokens, assembled
# line by line; blank entries produce the blank lines between sections.
mapping_str_special_token_test = "\n".join(
    [
        "Mapping of special move tokens to board positions:",
        "For Player 1 (X):",
        "  <move_1> = top-left, <move_2> = top-center, <move_3> = top-right,",
        "  <move_4> = middle-left, <move_5> = center, <move_6> = middle-right,",
        "  <move_7> = bottom-left, <move_8> = bottom-center, <move_9> = bottom-right.",
        "",
        "For Player 2 (O):",
        "  <move_10> = top-left, <move_11> = top-center, <move_12> = top-right,",
        "  <move_13> = middle-left, <move_14> = center, <move_15> = middle-right,",
        "  <move_16> = bottom-left, <move_17> = bottom-center, <move_18> = bottom-right.",
        "",
        "In this game, Player 1 (X) moves first, and moves are represented using special tokens. "
        "Player 2 (O) uses the same positions but with tokens indexed from 10 to 18.",
    ]
)

def _add_special_move_tokens(tokenizer, model) -> None:
    """Add any missing <move_1>..<move_18> tokens to the tokenizer and resize the model to match."""
    # TODO: Move model loading and data loading to another script; the sft and
    # inference scripts duplicate this token-adding logic.
    special_tokens = [f"<move_{i}>" for i in range(1, 19)]
    # Only add tokens the saved tokenizer does not already contain.
    current_vocab = tokenizer.get_vocab()
    tokens_to_add = [token for token in special_tokens if token not in current_vocab]
    if tokens_to_add:
        tokenizer.add_tokens(tokens_to_add)
        # IMPORTANT: resize model embeddings so the new token ids have rows.
        model.resize_token_embeddings(len(tokenizer))
        print("Added special tokens:", tokens_to_add)
    else:
        print("Special tokens already exist in the tokenizer.")

    # Debug prints: round-trip the mapping legend and a single move token.
    print("Original Prompt:\n", mapping_str_special_token_test)
    encoded = tokenizer.encode(mapping_str_special_token_test)
    print("\nTokenized IDs:\n", encoded)
    print("\nDecoded Prompt:\n", tokenizer.decode(encoded))

    encoded = tokenizer.encode("<move_1>")
    print("\nTokenized ID for <move_1>:\n", encoded)
    print("\nDecoded Prompt for <move_1>:\n", tokenizer.decode(encoded))


def _load_tokenizer_and_model(config, data_args):
    """Load the saved tokenizer/model from config.model_path and prepare them for batched generation."""
    use_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(config.model_path, use_fast=True)
    # Half precision only on GPU: many CPU kernels do not implement float16.
    model = AutoModelForCausalLM.from_pretrained(
        config.model_path,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    model.eval()

    if data_args.representation_mode == "move_seq_special":
        _add_special_move_tokens(tokenizer, model)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models must be left-padded for batched generation; with right
    # padding, shorter prompts would be continued from pad tokens.
    tokenizer.padding_side = "left"

    if use_cuda:
        model.to("cuda")
    return tokenizer, model


def _build_generation_kwargs(config, tokenizer) -> dict:
    """Return the keyword arguments passed to model.generate for every batch."""
    generation_kwargs = {
        "max_new_tokens": config.max_new_tokens,
        # Return a ModelOutput so `.sequences` is available in every mode.
        "return_dict_in_generate": True,
        "temperature": 0.1,
    }
    if not config.constrained:
        return generation_kwargs

    print("RUNNING CONSTRAINED GENERATION.....")
    if config.representation_mode == "move_seq_special":
        allowed_moves = [f"<move_{i}>" for i in range(1, 19)]
        move_ids = tokenizer(allowed_moves, add_special_tokens=False).input_ids
        # The extra nesting level (List[List[List[int]]]) makes this a disjunctive
        # constraint: ONE of the moves must appear. A flat List[List[int]] would
        # require ALL 18 moves to appear in the output.
        force_words_ids = [move_ids]
        print("FORCE WORDS IDs for special move tokens:", force_words_ids)
        generation_kwargs["force_words_ids"] = force_words_ids

        # Restrict every generated token to a move token, plus EOS so that beams
        # can finish instead of being filled with moves up to max_new_tokens.
        allowed_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in allowed_moves]
        if tokenizer.eos_token_id is not None:
            allowed_token_ids.append(tokenizer.eos_token_id)
        generation_kwargs["logits_processor"] = LogitsProcessorList(
            [AllowedTokensLogitsProcessor(allowed_token_ids)]
        )
    elif config.representation_mode == "natural":
        # Force inclusion of the delimiter (a single required word).
        generation_kwargs["force_words_ids"] = tokenizer(
            ["#########"], add_special_tokens=False
        ).input_ids
    elif config.representation_mode in ("ascii", "move_seq_explained"):
        valid_moves = [str(i) for i in range(1, 19)]
        # Disjunctive constraint: the output must contain one of the move numbers.
        # NOTE(review): if the tokenizer splits digits (so "1" is a token prefix of
        # "10"), transformers' DisjunctiveConstraint rejects this list — confirm.
        generation_kwargs["force_words_ids"] = [
            tokenizer(valid_moves, add_special_tokens=False).input_ids
        ]

    # Constrained generation runs as (non-sampling) beam search.
    generation_kwargs.update(num_beams=5, remove_invalid_values=True, do_sample=False)
    print("Constrained Generation args:", generation_kwargs)
    return generation_kwargs


def _format_prompts(input_texts, tokenizer, instruction_model: bool):
    """Return the prompts to tokenize, wrapped in the chat template for instruction models."""
    if not instruction_model:
        return list(input_texts)
    prompts = []
    for text in input_texts:
        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE_INSTRUCTION_MODEL},
            {"role": "user", "content": text},
        ]
        # add_generation_prompt=True appends the assistant-turn header, so the
        # model's continuation is the answer rather than a new turn header.
        prompts.append(
            tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        )
    return prompts


# TODO: Inference for chat models
# TODO: Update to consider any valid legal move
def main():
    """Evaluate a saved causal LM on the test dataset and print exact-match accuracy.

    Options are parsed from the command line into InferenceConfig. Each batch is
    generated once; the stripped continuation after the prompt is compared with
    the stripped target text.
    """
    # Use HfArgumentParser to parse only the InferenceConfig.
    parser = HfArgumentParser(InferenceConfig)
    config, = parser.parse_args_into_dataclasses()

    print(f"Model path: {config.model_path}")
    print(f"Test dataset path: {config.test_dataset_path}")
    print(f"Experiment mode: {config.experiment_mode}")
    print(f"Representation mode: {config.representation_mode}")
    print(f"Max new tokens: {config.max_new_tokens}")
    print(f"Batch size: {config.batch_size}")
    print(f"Constrained Generation: {config.constrained}")

    # DataArguments for processing the dataset; train/val paths are unused here.
    # NOTE(review): config.random_moves is not forwarded; confirm whether
    # DataArguments should receive it.
    data_args = DataArguments(
        train_dataset_path="",
        val_dataset_path="",
        test_dataset_path=config.test_dataset_path,
        experiment_mode=config.experiment_mode,
        representation_mode=config.representation_mode,
        random_seed=config.random_seed,
    )

    tokenizer, model = _load_tokenizer_and_model(config, data_args)

    with open(config.test_dataset_path, "r") as f:
        raw_data = json.load(f)
    test_dataset = prepare_test_dataset(raw_data, data_args)
    print(f"Test dataset prepared with {len(test_dataset)} examples.")

    # Generation settings do not depend on the batch, so build them once.
    generation_kwargs = _build_generation_kwargs(config, tokenizer)
    # Chat-templated text already contains its special tokens (e.g. BOS); adding
    # them again during tokenization would duplicate them.
    add_special_tokens = not config.instruction_model

    total, correct = 0, 0
    num_debug_samples = 5  # Number of samples to print for debugging
    debug_count = 0
    # Ceiling division: exact even when len is a multiple of batch_size.
    num_batches = (len(test_dataset) + config.batch_size - 1) // config.batch_size

    for batch in tqdm(test_dataset.batch(config.batch_size), total=num_batches):
        input_texts = _format_prompts(batch["input_text"], tokenizer, config.instruction_model)
        target_texts = batch["target_text"]

        inputs = tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=add_special_tokens,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # One generate call per batch: the prediction and the full debug output
        # are decoded from the same sequences, so they can never disagree.
        with torch.no_grad():
            generation_output = model.generate(**inputs, **generation_kwargs)
        sequences = generation_output.sequences
        prompt_length = inputs["input_ids"].shape[1]
        decoded_outputs = tokenizer.batch_decode(sequences[:, prompt_length:], skip_special_tokens=True)
        decoded_full_gen = tokenizer.batch_decode(sequences, skip_special_tokens=True)

        for pred, target, inp, full_out in zip(decoded_outputs, target_texts, input_texts, decoded_full_gen):
            # Strip the prediction as well as the target: decoded continuations
            # commonly begin with a space or newline.
            generated = pred.strip()
            if debug_count < num_debug_samples:
                print("----- Sample Debug Info -----")
                print(f"Input:\n{inp}")
                print("++++++++++++++++++++++++")
                print(f"Target: {target}")
                print("++++++++++++++++++++++++")
                print(f"Predicted:\n{generated}")
                print("++++++++++++++++++++++++")
                print(f"Predicted Full Outputs:\n{full_out}")
                print("-----------------------------\n")
                debug_count += 1

            if generated == target.strip():
                correct += 1
            total += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"Test Accuracy: {accuracy:.4f}")


if __name__ == "__main__":
    main()
