import torch
from torch import nn

class CopyEncoder(nn.Module):
    """Encoder that tiles its input ``seq_length`` times along a new leading axis.

    No learnable parameters are created; the module only replicates data.
    """

    def __init__(
        self,
        seq_length: int,
        **kwargs
    ):
        """Store the number of copies to produce.

        Args:
            seq_length: How many copies of the input ``forward`` stacks
                along the new leading dimension.
            **kwargs: Accepted and ignored. Presumably this lets the module
                be built from a shared config dict; confirm against callers.
        """
        super().__init__()
        self.seq_length = seq_length

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``seq_length`` stacked copies of ``x``.

        Args:
            x: Tensor of any rank, including 0-dim.

        Returns:
            A tensor of shape ``(seq_length, *x.shape)``. ``Tensor.repeat``
            copies the data, so the result does not share storage with ``x``.
        """
        # One repeat factor for the new leading axis, then 1 for every
        # existing axis. Passing more factors than x has dims makes
        # Tensor.repeat prepend the extra axis.
        repeat_factors = (self.seq_length,) + (1,) * x.dim()
        return x.repeat(*repeat_factors)
    