import logging
import os
import random
import re
from datetime import datetime
from dataclasses import dataclass
import string

from transformers.trainer_utils import get_last_checkpoint
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig, TrlParser

from grpo_data_util import generate_tictactoe_prompt

from torch.distributed.elastic.multiprocessing.errors import record

# import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# -------------------------------
# Script arguments and logging
# -------------------------------
@dataclass
class ScriptArguments:
    """
    Script-specific options, parsed together with ModelConfig and GRPOConfig in main().

    Only fields whose use is visible in this file are described in detail below;
    the others are declared here but not read by the code in this file.
    """
    # Path to the training file; loaded with the "json" dataset builder.
    train_dataset_path: str = None
    # Path to the validation file; loaded with the "json" dataset builder.
    val_dataset_path: str = None
    # Declared but not read anywhere in this file.
    test_dataset_path: str = None
    # Declared but not read anywhere in this file.
    experiment_mode: str = "legal_move"  # Options: legal_move, best_move
    # Tokenizer to load; when unset, the model's own name/path is used instead.
    tokenizer_name_or_path: str = None
    # Forwarded to generate_tictactoe_prompt; the value "special" additionally
    # causes <move_1>..<move_18> tokens to be added to the tokenizer.
    representation_mode: str = "nl"      # Options: nl, special, etc.
    # Only referenced from commented-out code in this file.
    dataset_id_or_path: str = None
    # Only referenced from commented-out code in this file.
    dataset_splits: str = "train"         # Default split if using a HF dataset
    # Forwarded to generate_tictactoe_prompt.
    instruct_model: bool = False          # If True, use an instruction-style chat template

# Configure the root logger at INFO level and create the module-level logger
# used by the reward functions and the training routine below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -------------------------------
# Reward Functions
# -------------------------------
def format_reward_func(prompts, completions, **kwargs):
    """
    Score each completion on its tag layout.

    A completion earns 1.0 when it holds exactly one <think>...</think> block and
    exactly one <answer>...</answer> block, and the first "<think>" opening tag
    appears before the first "<answer>" opening tag. Anything else earns 0.0.

    Returns a list of floats, one per completion. When both `prompts` and
    `completions` are non-empty, the first sample is logged for inspection.
    """
    def _score(text):
        # Count complete blocks; DOTALL lets a block span multiple lines.
        think_blocks = re.findall(r"<think>(.*?)</think>", text, re.DOTALL)
        answer_blocks = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL)
        if len(think_blocks) != 1 or len(answer_blocks) != 1:
            return 0.0
        # Ordering is decided by where each opening tag first shows up.
        if text.find("<think>") < text.find("<answer>"):
            return 1.0
        return 0.0

    rewards = [_score(text) for text in completions]

    # Log an example from the batch.
    if prompts and completions:
        logger.info("Format reward example:")
        logger.info("  Prompt: %s", prompts[0])
        logger.info("  Completion: %s", completions[0])
        logger.info("  Computed format reward: %s", rewards[0])
    return rewards

def format_reward_func_updated(prompts, completions, tokenizer=None, **kwargs):
    """
    Stricter format reward: tag layout plus a check on what follows </answer>.

    A completion earns 1.0 only if all of the following hold:
      * it contains exactly one <think>...</think> block,
      * it contains exactly one <answer>...</answer> block,
      * the first "<think>" precedes the first "<answer>",
      * the whitespace-stripped text after the first "</answer>" equals the
        tokenizer's EOS token (or the empty string when no tokenizer / EOS token
        is available).
    Otherwise it earns 0.0. Returns one float per completion.
    """
    # What is allowed to trail the closing answer tag.
    expected_tail = ""
    if tokenizer and tokenizer.eos_token is not None:
        expected_tail = tokenizer.eos_token

    closing_tag = "</answer>"

    def _is_well_formed(text):
        think_blocks = re.findall(r"<think>(.*?)</think>", text, re.DOTALL)
        answer_blocks = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL)
        # Collect whatever follows the first closing answer tag.
        close_at = text.find(closing_tag)
        if close_at == -1:
            tail = ""
        else:
            tail = text[close_at + len(closing_tag):].strip()
        return (
            len(think_blocks) == 1
            and len(answer_blocks) == 1
            and text.find("<think>") < text.find("<answer>")
            and tail == expected_tail
        )

    rewards = [1.0 if _is_well_formed(text) else 0.0 for text in completions]

    # Log an example for debugging.
    if prompts and completions:
        logger.info("Format reward example:")
        logger.info("  Prompt: %s", prompts[0])
        logger.info("  Completion: %s", completions[0])
        logger.info("  Computed format reward: %s", rewards[0])
    return rewards


def answer_length_reward_func(prompts, completions, **kwargs):
    """
    Reward concise final answers.

    For each completion, the text of the LAST <answer>...</answer> block is taken,
    all ASCII punctuation is removed, and the remainder is split on whitespace.
    The reward is 1.0 when that yields between 1 and 3 tokens (inclusive) and 0.0
    otherwise, including when no answer block exists.

    Args:
        prompts: Prompts for the batch; only the first one is used, for logging.
        completions: Generated completion strings.
        **kwargs: Ignored.

    Returns:
        A list of float rewards, one per completion.
    """
    threshold = 3  # Allow up to 3 tokens.
    # Build the punctuation-removal table once instead of once per completion.
    strip_punct = str.maketrans("", "", string.punctuation)

    rewards = []
    # Details of the FIRST sample, kept so the log below describes the same
    # completion it prints (previously it showed leftovers from the last one).
    first_answer = None
    first_tokens = []
    for idx, completion in enumerate(completions):
        # Get all answer blocks and use the last one.
        answer_matches = re.findall(r"<answer>(.*?)</answer>", completion, re.DOTALL)
        if answer_matches:
            raw_answer = answer_matches[-1].strip()
            # Remove punctuation, then split on whitespace.
            tokens = raw_answer.translate(strip_punct).split()
            reward = 1.0 if 1 <= len(tokens) <= threshold else 0.0
        else:
            raw_answer = None
            tokens = []
            reward = 0.0
        if idx == 0:
            first_answer = raw_answer
            first_tokens = tokens
        rewards.append(reward)

    # Log an example.
    if prompts and completions:
        logger.info("Answer length reward example:")
        logger.info("  Prompt: %s", prompts[0])
        logger.info("  Completion: %s", completions[0])
        logger.info("  Extracted answer: '%s'", first_answer if first_answer is not None else "None")
        logger.info("  Tokenized answer: %s", first_tokens)
        logger.info("  Computed answer length reward: %s", rewards[0])
    return rewards


def legal_move_reward_func(prompts, completions, **kwargs):
    """
    Reward answers whose first token is a legal move.

    For each completion the LAST <answer>...</answer> block is inspected. Its first
    token is compared against the allowed moves for that sample in three forms:
      1. with all ASCII punctuation removed (e.g. "5." -> "5"),
      2. exactly as written (e.g. "<move_3>"),
      3. with surrounding punctuation other than angle brackets removed
         (e.g. "<move_3>." -> "<move_3>").
    Forms 2 and 3 exist because removing all punctuation also removes "<", ">"
    and "_", which would make special tokens such as "<move_3>" or "<move_null>"
    impossible to match.

    If the allowed list for a sample is empty, the expected answer is the default
    "None" (NL mode) or "<move_null>" (any other mode).

    Args:
        prompts: Prompts for the batch; only the first one is used, for logging.
        completions: Generated completion strings.
        **kwargs: Expected to contain
            'allowed_moves': a list with one list of allowed moves per sample;
            'representation_mode': a single mode string, or a per-sample
            list/tuple of mode strings (assumed to be how dataset columns are
            forwarded by the trainer -- TODO confirm). Defaults to "nl".

    Returns:
        A list of float rewards (1.0 or 0.0), one per completion. Samples with no
        answer block, or with no entry in 'allowed_moves', receive 0.0.
    """
    allowed_moves_all = kwargs.get("allowed_moves", None)
    rep_mode = kwargs.get("representation_mode", "nl")
    default_nl = "None"
    default_special = "<move_null>"

    strip_all_punct = str.maketrans("", "", string.punctuation)
    # Punctuation that may surround a move token; angle brackets are excluded
    # because they are part of the special move tokens themselves.
    edge_punct = string.punctuation.replace("<", "").replace(">", "")

    def _mode_for(sample_idx):
        # Accept either one mode for the whole batch or one mode per sample.
        if isinstance(rep_mode, (list, tuple)):
            return rep_mode[sample_idx] if sample_idx < len(rep_mode) else "nl"
        return rep_mode

    rewards = []
    # Values actually used for the first sample, kept for the log below.
    example_token = ""
    example_allowed = []
    for idx, completion in enumerate(completions):
        first_token = ""
        allowed = []
        reward = 0.0
        # Use the final answer block.
        answer_matches = re.findall(r"<answer>(.*?)</answer>", completion, re.DOTALL)
        if answer_matches:
            raw_answer = answer_matches[-1].strip()
            cleaned_tokens = raw_answer.translate(strip_all_punct).split()
            first_token = cleaned_tokens[0] if cleaned_tokens else ""

            # Candidate spellings of the first token, most normalized first.
            candidates = [first_token]
            raw_tokens = raw_answer.split()
            if raw_tokens:
                for extra in (raw_tokens[0], raw_tokens[0].strip(edge_punct)):
                    if extra and extra not in candidates:
                        candidates.append(extra)

            if allowed_moves_all and isinstance(allowed_moves_all, list) and len(allowed_moves_all) > idx:
                allowed = allowed_moves_all[idx]
                # If allowed is empty, use default token based on representation mode.
                if not allowed:
                    allowed = [default_nl] if _mode_for(idx) == "nl" else [default_special]
                for candidate in candidates:
                    if candidate in allowed:
                        first_token = candidate
                        reward = 1.0
                        break
        if idx == 0:
            example_token = first_token
            example_allowed = allowed
        rewards.append(reward)

    # Log an example (the first sample, with the values that produced its reward).
    if prompts and completions:
        logger.info("Legal move reward example:")
        logger.info("  Prompt: %s", prompts[0])
        logger.info("  Completion: %s", completions[0])
        logger.info("  Extracted first token: '%s'", example_token)
        logger.info("  Allowed moves for sample: %s", example_allowed)
        logger.info("  Computed legal move reward: %s", rewards[0])
    return rewards

def get_checkpoint(training_args: GRPOConfig):
    """
    Look up a checkpoint to resume from.

    Returns None when `training_args.output_dir` is not an existing directory;
    otherwise returns whatever `get_last_checkpoint` reports for that directory.
    """
    output_dir = training_args.output_dir
    if not os.path.isdir(output_dir):
        return None
    return get_last_checkpoint(output_dir)

# -------------------------------
# Main GRPO Training Function
# -------------------------------
def grpo_training_function(model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig):
    """
    Run GRPO training end to end: tokenizer setup, data loading, training, saving.

    Steps:
      1. Load the tokenizer (from `script_args.tokenizer_name_or_path` when set,
         otherwise from the model path); add a pad token when none exists, and
         add <move_1>..<move_18> tokens in "special" representation mode.
      2. Load the JSON training set and, when a path is given, the validation set,
         and map every sample through `generate_tictactoe_prompt`.
      3. Build a `GRPOTrainer` with the format, answer-length and legal-move
         reward functions and align the model's vocabulary size with the tokenizer.
      4. Train (resuming from the last checkpoint found in the output directory,
         if any), then save metrics, model and tokenizer; optionally push to hub.

    Args:
        model_args: Model options (name/path, revision, trust_remote_code, PEFT).
        script_args: Script options; see `ScriptArguments`.
        training_args: GRPO/trainer configuration.

    Raises:
        ValueError: If `script_args.train_dataset_path` is not provided.
    """
    logger.info(f"Model parameters: {model_args}")
    logger.info(f"Training parameters: {training_args}")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        script_args.tokenizer_name_or_path if script_args.tokenizer_name_or_path else model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )

    if tokenizer.pad_token is None:
        # Reuse the EOS token as the pad token (registered as a special token).
        tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
        print("Added a padding token: ", tokenizer.pad_token)
        logger.info(f"No pad token found; setting pad token to {tokenizer.pad_token}.")

    # If using special move tokens, add them to the tokenizer's vocabulary.
    if script_args.representation_mode == "special":
        special_tokens = [f"<move_{i}>" for i in range(1, 19)]
        # Only add tokens that are not already in the vocabulary.
        vocab = tokenizer.get_vocab()
        new_tokens = list(set(special_tokens) - set(vocab.keys()))
        if new_tokens:
            num_added_tokens = tokenizer.add_tokens(new_tokens)
            logger.info(f"Added {num_added_tokens} new tokens: {new_tokens}")
        else:
            logger.info("All special tokens already exist in the vocabulary.")

    # Load your tic-tac-toe dataset.
    if not script_args.train_dataset_path:
        raise ValueError("train_dataset_path must be provided to run GRPO training.")
    train_dataset = load_dataset("json", data_files=script_args.train_dataset_path, split="train")
    # The validation set is optional; without it the trainer gets eval_dataset=None
    # (previously a missing path caused a NameError further down).
    val_dataset = None
    if script_args.val_dataset_path:
        val_dataset = load_dataset("json", data_files=script_args.val_dataset_path, split="train")

    # Convert each sample into a prompt (and include allowed moves for reward function)
    def prompt_fn(x):
        return generate_tictactoe_prompt(x, script_args.representation_mode, script_args.instruct_model)

    # Keep an untransformed sample so the "before" print really is before mapping.
    raw_sample = train_dataset[0]
    train_dataset = train_dataset.map(prompt_fn)
    if val_dataset is not None:
        val_dataset = val_dataset.map(prompt_fn)

    print("Train dataset:", train_dataset)
    print("Validation dataset:", val_dataset)

    print("Sample before transformation:", raw_sample)
    print("Sample after transformation:", train_dataset[0])

    # List of reward functions
    reward_funcs = [format_reward_func_updated, answer_length_reward_func, legal_move_reward_func]
    # One processing class per reward function, all pointing at the updated tokenizer.
    reward_processing_classes = [tokenizer, tokenizer, tokenizer]  # must match the number of reward functions

    # Instantiate the GRPO trainer
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=get_peft_config(model_args),
        processing_class=tokenizer,
        reward_processing_classes=reward_processing_classes,
    )

    # Keep the model config in sync with the (possibly extended) tokenizer.
    trainer.model.config.vocab_size = len(tokenizer)
    # NOTE(review): models whose name contains "Qwen" skip the embedding resize,
    # as in the original logic -- confirm this is intended when tokens were added.
    if "Qwen" not in model_args.model_name_or_path:
        trainer.model.resize_token_embeddings(len(tokenizer))
    logger.info("Resized model token embeddings to match the updated vocabulary size after adding pad token and move token if needed.")

    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Resuming from checkpoint: {last_checkpoint}")

    logger.info(f"*** Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} for {training_args.num_train_epochs} epochs ***")
    try:
        # NOTE(review): training_args.resume_from_checkpoint is not forwarded here;
        # only the checkpoint discovered in output_dir is used.
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
        metrics = train_result.metrics
        metrics["train_samples"] = len(train_dataset)
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
    except Exception as e:
        # Best effort: record the failure with its traceback, then still save
        # whatever state the model is in below.
        logger.exception(f"Error during training:\n{e}")

    logger.info("*** Training complete ***")
    logger.info("*** Saving model ***")
    trainer.model.config.use_cache = True
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    training_args.distributed_state.wait_for_everyone()
    tokenizer.save_pretrained(training_args.output_dir)
    logger.info(f"Tokenizer saved to {training_args.output_dir}")

    if trainer.accelerator.is_main_process:
        trainer.create_model_card({"tags": ["rl", "grpo", "tictactoe", script_args.representation_mode]})
    if training_args.push_to_hub:
        logger.info("Pushing model to hub...")
        trainer.push_to_hub()

    logger.info("*** Training complete! ***")

@record
def main():
    """Parse model, script and GRPO arguments, then start training with them."""
    arg_parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    parsed = arg_parser.parse_args_and_config()
    # Order follows the dataclass tuple given to the parser above.
    model_args, script_args, training_args = parsed
    grpo_training_function(model_args, script_args, training_args)

# Start training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
