import torch
import torch.nn as nn


class Specific(nn.Module):
    """Transformer encoder followed by a flattened linear head with ReLU output.

    The input sequence is encoded by a stack of ``nn.TransformerEncoderLayer``
    blocks, flattened to one vector per sample, projected to ``target_num``
    features and clamped to be non-negative with ReLU.

    Args:
        input_size: Feature width of each sequence element (``d_model``).
            Must be divisible by ``nhead``.
        target_num: Number of output features per sample.
        seq_len: Sequence length the linear head is sized for
            (``in_features = seq_len * input_size``). When ``None`` (the
            default) the head keeps its historical fixed width of
            ``128 * 128`` flattened features.
        num_layers: Number of stacked encoder layers.
        nhead: Number of attention heads per encoder layer.
        batch_first: If ``True`` (the default) inputs are
            ``(batch, seq, feature)``. This matches the flatten/linear head,
            which always treats dim 0 as the batch. ``False`` reproduces the
            previous behaviour, where attention ran over dim 0 while the head
            flattened from dim 1.
    """

    def __init__(
        self,
        input_size,
        target_num,
        seq_len=None,
        num_layers=128,
        nhead=8,
        batch_first=True,
    ):
        super().__init__()
        # NOTE(review): this template layer is registered as a submodule but is
        # never called in forward(). nn.TransformerEncoder clones it into its
        # own ``layers``, so these parameters get no gradients (a problem for
        # DDP without find_unused_parameters=True). It is kept so existing
        # state_dict keys (``encoder_layer.*``) still load.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_size, nhead=nhead, batch_first=batch_first
        )
        # NOTE(review): 128 stacked layers is unusually deep. The default is
        # kept for checkpoint compatibility; confirm it was not meant to be the
        # sequence length.
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=num_layers
        )
        # Flatten everything after the batch dim: (batch, seq, feat) -> (batch, seq*feat).
        self.flatten = nn.Flatten(1, -1)
        in_features = 128 * 128 if seq_len is None else seq_len * input_size
        self.linear = nn.Linear(in_features, target_num)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode ``x`` and map each sample to ``target_num`` non-negative outputs.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)`` when
                ``batch_first`` is ``True``. ``seq_len * input_size`` must
                equal ``self.linear.in_features``.

        Returns:
            Tensor of shape ``(batch, target_num)``.
        """
        x = self.transformer_encoder(x)
        x = self.flatten(x)
        x = self.linear(x)
        # NOTE(review): ReLU clamps the head's output at zero. If these are
        # classification logits or signed regression targets, confirm intent.
        return self.relu(x)