import numpy as np
import torch

def pad_sequence(sequence, pad_token, target_length, side="left"):
    """
    Pads or truncates a token sequence to exactly target_length and builds its attention mask.

    Sequences shorter than target_length are padded with pad_token on the requested
    side. Sequences that are longer are truncated by keeping the LAST target_length
    tokens (the beginning is trimmed); truncation does not depend on side.

    Args:
        sequence (list[int]): The input tokenized sequence. Any iterable of ints
            (e.g. a list or tuple) is accepted; it is copied, never modified.
        pad_token (int): The token ID used for padding.
        target_length (int): The desired length of the sequence. Must be >= 0.
        side (str): Which side to pad on, either "left" (default) or "right".

    Returns:
        dict: {"input_ids": padded_sequence, "attn_mask": attention_mask}, both
        torch.long tensors of length target_length. The mask holds 1 for real
        tokens and 0 for padding positions.

    Raises:
        ValueError: If side is not "left" or "right", or target_length is negative.
    """
    # Explicit raises instead of assert: asserts are stripped under `python -O`,
    # which would let an invalid side fall through to an unbound-variable error.
    if side not in ("left", "right"):
        raise ValueError("padding side must be either 'left' or 'right'")
    if target_length < 0:
        raise ValueError(f"target_length must be non-negative, got {target_length}")

    tokens = list(sequence)
    seq_len = len(tokens)
    pad_len = target_length - seq_len

    if pad_len <= 0:
        # Truncate (or keep as-is) by dropping tokens from the beginning.
        # Slice from an explicit start index: tokens[-target_length:] would
        # return the WHOLE sequence when target_length is 0, because -0 == 0.
        padded_sequence = tokens[seq_len - target_length:]
        attention_mask = [1] * target_length  # all remaining tokens are valid
    elif side == "left":
        padded_sequence = [pad_token] * pad_len + tokens
        attention_mask = [0] * pad_len + [1] * seq_len
    else:
        padded_sequence = tokens + [pad_token] * pad_len
        attention_mask = [1] * seq_len + [0] * pad_len

    return {
        "input_ids": torch.tensor(padded_sequence, dtype=torch.long),
        "attn_mask": torch.tensor(attention_mask, dtype=torch.long),
    }

# Demo: pad one short example sequence on each side and print the results
if __name__ == "__main__":
    sequence = [101, 1024, 2023, 2003, 1037, 2742]  # Example tokenized input
    pad_token = 50246  # Example padding token
    target_length = 10

    # Each entry pairs a padding side with the heading printed above its output.
    demo_cases = (
        ("left", "Using left padding:"),
        ("right", "Using right padding:"),
    )
    for side, heading in demo_cases:
        output = pad_sequence(sequence, pad_token, target_length, side=side)
        print(heading)
        print("Padded Sequence:", output["input_ids"])
        print("Attention Mask:", output["attn_mask"])