import os
import random
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, load_dataset
from reward_func import extract_hash_answer


# Active trainer type; rebound by set_trainer_type(). The types "b1_wll",
# "b1_d1" and "b1_gdpo" select the prompts that use the \block step marker;
# any other value selects the baseline prompts.
TRAINER_TYPE = "b1_wll"


def set_trainer_type(t):
    """Select the trainer type and rebuild every task system prompt to match.

    The trainer types "b1_wll", "b1_d1" and "b1_gdpo" use dynamic generation,
    which requires the model to end each reasoning step with a literal
    ``\\block`` marker, so their prompts ask for it. Every other trainer type
    gets the baseline prompts, which contain no such marker.

    Side effects: rebinds the module globals ``TRAINER_TYPE``,
    ``SYSTEM_PROMPT``, ``GSM_SYSTEM_PROMPT``, ``MATH_SYSTEM_PROMPT``,
    ``CTD_SYSTEM_PROMPT`` and ``SUDOKU_SYSTEM_PROMPT``. All six are assigned
    on both branches, so switching trainer types never leaves a prompt from
    the previous type behind. Names bound elsewhere via ``from ... import``
    before this call keep their old values.

    Args:
        t: Trainer type identifier, e.g. "b1_wll".
    """
    global TRAINER_TYPE, SYSTEM_PROMPT, CTD_SYSTEM_PROMPT, SUDOKU_SYSTEM_PROMPT, GSM_SYSTEM_PROMPT, MATH_SYSTEM_PROMPT
    TRAINER_TYPE = t

    if TRAINER_TYPE not in ("b1_wll", "b1_d1", "b1_gdpo"):
        # Baseline prompts: plain <reasoning>/<answer> format, no \block marker.
        # SYSTEM_PROMPT is reset here as well; previously it kept the \block
        # prompt assigned by the last b1_* call (e.g. the import-time default).
        SYSTEM_PROMPT = """You are a math expert. You will be given a question to solve. You should provide clear, logical reasoning step by step. Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
        GSM_SYSTEM_PROMPT = MATH_SYSTEM_PROMPT = SYSTEM_PROMPT

        # NOTE(review): this prompt asks for "\boxed{} tags", but its example
        # and required format use <answer> tags, and "</answer>" runs straight
        # into "Respond" with no separator. The text is left unchanged because
        # it is model input; confirm which tag the reward function parses.
        CTD_SYSTEM_PROMPT = (
            "Using only the provided numbers, create an arithmetic expression that evaluates to exactly the provided target number. You may use the operations +, -, *, and / as needed, but each number must be used exactly once. Think step-by-step. After reasoning, provide only your final expression inside \\boxed"
            + "{}"
            + " tags without including an equals sign or the target number. For example: <answer>a + b * c</answer>"
            + """Respond in the following format:
<reasoning>
Your reasoning here
</reasoning>
<answer>
...
</answer>"""
        )

        SUDOKU_SYSTEM_PROMPT = """Please solve the following 4x4 Sudoku puzzle. The puzzle is provided as a 16-character string reading left-to-right, top-to-bottom, where '0' represents empty cells.

Rules:
- Fill empty cells with digits 1-4
- Each row must contain digits 1-4 exactly once
- Each column must contain digits 1-4 exactly once
- Each 2x2 box must contain digits 1-4 exactly once

Important: Your solution must be a COMPLETE 16-character string with only the digits 1-4, representing your final solved grid.

Respond in this exact format:
<reasoning>
Your step-by-step solving process
</reasoning>
<answer>
[16-character solution string with no spaces or separators]
</answer>
"""

    else:
        # Dynamic generation requires the end-of-block marker \block for training.
        SYSTEM_PROMPT = """You are a math expert. You will be given a question to solve. You should provide clear, logical reasoning step by step. Append the tag \\block directly to the end of the last sentence of each reasoning step without starting a new line. Respond exactly in the following format:
<reasoning>
Step 1, ... \\block
Step 2, ... \\block
...
Step n, ... \\block
</reasoning>
<answer>
FINAL ANSWER ONLY
</answer>
"""
        MATH_SYSTEM_PROMPT = GSM_SYSTEM_PROMPT = SYSTEM_PROMPT

        # NOTE(review): the example at the end of constraint 3 ends with a
        # stray '"' after </answer>, and "\n" inside this non-raw literal
        # becomes real newlines. Left as-is because it is model input; confirm
        # whether the quote is intended.
        CTD_SYSTEM_PROMPT = """You are a mathematical expert solving Countdown questions.
Your Task: Given a list of numbers and a target integer, construct an arithmetic expression that evaluates exactly to the target.

STRICT CONSTRAINTS ON NUMBER USAGE:
1. **NO EXTERNAL NUMBERS**: You must use ONLY and ALL the numbers provided in the input list. Do not introduce any other integers (e.g., do not use 1 or 2 unless they are explicitly in the provided list).
2. **EXACTLY ONCE**: You must use EACH number from the provided list EXACTLY ONCE. You cannot skip any number, and you cannot reuse any number multiple times.
3. **NO SPACES OR SEPARATORS**: You may use the operations +, -, *, and / as needed. After reasoning, provide only your final expression inside <answer></answer> tags without including an equals sign or the target number. For example, if the numbers are [2, 3, 4] and the target is 5, a valid answer is: <answer>\n2*4-3\n</answer>"

You should provide clear, logical reasoning step by step. Append the tag \\block directly to the end of the last sentence of each reasoning step without starting a new line. Respond exactly in the following format:
<reasoning>
Step 1, ... \\block
...
Step n, ... \\block
</reasoning>
<answer>
...
</answer>"""

        SUDOKU_SYSTEM_PROMPT = """Please solve the following 4x4 Sudoku puzzle. The puzzle is provided as a 16-character string reading left-to-right, top-to-bottom, where '0' represents empty cells.
RULES:
- Fill empty cells with digits 1-4
- Each row must contain digits 1-4 exactly once
- Each column must contain digits 1-4 exactly once
- Each 2x2 box must contain digits 1-4 exactly once

STRICT OUTPUT FORMAT:
- The solution inside the <answer> tag must be a SINGLE CONTINUOUS STRING of EXACTLY 16 DIGITS, not 15 or 17.
- DO NOT use any spaces, newlines, commas, or grid visualization inside the <answer> tag.

You should provide clear, logical reasoning step by step. Append the tag \\block directly to the end of the last sentence of each reasoning step without starting a new line. Respond exactly in the following format:
<reasoning>
Step 1, ... \\block
...
Step n, ... \\block
</reasoning>
<answer>
[SINGLE CONTINUOUS SOLUTION STRING of EXACTLY 16 DIGITS]
</answer>
"""


# Populate the prompt globals at import time using the default trainer type
# ("b1_wll"); callers may switch later by calling set_trainer_type() again.
set_trainer_type("b1_wll")


def set_random_seed(seed: int = 42):
    """Seed every RNG this project relies on and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG, and
    also every CUDA device's RNG when CUDA is available. It then pins cuDNN
    to deterministic kernels and disables autotuning, which can cost speed.

    Args:
        seed: Seed value shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_gsm8k_questions(split="train") -> Dataset:
    """Load a GSM8K split and turn each row into a single-turn chat prompt.

    Args:
        split: Split name of the ``openai/gsm8k`` "main" config to use.

    Returns:
        The split mapped so that ``prompt`` holds one user message
        (``GSM_SYSTEM_PROMPT``, a newline, then the question). The ``answer``
        column is replaced by ``extract_hash_answer`` applied to the
        original answer text.
    """
    gsm8k = load_dataset("openai/gsm8k", "main")

    def _to_chat_example(example):
        # GSM_SYSTEM_PROMPT is read here, so it reflects the trainer type
        # active when map() runs.
        user_turn = {"role": "user", "content": f"{GSM_SYSTEM_PROMPT}\n{example['question']}"}
        return {
            "prompt": [user_turn],
            "answer": extract_hash_answer(example["answer"]),
        }

    return gsm8k[split].map(_to_chat_example)


def get_countdown_questions(split="train") -> Dataset:
    """Load Countdown tasks that use exactly three numbers, as chat prompts.

    Args:
        split: Split of ``Jiayi-Pan/Countdown-Tasks-3to4`` to load.

    Returns:
        The filtered split mapped so that ``prompt`` holds one user message
        built from ``CTD_SYSTEM_PROMPT``, the number list and the target. The
        ``target`` column is kept, and ``numbers`` copies ``nums``.
    """
    dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split=split)
    three_number_tasks = dataset.filter(lambda row: len(row["nums"]) == 3)

    def _to_chat_example(row):
        nums, target = row["nums"], row["target"]
        content = (
            f"{CTD_SYSTEM_PROMPT}\nThe provided input list is {nums} and the target number is {target}.\n"
            "Please solve it according to the strict constraints above."
        )
        return {
            "prompt": [{"role": "user", "content": content}],
            "target": target,
            "numbers": nums,
        }

    return three_number_tasks.map(_to_chat_example)


def get_sudoku_questions() -> Dataset:
    """Load the Sudoku dataset for training or evaluation.

    Reads ``../dataset/4x4_sudoku_unique_puzzles.csv``, resolved relative to
    this module's directory. ``Puzzle`` and ``Solution`` are read as strings,
    presumably so puzzles with leading '0' (empty) cells are not parsed as
    integers.

    Returns:
        A Dataset whose ``prompt`` holds one user message
        (``SUDOKU_SYSTEM_PROMPT`` followed by the puzzle), plus ``puzzle`` and
        ``solution`` string columns copied from the CSV.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(module_dir, "../dataset/4x4_sudoku_unique_puzzles.csv")
    frame = pd.read_csv(csv_path, dtype={"Puzzle": str, "Solution": str})

    def _to_chat_example(row):
        puzzle = row["Puzzle"]
        content = f"{SUDOKU_SYSTEM_PROMPT}\nSolve the following Sudoku puzzle: {puzzle}\n"
        return {
            "prompt": [{"role": "user", "content": content}],
            "puzzle": puzzle,
            "solution": row["Solution"],
        }

    return Dataset.from_pandas(frame).map(_to_chat_example)


def get_math_questions(split="train") -> Dataset:
    """Load a MATH-500 split and turn each problem into a chat prompt.

    Args:
        split: Split of ``ankner/math-500`` to load.

    Returns:
        The split mapped so that ``prompt`` holds one user message
        (``MATH_SYSTEM_PROMPT``, a newline, then the problem) and ``answer``
        holds the row's ``solution`` text.
    """

    def _to_chat_example(row):
        # NOTE(review): "answer" is the full solution text, not an extracted
        # final answer; presumably the reward function extracts it. Confirm.
        return {
            "prompt": [{"role": "user", "content": f"{MATH_SYSTEM_PROMPT}\n{row['problem']}"}],
            "answer": row["solution"],
        }

    return load_dataset("ankner/math-500", split=split).map(_to_chat_example)
