import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from data_utils import getMatrix
class FeatureEnhance(nn.Module):
    """Shift inputs by one shared, learnable embedding vector.

    Args:
        in_channels: Width of the learnable embedding ``global_emb``
            (stored with shape ``(1, in_channels)``); ``x`` given to
            :meth:`augment` must broadcast against that shape.
        out_channels: Unused by this module.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Allocated without meaningful values; reset_parameters() fills it in.
        self.global_emb = nn.Parameter(torch.Tensor(1, in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``global_emb`` with Kaiming-uniform initialisation."""
        nn.init.kaiming_uniform_(self.global_emb)

    def augment(self, x: Tensor):
        """Return ``x`` with ``global_emb`` added (broadcast over leading dims)."""
        enhanced = x + self.global_emb
        return enhanced


class SimFeatureEnhance(nn.Module):
    """Enhance input features with a soft lookup into a learnable vector bank.

    The module holds ``out_channels`` learnable vectors of width
    ``in_channels`` (``self.w``). :meth:`augment` optionally fuses extra
    embeddings into the input, scores each row against every bank vector,
    converts the scores to softmax weights (optionally sparsified), and adds
    the weighted mix of bank vectors back onto the original input.

    Args:
        in_channels: Feature width of ``x`` and of each bank vector.
        out_channels: Number of bank vectors (stored as ``num_w``).
        str_enc_dim: Extra width expected from ``rw_embeddings``; the fusion
            layer maps ``in_channels + str_enc_dim`` back to ``in_channels``.
        thre: Forwarded to ``getMatrix`` to derive the sparsification cutoff;
            only consulted when ``out_channels != 1``.
    """

    def __init__(self, in_channels: int, out_channels: int, str_enc_dim=16, thre=0.9):
        super().__init__()
        # Vector bank, one row per output channel; values set in reset_parameters().
        self.w = nn.Parameter(torch.Tensor(out_channels, in_channels))
        # Projects the concatenation [x | rw_embeddings] back to in_channels.
        self.linear = torch.nn.Linear(in_channels + str_enc_dim, in_channels)
        self.thre = thre
        self.num_w = out_channels
        # Runs after self.linear is built; it only (re)initialises w.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the bank ``w`` (Kaiming-normal); ``linear`` keeps its own init."""
        nn.init.kaiming_normal_(self.w)

    def attention_(self, x):
        """Return raw scores ``x @ w.T``: one column per bank vector."""
        return x.mm(self.w.t())

    def augment(self, x: Tensor, rw_embeddings):
        """Add an attention-weighted mix of bank vectors to ``x``.

        Args:
            x: Input features; ``torch.mm`` requires it to be 2-D with
                ``in_channels`` columns when no fusion happens.
            rw_embeddings: Extra embeddings concatenated to ``x`` along
                dim 1 before the fusion layer. If any of its dimensions is
                zero, fusion is skipped and ``x`` is scored directly.

        Returns:
            Tuple ``(x + p, p)`` where ``p`` is the weighted bank mix.
        """
        if 0 in rw_embeddings.shape:
            # Nothing to fuse (an empty dimension): score x as-is.
            fused = x
        else:
            fused = self.linear(torch.cat([x, rw_embeddings], dim=1))

        attn = F.softmax(self.attention_(fused), dim=-1)

        if self.num_w != 1:
            # Zero out weights at or below a cutoff produced by getMatrix
            # (from data_utils; its exact semantics are not visible here).
            # .data keeps the cutoff out of the autograd graph.
            cutoff = getMatrix(attn, self.thre).data
            attn = torch.threshold(attn, cutoff, 0)

        p = attn.mm(self.w)
        return x + p, p
