#!/usr/bin/env python
import argparse
import logging
import sys
from transformers import AutoTokenizer
import prometheus_client

# Older prometheus_client releases do not provide `disable_created_metrics`;
# install a no-op fallback so the attribute always exists. This runs before the
# `vllm` import below — presumably vllm (or one of its dependencies) calls it;
# TODO confirm which component needs it.
if not hasattr(prometheus_client, "disable_created_metrics"):
    prometheus_client.disable_created_metrics = lambda: None

from vllm import LLM, SamplingParams  # vLLM for efficient inference
from data_utils import SYSTEM_MESSAGE_INSTRUCTION_MODEL
from grpo_data_util import generate_tictactoe_prompt, extract_final_answer
from logging_util import setup_logging, get_logger

# Configure logging at import time with the project's logging_util helpers,
# then create this module's logger (used by main()).
setup_logging()
logger = get_logger(__name__)

def _str2bool(value):
    """
    Parse a boolean command-line value.

    Accepts (case-insensitively, surrounding whitespace ignored) true/t/yes/y/1
    and false/f/no/n/0. Anything else raises ``argparse.ArgumentTypeError``,
    which argparse reports as a usage error. Without this, a typo such as
    ``--instruct_model ture`` would silently mean False.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def parse_args(argv=None):
    """
    Parse the CLI arguments for GRPO Tic-Tac-Toe inference.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in which
            case argparse reads ``sys.argv[1:]`` (the original behavior).

    Returns:
        argparse.Namespace with model_checkpoint, representation_mode,
        instruct_model (bool), max_new_tokens, temperature, top_k, top_p and mode.

    Raises:
        SystemExit: On invalid arguments (argparse's standard behavior), including
            a non-boolean value for ``--instruct_model``.
    """
    parser = argparse.ArgumentParser(description="GRPO Tic-Tac-Toe CLI Inference")
    parser.add_argument("--model_checkpoint", type=str, required=True,
                        help="Path to the trained model checkpoint")
    parser.add_argument("--representation_mode", type=str, default="nl", choices=["nl", "special"],
                        help="Representation mode")
    # Strict boolean parsing: unrecognized values are rejected instead of
    # silently becoming False.
    parser.add_argument("--instruct_model", type=_str2bool, default=False,
                        help="Whether the model uses instruct prompt style (true/false)")
    parser.add_argument("--max_new_tokens", type=int, default=256,
                        help="Maximum tokens to generate per prompt")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature for generation")
    parser.add_argument("--top_k", type=int, default=50,
                        help="Top-K sampling parameter")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-P sampling parameter")
    parser.add_argument("--mode", type=str, default="game", choices=["game", "chat"],
                        help="Mode of interaction: 'game' (suggests a move) or 'chat' (freeform interaction)")
    return parser.parse_args(argv)

def format_chat_message(history, user_input):
    """
    Build a plain-text chat prompt from the conversation so far.

    The prompt consists of the system message (SYSTEM_MESSAGE_INSTRUCTION_MODEL),
    the prior turns in ``history`` and the new user message. Each is rendered on
    its own line as ``"<Role>: <content>"``, with the role capitalized. A final
    ``"Assistant:"`` line cues the model to answer as the assistant instead of
    continuing the user's turn.

    Args:
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts, oldest
            first. May be ``None`` or empty. It is not modified. Any "system"
            entries in it are ignored, because the system message is always
            prepended here.
        user_input: The new user message.

    Returns:
        A ``(prompt, updated_history)`` tuple. ``updated_history`` is a new list
        holding the prior non-system turns plus the new user turn. The system
        message is deliberately excluded, so passing ``updated_history`` back in
        on the next turn does not duplicate it.
    """
    # Drop any stored system messages; keeping them would repeat the system
    # prompt once per turn as the history is fed back in.
    turns = [m for m in (history or []) if m.get("role") != "system"]
    turns.append({"role": "user", "content": user_input})

    messages = [{"role": "system", "content": SYSTEM_MESSAGE_INSTRUCTION_MODEL}, *turns]
    lines = [f"{m['role'].capitalize()}: {m['content']}" for m in messages]
    lines.append("Assistant:")  # generation cue for the model's reply
    return "\n".join(lines), turns

def _read_user_input():
    """
    Read one (possibly multi-line) message from stdin, terminated by EOF.

    Returns:
        The stripped text, which may be empty. Returns ``None`` when stdin is not
        an interactive terminal and the read hit EOF without any data. Such a
        stream returns "" on every later read, so the caller must stop rather
        than re-prompt forever.
    """
    # read() consumes until EOF (Ctrl+D / Ctrl+Z + Enter) to allow multi-line input.
    raw = sys.stdin.read()
    if not raw and not sys.stdin.isatty():
        return None
    return raw.strip()


def main():
    """
    Run the interactive Tic-Tac-Toe / chat CLI.

    Loads the tokenizer and a vLLM engine from ``--model_checkpoint``. It then
    loops: read a message from stdin, build a prompt for the selected ``--mode``,
    generate a completion and print it.

    - "game": the input goes through generate_tictactoe_prompt, and
      extract_final_answer pulls the suggested move out of the completion.
    - "chat": free-form conversation that keeps the last 10 user/assistant
      messages.

    Exits on Ctrl+C, or when a non-interactive stdin runs out of input.
    """
    args = parse_args()

    logger.info("Initializing model and tokenizer...")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
        logger.info("Added padding token to tokenizer.")

    # Load model with vLLM for fast inference
    engine = LLM(model=args.model_checkpoint)

    sampling_params = SamplingParams(
        max_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )

    # Prior user/assistant turns for chat mode. The system message is not stored
    # here, because format_chat_message prepends it on every turn.
    chat_history = []

    print("\n🧠 Tic-Tac-Toe AI CLI\n")

    if args.mode == "game":
        print("Game Mode: Enter the board state (e.g., 'XOX| XO|X  ' where | separates rows).")
    else:
        print("Chat Mode: Type anything to interact with the model. Chat history will be maintained.")

    print("Press Ctrl+C to exit.\n")

    while True:
        try:
            print("\nYou (Press Ctrl+D to submit on Linux/macOS or Ctrl+Z + Enter on Windows):")
            user_input = _read_user_input()

            if user_input is None:
                # Piped/redirected stdin is exhausted. Every further read would
                # return "" immediately, so stop instead of looping forever.
                print("\n👋 Exiting the Tic-Tac-Toe AI CLI.")
                break

            if not user_input:
                print("⚠️  Please enter a valid input.")
                continue

            if args.mode == "game":
                # Format input for Tic-Tac-Toe move prediction
                sample = {"text_instruction": user_input, "board": []}  # Dummy board representation
                prompt_dict = generate_tictactoe_prompt(sample, tokenizer, args.representation_mode, args.instruct_model)
                prompt_text = prompt_dict.get("prompt")

                # Reset chat history since each move is independent
                chat_history = []
            else:
                # Format conversation using chat history
                prompt_text, chat_history = format_chat_message(chat_history, user_input)
                # Keep only user/assistant turns. A stored system message would be
                # prepended again on the next call and duplicated every turn.
                chat_history = [m for m in chat_history if m.get("role") != "system"]

            if not prompt_text:
                # prompt_dict.get("prompt") may be missing; don't hand None to the engine.
                logger.error("Could not build a prompt for this input; please try again.")
                continue

            print("\n🤖 Thinking...\n")
            outputs = engine.generate(prompt_text, sampling_params)

            for output in outputs:
                completion = output.outputs[0].text

                if args.mode == "game":
                    predicted_move = extract_final_answer(completion)
                    print(f"✨ Suggested Move: {predicted_move}\n")
                else:
                    print(f"🤖 {completion}\n")  # Direct chat response

                    # Update chat history with assistant's response
                    chat_history.append({"role": "assistant", "content": completion})

                    # Keep chat history within the last 10 messages
                    if len(chat_history) > 10:
                        chat_history = chat_history[-10:]

        except KeyboardInterrupt:
            print("\n👋 Exiting the Tic-Tac-Toe AI CLI.")
            sys.exit(0)


if __name__ == "__main__":
    main()
