from transformers import PreTrainedTokenizerBase

from rllm.parser.chat_template.parser import ChatTemplateParser

def construct_prompt_for_tips(trajectory, last_reward):
    """
    Build the prompt used to ask for tips about a trajectory.

    Currently a placeholder: both ``trajectory`` and ``last_reward`` are
    ignored and an empty prompt string is returned.
    """
    tips_prompt = ""
    return tips_prompt

def remove_tips_from_messages(messages: list[dict[str, str]], prompt_ids: list[int]):
    """
    Strip the ``<tips>`` section from the leading system message.

    Only ``messages[0]`` is inspected, and only when its role is ``"system"``
    and its content is a string containing ``<tips>``. In that case everything
    from the first ``<tips>`` onward is removed. The message dict is modified
    in place.

    Args:
        messages: Chat messages; ``messages[0]`` is the candidate system message.
        prompt_ids: Token ids previously associated with ``messages``.

    Returns:
        Tuple ``(messages, prompt_ids)``. ``prompt_ids`` is returned unchanged
        when nothing was removed. It is returned as ``None`` when the system
        content was changed, since the old ids no longer describe the text.
    """
    # Nothing to strip: empty conversation, first message is not the system
    # prompt, or the system prompt carries no tips.
    if not messages:
        return messages, prompt_ids
    first = messages[0]
    content = first.get("content")
    if first.get("role") != "system" or not isinstance(content, str) or "<tips>" not in content:
        return messages, prompt_ids

    first["content"] = content.split("<tips>", 1)[0]
    # The previous token ids were computed from the old text; discard them.
    return messages, None


def get_recent_assistant_user_messages(chat_completions_messages):
    """
    Find the most recent assistant turn and the environment replies after it.

    The list is scanned from the end. Messages whose role is ``"user"`` or
    ``"tool"`` are collected until the first assistant message encountered
    (i.e. the most recent one) is reached; messages with any other role are
    skipped.

    Args:
        chat_completions_messages (List[Dict]): Message dicts from chat completions.

    Returns:
        Tuple[Dict, List[Dict]]:
            - The most recent assistant message, or None if there is none.
            - The user/tool messages that come after it, in chronological
              order. With no assistant message, all user/tool messages are
              returned.
    """
    trailing_env = []
    latest_assistant = None
    for entry in reversed(chat_completions_messages):
        kind = entry.get("role", None)
        if kind == "assistant":
            # Anything earlier than this turn is irrelevant, so stop here.
            latest_assistant = entry
            break
        if kind in ("user", "tool"):
            trailing_env.append(entry)
    # Collected newest-first; flip back to chronological order.
    trailing_env.reverse()
    return latest_assistant, trailing_env


def convert_messages_to_tokens_and_masks(messages: list[dict[str, str]], tokenizer: "PreTrainedTokenizerBase", parser: "ChatTemplateParser", contains_first_msg=False, contains_generation_msg=False):
    """
    Tokenize chat messages one at a time and build a per-token mask.

    Each message is rendered on its own with ``parser.parse`` and encoded with
    ``tokenizer.encode(..., add_special_tokens=False)``. Tokens from assistant
    messages get mask 1; tokens from every other role get mask 0.

    ``contains_first_msg`` / ``contains_generation_msg`` tell the parser
    whether ``messages[0]`` should be rendered as the start of the
    conversation, and whether ``messages[-1]`` should be followed by the
    generation prompt.

    Args:
        messages (List[Dict]): The messages to convert; each needs a ``role``.
        tokenizer: Tokenizer used to encode each rendered message.
        parser: Chat template parser exposing ``parse`` and ``assistant_token``.
        contains_first_msg (bool): Render ``messages[0]`` with ``is_first_msg=True``.
        contains_generation_msg (bool): Render ``messages[-1]`` with
            ``add_generation_prompt=True``.

    Returns:
        Tuple[List[int], List[int]]: All tokens and their masks (equal length).

    Raises:
        AssertionError: If a rendered assistant message does not start with
            ``parser.assistant_token``.
    """
    all_msg_tokens = []
    all_msg_masks = []
    last_index = len(messages) - 1

    def _convert_message_to_tokens_and_masks(msg, first_msg=False, generation_msg=False):
        msg_text = parser.parse([msg], add_generation_prompt=generation_msg, is_first_msg=first_msg)

        # Remove the assistant token since it is contained in previous message as
        # generation prompt. Strip only the leading occurrence: str.replace would
        # also delete the token from the message content or from a trailing
        # generation prompt.
        if msg["role"] == "assistant":
            assert msg_text.startswith(parser.assistant_token), f"Expected assistant token {parser.assistant_token} but got {msg_text}"
            msg_text = msg_text[len(parser.assistant_token):]

        msg_tokens = tokenizer.encode(msg_text, add_special_tokens=False)
        mask_value = 1 if msg["role"] == "assistant" else 0
        return msg_tokens, [mask_value] * len(msg_tokens)

    for i, msg in enumerate(messages):
        msg_tokens, msg_mask = _convert_message_to_tokens_and_masks(
            msg,
            first_msg=(contains_first_msg and i == 0),
            generation_msg=(contains_generation_msg and i == last_index),
        )
        all_msg_tokens.extend(msg_tokens)
        all_msg_masks.extend(msg_mask)

    return all_msg_tokens, all_msg_masks


def remove_concepts_from_messages(prompt: str, prefix_token: str, suffix_token: str) -> str:
    """
    Remove the concept span delimited by ``prefix_token`` ... ``suffix_token``.

    Everything from the first ``prefix_token`` up to, but not including, the
    first ``suffix_token`` that follows it is dropped. ``suffix_token`` itself
    and all text after it are kept, e.g. with ``"<c>"`` / ``"</c>"``::

        "sys<c>concept</c> rest" -> "sys</c> rest"

    If the prompt contains no ``prefix_token``, or no ``suffix_token`` after
    it, the prompt is returned unchanged.

    Raises:
        ValueError: If either token is the empty string.
    """
    if not prefix_token or not suffix_token:
        raise ValueError("prefix_token and suffix_token must be non-empty")

    head, found_prefix, rest = prompt.partition(prefix_token)
    if not found_prefix:
        return prompt
    # Search for the suffix only after the prefix, and split at its first
    # occurrence so any later suffix tokens and text are preserved.
    _, found_suffix, tail = rest.partition(suffix_token)
    if not found_suffix:
        return prompt
    return head + suffix_token + tail


import difflib
from itertools import zip_longest

def highlight_changed_text(
    old_observation: str,
    new_observation: str,
    *,
    start_marker: str = "<strong>",
    end_marker: str = "</strong>",
) -> str:
    """
    Return ``new_observation`` with the parts that differ from ``old_observation`` marked.

    1) Lines are aligned with ``difflib.SequenceMatcher`` so inserted and
       deleted lines are matched reliably.
    2) Each pair of aligned lines is diffed character by character; only the
       ``insert``/``replace`` segments of the new line are marked.
    3) Lines that exist only in ``new_observation`` are marked as a whole.
    4) Markers never span a line boundary and deleted text is not emitted, so
       removing the markers yields ``new_observation`` exactly (line endings
       included).

    Args:
        old_observation: Previous text.
        new_observation: Current text; the output is based on this.
        start_marker: Text inserted before each changed segment.
        end_marker: Text inserted after each changed segment.

    Returns:
        ``new_observation`` with every changed segment wrapped as
        ``start_marker + segment + end_marker``.
    """

    def mark(text: str) -> str:
        # Skip empty segments so no empty marker pair is emitted.
        return f"{start_marker}{text}{end_marker}" if text else ""

    def split_ending(line: str):
        # Separate content from its terminator (any that str.splitlines
        # recognises, e.g. "\n", "\r\n", "\r") so markers never wrap it.
        content = line.splitlines()[0] if line else ""
        return content, line[len(content):]

    def highlight_line(old_content: str, new_content: str) -> str:
        # Minimal character-level diff within a single line.
        sm = difflib.SequenceMatcher(a=old_content, b=new_content, autojunk=False)
        parts = []
        for tag, _i1, _i2, j1, j2 in sm.get_opcodes():
            segment = new_content[j1:j2]
            if tag == "equal":
                parts.append(segment)
            elif tag in ("replace", "insert"):
                parts.append(mark(segment))
            # "delete": characters only in old are not output.
        return "".join(parts)

    old_contents = [split_ending(ln)[0] for ln in old_observation.splitlines(keepends=True)]
    new_split = [split_ending(ln) for ln in new_observation.splitlines(keepends=True)]
    new_contents = [content for content, _ in new_split]
    new_ends = [ending for _, ending in new_split]

    # Line-level alignment (properly recognises inserted/deleted lines).
    line_sm = difflib.SequenceMatcher(a=old_contents, b=new_contents, autojunk=False)

    out_lines = []
    for tag, i1, i2, j1, j2 in line_sm.get_opcodes():
        if tag == "equal":
            out_lines.extend(new_contents[j] + new_ends[j] for j in range(j1, j2))
        elif tag == "replace":
            # Line counts may differ inside a replaced region: pair lines 1:1.
            for old_line, j in zip_longest(old_contents[i1:i2], range(j1, j2)):
                if j is None:
                    # Surplus old line -> deleted, no output.
                    continue
                if old_line is None:
                    body = mark(new_contents[j])
                else:
                    body = highlight_line(old_line, new_contents[j])
                out_lines.append(body + new_ends[j])
        elif tag == "insert":
            out_lines.extend(mark(new_contents[j]) + new_ends[j] for j in range(j1, j2))
        # "delete": lines only in old produce no output.

    return "".join(out_lines)

# Two-digit code per sub task type; "default" is used for unknown types.
_SUB_TASK_TYPE_CODES = {
    "standard_sudoku": 10,
    "ctc_sudoku": 11,
    "default": 99,
}
# One-digit code per difficulty level; "default" is used for unknown levels.
_DIFFICULTY_LEVEL_CODES = {
    "easy": 1,
    "medium": 2,
    "hard": 3,
    "extremely_hard": 4,
    "default": 9,
}


def task_type_matching(sub_task_type: str, difficulty_level: str) -> int:
    """
    Combine a sub task type and a difficulty level into an integer task-type code.

    The code is ``sub_task_code * 10 + difficulty_code``, e.g.
    ``("standard_sudoku", "easy") -> 101``. Unknown sub task types map to 99
    and unknown difficulty levels map to 9, so ``("?", "?") -> 999``.

    Returns:
        The integer task-type code.
    """
    sub_code = _SUB_TASK_TYPE_CODES.get(sub_task_type, _SUB_TASK_TYPE_CODES["default"])
    difficulty_code = _DIFFICULTY_LEVEL_CODES.get(difficulty_level, _DIFFICULTY_LEVEL_CODES["default"])
    return sub_code * 10 + difficulty_code