import torch

# Configuration
# Operands are single digits drawn from 0 .. modulus - 1.
modulus = 5
# Vocabulary size = padding (1) + digits (modulus) + operators/brackets/"=" (6) = 12
vocab_size = modulus + 7
# Token ids: 0 = <PAD>, 1 .. modulus = digits, modulus+1 .. modulus+6 = + - * ( ) =


def decode(sample, modulus=5):
    """
    Render an encoded sample as the human-readable string "infix=postfix".

    Args:
        sample: A pair (input_seq, target_seq). Each element is an iterable
            of token ids, given either as plain ints or as tensor elements.
        modulus: Number of distinct digits in the vocabulary.

    Returns:
        The decoded infix expression, an "=" sign, and the decoded postfix
        expression. Ids that are not in the vocabulary are rendered as "?".
    """
    infix_ids, postfix_ids = sample

    # Index -> character table: 0 is padding, 1..modulus are the digits and
    # the six symbols follow directly after the digits.
    index_to_char = {0: "<PAD>"}
    for offset, symbol in enumerate("+-*()=", start=1):
        index_to_char[modulus + offset] = symbol
    index_to_char.update({digit + 1: str(digit) for digit in range(modulus)})

    def render(token_ids):
        # Tensor elements are unwrapped to Python ints before the lookup.
        return "".join(
            index_to_char.get(
                tok.item() if isinstance(tok, torch.Tensor) else tok, "?"
            )
            for tok in token_ids
        )

    infix_text = render(infix_ids)
    postfix_text = render(postfix_ids)
    return infix_text + "=" + postfix_text


def generate_sample(min_length, max_length, seed=None, mult=True):
    """
    Generates an Infix expression and its corresponding Postfix (RPN) form.

    Only ODD expression lengths are produced: an even length would require a
    unary minus (e.g. "-3"), which is deliberately excluded. The requested
    range is therefore normalised before sampling:

    * a lower bound below 1 is raised to 1 (the shortest expression is a
      single digit),
    * an even lower bound is rounded up to the next odd number,
    * if that pushes the lower bound past ``max_length``, ``max_length`` is
      raised to match it.

    Args:
        min_length: Smallest infix length (in characters) to generate.
        max_length: Largest infix length (in characters) to generate.
        seed: If given, torch's global RNG is re-seeded with this value
            before sampling, so the call is reproducible. Note that this
            affects every later use of torch's global RNG as well.
        mult: If True the operators are "+", "-" and "*"; otherwise only
            "+" and "-".

    Returns:
        A pair ``(input_sequence, target_sequence)`` of lists of token ids:
        the encoded infix expression and the encoded postfix expression.
        Digits map to ``digit + 1`` and the symbols ``+ - * ( ) =`` map to
        ``modulus + 1 .. modulus + 6`` (``modulus`` is the module-level
        constant).

    Raises:
        ValueError: If ``min_length`` is greater than ``max_length``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    if min_length > max_length:
        raise ValueError("min_length must be less than or equal to max_length")

    # Normalise the range to positive odd lengths. Without the clamp to 1,
    # non-positive odd values would enter the sampling pool and fail deep
    # inside the recursion on some draws only.
    min_length = max(min_length, 1)
    if min_length % 2 == 0:
        min_length += 1
    max_length = max(max_length, min_length)

    # Built once instead of on every recursive call.
    operators = ("+", "-", "*") if mult else ("+", "-")

    def random_int(upper):
        """Draws a uniform integer from [0, upper) using torch's global RNG."""
        return torch.randint(0, upper, (1,)).item()

    def gen_terminal():
        """Generates a random single-digit operand as an (Infix, Postfix) pair."""
        digit = str(random_int(modulus))
        return digit, digit

    def generate_expression(length):
        """Recursively generates an (Infix, Postfix) pair of the given odd length."""
        # --- BASE CASES ---
        if length == 1:
            return gen_terminal()
        if length == 3:
            # Infix "(3)" becomes Postfix "3": brackets are implicit in RPN.
            operand_infix, operand_postfix = gen_terminal()
            return f"({operand_infix})", operand_postfix

        # --- RECURSIVE STEP (lengths 5, 7, 9, ...) ---
        # "(" + left + op + right + ")" uses 3 characters besides the operands,
        # so left + right = length - 3, and both operands must be odd.
        high = length - 3
        if high < 1:
            # Defensive guard: only reachable for a non-positive or even length.
            raise ValueError(f"Cannot generate binary op for length {length}")

        # Pick left_length uniformly from the odd values 1, 3, ..., <= high.
        left_length = 2 * random_int((high - 1) // 2 + 1) + 1
        right_length = length - left_length - 3

        left_infix, left_postfix = generate_expression(left_length)
        right_infix, right_postfix = generate_expression(right_length)

        op_char = operators[random_int(len(operators))]

        # Infix: ( Left op Right )   Postfix: Left Right op
        return (
            f"({left_infix}{op_char}{right_infix})",
            left_postfix + right_postfix + op_char,
        )

    # min_length is odd and <= max_length here, so the odd lengths in range
    # are min_length, min_length + 2, ... and there is at least one of them.
    num_choices = (max_length - min_length) // 2 + 1
    length = min_length + 2 * random_int(num_choices)

    infix_str, postfix_str = generate_expression(length)

    # --- Vocabulary Mapping (character -> token id) ---
    vocab = {str(digit): digit + 1 for digit in range(modulus)}
    for offset, symbol in enumerate("+-*()=", start=1):
        vocab[symbol] = modulus + offset

    input_sequence = [vocab[char] for char in infix_str]
    target_sequence = [vocab[char] for char in postfix_str]

    return input_sequence, target_sequence


def preprocess_data(sample):
    """
    Constructs a Training Tensor with [Input + SEP + Dummies] structure.

    Input:  [Infix...] + [=] + [PAD] * len(Postfix)
    Target: [PAD...]   + [PAD] + [Postfix...]
    Mask:   [0...]     + [0]   + [1...]

    Args:
        sample: A pair (infix_seq, postfix_seq) of token-id sequences, as
            returned by generate_sample.

    Returns:
        A tuple (input, target, mask) of 1-D tensors that all have length
        len(infix_seq) + 1 + len(postfix_seq). input and target are
        torch.long; mask is torch.bool and is True exactly on the positions
        that hold the postfix tokens in target.
    """
    infix_seq, postfix_seq = sample
    infix_seq = list(infix_seq)
    postfix_seq = list(postfix_seq)

    # 0 is PAD. "=" (used as SEP) is token modulus + 6, the same id that
    # generate_sample and decode use for it. modulus + 7 would equal
    # vocab_size and so fall outside the vocabulary.
    PAD_IDX = 0
    SEP_IDX = modulus + 6

    # The prompt is the infix expression followed by the separator.
    prompt_len = len(infix_seq) + 1

    # Input: the prompt plus one PAD slot per postfix token for the model to
    # overwrite.
    input_list = infix_seq + [SEP_IDX] + [PAD_IDX] * len(postfix_seq)

    # Target: PAD over the prompt, then the postfix tokens.
    target_list = [PAD_IDX] * prompt_len + postfix_seq

    # Mask: False over the prompt, True over the result positions.
    mask = torch.zeros(len(input_list), dtype=torch.bool)
    mask[prompt_len:] = True

    return (
        torch.tensor(input_list, dtype=torch.long),
        torch.tensor(target_list, dtype=torch.long),
        mask,
    )
