# Time : 2023/11/13 14:18
# Author : 小霸奔
# FileName: util_block.p
import torch.nn as nn
import math
import torch
import copy
import torch.nn.functional as F
from utils.util import fix_randomness
import numpy as np


def clones(module, N):
    """用于生成相同网络层的克隆函数, 它的参数module表示要克隆的目标网络层, N代表需要克隆的数量"""
    # 在函数中, 我们通过for循环对module进行N次深度拷贝, 使其每个module成为独立的层,
    # 然后将其放在nn.ModuleList类型的列表中存放.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_head):
        super(MultiHeadAttention, self).__init__()
        self.w_query = nn.Linear(input_size, input_size)
        self.w_key = nn.Linear(input_size, input_size)
        self.w_value = nn.Linear(input_size, input_size)
        self.num_head = num_head
        self.dense = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.2)
        self.input_size = input_size
        self.att_dropout = nn.Dropout(0.25)

    def forward(self, input_tensor):
        batch_size = input_tensor.shape[0]
        query = self.w_query(input_tensor)
        key = self.w_key(input_tensor)
        value = self.w_value(input_tensor)
        query = query.view(batch_size, -1, self.num_head, self.input_size//self.num_head).permute(0, 2, 1 ,3)
        key = key.view(batch_size, -1, self.num_head, self.input_size // self.num_head).permute(0, 2, 1, 3)
        value = value.view(batch_size, -1, self.num_head, self.input_size // self.num_head).permute(0, 2, 1, 3)
        attention_score = torch.matmul(query, key.transpose(-1, -2))
        attention_score = attention_score / math.sqrt(input_tensor.shape[2])
        attention_prob = nn.Softmax(dim=1)(attention_score)
        attention_prob = self.att_dropout(attention_prob)
        context = torch.matmul(attention_prob, value)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(batch_size, -1, self.num_head*(self.input_size//self.num_head))
        hidden_state = self.dense(context)
        return hidden_state


class FeedForward(torch.nn.Module):
    def __init__(self, d_model, d_ff, dropout):
        """初始化函数有三个输入参数分别是d_model, d_ff,和dropout=0.1，第一个是线性层的输入维度也是第二个线性层的输出维度，
           因为我们希望输入通过前馈全连接层后输入和输出的维度不变. 第二个参数d_ff就是第二个线性层的输入维度和第一个线性层的输出维度.
           最后一个是dropout置0比率."""
        super(FeedForward, self).__init__()
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.drop(self.relu(self.linear1(x)))
        x = self.drop(self.relu(self.linear2(x)))
        return x


class LayerNorm(torch.nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a1 = nn.Parameter(torch.ones(features))
        self.b1 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    # 根据features的形状初始化两个参数张量a2，和b2，第一个初始化为1张量，
    # 也就是里面的元素都是1，第二个初始化为0张量，也就是里面的元素都是0，这两个张量就是规范化层的参数，
    # 因为直接对上一层得到的结果做规范化公式计算，将改变结果的正常表征，因此就需要有参数作为调节因子，
    # 使其即能满足规范化要求，又能不改变针对目标的表征.最后使用nn.parameter封装，代表他们是模型的参数。

    def forward(self, x):
        """输入参数x代表来自上一层的输出"""
        # 在函数中，首先对输入变量x求其最后一个维度的均值，并保持输出维度与输入维度一致.
        # 接着再求最后一个维度的标准差，然后就是根据规范化公式，用x减去均值除以标准差获得规范化的结果，
        # 最后对结果乘以我们的缩放参数，即a2，*号代表同型点乘，即对应位置进行乘法操作，加上位移参数b2.返回即可.
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a1 * (x - mean) / (std + self.eps) + self.b1


class SublayerConnection(nn.Module):
    def __init__(self, size, dropout=0.1):
        """它输入参数有两个, size以及dropout， size一般是都是词嵌入维度的大小，
           dropout本身是对模型结构中的节点数进行随机抑制的比率，
           又因为节点被抑制等效就是该节点的输出都是0，因此也可以把dropout看作是对输出矩阵的随机置0的比率.
        """
        super(SublayerConnection, self).__init__()
        # 实例化了规范化对象self.norm
        self.norm = LayerNorm(size)
        # 又使用nn中预定义的droupout实例化一个self.dropout对象.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, sublayer):
        """前向逻辑函数中, 接收上一个层或者子层的输入作为第一个参数，
           将该子层连接中的子层函数作为第二个参数"""

        # 我们首先对输出进行规范化，然后将结果传给子层处理，之后再对子层进行dropout操作，
        # 随机停止一些网络中神经元的作用，来防止过拟合. 最后还有一个add操作，
        # 因为存在跳跃连接，所以是将输入x与dropout后的子层输出结果相加作为最终的子层连接输出.
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        """它的初始化函数参数有四个，分别是size，其实就是我们词嵌入维度的大小，它也将作为我们编码器层的大小,
           第二个self_attn，之后我们将传入多头自注意力子层实例化对象, 并且是自注意力机制,
           第三个是feed_froward, 之后我们将传入前馈全连接层实例化对象, 最后一个是置0比率dropout."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.pos_emb = PositionalEncoding(size)
        # 如图所示, 编码器层中有两个子层连接结构, 所以使用clones函数进行克隆
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        # 把size传入其中
        self.size = size

    def forward(self, x):
        """forward函数中有两个输入参数，x和mask，分别代表上一层的输出，和掩码张量mask."""
        # 里面就是按照结构图左侧的流程. 首先通过第一个子层连接结构，其中包含多头自注意力子层，
        # 然后通过第二个子层连接结构，其中包含前馈全连接子层. 最后返回结果.
        x = self.pos_emb(x.transpose(0, 1)).transpose(0, 1)
        x = self.sublayer[0](x, lambda x: self.self_attn(x))
        return self.sublayer[1](x, self.feed_forward)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # pe的维度是（5000，512）
        pe = torch.zeros(max_len, d_model)

        # position是一个5000行1列的tensor
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # div_term是一个256长度的一维tensor
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0).transpose(0, 1)
        # 最终的pe是一个torch.Size([5000, 1, 512])的维度
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]

        return self.dropout(x)


class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model, layer_num, drop_out, n_head):
        super(MultiHeadAttentionBlock, self).__init__()
        self.multi_attention = MultiHeadAttention(d_model, d_model, n_head)
        self.dropout_rate = drop_out
        self.feedforward = FeedForward(d_model, d_model*4, self.dropout_rate)
        self.encoder = EncoderLayer(d_model, self_attn=self.multi_attention, feed_forward=self.feedforward,
                                    dropout=self.dropout_rate)
        self.layer_num = layer_num

    def forward(self, x):
        for _ in range(self.layer_num):
            x = self.encoder(x)
        return x


def D(p, z, version='simplified'):  # negative cosine similarity
    if version == 'original':
        z = z.detach()  # stop gradient
        p = F.normalize(p, dim=2)  # l2-normalize
        z = F.normalize(z, dim=2)  # l2-normalize
        return -(p * z).sum(dim=2).mean()

    elif version == 'simplified':  # same thing, much faster. Scroll down, speed test in __main__
        return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
    else:
        raise Exception


class prediction_MLP(nn.Module):
    def __init__(self, in_dim=128, hidden_dim=512, out_dim=128):  # bottleneck structure
        super().__init__()
        ''' page 3 baseline setting
        Prediction MLP. The prediction MLP (h) has BN applied 
        to its hidden fc layers. Its output fc does not have BN
        (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers. 
        The dimension of h’s input and output (z and p) is d = 2048, 
        and h’s hidden layer’s dimension is 512, making h a 
        bottleneck structure (ablation in supplement). 
        '''
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        """
        Adding BN to the output of the prediction MLP h does not work
        well (Table 3d). We find that this is not about collapsing. 
        The training is unstable and the loss oscillates.
        """

    def forward(self, x):
        batch, seq_len = x.shape[0], x.shape[1]
        x = x.view(batch * seq_len, -1)
        x = self.layer1(x)
        x = x.view(batch, seq_len, -1)
        x = self.layer2(x)
        return x


class projection_MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=512, out_dim=128):
        super().__init__()
        ''' page 3 baseline setting
        Projection MLP. The projection MLP (in f) has BN ap-
        plied to each fully-connected (fc) layer, including its out- 
        put fc. Its output fc has no ReLU. The hidden fc is 2048-d. 
        This MLP has 3 layers.
        '''
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.BatchNorm1d(out_dim)
        )
        self.num_layers = 3

    def set_layers(self, num_layers):
        self.num_layers = num_layers

    def forward(self, x):
        batch, seq_len = x.shape[0], x.shape[1]
        x = x.view(batch*seq_len, -1)
        if self.num_layers == 3:
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
        elif self.num_layers == 2:
            x = self.layer1(x)
            x = self.layer3(x)
        else:
            raise Exception
        x = x.view(batch, seq_len, -1)
        return x


def get_item(rand, signal):
    if rand == 1:
        return flipping(signal)
    elif rand == 2:
        return scaling(signal)
    else:
        return negate(signal)


def augmentation(eeg, eog, args):
    rand = np.random.randint(args["rand"])
    fix_randomness(rand)
    rand1 = np.random.randint(1, 4)
    rand2 = np.random.randint(1, 4)
    while rand1 == rand2:
        rand2 = np.random.randint(1, 4)

    eeg_aug1 = get_item(rand1, eeg)
    eog_aug1 = get_item(rand1, eog)

    eeg_aug2 = get_item(rand2, eeg)
    eog_aug2 = get_item(rand2, eog)

    return eeg_aug1, eog_aug1, eeg_aug2, eog_aug2

def negate(signal):
    """
    negate the signal
    """
    # sigma = np.random.uniform(0.8, 1.2)
    negated_signal = signal * (-1)
    return negated_signal


def scaling(x):
    sigma = np.random.uniform(1.1)
    return x*sigma


def flipping(x):
    return torch.flip(x, dims=[1])


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        # self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def get_correlated_mask(self):
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        losses = 0
        for i in range(zis.shape[0]):
            zis_ = zis[i, :, :]
            zjs_ = zjs[i, :, :]

            representations = torch.cat([zjs_, zis_], dim=0)

            similarity_matrix = self.similarity_function(representations, representations)
            # filter out the scores from the positive samples
            self.batch_size = zis_.shape[0]
            l_pos = torch.diag(similarity_matrix, self.batch_size)
            r_pos = torch.diag(similarity_matrix, -self.batch_size)

            positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
            # print(positives.shape)
            # negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
            negatives = similarity_matrix[self.get_correlated_mask().type(torch.bool)].view(2 * self.batch_size, -1)
            # print(negatives.shape)
            logits = torch.cat((positives, negatives), dim=1)
            logits /= self.temperature

            labels = torch.zeros(2 * self.batch_size).to(self.device).long()
            loss = self.criterion(logits, labels)
            losses += loss
        return losses / (2 * self.batch_size)


class ProjHead(nn.Module):
    def __init__(self, args):
        super(ProjHead, self).__init__()


        self.input_length = 512

        self.classifier = nn.Sequential(
            nn.Linear(self.input_length, self.input_length // 2),
            nn.BatchNorm1d(self.input_length // 2),
            nn.ReLU(True),
            nn.Linear(self.input_length // 2, self.input_length // 4),
        )

    def forward(self, x):
        batch = x.shape[0]
        x = x.view(-1, self.input_length)
        x = self.classifier(x)
        x = x.view(batch, 20, -1)
        return x