#!/usr/bin/env python
"""
sft.py – A general-purpose SFT training script for tic-tac-toe next-move generation.
This script supports multiple experiments:
  - Experiment mode: "legal_move" (use last move token from game sequences)
                     "best_move" (use a lookup/minimax-derived optimal move)
  - Representation mode: "ascii", "natural", "move_seq_explained", or "move_seq_special"
  - Expects pre-split train, validation, and test datasets (as JSON files)
  - Support for both instruction-finetuned and non-instruction-finetuned models
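  - Model families: causal LMs such as GPT-2 (starting with the smallest models); the model is loaded with AutoModelForCausalLM, so encoder-decoder families such as T5 would need a different loader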
  - Distributed training via DeepSpeed is supported via command-line arguments
  - Debug logging for a few training and evaluation samples via wandb and stdout
  - Additional metric (accuracy) computation during evaluation (currently disabled: compute_metrics is commented out in the SFTTrainer call)
  - Memory efficient training enabled with use_liger=True
"""

import json
import random
from dataclasses import dataclass, field
from typing import Optional, Dict, List
import logging
import traceback

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    HfArgumentParser,
    TrainerCallback,
    EarlyStoppingCallback
)
from trl import SFTTrainer, SFTConfig
from data_utils import *

# Try importing wandb. If not available, we simply skip wandb logging.
try:
    import wandb
except ImportError:
    wandb = None


# -----------------------------------------------------------------------------
# Custom Callbacks
# -----------------------------------------------------------------------------

class MemoryClearCallback(TrainerCallback):
    """Trainer callback that empties the CUDA caching allocator after every step.

    NOTE(review): this callback is defined here but is not registered on the
    trainer anywhere in the visible part of this file.
    """

    def on_step_end(self, args, state, control, **kwargs):
        """Call ``torch.cuda.empty_cache()`` once the training step has finished."""
        # Runs on every single step; presumably trades some throughput for a
        # lower reserved-memory footprint -- TODO confirm this is still wanted.
        torch.cuda.empty_cache()

class SampleLoggingCallback(TrainerCallback):
    """
    Prints the first few formatted training and evaluation samples.

    Training samples are emitted once when training begins; evaluation samples
    are emitted on every evaluation. When wandb is importable *and* a run is
    active, the same samples (and the dataset statistics collected in the
    module-level ``wandb_dataset_stats``) are also sent to wandb.

    Args:
        tokenizer: Stored on the instance; not used by the callback itself.
        train_dataset: Indexable training dataset, or None to skip.
        eval_dataset: Indexable evaluation dataset, or None to skip.
        formatting_func: Callable turning one dataset row into the text that
            is printed/logged.
        num_samples: Maximum number of rows to show per dataset.
    """

    def __init__(self, tokenizer, train_dataset, eval_dataset, formatting_func, num_samples=5):
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.formatting_func = formatting_func
        self.num_samples = num_samples

    @staticmethod
    def _wandb_run_active():
        """Return True only if wandb is importable and a run has been started.

        ``wandb.log`` raises when no run exists (e.g. wandb is installed but
        not used as a reporter, or this is a process that never initialised a
        run), so being importable is not a sufficient condition for logging.
        """
        return wandb is not None and getattr(wandb, "run", None) is not None

    def _emit_samples(self, dataset, title, key_prefix):
        """Print up to ``num_samples`` formatted rows and mirror them to wandb.

        All samples are sent in a single ``wandb.log`` call so that sample
        logging adds one history row instead of one per sample.
        """
        if dataset is None:
            return
        payload = {}
        for i in range(min(self.num_samples, len(dataset))):
            formatted = self.formatting_func(dataset[i])
            print(f"{title} sample {i}:\n{formatted}\n")
            payload[f"{key_prefix}_sample_{i}"] = formatted
        if payload and self._wandb_run_active():
            wandb.log(payload)

    def on_train_begin(self, args, state, control, **kwargs):
        """Report dataset statistics and the first training samples."""
        print("Dataset stats after processing")
        print(wandb_dataset_stats)
        # Only claim success after the stats were actually sent to an active run.
        if wandb_dataset_stats and self._wandb_run_active():
            try:
                wandb.log({"dataset_stats": wandb_dataset_stats})
                print("Successfully logged dataset stats to wandb.")
            except Exception as e:
                # Best effort: a logging problem must not abort training.
                print(f"Failed to log dataset to wandb: {e}")

        print("=== First few training samples ===")
        self._emit_samples(self.train_dataset, "Train", "train")

    def on_evaluate(self, args, state, control, **kwargs):
        """Report the first evaluation samples."""
        print("=== First few evaluation samples ===")
        self._emit_samples(self.eval_dataset, "Eval", "eval")

# -----------------------------------------------------------------------------
# Dataset Preparation Functions
# -----------------------------------------------------------------------------

# Module-level store for per-split dataset statistics. prepare_dataset() fills
# it (keyed by the split name it is given) before training starts, and
# SampleLoggingCallback.on_train_begin() reads it once training begins.
wandb_dataset_stats = {}

# def prepare_dataset(raw_data: List[Dict], data_args: DataArguments, type: str) -> Dataset:
#     # Store dataset info globally for later logging, 
#     # this is done because wandb logger is initialized after dataset processing, when training begins
#     global wandb_dataset_stats 

#     flattened_examples = []

#     for example in raw_data:
#         processed_output = preprocess_example(example, data_args)
#         if processed_output:  # Only add if not None
#             if isinstance(processed_output, list):
#                 flattened_examples.extend(processed_output)
#             else:
#                 flattened_examples.append(processed_output)
    
#     # Convert to Hugging Face dataset
#     dataset = Dataset.from_list(flattened_examples)

#     # Log the processed dataset to wandb (only if wandb is available)
#     # Store dataset stats for later logging
#     wandb_dataset_stats[type] = {
#         "raw_dataset_len": len(raw_data),
#         "total_examples": len(flattened_examples),
#     }

#     return dataset

def prepare_dataset(raw_data: List[Dict], data_args: DataArguments, type: str) -> Dataset:
    """Process one raw split and remember its statistics for later logging.

    Args:
        raw_data: Raw examples of the split, as loaded from JSON.
        data_args: Data options forwarded unchanged to ``process_raw_data``.
        type: Name of the split; used as the key under which the statistics
            are stored in the module-level ``wandb_dataset_stats``.

    Returns:
        The dataset half of the ``(dataset, stats)`` pair produced by
        ``process_raw_data``.
    """
    global wandb_dataset_stats
    processed = process_raw_data(raw_data, data_args)
    split_dataset, split_stats = processed
    # Stats are parked globally because they are consumed only once training begins.
    wandb_dataset_stats[type] = split_stats
    return split_dataset


# Human-readable description of the <move_1>..<move_18> special tokens
# (1-9 for Player 1 / X, 10-18 for Player 2 / O, each in row-major board order).
# In this file it is only used by main() as a sample prompt for the
# encode/decode round-trip printout done when
# representation_mode == "move_seq_special".
mapping_str_special_token_test = (
    "Mapping of special move tokens to board positions:\n"
    "For Player 1 (X):\n"
    "  <move_1> = top-left, <move_2> = top-center, <move_3> = top-right,\n"
    "  <move_4> = middle-left, <move_5> = center, <move_6> = middle-right,\n"
    "  <move_7> = bottom-left, <move_8> = bottom-center, <move_9> = bottom-right.\n\n"
    "For Player 2 (O):\n"
    "  <move_10> = top-left, <move_11> = top-center, <move_12> = top-right,\n"
    "  <move_13> = middle-left, <move_14> = center, <move_15> = middle-right,\n"
    "  <move_16> = bottom-left, <move_17> = bottom-center, <move_18> = bottom-right.\n\n"
    "In this game, Player 1 (X) moves first, and moves are represented using special tokens. "
    "Player 2 (O) uses the same positions but with tokens indexed from 10 to 18."
)

# -----------------------------------------------------------------------------
# Main Training Function
# -----------------------------------------------------------------------------

def _load_json(path):
    """Read and parse the JSON file at ``path``.

    The file is opened explicitly as UTF-8 so parsing does not depend on the
    platform's default locale encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _argmax_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids before the Trainer caches them.

    Intended for the Trainer's ``preprocess_logits_for_metrics`` hook. When
    ``compute_metrics`` is set, predictions are accumulated over the whole
    evaluation set; keeping full ``(batch, seq, vocab)`` logits is what makes
    evaluation memory explode, whereas token ids are ``vocab`` times smaller.
    ``labels`` is unused but part of the hook's signature.
    """
    if isinstance(logits, tuple):
        # Some models return extra outputs after the logits.
        logits = logits[0]
    return logits.argmax(dim=-1)


def _make_compute_metrics(tokenizer):
    """Build the exact-match accuracy metric bound to ``tokenizer``.

    The returned function accepts the Trainer's ``(predictions, label_ids)``
    pair, where predictions may be either raw logits or already-argmaxed token
    ids, and returns ``{"accuracy": float}``.
    """

    def compute_metrics(eval_preds):
        import numpy as np  # function-scope: only needed when the metric is enabled

        preds, labels = eval_preds
        try:
            if isinstance(preds, tuple):
                preds = preds[0]
            preds = np.asarray(preds)
            labels = np.asarray(labels)

            # Raw logits carry one extra (vocab) axis: reduce them to token ids.
            if preds.ndim == labels.ndim + 1:
                preds = preds.argmax(axis=-1)

            # Causal LM: the prediction at position t is for the token at t + 1.
            preds = preds[..., :-1]
            labels = labels[..., 1:]

            # Ignored/padded positions are marked with negative ids (-100),
            # which cannot be decoded; blank them out on both sides so that
            # predictions made at padded positions do not pollute the text.
            pad_id = tokenizer.pad_token_id
            if pad_id is None:
                pad_id = 0
            ignored = labels < 0
            preds = np.where(ignored | (preds < 0), pad_id, preds).tolist()
            labels = np.where(ignored, pad_id, labels).tolist()

            # Log the content of the first prediction for debugging.
            print(f"After flattening, first prediction: {preds[0]}")
            print(f"After flattening, first label: {labels[0]}")

            decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            # Compare only the text following the last "Answer:" marker.
            matches = [
                int(label.split("Answer:")[-1].strip() == pred.split("Answer:")[-1].strip())
                for pred, label in zip(decoded_preds, decoded_labels)
            ]
            accuracy = sum(matches) / len(matches) if matches else 0

            # wandb.log raises without an active run, so check for one.
            if wandb is not None and getattr(wandb, "run", None) is not None:
                wandb.log({
                    "debug/flattened_preds": preds[:5],
                    "debug/flattened_labels": labels[:5],
                    "debug/decoded_preds": decoded_preds[:5],
                    "debug/decoded_labels": decoded_labels[:5],
                    "debug/accuracy": accuracy
                })

            return {"accuracy": accuracy}

        except Exception as e:
            # Best effort: a metric failure is reported but must not kill training.
            print("=== ERROR in compute_metrics ===")
            traceback.print_exc()
            if wandb is not None and getattr(wandb, "run", None) is not None:
                wandb.log({"error/exception": str(e)})
            return {"accuracy": 0.0}

    return compute_metrics


def main():
    """Parse CLI arguments, build datasets/tokenizer/model, run SFT, save the model.

    Steps, in order: parse ``ModelArguments``/``DataArguments``/
    ``TTTTrainingArguments``; load and preprocess the train/val (and optional
    test) JSON splits; load the tokenizer (adding ``<move_1>``..``<move_18>``
    for the ``move_seq_special`` representation); load or freshly initialise
    the model; train with ``SFTTrainer``; save to ``training_args.output_dir``.
    """
    # Parse arguments.
    parser = HfArgumentParser((ModelArguments, DataArguments, TTTTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # If using memory-efficient training, note it.
    if training_args.use_liger:
        print("Using Liger for memory efficient training.")

    # Load the pre-split datasets.
    train_raw = _load_json(data_args.train_dataset_path)
    val_raw = _load_json(data_args.val_dataset_path)
    test_raw = None
    if data_args.test_dataset_path is not None:
        test_raw = _load_json(data_args.test_dataset_path)

    train_dataset = prepare_dataset(train_raw, data_args, type="train")
    val_dataset = prepare_dataset(val_raw, data_args, type="val")
    # NOTE(review): the test split is preprocessed (which records its stats)
    # but is not evaluated anywhere in this function.
    test_dataset = prepare_dataset(test_raw, data_args, type="test") if test_raw is not None else None

    # Load tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # For special move tokens representation, add tokens to vocabulary.
    if data_args.representation_mode == "move_seq_special":
        special_tokens = [f"<move_{i}>" for i in range(1, 19)]
        tokenizer.add_tokens(special_tokens)

        # Sanity check: print an encode/decode round trip of the mapping prompt
        # and of a single move token so the new vocabulary can be inspected.
        print("Original Prompt:\n", mapping_str_special_token_test)

        encoded = tokenizer.encode(mapping_str_special_token_test)
        print("\nTokenized IDs:\n", encoded)

        decoded = tokenizer.decode(encoded)
        print("\nDecoded Prompt:\n", decoded)

        encoded = tokenizer.encode("<move_1>")
        print("\nTokenized ID for <move_1>:\n", encoded)

        decoded = tokenizer.decode(encoded)
        print("\nDecoded Prompt:\n", decoded)

    # Load (or initialize) model.
    if model_args.use_pretrained:
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint; only use with model sources you trust.
        # NOTE(review): weights are loaded directly in float16; confirm this is
        # intended for training rather than float32/bfloat16 weights.
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
    else:
        # Random initialisation from the architecture config only.
        config = AutoConfig.from_pretrained(model_args.model_name_or_path)
        model = AutoModelForCausalLM.from_config(config)

    # Resize model embeddings if new tokens were added.
    if data_args.representation_mode == "move_seq_special":
        model.resize_token_embeddings(len(tokenizer))

    # EOS text appended to plain (non-chat) samples; empty if the tokenizer has none.
    eos_token = tokenizer.eos_token if tokenizer.eos_token is not None else ""

    # Define formatting function for SFTTrainer.
    def formatting_func(example):
        """Turn one dataset row into the full training text (prompt + answer)."""
        # NOTE(review): the helper name is spelled "tenxt" here; confirm it
        # matches the name actually exported by data_utils.
        target_text = convert_target_tenxt_to_representation(example["target_text"], data_args.representation_mode)
        if model_args.instruction_model:
            # For instruction-finetuned models, format as a conversational prompt.
            messages = [
                {"role": "system", "content": SYSTEM_MESSAGE_INSTRUCTION_MODEL},
                {"role": "user", "content": example["input_text"]},
                {"role": "assistant", "content": target_text},
            ]
            return tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
        # For standard models, simply concatenate prompt, answer and EOS.
        return example["input_text"] + target_text + eos_token

    # Accuracy metric and early stopping are prepared but deliberately not
    # wired into the trainer (see the commented-out arguments below). When
    # enabling compute_metrics, also enable preprocess_logits_for_metrics so
    # that token ids rather than full logits are cached during evaluation.
    compute_metrics = _make_compute_metrics(tokenizer)

    early_stop = EarlyStoppingCallback(
        early_stopping_patience=3, early_stopping_threshold=0.0001
    )

    # Initialize SFTTrainer.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        formatting_func=formatting_func,
        # compute_metrics=compute_metrics,
        # preprocess_logits_for_metrics=_argmax_logits_for_metrics,
        # callbacks=[early_stop],
    )

    # Add sample logging callback to help debug the first few samples.
    trainer.add_callback(SampleLoggingCallback(tokenizer, train_dataset, val_dataset, formatting_func, num_samples=5))

    # Begin training.
    trainer.train()
    trainer.save_model(training_args.output_dir)
    print(f"Model weights saved to {training_args.output_dir}")

# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
