import torch
import torch.nn as nn
import torch.nn.functional as F

# Defined as in Equation (2) of https://arxiv.org/pdf/2212.07677.pdf
# config will have the following information:
# - config.device
# - config.num_heads
# - config.head_dim
# - config.point_dim
# - config.query_dim (query and key share same dimension)

class LinearSelfAttention(nn.Module):
    """Multi-head linear (softmax-free) self-attention layer.

    For every head ``h`` the layer computes ``(Q_h @ K_h^T) @ V_h`` without
    any softmax or scaling, maps the per-head result back to ``point_dim``
    with a learned per-head projection matrix, and sums over the heads.
    No residual connection is added here; the caller is responsible for it.

    ``config`` must expose the attributes ``device``, ``num_heads``,
    ``head_dim``, ``point_dim`` and ``query_dim`` (queries and keys share
    ``query_dim``).

    Note: the projection has shape ``(num_heads, head_dim, point_dim)``, so
    the module does not require ``num_heads * head_dim == point_dim``.
    """

    def __init__(self, config):
        """Create the query/key/value maps and the per-head output projection."""
        super().__init__()
        self.config = config
        # Each linear map produces the features of all heads side by side;
        # they are split into separate heads in ``_split_heads``.
        self.query_layer = nn.Linear(
            config.point_dim, config.query_dim * config.num_heads, device=config.device
        )
        self.key_layer = nn.Linear(
            config.point_dim, config.query_dim * config.num_heads, device=config.device
        )
        self.value_layer = nn.Linear(
            config.point_dim, config.head_dim * config.num_heads, device=config.device
        )
        # One (head_dim, point_dim) projection matrix per head, drawn from a
        # standard normal distribution.
        self.projection = nn.Parameter(
            torch.randn(
                config.num_heads, config.head_dim, config.point_dim, device=config.device
            )
        )

    def _split_heads(self, features: torch.Tensor, per_head_dim: int) -> torch.Tensor:
        """Reshape ``(..., seq, heads * d)`` into ``(..., heads, seq, d)``."""
        leading_shape = features.shape[:-1]
        features = features.view(*leading_shape, self.config.num_heads, per_head_dim)
        # Move the head axis in front of the sequence axis so that matmul
        # treats every head as an independent batch entry.
        return features.transpose(-3, -2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear self-attention along the sequence axis.

        Args:
            x: Tensor of shape ``(..., sequence_length, point_dim)``. The
                usual case is ``(batch_size, sequence_length, point_dim)``;
                unbatched input and extra leading batch dimensions are
                handled the same way.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                "expected input of shape (..., sequence_length, point_dim), "
                f"got a tensor with {x.dim()} dimension(s)"
            )

        # 1-3. Project to queries / keys / values and split into heads.
        queries = self._split_heads(self.query_layer(x), self.config.query_dim)
        keys = self._split_heads(self.key_layer(x), self.config.query_dim)
        values = self._split_heads(self.value_layer(x), self.config.head_dim)

        # 4. Un-normalised attention: pairwise query/key dot products weight
        #    the values directly (no softmax, no scaling).
        scores = torch.matmul(queries, keys.transpose(-2, -1))  # (..., heads, seq, seq)
        per_head = torch.matmul(scores, values)  # (..., heads, seq, head_dim)

        # 5. Project every head back to point_dim (the projection broadcasts
        #    over the leading batch dimensions), then sum over the head axis.
        projected = torch.matmul(per_head, self.projection)  # (..., heads, seq, point_dim)
        return projected.sum(dim=-3)  # (..., seq, point_dim)