# Time : 2023/11/13 12:52
# Author : 小霸奔
# FileName: pretrain_net.p
from utils.config import ModelConfig
from utils.util_block import MultiHeadAttentionBlock
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

class FeatureExtractor(nn.Module):
    """Per-epoch CNN encoder for paired EEG / EOG signals, fused into one token per epoch.

    Each modality goes through its own (identically shaped) 1-D convolutional stack
    that ends with 512 channels, is averaged over time, and the two 512-d vectors are
    concatenated and projected back to 512 features by ``fusion``.
    """

    def __init__(self, args):
        super().__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        # Construction order (EEG stack, EOG stack, pooling, fusion) is kept so that
        # seeded parameter initialisation and state_dict keys are unchanged.
        self.FEBlock_EEG = self._build_branch(self.ModelParam.EegNum)
        self.FEBlock_EOG = self._build_branch(self.ModelParam.EogNum)
        self.avg = nn.AdaptiveAvgPool1d(1)
        self.fusion = nn.Linear(1024, 512)  # concatenated 512 (EEG) + 512 (EOG) -> 512

    @staticmethod
    def _build_branch(in_channels):
        """Return the conv stack for one modality: ``in_channels`` -> 512 channels."""
        layers = [
            nn.Conv1d(in_channels, 64, kernel_size=50, stride=6, bias=False),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.MaxPool1d(kernel_size=8, stride=8),
            nn.Dropout(0.1),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv1d(c_in, c_out, kernel_size=8),
                nn.BatchNorm1d(c_out),
                nn.GELU(),
            ])
        layers.append(nn.MaxPool1d(kernel_size=4, stride=4))
        return nn.Sequential(*layers)

    def forward(self, eeg, eog):
        """Encode a batch of epoch sequences.

        Args:
            eeg: tensor of shape ``(batch * SeqLength, EegNum, T)``.
            eog: tensor of shape ``(batch * SeqLength, EogNum, T)``.

        Returns:
            Tensor of shape ``(batch, SeqLength, 512)``. ``EncoderParam.d_model``
            must equal the stacks' 512 output channels for the ``view`` calls to work.
        """
        seq_len = self.ModelParam.SeqLength
        d_model = self.ModelParam.EncoderParam.d_model
        batch = eeg.shape[0] // seq_len
        n_epochs = batch * seq_len

        eeg_tok = self.avg(self.FEBlock_EEG(eeg)).view(n_epochs, 1, d_model)
        eog_tok = self.avg(self.FEBlock_EOG(eog)).view(n_epochs, 1, d_model)

        fused = self.fusion(torch.concat((eeg_tok, eog_tok), dim=2))
        return fused.view(batch, seq_len, -1)


class TransformerEncoder(torch.nn.Module):
    """Thin wrapper that builds a ``MultiHeadAttentionBlock`` from ``EncoderParam``."""

    def __init__(self, args):
        super().__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        enc_cfg = self.ModelParam.EncoderParam
        # Passed positionally in the order (d_model, layer_num, drop, n_head).
        self.encoder = MultiHeadAttentionBlock(enc_cfg.d_model, enc_cfg.layer_num,
                                               enc_cfg.drop, enc_cfg.n_head)

    def forward(self, x):
        """Run ``x`` through the encoder; presumably (batch, seq, d_model) -- TODO confirm."""
        encoded = self.encoder(x)
        return encoded


class SleepMLP(nn.Module):
    """Token-wise classification head configured by ``SleepMlpParam``.

    Two Linear -> Dropout -> GELU stages followed by a bias-free output layer; the
    class axis is moved from the last position to dim 1.
    """

    def __init__(self, args):
        super().__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        cfg = self.ModelParam.SleepMlpParam
        self.dropout_rate = cfg.drop
        self.sleep_stage_mlp = nn.Sequential(
            nn.Linear(cfg.first_linear[0], cfg.first_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
            nn.Linear(cfg.second_linear[0], cfg.second_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
        )
        self.sleep_stage_classifier = nn.Linear(cfg.out_linear[0], cfg.out_linear[1], bias=False)

    def forward(self, x):
        """Map a 3-D ``(batch, seq, features)`` input to ``(batch, classes, seq)`` logits."""
        logits = self.sleep_stage_classifier(self.sleep_stage_mlp(x))
        # channels-first layout, presumably for a sequence-wise CrossEntropyLoss -- TODO confirm
        return logits.permute(0, 2, 1)

class FeatureExtractor_Face(nn.Module):
    """Two-branch CNN encoder that turns one recording into a sequence of window tokens.

    ``forward`` cuts the first ``n_segments * segment_len`` samples into
    ``n_segments`` consecutive windows and encodes every window with two conv
    stacks of different receptive field: ``eeg_fe1`` (kernel 50, stride 6) and
    ``eeg_fe2`` (kernel 400, stride 50). Each stack's output is averaged over
    time, and the two 512-d vectors are fused by ``fusion``.
    """

    # Window geometry used by ``forward``.
    # NOTE(review): SleepMLP_Face hard-codes a 5120-wide input (presumably
    # 10 tokens x 512), so change these only together with that head.
    n_segments = 10
    segment_len = 750

    def __init__(self, args):
        super(FeatureExtractor_Face, self).__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        # Short-kernel branch.
        self.eeg_fe1 = nn.Sequential(
            nn.Conv1d(self.ModelParam.FaceCn, 64, kernel_size=50, stride=6, bias=False, padding=24),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=8, stride=8, padding=4),
            nn.Dropout(0.1),

            nn.Conv1d(64, 128, kernel_size=6, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Conv1d(128, 256, kernel_size=6, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Conv1d(256, 512, kernel_size=6, padding=3),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
        )

        # Long-kernel branch.
        self.eeg_fe2 = nn.Sequential(
            nn.Conv1d(self.ModelParam.FaceCn, 64, kernel_size=400, stride=50, bias=False, padding=200),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4, padding=2),
            nn.Dropout(0.1),

            nn.Conv1d(64, 128, kernel_size=8, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Conv1d(128, 256, kernel_size=8, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Conv1d(256, 512, kernel_size=8, padding=3),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )
        self.avg = nn.AdaptiveAvgPool1d(output_size=1)
        self.fusion = nn.Linear(1024, 512)  # concatenated 512 + 512 -> 512

    def forward(self, x):
        """Encode ``x`` into one token per window.

        Args:
            x: tensor of shape ``(batch, FaceCn, L)`` with
                ``L >= n_segments * segment_len``; later samples are ignored.

        Returns:
            Tensor of shape ``(batch, n_segments, 512)``.
        """
        batch = x.shape[0]
        n_seg, seg_len = self.n_segments, self.segment_len
        channels = self.ModelParam.FaceCn
        # (B, C, n_seg*seg_len) -> (B, C, n_seg, seg_len) -> (B, n_seg, C, seg_len)
        # -> (B*n_seg, C, seg_len): row b*n_seg + i is window i of sample b.
        # A single copy instead of re-concatenating a growing tensor per window.
        windows = (x[:, :, :n_seg * seg_len]
                   .reshape(batch, channels, n_seg, seg_len)
                   .permute(0, 2, 1, 3)
                   .reshape(-1, channels, seg_len))
        x1 = self.avg(self.eeg_fe1(windows)).view(batch, n_seg, -1)
        x2 = self.avg(self.eeg_fe2(windows)).view(batch, n_seg, -1)
        return self.fusion(torch.concat((x1, x2), dim=2))


class TransformerEncoder_Face(torch.nn.Module):
    """``MultiHeadAttentionBlock`` wrapper for the Face pipeline (keeps ``args`` around)."""

    def __init__(self, args):
        super().__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        self.args = args
        enc_cfg = self.ModelParam.EncoderParam
        # Passed positionally in the order (d_model, layer_num, drop, n_head).
        self.encoder = MultiHeadAttentionBlock(enc_cfg.d_model, enc_cfg.layer_num,
                                               enc_cfg.drop, enc_cfg.n_head)

    def forward(self, x):
        """Encode the token sequence.

        NOTE(review): FeatureExtractor_Face emits (batch, 10, 512); the output shape
        is whatever MultiHeadAttentionBlock returns -- confirm.
        """
        encoded = self.encoder(x)
        return encoded


class SleepMLP_Face(nn.Module):
    """Sequence-level classification head for the Face pipeline (9 output columns).

    The token sequence of each sample is flattened to a 5120-wide vector, passed
    through two Linear -> Dropout -> GELU stages and a bias-free output layer.
    """

    def __init__(self, args):
        super().__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        cfg = self.ModelParam.FaceMlpParam
        self.dropout_rate = cfg.drop
        self.sleep_stage_mlp = nn.Sequential(
            # Input width is fixed at 5120 (presumably 10 tokens x 512); first_linear[0] is not read.
            nn.Linear(5120, cfg.first_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
            nn.Linear(cfg.second_linear[0], cfg.second_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
        )
        self.sleep_stage_classifier = nn.Linear(cfg.out_linear[0], cfg.out_linear[1], bias=False)

    def forward(self, x):
        """Flatten everything after the batch axis and return logits reshaped to ``(-1, 9)``."""
        flat = x.view(x.shape[0], -1)
        logits = self.sleep_stage_classifier(self.sleep_stage_mlp(flat))
        return logits.view(-1, 9)


class FeatureExtractor_BCI2000(nn.Module):
    """CNN encoder that turns one BCI2000 trial into a sequence of window tokens.

    ``forward`` cuts the first ``n_segments * segment_len`` samples into
    ``n_segments`` consecutive windows, encodes each with ``eeg_fe`` (ends with
    512 channels) and averages over time.
    """

    # Window geometry used by ``forward``.
    # NOTE(review): SleepMLP_BCI2000 hard-codes a 5120-wide input (presumably
    # 10 tokens x 512), so change these only together with that head.
    n_segments = 10
    segment_len = 64

    def __init__(self, args):
        super(FeatureExtractor_BCI2000, self).__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        self.eeg_fe = nn.Sequential(
            nn.Conv1d(self.ModelParam.BCICn, 64, kernel_size=50, stride=6, bias=False, padding=24),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=8, stride=8, padding=4),
            nn.Dropout(0.1),

            nn.Conv1d(64, 128, kernel_size=6, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Conv1d(128, 256, kernel_size=6, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Conv1d(256, 512, kernel_size=6, padding=3),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
        )

        self.avg = nn.AdaptiveAvgPool1d(output_size=1)

    def forward(self, x):
        """Encode ``x`` into one token per window.

        Args:
            x: tensor of shape ``(batch, BCICn, L)`` with
                ``L >= n_segments * segment_len``; later samples are ignored.

        Returns:
            Tensor of shape ``(batch, n_segments, 512)``.
        """
        batch = x.shape[0]
        n_seg, seg_len = self.n_segments, self.segment_len
        channels = self.ModelParam.BCICn
        # (B, C, n_seg*seg_len) -> (B, n_seg, C, seg_len) -> (B*n_seg, C, seg_len):
        # row b*n_seg + i is window i of sample b, built with a single copy.
        windows = (x[:, :, :n_seg * seg_len]
                   .reshape(batch, channels, n_seg, seg_len)
                   .permute(0, 2, 1, 3)
                   .reshape(-1, channels, seg_len))
        return self.avg(self.eeg_fe(windows)).view(batch, n_seg, -1)


class TransformerEncoder_BCI2000(torch.nn.Module):
    """``MultiHeadAttentionBlock`` wrapper for the BCI2000 pipeline (keeps ``args`` around)."""

    def __init__(self, args):
        super().__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        self.args = args
        enc_cfg = self.ModelParam.EncoderParam
        # Passed positionally in the order (d_model, layer_num, drop, n_head).
        self.encoder = MultiHeadAttentionBlock(enc_cfg.d_model, enc_cfg.layer_num,
                                               enc_cfg.drop, enc_cfg.n_head)

    def forward(self, x):
        """Encode the token sequence.

        NOTE(review): FeatureExtractor_BCI2000 emits (batch, 10, 512); the output
        shape is whatever MultiHeadAttentionBlock returns -- confirm.
        """
        encoded = self.encoder(x)
        return encoded


class SleepMLP_BCI2000(nn.Module):
    """Sequence-level classification head for BCI2000 (4 output columns).

    The token sequence of each sample is flattened to a 5120-wide vector, passed
    through two Linear -> Dropout -> GELU stages and a bias-free output layer.
    """

    def __init__(self, args):
        super(SleepMLP_BCI2000, self).__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        mlp_cfg = self.ModelParam.BCI2000MlpParam
        # Every other hyper-parameter of this head comes from BCI2000MlpParam, but the
        # dropout rate was read from FaceMlpParam (copy-paste from SleepMLP_Face).
        # Prefer the BCI2000 section's own ``drop``; fall back to the previous source
        # when it is not defined so existing configs keep working.
        if hasattr(mlp_cfg, "drop"):
            self.dropout_rate = mlp_cfg.drop
        else:
            self.dropout_rate = self.ModelParam.FaceMlpParam.drop
        self.sleep_stage_mlp = nn.Sequential(
            # Input width is fixed at 5120 (presumably 10 tokens x 512); first_linear[0] is not read.
            nn.Linear(5120,
                      mlp_cfg.first_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
            nn.Linear(mlp_cfg.second_linear[0],
                      mlp_cfg.second_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
        )
        self.sleep_stage_classifier = nn.Linear(mlp_cfg.out_linear[0],
                                                mlp_cfg.out_linear[1], bias=False)

    def forward(self, x):
        """Flatten everything after the batch axis and return logits reshaped to ``(-1, 4)``."""
        batch = x.shape[0]
        x = x.view(batch, -1)
        x = self.sleep_stage_mlp(x)
        x = self.sleep_stage_classifier(x)
        x = x.view(-1, 4)
        return x


class SleepMLP_BCI2000_2(nn.Module):
    """Sequence-level classification head for the 2-class BCI2000 variant.

    The token sequence of each sample is flattened to a 5120-wide vector, passed
    through two Linear -> Dropout -> GELU stages and a bias-free output layer.
    """

    def __init__(self, args):
        super(SleepMLP_BCI2000_2, self).__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        mlp_cfg = self.ModelParam.BCI2000_2_MlpParam
        # Every other hyper-parameter of this head comes from BCI2000_2_MlpParam, but the
        # dropout rate was read from FaceMlpParam (copy-paste from SleepMLP_Face).
        # Prefer this section's own ``drop``; fall back to the previous source when it
        # is not defined so existing configs keep working.
        if hasattr(mlp_cfg, "drop"):
            self.dropout_rate = mlp_cfg.drop
        else:
            self.dropout_rate = self.ModelParam.FaceMlpParam.drop
        self.sleep_stage_mlp = nn.Sequential(
            # Input width is fixed at 5120 (presumably 10 tokens x 512); first_linear[0] is not read.
            nn.Linear(5120,
                      mlp_cfg.first_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
            nn.Linear(mlp_cfg.second_linear[0],
                      mlp_cfg.second_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
        )
        self.sleep_stage_classifier = nn.Linear(mlp_cfg.out_linear[0],
                                                mlp_cfg.out_linear[1], bias=False)

    def forward(self, x):
        """Flatten everything after the batch axis and return logits reshaped to ``(-1, 2)``."""
        batch = x.shape[0]
        x = x.view(batch, -1)
        x = self.sleep_stage_mlp(x)
        x = self.sleep_stage_classifier(x)
        x = x.view(-1, 2)
        return x


class FeatureExtractor_MDD(nn.Module):
    """CNN encoder that turns one MDD recording into a sequence of window tokens.

    ``forward`` cuts the first ``n_segments * segment_len`` samples into
    ``n_segments`` consecutive windows, encodes each with ``eeg_fe`` (ends with
    512 channels) and averages over time.
    """

    # Window geometry used by ``forward``.
    # NOTE(review): SleepMLP_MDD hard-codes a 5120-wide input (presumably
    # 10 tokens x 512), so change these only together with that head.
    n_segments = 10
    segment_len = 100

    def __init__(self, args):
        super(FeatureExtractor_MDD, self).__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        self.eeg_fe = nn.Sequential(
            nn.Conv1d(self.ModelParam.MDDCn, 64, kernel_size=50, stride=6, bias=False, padding=24),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=8, stride=8, padding=4),
            nn.Dropout(0.1),

            nn.Conv1d(64, 128, kernel_size=6, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Conv1d(128, 256, kernel_size=6, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Conv1d(256, 512, kernel_size=6, padding=3),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
        )

        self.avg = nn.AdaptiveAvgPool1d(output_size=1)

    def forward(self, x):
        """Encode ``x`` into one token per window.

        Args:
            x: tensor of shape ``(batch, MDDCn, L)`` with
                ``L >= n_segments * segment_len``; later samples are ignored.

        Returns:
            Tensor of shape ``(batch, n_segments, 512)``.
        """
        batch = x.shape[0]
        n_seg, seg_len = self.n_segments, self.segment_len
        channels = self.ModelParam.MDDCn
        # (B, C, n_seg*seg_len) -> (B, n_seg, C, seg_len) -> (B*n_seg, C, seg_len):
        # row b*n_seg + i is window i of sample b, built with a single copy.
        windows = (x[:, :, :n_seg * seg_len]
                   .reshape(batch, channels, n_seg, seg_len)
                   .permute(0, 2, 1, 3)
                   .reshape(-1, channels, seg_len))
        return self.avg(self.eeg_fe(windows)).view(batch, n_seg, -1)


class TransformerEncoder_MDD(torch.nn.Module):
    """``MultiHeadAttentionBlock`` wrapper for the MDD pipeline (keeps ``args`` around)."""

    def __init__(self, args):
        super().__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        self.args = args
        enc_cfg = self.ModelParam.EncoderParam
        # Passed positionally in the order (d_model, layer_num, drop, n_head).
        self.encoder = MultiHeadAttentionBlock(enc_cfg.d_model, enc_cfg.layer_num,
                                               enc_cfg.drop, enc_cfg.n_head)

    def forward(self, x):
        """Encode the token sequence.

        NOTE(review): FeatureExtractor_MDD emits (batch, 10, 512); the output shape
        is whatever MultiHeadAttentionBlock returns -- confirm.
        """
        encoded = self.encoder(x)
        return encoded


class SleepMLP_MDD(nn.Module):
    """Sequence-level classification head for MDD (2 output columns).

    The token sequence of each sample is flattened to a 5120-wide vector, passed
    through two Linear -> Dropout -> GELU stages and a bias-free output layer.
    """

    def __init__(self, args):
        super(SleepMLP_MDD, self).__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        mlp_cfg = self.ModelParam.MDDMlpParam
        # Every other hyper-parameter of this head comes from MDDMlpParam, but the
        # dropout rate was read from FaceMlpParam (copy-paste from SleepMLP_Face).
        # Prefer the MDD section's own ``drop``; fall back to the previous source when
        # it is not defined so existing configs keep working.
        if hasattr(mlp_cfg, "drop"):
            self.dropout_rate = mlp_cfg.drop
        else:
            self.dropout_rate = self.ModelParam.FaceMlpParam.drop
        self.sleep_stage_mlp = nn.Sequential(
            # Input width is fixed at 5120 (presumably 10 tokens x 512); first_linear[0] is not read.
            nn.Linear(5120,
                      mlp_cfg.first_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
            nn.Linear(mlp_cfg.second_linear[0],
                      mlp_cfg.second_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
        )
        self.sleep_stage_classifier = nn.Linear(mlp_cfg.out_linear[0],
                                                mlp_cfg.out_linear[1], bias=False)

    def forward(self, x):
        """Flatten everything after the batch axis and return logits reshaped to ``(-1, 2)``."""
        batch = x.shape[0]
        x = x.view(batch, -1)
        x = self.sleep_stage_mlp(x)
        x = self.sleep_stage_classifier(x)
        x = x.view(-1, 2)
        return x



class ConvModule(nn.Module):
    """Temporal + spatial convolution front-end that produces a token sequence.

    The input is reshaped to ``(N, 1, n_chs, eeg_len)``. Processing then runs:
    a 1x25 temporal conv, an ``n_chs`` x 1 spatial conv (collapses the channel
    axis), BatchNorm, ELU, and 1x75 / stride-15 average pooling with dropout.
    The result has shape ``(N, T, n_filters)`` with
    ``T = (eeg_len - 24 - 75) // 15 + 1``.
    """

    def __init__(self, n_filters, n_chs=32, eeg_len=7500, dropout=0.5):
        super().__init__()
        # Registration order is kept: it fixes parameter initialisation under a seed.
        self.temporal_conv = nn.Conv2d(1, n_filters, kernel_size=(1, 25), bias=False)
        self.spatial_conv = nn.Conv2d(n_filters, n_filters, kernel_size=(n_chs, 1))
        self.bn = nn.BatchNorm2d(n_filters)
        self.avg_pooling = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))
        self.dp = nn.Dropout(dropout)
        self.elu = nn.ELU()
        self.n_chs = n_chs
        self.eeg_len = eeg_len

    def forward(self, X):
        feats = X.reshape(-1, 1, self.n_chs, self.eeg_len)
        feats = self.spatial_conv(self.temporal_conv(feats))  # height becomes 1
        feats = self.elu(self.bn(feats))
        feats = self.dp(self.avg_pooling(feats))
        # drop the unit height axis, then put time before filters: (N, T, n_filters)
        return feats.squeeze(-2).mT




class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with Q/K/V/output projections.

    ``config`` must provide ``hidden_size``, ``num_heads`` and ``dropout_rate``;
    ``hidden_size`` has to be a multiple of ``num_heads``.
    """

    def __init__(self, config):
        super().__init__()

        self.hidden_size = config['hidden_size']
        self.num_heads = config['num_heads']
        self.dropout_rate = config['dropout_rate']
        assert self.hidden_size % self.num_heads == 0, \
            "Hidden size must be divisible by num_heads but got {} and {}".format(
                self.hidden_size, self.num_heads)
        self.head_dim = self.hidden_size // self.num_heads
        self.dropout = nn.Dropout(self.dropout_rate)  # applied to the attention weights

        # Created in q, k, v, o order so seeded initialisation is unchanged.
        self.wq, self.wk, self.wv, self.wo = (
            nn.Linear(self.hidden_size, self.hidden_size) for _ in range(4))

    def _split_heads(self, x: Tensor) -> Tensor:
        """(batch, seq, hidden) -> (batch, heads, seq, head_dim)."""
        b, s, _ = x.shape
        return x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: Tensor) -> Tensor:
        """(batch, heads, seq, head_dim) -> (batch, seq, hidden)."""
        b, _, s, _ = x.shape
        return x.transpose(1, 2).reshape(b, s, self.hidden_size)

    def forward(self, q, k, v, att_mask=None) -> Tensor:
        """Attend from ``q`` to ``k``/``v``.

        ``att_mask`` (optional) is a 4-D tensor cropped to (q_len, k_len); positions
        where it is 1 receive a -1e9 logit penalty, i.e. are effectively blocked.
        """
        queries = self._split_heads(self.wq(q))
        keys = self._split_heads(self.wk(k))
        values = self._split_heads(self.wv(v))

        logits = torch.matmul(queries, keys.transpose(-2, -1))
        if att_mask is not None:
            crop = att_mask[:, :, :queries.shape[-2], :keys.shape[-2]]
            # in-place: keeps logits' dtype/shape; penalty is added before scaling
            logits += crop * -1e9
        weights = F.softmax(logits / (self.head_dim ** 0.5), dim=-1)
        weights = self.dropout(weights)
        return self.wo(self._merge_heads(torch.matmul(weights, values)))


class AddNorm(nn.Module):
    """Residual add followed by LayerNorm: ``LayerNorm(X + Dropout(Y))``."""

    def __init__(self, ln_shape, dropout_rate):
        super().__init__()
        self.layer_norm = nn.LayerNorm(ln_shape)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X, Y):
        residual = X + self.dropout(Y)
        return self.layer_norm(residual)


class FeedForwardNetwork(nn.Module):
    """Position-wise two-layer MLP: ``fc2(GELU(fc1(X)))``."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, X):
        hidden = self.gelu(self.fc1(X))
        return self.fc2(hidden)


class EncoderBlock(nn.Module):
    """Post-norm Transformer encoder layer: self-attention + AddNorm, then FFN + AddNorm.

    ``config`` keys read: ``hidden_size``, ``num_heads``, ``dropout_rate``,
    ``ffn_size`` (and whatever MultiHeadAttention reads from it).
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        self.num_heads = config['num_heads']
        self.dropout_rate = config['dropout_rate']
        self.ffn_size = config['ffn_size']
        # Submodules created in the order mha, an1, an2, ffn (fixes seeded init).
        self.mha = MultiHeadAttention(config)
        self.an1 = AddNorm(self.hidden_size, self.dropout_rate)
        self.an2 = AddNorm(self.hidden_size, self.dropout_rate)
        self.ffn = FeedForwardNetwork(self.hidden_size, self.ffn_size, self.hidden_size)

    def forward(self, X):
        attended = self.an1(X, self.mha(X, X, X))
        return self.an2(attended, self.ffn(attended))


class TransformerModule(nn.Module):
    """Stack of ``config['num_layers']`` EncoderBlocks named ``block0``, ``block1``, ..."""

    def __init__(self, config):
        super().__init__()
        self.num_layers = config['num_layers']
        self.blks = nn.Sequential()
        for idx in range(self.num_layers):
            self.blks.add_module(f"block{idx}", EncoderBlock(config))

    def forward(self, X):
        # Blocks are applied one by one (the Sequential's own hooks are not invoked).
        out = X
        for blk in self.blks:
            out = blk(out)
        return out


class Conformer(nn.Module):
    """EEG-Conformer style classifier.

    The pipeline runs in four stages:

    1. ``ConvModule`` turns the raw signal into a sequence of
       ``cnn_filters``-dim tokens.
    2. The features are tiled ``num_heads`` times so the token width equals
       ``hidden_size = cnn_filters * num_heads``.
    3. A ``TransformerModule`` encodes the sequence.
    4. A three-layer ELU MLP maps the flattened result to ``num_classes``
       softmax probabilities.

    Args:
        config: dict with keys ``cnn_filters``, ``n_chs``, ``eeg_len``,
            ``num_heads``, ``num_layers``, ``num_classes``, ``dropout_rate`` and
            ``fc_size``, plus what ``EncoderBlock`` reads (e.g. ``ffn_size``).
            NOTE: the constructor writes ``config['hidden_size']`` into the
            caller's dict.
    """

    # Layer geometry hard-coded in ConvModule; keep these in sync with it.
    _TEMPORAL_KERNEL = 25  # temporal_conv width (stride 1, no padding)
    _POOL_KERNEL = 75      # avg_pooling width
    _POOL_STRIDE = 15      # avg_pooling stride

    @classmethod
    def _cnn_out_len(cls, eeg_len):
        """Return how many tokens ConvModule produces for ``eeg_len`` samples.

        Uses the standard output-length formulas for an unpadded conv followed by
        an unpadded, floor-mode average pool. The earlier expression
        ``(eeg_len - 24 - 74) // 15 + 1`` over-counted by one whenever
        ``eeg_len - 98`` is a multiple of 15 (e.g. 113 or 128), giving ``fc`` a
        wrong input width.
        """
        conv_len = eeg_len - cls._TEMPORAL_KERNEL + 1
        return (conv_len - cls._POOL_KERNEL) // cls._POOL_STRIDE + 1

    def __init__(self, config):
        super().__init__()

        # hyper params
        self.cnn_filters = config['cnn_filters']
        self.n_chs = config['n_chs']
        self.eeg_len = config['eeg_len']
        self.num_heads = config['num_heads']
        self.hidden_size = self.cnn_filters * self.num_heads
        self.num_layers = config['num_layers']
        self.num_classes = config['num_classes']
        self.dropout_rate = config['dropout_rate']
        self.fc_size = config['fc_size']
        # Read by TransformerModule -> EncoderBlock / MultiHeadAttention.
        config['hidden_size'] = self.hidden_size
        # NOTE(review): dropout_rate is not forwarded, so ConvModule uses its default 0.5.
        self.cnn_module = ConvModule(self.cnn_filters, self.n_chs, self.eeg_len)
        self.transformer_module = TransformerModule(config)
        self.flatten = nn.Flatten()
        cnn_len = self._cnn_out_len(self.eeg_len)
        self.fc = nn.Sequential(nn.Linear(self.cnn_filters * cnn_len * self.num_heads, self.fc_size * 4),
                                nn.ELU(), nn.Dropout(self.dropout_rate),
                                nn.Linear(self.fc_size * 4, self.fc_size),
                                nn.ELU(), nn.Dropout(self.dropout_rate),
                                nn.Linear(self.fc_size, self.num_classes))

    def forward(self, X):
        X = self.cnn_module(X)                # (N, T, cnn_filters)
        X = X.repeat(1, 1, self.num_heads)    # tile features -> (N, T, hidden_size)
        X = self.transformer_module(X)
        X = self.flatten(X)
        X = self.fc(X)
        # NOTE(review): returns probabilities, not logits; if training uses
        # nn.CrossEntropyLoss, softmax is applied twice -- confirm with the caller.
        X = F.softmax(X, dim=-1)
        return X


class FeatureExtractor_TUEV(nn.Module):
    """CNN encoder that turns one TUEV recording into a sequence of window tokens.

    ``forward`` cuts the first ``n_segments * segment_len`` samples into
    ``n_segments`` consecutive windows, encodes each with ``eeg_fe`` (ends with
    512 channels) and averages over time.
    """

    # Window geometry used by ``forward``.
    # NOTE(review): SleepMLP_TUEV hard-codes a 5120-wide input (presumably
    # 10 tokens x 512), so change these only together with that head.
    n_segments = 10
    segment_len = 100

    def __init__(self, args):
        super(FeatureExtractor_TUEV, self).__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        self.eeg_fe = nn.Sequential(
            nn.Conv1d(self.ModelParam.TUEVCn, 64, kernel_size=50, stride=6, bias=False, padding=24),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=8, stride=8, padding=4),
            nn.Dropout(0.1),

            nn.Conv1d(64, 128, kernel_size=6, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Conv1d(128, 256, kernel_size=6, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Conv1d(256, 512, kernel_size=6, padding=3),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
        )

        self.avg = nn.AdaptiveAvgPool1d(output_size=1)

    def forward(self, x):
        """Encode ``x`` into one token per window.

        Args:
            x: tensor of shape ``(batch, TUEVCn, L)`` with
                ``L >= n_segments * segment_len``; later samples are ignored.

        Returns:
            Tensor of shape ``(batch, n_segments, 512)``.
        """
        batch = x.shape[0]
        n_seg, seg_len = self.n_segments, self.segment_len
        channels = self.ModelParam.TUEVCn
        # (B, C, n_seg*seg_len) -> (B, n_seg, C, seg_len) -> (B*n_seg, C, seg_len):
        # row b*n_seg + i is window i of sample b, built with a single copy.
        windows = (x[:, :, :n_seg * seg_len]
                   .reshape(batch, channels, n_seg, seg_len)
                   .permute(0, 2, 1, 3)
                   .reshape(-1, channels, seg_len))
        return self.avg(self.eeg_fe(windows)).view(batch, n_seg, -1)


class TransformerEncoder_TUEV(torch.nn.Module):
    """``MultiHeadAttentionBlock`` wrapper for the TUEV pipeline (keeps ``args`` around)."""

    def __init__(self, args):
        super().__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        self.args = args
        enc_cfg = self.ModelParam.EncoderParam
        # Passed positionally in the order (d_model, layer_num, drop, n_head).
        self.encoder = MultiHeadAttentionBlock(enc_cfg.d_model, enc_cfg.layer_num,
                                               enc_cfg.drop, enc_cfg.n_head)

    def forward(self, x):
        """Encode the token sequence.

        NOTE(review): FeatureExtractor_TUEV emits (batch, 10, 512); the output shape
        is whatever MultiHeadAttentionBlock returns -- confirm.
        """
        encoded = self.encoder(x)
        return encoded


class SleepMLP_TUEV(nn.Module):
    """Sequence-level classification head for TUEV (6 output columns).

    The token sequence of each sample is flattened to a 5120-wide vector, passed
    through two Linear -> Dropout -> GELU stages and a bias-free output layer.
    """

    def __init__(self, args):
        super(SleepMLP_TUEV, self).__init__()
        self.ModelParam = ModelConfig(args["dataset"])
        mlp_cfg = self.ModelParam.TUEVMlpParam
        # Every other hyper-parameter of this head comes from TUEVMlpParam, but the
        # dropout rate was read from FaceMlpParam (copy-paste from SleepMLP_Face).
        # Prefer the TUEV section's own ``drop``; fall back to the previous source when
        # it is not defined so existing configs keep working.
        if hasattr(mlp_cfg, "drop"):
            self.dropout_rate = mlp_cfg.drop
        else:
            self.dropout_rate = self.ModelParam.FaceMlpParam.drop
        self.sleep_stage_mlp = nn.Sequential(
            # Input width is fixed at 5120 (presumably 10 tokens x 512); first_linear[0] is not read.
            nn.Linear(5120,
                      mlp_cfg.first_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
            nn.Linear(mlp_cfg.second_linear[0],
                      mlp_cfg.second_linear[1]),
            nn.Dropout(self.dropout_rate),
            nn.GELU(),
        )
        self.sleep_stage_classifier = nn.Linear(mlp_cfg.out_linear[0],
                                                mlp_cfg.out_linear[1], bias=False)

    def forward(self, x):
        """Flatten everything after the batch axis and return logits reshaped to ``(-1, 6)``."""
        batch = x.shape[0]
        x = x.view(batch, -1)
        x = self.sleep_stage_mlp(x)
        x = self.sleep_stage_classifier(x)
        x = x.view(-1, 6)
        return x

# if __name__ == '__main__':
#     args = {"dataset": "BCI2000"}
#     a = torch.randn(size=(28, 64, 640))
#     fe = FeatureExtractor_BCI2000(args)
#     en = TransformerEncoder_BCI2000(args)
#     mlp = SleepMLP_BCI2000(args)
#     x = fe(a)
#     print(x.shape)
#     x = en(x)
#     print(x.shape)
#     x = mlp(x)
#     print(x.shape)