import random
import re

import torch

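# Helpers for building carry-based hard negatives from tokenized numeric text.
# A sample tokenizer and input are set up in the __main__ block at the bottom.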


def _is_number_token(token):
    """Return True if ``token`` consists solely of ASCII digits."""
    # str.isdigit() alone also accepts characters such as superscripts that
    # int() cannot parse, so the match is restricted to ASCII.
    return token.isascii() and token.isdigit()


def _carry_candidates(number):
    """Return the neighbouring values preferred as hard negatives.

    When ``number`` sits on a carry boundary only the neighbour across that
    boundary is returned (39 -> 40, 40 -> 39, 399 -> 400, 400 -> 399);
    otherwise both ``number + 1`` and ``number - 1`` are returned.
    Zero yields only ``1`` so no negative value is produced.
    """
    if number < 10:
        return [number + 1, number - 1] if number > 0 else [number + 1]
    # Two-digit numbers carry on the last digit, longer ones on the last two.
    modulus = 10 if number < 100 else 100
    remainder = number % modulus
    if remainder == modulus - 1:
        return [number + 1]
    if remainder == 0:
        return [number - 1]
    return [number + 1, number - 1]


def _render_same_width(candidates, original_digits):
    """Render ``candidates`` as digit strings as wide as ``original_digits``.

    Candidates that cannot be written with exactly the same number of
    characters are dropped: negative values, values that gained a digit
    (99 -> 100), and values that lost one (10 -> 9) unless the original
    already carried leading zeros, in which case zero padding is kept.
    """
    width = len(original_digits)
    had_leading_zero = original_digits.startswith("0")
    rendered = []
    for value in candidates:
        if value < 0:
            continue
        text = str(value).zfill(width)
        if len(text) != width:
            continue
        if text[0] == "0" and not had_leading_zero:
            continue
        rendered.append(text)
    return rendered


def get_negative(tokenizer, input_ids, prefix_len):
    """Build a hard-negative copy of ``input_ids`` by perturbing its numbers.

    The first ``prefix_len`` ids are kept untouched. In the remainder, every
    run of two or more consecutive all-digit tokens is read as one number and
    replaced by a neighbouring value (off by one, preferring the neighbour
    across a carry boundary). Runs made of a single token are left alone.

    The replacement always has the same number of characters as the original
    run and is split back into pieces with the original token lengths, so the
    token count of the sequence never changes.

    Args:
        tokenizer: Object providing ``convert_ids_to_tokens`` and
            ``convert_tokens_to_ids``.
        input_ids: 1-D tensor of token ids (prefix followed by continuation).
        prefix_len: Number of leading ids that must not be modified.

    Returns:
        A tensor with the original prefix followed by the perturbed
        continuation, the same length as ``input_ids``.

    Raises:
        AssertionError: If the perturbed continuation does not have the same
            length as the original one.
    """
    prefix = input_ids[:prefix_len]
    cont = input_ids[prefix_len:]
    tokens = tokenizer.convert_ids_to_tokens(cont.tolist())

    negatives = list(tokens)
    i = 0
    while i < len(tokens):
        if not _is_number_token(tokens[i]):
            i += 1
            continue

        # Gather the whole run of digit tokens starting here.
        start = i
        while i < len(tokens) and _is_number_token(tokens[i]):
            i += 1
        digit_tokens = tokens[start:i]

        if len(digit_tokens) < 2:
            # Single-token numbers are skipped as they are easy negatives.
            continue

        digits = "".join(digit_tokens)
        number = int(digits)

        # Prefer the carry-based neighbours; if none of them can be written
        # with the original width, fall back to either plain neighbour.
        choices = _render_same_width(_carry_candidates(number), digits)
        if not choices:
            choices = _render_same_width([number + 1, number - 1], digits)
        if not choices:
            # No width-preserving neighbour exists; leave the number as is.
            continue
        replacement = random.choice(choices)

        # Cut the replacement into pieces matching the original token
        # lengths so that indices into `tokens` stay valid for `negatives`.
        pieces = []
        offset = 0
        for token in digit_tokens:
            pieces.append(replacement[offset : offset + len(token)])
            offset += len(token)
        negatives[start:i] = pieces

    negative_ids = tokenizer.convert_tokens_to_ids(negatives)
    # Create the tensor on the same device as the input so torch.cat works.
    negative_cont = torch.tensor(
        negative_ids, dtype=torch.long, device=input_ids.device
    )
    assert len(negative_cont) == len(cont), "Negative length mismatch"
    return torch.cat([prefix, negative_cont], dim=0)


if __name__ == "__main__":
    # Manual smoke test: build one prefix/continuation pair, perturb the
    # continuation and print both versions for visual comparison.
    from transformers import AutoTokenizer
    from utils import encode

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    prefix = "Add 0.123 to 0.566."
    cont = "Answer: 0.689"
    # NOTE(review): `encode` is a project helper; it is assumed to return a
    # 1-D tensor of token ids, since its results are concatenated below.
    prefix_ids = encode(tokenizer, prefix)
    cont_ids = encode(tokenizer, cont)
    input_ids = torch.cat([prefix_ids, cont_ids], dim=0)
    prefix_len = len(prefix_ids)

    negatives = get_negative(tokenizer, input_ids, prefix_len)
    # Print the result instead of dropping into a debugger so the script
    # also completes when run non-interactively.
    print("original:", tokenizer.decode(input_ids.tolist()))
    print("negative:", tokenizer.decode(negatives.tolist()))
