import numpy as np
import torch
import torch.nn as nn

class MaskedTransformerDecoderLayer(nn.Module):
    """Decoder layer made of masked multi-head self-attention only.

    The attention output goes through dropout, is added back onto the layer
    input (residual connection) and the sum is layer-normalised. The layer
    has no cross-attention and no feed-forward sub-layer.
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, tgt_mask=None):
        """Apply masked self-attention, then the residual add and LayerNorm.

        Args:
            tgt: Target sequence of shape (batch_size, seq_len, d_model).
            tgt_mask: Optional (seq_len, seq_len) causal mask that keeps a
                position from attending to later positions.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        # The attention module returns (output, weights); only the output is used.
        attended = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)[0]

        # Residual connection followed by layer normalisation.
        residual = tgt + self.dropout(attended)
        return self.norm(residual)


class MaskedTransformerDecoder(nn.Module):
    """Transformer decoder built from a stack of MaskedTransformerDecoderLayer."""

    def __init__(self, d_model, nhead, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(MaskedTransformerDecoderLayer(d_model, nhead, dropout))

    def forward(self, tgt, tgt_mask=None):
        """Run ``tgt`` through every layer in order, passing the same mask to each.

        Args:
            tgt: Input sequence handed to the first layer.
            tgt_mask: Optional attention mask forwarded unchanged to every layer.

        Returns:
            The output of the last layer (``tgt`` itself when there are no layers).
        """
        hidden = tgt
        for block in self.layers:
            hidden = block(hidden, tgt_mask)
        return hidden

class TransformerRNN(nn.Module):
    """Causal self-attention model with an RNN-style ``forward`` interface.

    A stimulus sequence is assembled from the task timing in ``config``,
    projected to ``num_rnn // 4`` features, passed through one masked
    self-attention layer and mapped to ``num_rnn_out`` values per time step.

    Args:
        config: Mapping read for ``seed``, ``rng``, ``num_input``, ``num_rnn``,
            ``num_rnn_out``, ``batch_size``, ``image_shape`` and
            ``fixationInput``.
        device: Device on which the constant tensors and, in ``forward``, the
            stimulus sequence and causal mask are created. The learnable
            sub-modules are not moved here; the caller places the model.
        *args, **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, config, device, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.seed = config['seed']
        self.rng = config['rng']
        self.config = config

        batch_size = config['batch_size']
        n_input = config['num_input']
        n_rnn = config['num_rnn']
        # Number of leading input channels that carry the flattened image.
        n_image = int(np.prod(config['image_shape']))

        # Constant tensors, allocated directly on the target device.
        self.initNoise = torch.zeros(batch_size, n_rnn, device=device)
        self.zeroStims = torch.zeros(batch_size, n_input, device=device)
        # Fixation-only input: image channels are zero, the remaining channels
        # hold the fixation value (rounded to float32 like the other tensors).
        fixation_value = float(np.float32(config['fixationInput']))
        self.onlyFixStim = torch.cat(
            [
                torch.zeros(batch_size, n_image, device=device),
                torch.full((batch_size, n_input - n_image), fixation_value, device=device),
            ],
            dim=1,
        )

        # Model. The attention width is n_rnn // nhead, not n_rnn itself.
        num_layers = 1  # number of decoder layers
        nhead = 4  # number of attention heads
        d_model = n_rnn // nhead
        self.input_proj = nn.Linear(n_input, d_model)
        self.transformer_decoder = MaskedTransformerDecoder(
            d_model=d_model, nhead=nhead, num_layers=num_layers, dropout=0.1)
        self.fc_out = nn.Linear(d_model, config['num_rnn_out'])

    @staticmethod
    def _step_at_or_after(bound, duration):
        """Return the smallest step index in ``[0, duration]`` that is >= ``bound``.

        For integer steps ``t``, ``t >= bound`` equals ``t >= ceil(bound)`` and
        ``t < bound`` equals ``t < ceil(bound)``, so the result can be used as
        either end of a slice, including for fractional or out-of-range bounds.
        """
        return int(np.clip(np.ceil(bound), 0, duration))

    def forward(self, config, imageStims):
        """Build the stimulus sequence and run it through the causal model.

        Each time step ``t`` in ``range(config['duration'])`` receives:
        ``imageStims`` when ``stimPeriod[0] <= t < stimPeriod[1]``; otherwise
        the fixation-only input when ``t < fixationPeriod[1]``; otherwise zeros.

        Args:
            config: Mapping read for ``batch_size``, ``duration``,
                ``num_input``, ``stimPeriod`` and ``fixationPeriod``.
            imageStims: Tensor assigned to every stimulus-period step; expected
                to broadcast to (batch_size, num_input).

        Returns:
            Tuple ``(y_hat, output_states, None, None)``. ``y_hat`` holds the
            read-out for every time step and ``output_states`` the decoder
            output. Both are batch-first and have ``squeeze(0)`` applied, so
            the batch axis disappears only when the batch size is 1.
        """
        duration = config['duration']
        stim_start = self._step_at_or_after(config['stimPeriod'][0], duration)
        stim_stop = self._step_at_or_after(config['stimPeriod'][1], duration)
        fixation_stop = self._step_at_or_after(config['fixationPeriod'][1], duration)

        # Start from zeros so steps outside both periods need no further write.
        stims = torch.zeros(config['batch_size'], duration, config['num_input'],
                            device=self.device)
        # Fixation goes first; the stimulus window then overrides it wherever
        # the two overlap, so the stimulus period takes precedence.
        if fixation_stop > 0:
            stims[:, :fixation_stop] = self.onlyFixStim.unsqueeze(1)
        if stim_stop > stim_start:
            stims[:, stim_start:stim_stop] = imageStims.unsqueeze(-2)

        # Input projection -> (batch_size, duration, d_model).
        x = self.input_proj(stims)

        # Causal mask so a time step cannot attend to later time steps.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(duration).to(self.device)

        output_states = self.transformer_decoder(x, tgt_mask=tgt_mask)

        # Read-out applied to every time step.
        y_hat = self.fc_out(output_states)
        y_hat = y_hat.squeeze(0)
        output_states = output_states.squeeze(0)
        return y_hat, output_states, None, None


