import json
import logging
import random
import traceback
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union

import torch
from datasets import Dataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    EarlyStoppingCallback,
    HfArgumentParser,
    TrainerCallback,
)
from trl import SFTConfig, SFTTrainer

# V1
# SYSTEM_MESSAGE_INSTRUCTION_MODEL = "You are a game playing model which plays tic-tac-toe by generating valid moves in the requested format."

# V2
# System prompt asking the model to reason inside <think> tags and then give
# the final answer inside <answer> tags. Adjacent literals are concatenated
# by the parser, so the resulting value is a single string.
SYSTEM_MESSAGE_INSTRUCTION_MODEL = (
    "You are a helpful assistant skilled at reasoning for tic tac toe. "
    "You first think about the reasoning process as an internal monologue "
    "and then provide the user with the answer. "
    "Respond in the following format: "
    "<think>\n...\n</think>\n<answer>\n...\n</answer>"
)
# -----------------------------------------------------------------------------
# Argument dataclasses
# -----------------------------------------------------------------------------

@dataclass
class ModelArguments:
    """Options describing which base model to use and how to create it.

    Each field carries a ``help`` entry in its metadata describing the option.
    """

    # Required: there is no default, so it must always be supplied.
    model_name_or_path: str = field(
        metadata={"help": "Base model identifier (e.g. 'gpt2', 't5-small')."},
    )
    use_pretrained: bool = field(
        default=True,
        metadata={"help": "If False, instantiate model from scratch."},
    )
    instruction_model: bool = field(
        default=False,
        metadata={"help": "True if the model is instruction-finetuned."},
    )

@dataclass
class DataArguments:
    """Options describing the dataset files and how examples are rendered.

    ``experiment_mode`` and ``representation_mode`` are free-form strings;
    the accepted values are listed in each field's ``help`` text.
    """

    # Required paths (no defaults).
    train_dataset_path: str = field(
        metadata={"help": "Path to the pre-split training dataset JSON file."},
    )
    val_dataset_path: str = field(
        metadata={"help": "Path to the pre-split validation dataset JSON file."},
    )
    # Optional settings.
    test_dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the pre-split test dataset JSON file (optional)."},
    )
    experiment_mode: str = field(
        default="legal_move",
        metadata={"help": "Mode: 'legal_move' or 'best_move'."},
    )
    representation_mode: str = field(
        default="ascii",
        metadata={"help": "Representation: 'ascii', 'natural', 'move_seq_explained', or 'move_seq_special'."},
    )
    random_seed: int = field(
        default=42,
        metadata={"help": "Random seed for any random operations."},
    )

@dataclass
class TTTTrainingArguments(SFTConfig):
    """SFTConfig subclass that overrides three defaults for this project.

    Every other SFTConfig field (e.g., learning_rate, num_train_epochs, etc.)
    is inherited unchanged.
    """

    dataset_text_field: Optional[str] = field(
        default="input_text",
        metadata={"help": "Leave blank to use custom formatting."},
    )
    remove_unused_columns: bool = field(
        default=False,
        metadata={"help": "Do not remove unused dataset columns."},
    )
    use_liger: bool = field(
        default=True,
        metadata={"help": "Enable memory efficient training with liger."},
    )

# -----------------------------------------------------------------------------
# Dataset Preparation Functions
# -----------------------------------------------------------------------------

def process_raw_data(raw_data: List[Dict], data_args: DataArguments) -> Tuple[Dataset, dict]:
    """Turn raw dataset entries into a flat ``Dataset`` of training examples.

    Each entry is passed through ``preprocess_example``. That call may yield
    nothing (the entry is skipped), a single example dict, or a list of
    example dicts; lists are flattened into the output.

    Args:
        raw_data: Raw entries as loaded from the dataset file.
        data_args: Settings forwarded to ``preprocess_example``.

    Returns:
        A ``(dataset, stats)`` pair. ``stats`` has two keys:
        ``"raw_dataset_len"`` (number of raw entries) and
        ``"total_examples"`` (number of examples after flattening).
    """
    flattened_examples: List[Dict] = []
    for example in raw_data:
        processed_output = preprocess_example(example, data_args)
        if not processed_output:
            # None, empty list or empty dict: nothing usable for this entry.
            continue
        if isinstance(processed_output, list):
            flattened_examples.extend(processed_output)
        else:
            flattened_examples.append(processed_output)

    dataset = Dataset.from_list(flattened_examples)
    stats = {
        "raw_dataset_len": len(raw_data),
        "total_examples": len(flattened_examples),
    }
    return dataset, stats


def convert_target_text_to_representation(target_text: str, representation_mode: str):
    """Render a move target in the form required by ``representation_mode``.

    The target text is expected to be a number. For ``"move_seq_special"`` it
    is wrapped as ``<move_N>``; for every other mode it is returned unchanged.
    """
    wants_special_token = representation_mode == "move_seq_special"
    return f"<move_{target_text}>" if wants_special_token else target_text

# TODO: Update prompts for everything except natural

# Prompt fragments used by _build_input_text. They are runtime prompt text:
# any change alters the generated training data.
_NATURAL_MAPPING_LEGAL_MOVE = (
    "For Player 1 (X):\n"
    "  1 = top-left, 2 = top-center, 3 = top-right,\n"
    "  4 = middle-left, 5 = center, 6 = middle-right,\n"
    "  7 = bottom-left, 8 = bottom-center, 9 = bottom-right.\n\n"
    "For Player 2 (O), the move numbers start from 10:\n"
    "  10 = top-left, 11 = top-center, 12 = top-right,\n"
    "  13 = middle-left, 14 = center, 15 = middle-right,\n"
    "  16 = bottom-left, 17 = bottom-center, 18 = bottom-right.\n\n"
    "In this game, Player 1 (X) moves first, and each move is represented as a number. "
    "Player 2 (O) uses the same positions but with numbers increased by 9."
)

_SPECIAL_MAPPING_LEGAL_MOVE = (
    "For Player 1 (X):\n"
    "  <move_1> = top-left, <move_2> = top-center, <move_3> = top-right,\n"
    "  <move_4> = middle-left, <move_5> = center, <move_6> = middle-right,\n"
    "  <move_7> = bottom-left, <move_8> = bottom-center, <move_9> = bottom-right.\n\n"
    "For Player 2 (O):\n"
    "  <move_10> = top-left, <move_11> = top-center, <move_12> = top-right,\n"
    "  <move_13> = middle-left, <move_14> = center, <move_15> = middle-right,\n"
    "  <move_16> = bottom-left, <move_17> = bottom-center, <move_18> = bottom-right.\n\n"
    "In this game, Player 1 (X) moves first."
    # "Player 2 (O) uses the same positions but with tokens indexed from 10 to 18."
)

_NATURAL_MAPPING_BEST_MOVE = (
    "Mapping: For Player 1, cells are numbered as follows: 1 = top-left, 2 = top-center, 3 = top-right, "
    "4 = middle-left, 5 = center, 6 = middle-right, 7 = bottom-left, 8 = bottom-center, 9 = bottom-right. "
    "For Player 2, add 9 to each number (e.g., 10 = top-left, 11 = top-center, etc.)."
)

_SPECIAL_MAPPING_BEST_MOVE = (
    "Mapping of special move tokens to board positions:\n"
    "For Player 1 (X):\n"
    "  <move_1> = top-left, <move_2> = top-center, <move_3> = top-right,\n"
    "  <move_4> = middle-left, <move_5> = center, <move_6> = middle-right,\n"
    "  <move_7> = bottom-left, <move_8> = bottom-center, <move_9> = bottom-right.\n\n"
    "For Player 2 (O):\n"
    "  <move_10> = top-left, <move_11> = top-center, <move_12> = top-right,\n"
    "  <move_13> = middle-left, <move_14> = center, <move_15> = middle-right,\n"
    "  <move_16> = bottom-left, <move_17> = bottom-center, <move_18> = bottom-right.\n\n"
    "In this game, Player 1 (X) moves first, and moves are represented using special tokens. "
    "Player 2 (O) uses the same positions but with tokens indexed from 10 to 18."
)


def _infer_next_player(first_move):
    """Return the player owning a move token: 1 (1-9), 2 (10-18), else "Unknown"."""
    if 1 <= first_move <= 9:
        return 1
    if 10 <= first_move <= 18:
        return 2
    return "Unknown"


def _build_input_text(example, experiment_mode, representation_mode, next_player, move):
    """Build the prompt text for one example in the given mode and representation."""
    is_best_move = experiment_mode == "best_move"

    if representation_mode == "ascii":
        return (
            f"Board state (ASCII):\n{example['ascii_board']}\n"
            f"Next player to move: {next_player}. What is your move?"
        )

    if representation_mode == "natural":
        if is_best_move:
            # NOTE(review): the explanation describes the target move itself,
            # so the expected answer is visible inside the prompt.
            explanation = get_move_explanation(move)
            return (
                f"Board description:\n{example['text_instruction']}\n"
                f"Context: {_NATURAL_MAPPING_BEST_MOVE}\n"
                f"Next player to move: {next_player}. For instance, {explanation}\n"
                "What move should be made next?"
            )
        # NOTE(review): there is no space or newline between the
        # "Next player to move: N." sentence and "What move should be made
        # next?". Kept as-is because changing it changes the training prompt.
        return (
            f"Board description:\n{example['text_instruction']}\n"
            f"Context: {_NATURAL_MAPPING_LEGAL_MOVE}\n"
            f"Next player to move: {next_player}."
            "What move should be made next?"
            f"\nAnswer:\n"
            f"#########\n"
        )

    if representation_mode == "move_seq_explained":
        # NOTE(review): the explanation describes the same move that is used
        # as the target, so the expected answer is visible inside the prompt.
        explanation = get_move_explanation(move)
        return (
            f"Board state:\n{example['ascii_board']}\n"
            f"Context: In this game, moves are encoded as numbers. Next player to move: {next_player}. {explanation}\n"
            "What is your move?"
        )

    if representation_mode == "move_seq_special":
        if is_best_move:
            return (
                f"Board state:\n{example['ascii_board']}\n"
                f"Context: {_SPECIAL_MAPPING_BEST_MOVE}\n"
                f"Next player to move: {next_player}. Use the special move tokens provided. What is your move?"
            )
        return (
            f"Board state:\n{example['ascii_board']}\n"
            f"Here is how each game move token maps to the game action in the tic tac toe board: {_SPECIAL_MAPPING_LEGAL_MOVE}\n"
            f"Next player to move: {next_player}. Use the game move tokens provided. What is your move?"
        )

    # Unrecognised representation: fall back to a plain board prompt.
    return (
        f"Board state:\n{example['ascii_board']}\n"
        f"Next player to move: {next_player}. What is your move?"
    )


def preprocess_example(example: Dict, data_args: DataArguments) -> Optional[Union[Dict, List[Dict]]]:
    """
    Convert a raw dataset entry (which includes keys such as 'board', 'ascii_board',
    'text_instruction', 'move_sequences', and 'next_legal_moves')
    into one or more training examples.

    For experiment_mode "legal_move", each legal move from the "next_legal_moves" list
    is returned as a separate training example (a list of dicts). The next player is
    inferred from the first legal move. Entries without legal moves are skipped.

    For experiment_mode "best_move", a single example dict is returned. If the
    "next_legal_moves" list exists, the next player is determined from its first
    element; otherwise the player is derived from the piece counts on the board.
    Entries whose board has no empty cell are skipped.

    Each example dict has the keys "input_text", "target_text" (the move number
    as a string), "board" and "next_player".

    Returns None when the entry is skipped: it is not (and cannot be parsed into)
    a dict, it has no usable move, or experiment_mode is not recognised.

    Raises KeyError when a required key ('board', or the board text used by the
    selected representation) is missing from the entry.
    """
    # Entries may arrive as JSON text instead of already-parsed dicts.
    if not isinstance(example, dict):
        try:
            example = json.loads(example)
        except (TypeError, ValueError):
            # TypeError: not str/bytes; ValueError: invalid JSON or encoding.
            return None
        # Valid JSON that is not an object (number, list, ...) is unusable.
        if not isinstance(example, dict):
            return None

    experiment_mode = data_args.experiment_mode
    representation_mode = data_args.representation_mode

    # --- LEGAL MOVE MODE ---
    if experiment_mode == "legal_move":
        legal_moves = example.get("next_legal_moves")
        if not legal_moves:
            return None
        board = example["board"]
        next_player = _infer_next_player(legal_moves[0])

        # NOTE(review): target_text is the bare move number even in
        # "move_seq_special" mode; presumably the <move_N> conversion is
        # applied downstream - verify against the caller.
        return [
            {
                "input_text": _build_input_text(
                    example, experiment_mode, representation_mode, next_player, move
                ),
                "target_text": str(move),
                "board": board,
                "next_player": next_player,
            }
            for move in legal_moves
        ]

    # --- BEST MOVE MODE ---
    # TODO: Review and update this
    if experiment_mode == "best_move":
        board = example["board"]
        legal_moves = example.get("next_legal_moves")
        if legal_moves:
            next_player = _infer_next_player(legal_moves[0])
        else:
            # Equal piece counts mean it is player 1's turn, otherwise player 2's.
            p1_count = sum(1 for cell in board if cell == 1)
            p2_count = sum(1 for cell in board if cell == 2)
            next_player = 1 if p1_count == p2_count else 2

        target_move = get_best_move(board, next_player)
        if target_move == -1:
            # get_best_move found no empty cell; a "-1" target is not a move,
            # so skip the entry like legal-move mode skips move-less entries.
            return None

        return {
            "input_text": _build_input_text(
                example, experiment_mode, representation_mode, next_player, target_move
            ),
            "target_text": str(target_move),
            "board": board,
            "next_player": next_player,
        }

    return None

# -----------------------------------------------------------------------------
# Helper functions for board representations and move explanations
# -----------------------------------------------------------------------------

def board_to_ascii(board: List[int]) -> str:
    """Render a 9-cell board as three lines of space-separated symbols.

    Cell values map to symbols as 0 -> '.', 1 -> 'X', 2 -> 'O'. Cells are read
    row by row from indices 0..8. Any other cell value raises KeyError and a
    board with fewer than 9 cells raises IndexError.
    """
    glyph_for = {0: '.', 1: 'X', 2: 'O'}
    return "\n".join(
        " ".join(glyph_for[board[3 * row + col]] for col in range(3))
        for row in range(3)
    )

def board_to_text_instruction(board: List[int]) -> str:
    """Describe a 9-cell board in words, one sentence per row.

    Cell values map to words as 0 -> 'empty', 1 -> 'X', 2 -> 'O'. Rows are
    labelled "Row 0" to "Row 2" and the three sentences are joined by spaces.
    Any other cell value raises KeyError and a board with fewer than 9 cells
    raises IndexError.
    """
    word_for = {0: 'empty', 1: 'X', 2: 'O'}
    sentences = []
    for row in range(3):
        cells = ", ".join(word_for[board[3 * row + col]] for col in range(3))
        sentences.append(f"Row {row}: {cells}.")
    return " ".join(sentences)

def get_move_explanation(move_token: int) -> str:
    """
    Describe a move token in natural language.

    Tokens 1-9 belong to Player 1 (X) and tokens 10-18 to Player 2 (O); both
    ranges address the same nine cells. Any other token is reported with an
    "Unknown" player and an "unknown position".
    """
    # Cell numbers 1..9 in row-major order.
    cell_names = dict(enumerate(
        (
            "top-left", "top-center", "top-right",
            "middle-left", "center", "middle-right",
            "bottom-left", "bottom-center", "bottom-right",
        ),
        start=1,
    ))

    if 10 <= move_token <= 18:
        who, cell = "Player 2 (O)", move_token - 9
    elif 1 <= move_token <= 9:
        who, cell = "Player 1 (X)", move_token
    else:
        who, cell = "Unknown", move_token

    where = cell_names.get(cell, "unknown position")
    return f"{who} moves at {where} (token {move_token})."

# TODO: Implement minimax based lookup table for the next best move.
def get_best_move(board: List[int], current_player: int) -> int:
    """
    Return the 'best' move token for ``current_player`` on ``board``.

    Placeholder heuristic: the first empty cell (value 0) among indices 0..8
    is chosen. The token is index + 1 when ``current_player`` is 1 and
    index + 10 for any other value. Returns -1 when no cell is empty.
    """
    token_offset = 1 if current_player == 1 else 10
    first_empty = next((idx for idx in range(9) if board[idx] == 0), None)
    if first_empty is None:
        return -1  # Should not happen if board is non-terminal
    return first_empty + token_offset