import torch
import torch.nn as nn

from SourceCode.ModelModule.SparseSoftmax import Sparsemax


class AttentionMatrix(nn.Module):
    """Learnable projection followed by a sparse softmax.

    The layer owns a ``(source_refined_dim, row_dim)`` weight matrix whose
    columns are rescaled to unit L2 norm at construction time (and again
    whenever :meth:`normalize` is called).
    """

    def __init__(self, source_refined_dim: int, row_dim: int) -> None:
        """Create the weight matrix and the sparse softmax.

        Args:
            source_refined_dim: Number of rows of the weight matrix; must
                match the last dimension of the input given to ``forward``.
            row_dim: Number of columns of the weight matrix.
        """
        super().__init__()
        # nn.Parameter already sets requires_grad=True, so the flag does not
        # need to be passed to torch.rand as well.
        self.attention_matrix = nn.Parameter(torch.rand(source_refined_dim, row_dim))
        self.normalize()
        # NOTE(review): Sparsemax is project code constructed with ``row_dim``
        # (a size). Confirm its constructor expects a size and not an axis
        # index.
        self.sparse_softmax = Sparsemax(row_dim)

    def forward(self, refined_vec: torch.Tensor) -> torch.Tensor:
        """Multiply ``refined_vec`` by the weight matrix, then apply the sparse softmax.

        Args:
            refined_vec: Tensor whose last dimension equals
                ``source_refined_dim`` (required by ``matmul``).

        Returns:
            The output of the sparse softmax applied to the product.
        """
        product_tensor = refined_vec.matmul(self.attention_matrix)
        return self.sparse_softmax(product_tensor)

    def normalize(self) -> None:
        """Rescale every column of the weight matrix to unit L2 norm.

        The norm is clamped from below by a tiny epsilon, so an all-zero
        column stays zero instead of turning into NaN (0 / 0).
        """
        with torch.no_grad():
            # F.normalize divides by max(norm, eps) and broadcasts the
            # per-column norm, so no explicit repeat of the divisor is needed.
            unit_columns = nn.functional.normalize(self.attention_matrix, p=2.0, dim=0)
            # Kept as a ``.data`` reassignment so the update stays invisible
            # to autograd exactly as before.
            self.attention_matrix.data = unit_columns


class ScaleAttentionMatrix(nn.Module):
    """Learnable projection with a learnable input scale, followed by a sparse softmax.

    The layer owns a ``(source_refined_dim, row_dim)`` weight matrix whose
    columns are rescaled to unit L2 norm at construction time (and again
    whenever :meth:`normalize` is called), plus a single scalar parameter
    that multiplies the input before the projection.
    """

    def __init__(self, source_refined_dim: int, row_dim: int) -> None:
        """Create the weight matrix, the scale parameter and the sparse softmax.

        Args:
            source_refined_dim: Number of rows of the weight matrix; must
                match the last dimension of the input given to ``forward``.
            row_dim: Number of columns of the weight matrix.
        """
        super().__init__()
        # nn.Parameter already sets requires_grad=True, so the flag does not
        # need to be passed to the tensor factories as well.
        self.attention_matrix = nn.Parameter(torch.rand(source_refined_dim, row_dim))
        self.normalize()
        # One-element scale, initialised to 1 so the layer starts unscaled.
        self.scale_value = nn.Parameter(torch.ones(1))
        # NOTE(review): Sparsemax is project code constructed with ``row_dim``
        # (a size). Confirm its constructor expects a size and not an axis
        # index.
        self.sparse_softmax = Sparsemax(row_dim)

    def forward(self, refined_vec: torch.Tensor) -> torch.Tensor:
        """Scale ``refined_vec``, multiply by the weight matrix, apply the sparse softmax.

        Args:
            refined_vec: Tensor whose last dimension equals
                ``source_refined_dim`` (required by ``matmul``).

        Returns:
            The output of the sparse softmax applied to the scaled product.
        """
        refined_vec = refined_vec * self.scale_value
        product_tensor = refined_vec.matmul(self.attention_matrix)
        return self.sparse_softmax(product_tensor)

    def normalize(self) -> None:
        """Rescale every column of the weight matrix to unit L2 norm.

        The norm is clamped from below by a tiny epsilon, so an all-zero
        column stays zero instead of turning into NaN (0 / 0).
        """
        with torch.no_grad():
            # F.normalize divides by max(norm, eps) and broadcasts the
            # per-column norm, so no explicit repeat of the divisor is needed.
            unit_columns = nn.functional.normalize(self.attention_matrix, p=2.0, dim=0)
            # Kept as a ``.data`` reassignment so the update stays invisible
            # to autograd exactly as before.
            self.attention_matrix.data = unit_columns
