import torch
import torch.nn as nn


class VectorToContinuous(nn.Module):
    """Map binary (or real-valued) vectors to sine-based continuous features.

    Output feature ``j`` of an input vector ``x`` is
    ``amplitude * sin(sum_k weights[j, k] * x[k] + phase)``, where ``weights``
    is a learnable ``(num_functions, vector_length)`` matrix initialised from a
    standard normal distribution. ``amplitude`` and ``phase`` are plain
    (non-learnable) scalars shared by all sine functions.
    """

    def __init__(self, vector_length, num_functions, amplitude=1.0, phase=0.0):
        """
        Initialize the converter with the vector length, number of sine functions, and constant amplitude and phase.

        :param vector_length: Length of the input vectors (size of the last input dimension).
        :param num_functions: Number of sine functions, i.e. output features produced per vector.
        :param amplitude: Constant amplitude applied to every sine function.
        :param phase: Constant phase offset added inside every sine function.
        """
        super().__init__()
        self.vector_length = vector_length
        self.num_functions = num_functions
        self.amplitude = amplitude
        self.phase = phase

        # Learnable projection: row j holds the per-input coefficients of sine function j.
        self.weights = nn.Parameter(torch.randn(num_functions, vector_length))

    def forward(self, batch):
        """
        Convert vectors to continuous values using the sine functions.

        :param batch: Tensor of shape ``(..., vector_length)``. This can be a single
            vector, a 2D batch, or a tensor with any number of leading batch
            dimensions. Integer or boolean inputs are cast to the parameter dtype.
        :return: Tensor of shape ``(..., num_functions)`` with values in
            ``[-|amplitude|, |amplitude|]``.
        """
        # Cast to the weights' dtype rather than a hard-coded float32, so the
        # module keeps working after e.g. ``.double()``.
        batch = batch.to(self.weights.dtype)

        # For 2D input, batch @ W^T equals (W @ batch^T)^T. This form also
        # handles 1D and N-D inputs and avoids ``.T`` on non-2D tensors.
        weighted_sums = torch.matmul(batch, self.weights.T) + self.phase

        # Apply sine function and scale by amplitude.
        return self.amplitude * torch.sin(weighted_sums)


if __name__ == "__main__":
    # Demo: project three 5-element binary vectors onto 3 sine features each.
    demo_converter = VectorToContinuous(vector_length=5, num_functions=3)

    demo_inputs = torch.tensor(
        [
            [0, 1, 0, 1, 1],
            [1, 0, 1, 0, 1],
            [1, 1, 1, 1, 1],
        ]
    )
    print(demo_converter(demo_inputs))
