import torch


def get_cont_distances(tensor1, tensor2):
    """Apply ``get_cont_distance`` to each row of a batch.

    Both inputs are cast to ``torch.long`` first. Row ``i`` of ``tensor1`` is
    compared with row ``i`` of ``tensor2``, and the per-row results are
    stacked along a new leading dimension.

    Args:
        tensor1: Batch of token sequences; presumably shaped
            ``(batch, seq_len)`` — TODO confirm with callers.
        tensor2: Batch indexed in lockstep with ``tensor1``; it must have at
            least ``len(tensor1)`` rows (``tensor2[i]`` is read for every row
            of ``tensor1``).

    Returns:
        A ``torch.long`` tensor with one distance row per row of ``tensor1``.
        An empty batch yields an empty tensor shaped like ``tensor1`` instead
        of raising.
    """
    tensor1 = tensor1.long()
    tensor2 = tensor2.long()
    if len(tensor1) == 0:
        # torch.stack rejects an empty list; return an empty long tensor
        # shaped like the (empty) input batch instead of crashing.
        return torch.zeros_like(tensor1)
    return torch.stack(
        [get_cont_distance(row1, tensor2[i]) for i, row1 in enumerate(tensor1)]
    )


def _extract_numbers_and_positions(tensor):
    """Split a 1-D token tensor into runs of consecutive non-``-1`` tokens.

    Every element that is not ``-1`` is treated as a digit. Consecutive digits
    form one number, and each ``-1`` ends the current run.

    NOTE(review): tokens are concatenated via ``str(token)``, so this assumes
    digit tokens lie in ``0..9``. Multi-digit or other negative values would
    not parse as a single decimal digit; confirm with callers.

    Args:
        tensor: 1-D tensor of tokens.

    Returns:
        ``(numbers, positions)``: ``numbers[k]`` is the integer spelled by the
        k-th run, and ``positions[k]`` is the tuple of indices that run occupies.
    """
    numbers = []
    positions = []
    digits = []
    run_positions = []

    for i, token in enumerate(tensor.tolist()):
        if token != -1:
            digits.append(str(token))
            run_positions.append(i)
        elif digits:
            # A separator closes the current run of digits.
            numbers.append(int("".join(digits)))
            positions.append(tuple(run_positions))
            digits = []
            run_positions = []

    # Flush a run that extends to the end of the sequence.
    if digits:
        numbers.append(int("".join(digits)))
        positions.append(tuple(run_positions))

    return numbers, positions


def get_cont_distance(tensor1, tensor2):
    """Compute per-digit absolute differences between aligned numbers.

    A "number" is a maximal run of non-``-1`` tokens. For every number in
    ``tensor1`` whose run occupies exactly the same indices as a number in
    ``tensor2``:

    1. Take ``abs(num1 - num2)`` of the two decimal values.
    2. Left-pad it with zeros to the run's length.
    3. Write its digits back into those indices.

    All other positions are 0.

    Args:
        tensor1: 1-D token tensor; ``-1`` separates numbers.
        tensor2: 1-D token tensor using the same encoding.

    Returns:
        A tensor from ``torch.zeros_like(tensor1)`` (same shape and dtype)
        holding the distance digits.
    """
    numbers1, positions1 = _extract_numbers_and_positions(tensor1)
    numbers2, positions2 = _extract_numbers_and_positions(tensor2)

    # Runs within one tensor are disjoint, so each span is a unique key. Looking
    # up by exact span pairs each number in tensor1 with the number at the same
    # indices in tensor2, not with whichever number happens to share its ordinal.
    span_to_number2 = dict(zip(positions2, numbers2))

    distance = torch.zeros_like(tensor1)
    for num1, span in zip(numbers1, positions1):
        num2 = span_to_number2.get(span)
        if num2 is None:
            continue  # tensor2 has no number at exactly these indices
        # Both numbers have len(span) digits, so the difference always fits.
        diff_digits = str(abs(num1 - num2)).zfill(len(span))
        for idx, digit in zip(span, diff_digits):
            distance[idx] = int(digit)

    return distance


if __name__ == "__main__":
    # Each sequence holds two numbers (at indices 2-3 and 6-7) separated by -1:
    # 20 vs 19 -> "01", and 31 vs 43 -> "12".
    example_a = torch.tensor([-1, -1, 2, 0, -1, -1, 3, 1, -1, -1], dtype=torch.int)
    example_b = torch.tensor([-1, -1, 1, 9, -1, -1, 4, 3, -1, -1], dtype=torch.int)
    # Expected values: [0, 0, 0, 1, 0, 0, 1, 2, 0, 0]
    result = get_cont_distance(example_a, example_b)
    print(result)
