import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Transformer model hyperparameters (read as module-level globals by the
# classes below).
num_blocks = 8  # number of TransformerBlock layers stacked in TransformerModel
context_length = 64  # side length of the lower-triangular mask built in Attention
d_model = 768  # feature width used by every Linear / LayerNorm in the model
num_heads = 4  # Attention heads per MultiHeadAttention; each is d_model // num_heads wide
dropout = 0.1  # probability passed to every nn.Dropout in this file

class FeedForwardNetwork(nn.Module):
    """Feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout.

    The first linear layer widens the last dimension from ``d_model`` to
    ``4 * d_model``; the second one projects it back to ``d_model``. Sizes and
    the dropout probability come from the module-level hyperparameters.
    """

    def __init__(self):
        super().__init__()
        hidden_dim = d_model * 4
        # Layers are created in application order so that parameter
        # initialisation order matches the order they appear in `self.ffn`.
        layers = [
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, d_model),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to ``x`` and return the result.

        ``x`` must have ``d_model`` features in its last dimension; the output
        has the same shape as the input.
        """
        out = self.ffn(x)
        return out

class Attention(nn.Module):
    """One head of masked (causal) scaled dot-product self-attention.

    Queries, keys and values are bias-free linear projections of the input
    from ``d_model`` down to ``d_model // num_heads`` features. A
    lower-triangular ``context_length x context_length`` matrix is stored as
    the ``mask`` buffer and used to stop each position from attending to later
    positions.
    """

    def __init__(self):
        super().__init__()
        head_size = d_model // num_heads
        self.Wq = nn.Linear(d_model, head_size, bias=False)
        self.Wk = nn.Linear(d_model, head_size, bias=False)
        self.Wv = nn.Linear(d_model, head_size, bias=False)
        # Ones on and below the diagonal, zeros above it.
        lower_triangle = torch.tril(torch.ones(context_length, context_length))
        self.register_buffer('mask', lower_triangle)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Compute the attention output for a 3-D input.

        Args:
            x: tensor whose shape unpacks into three dimensions, with the
                second one being the sequence length and the last one holding
                ``d_model`` features.

        Returns:
            Tensor with the same leading dimensions as ``x`` and
            ``d_model // num_heads`` features in the last dimension.
        """
        # The three-way unpack also rejects inputs that are not 3-D.
        _, seq_len, _ = x.shape

        query = self.Wq(x)
        key = self.Wk(x)
        value = self.Wv(x)

        # Scale the raw scores by sqrt(head size) before the softmax.
        scale = math.sqrt(d_model // num_heads)
        scores = (query @ key.transpose(-2, -1)) / scale

        # Entries above the diagonal (mask == 0) are set to -inf so that they
        # receive zero probability after the softmax.
        blocked = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(blocked, float('-inf'))

        probabilities = self.dropout(F.softmax(scores, dim=-1))
        return probabilities @ value

class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent ``Attention`` heads and merges them.

    The head outputs are concatenated along the last dimension, passed through
    a ``d_model -> d_model`` linear projection and then through dropout.
    """

    def __init__(self):
        super().__init__()
        per_head = [Attention() for _ in range(num_heads)]
        self.heads = nn.ModuleList(per_head)
        self.projection_layer = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the projected, dropout-regularised concatenation of all heads."""
        # Heads are evaluated in registration order on the same input.
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        projected = self.projection_layer(concatenated)
        return self.dropout(projected)

class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer receives a layer-normalised copy of its input, and its
    output is added back onto the un-normalised input (residual connection).
    """

    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention()
        self.ffn = FeedForwardNetwork()

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return the result."""
        # Normalise first, then attend, then add the residual.
        attended = self.mha(self.ln1(x))
        x = x + attended
        # Same pattern for the feed-forward network.
        transformed = self.ffn(self.ln2(x))
        return x + transformed

class TransformerModel(nn.Module):
    """Classifier that feeds a linearly embedded input through Transformer blocks.

    Args:
        input_dim: number of input features consumed by the embedding layer.
        num_classes: number of output logits produced by the final linear layer.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # A linear layer serves as the embedding: input_dim -> d_model.
        self.embedding = nn.Linear(input_dim, d_model)
        # `num_blocks` Transformer blocks followed by a final LayerNorm.
        stack = [TransformerBlock() for _ in range(num_blocks)]
        stack.append(nn.LayerNorm(d_model))
        self.transformer_blocks = nn.Sequential(*stack)
        self.model_out_linear_layer = nn.Linear(d_model, num_classes)

    def forward(self, x, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            x: input features with ``input_dim`` values in the last dimension.
                The attention heads unpack a 3-D shape after the unsqueeze
                below, so ``x`` is expected to be 2-D, i.e. (batch, input_dim).
            targets: optional targets forwarded to ``F.cross_entropy``.

        Returns:
            Tuple ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is
            ``None``.
        """
        # Map the raw features straight to d_model dimensions.
        embedded = self.embedding(x)
        # Insert a length-1 sequence axis at dim 1: (B, C) -> (B, 1, C).
        hidden = self.transformer_blocks(embedded.unsqueeze(1))
        # Drop the length-1 sequence axis again before classification.
        logits = self.model_out_linear_layer(hidden.squeeze(1))

        loss = None if targets is None else F.cross_entropy(input=logits, target=targets)
        return logits, loss
