# import math
# from copy import deepcopy
# import random
# import torch
# import torch.nn as nn
# from torch import Tensor
# from typing import Optional, Union, Tuple
# import torch.nn.functional as F
# from transformers.models.vit_mae.modeling_vit_mae import (
#     ViTMAEForPreTraining,
#     ViTMAEEncoder,
#     ViTMAEModel,
#     ViTMAEEmbeddings,
#     ViTMAEForPreTrainingOutput,
#     ViTMAEModelOutput,
#     ViTMAEDecoder,
#     ViTMAEDecoderOutput,
# )
# from transformers.models.nystromformer.modeling_nystromformer import NystromformerLayer
# from transformers.modeling_outputs import BaseModelOutput
# from random import randint
# from brainlm_mae.configuration_brainlm import BrainLMConfig
# from nitime.timeseries import TimeSeries
# from nitime.analysis import SpectralAnalyzer, FilterAnalyzer, NormalizationAnalyzer
# import numpy as np



# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# class PositionalEncoding_abs(nn.Module):
#     """
#     Positional Encoding module from PyTorch tutorial: [link]

#     Positional Encoding Formula:
#     - PE(pos, 2i) = sin(pos / ( 10000^{2i/d_model} ))  # Even dimensions = sin frequency
#     - PE(pos, 2i+1) = cos(pos / ( 10000^{2i/d_model} ))  # Odd dimensions = cosine frequency

#     10000 is a user-defined variable, chosen as 10000 by authors of original Transformer paper
#     - Scaling by 1/10000 makes 1 cycle very long => guarantees unique positional encodings
#     - If you plot sin(x / 10000), takes > 60k to complete 1 cycle of sin
#     """

#     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 512):
#         super().__init__()
#         self.dropout = nn.Dropout(p=dropout)

#         # div_term creates this part of expression: ( 10000^{2i/d_model} )
#         div_term = torch.exp(
#             torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
#         )

#         position = torch.arange(max_len).unsqueeze(1)  # 0 through max_len - 1
#         pe = torch.zeros(max_len, d_model)  # [max_seq_len, hidden_size]
#         pe[:, 0::2] = torch.sin(position * div_term)
#         pe[:, 1::2] = torch.cos(position * div_term)
#         self.register_buffer("pe", pe)

#     def forward(self, x: Tensor) -> Tensor:
#         """
#         Args:
#             x: Tensor, shape [batch_size, seq_len, embedding_dim]
#         """
#         # print(self.pe)
#         pos_encoding = self.pe[: x.size(1)]  # shape [seq_len, embedding_dim]
#         pos_encoding = (
#             pos_encoding.unsqueeze(0).repeat(x.shape[0], 1, 1)
#         )
#         x = x + pos_encoding
#         return x
        
#         # x = x + self.pe[:, :x.size(1)]
#         # return self.dropout(x)
    

    
    
    
    
# # class PositionalEncoding_abs(nn.Module):
# #     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 512):
# #         super(PositionalEncoding_abs, self).__init__()
# #         self.position_embedding = nn.Parameter(torch.empty(max_len, d_model))
# #         self.dropout = nn.Dropout(p=dropout)
# #         self.init_weights()
        

# #     def init_weights(self):
# #         #nn.init.uniform_(self.position_embedding, a=-0.1, b=0.1)
# #         nn.init.xavier_uniform_(self.position_embedding)

# #     def forward(self, x):
# #         # Assuming x is of shape (batch_size, seq_len, embedding_dim)
# #         seq_len = x.size(1)
# #         pos_encoding = self.position_embedding[:seq_len, :].unsqueeze(0).repeat(x.shape[0], 1, 1)
# #         return x + pos_encoding
    
# class BrainLMEmbeddings(ViTMAEEmbeddings):
#     """
#     Construct the CLS token plus linearly projected, position-encoded signal embeddings.
#     """

#     def __init__(self, config):
#         super().__init__(config)
#         self.patch_embeddings = None
#         self.position_embeddings = None
#         self.num_brain_voxels = 1000
#         self.num_timepoints_per_voxel = 200
#         self.mask_ratio = 0.7
#         self.tokens = 1000
#         self.pos_embedding = PositionalEncoding_abs(d_model=config.hidden_size)
        
        
      
        
#         self.signal_embedding_projection = nn.Linear(
            
#             1000, config.hidden_size, bias=True     
#         )
        
#         self.signal_embedding_projection2 = nn.Linear(
            
#             config.hidden_size, config.hidden_size, bias=False     
#         )
        
#         self.region_weight_net = nn.Sequential(
#             nn.Linear(config.hidden_size, config.hidden_size, bias=True ),
#             nn.ReLU(),
#             nn.Linear(config.hidden_size, config.hidden_size, bias=True ),
#             nn.Softmax(dim=-1)  # produce the weight vector; softmax keeps it normalized
#         )

       

#     def initialize_weights(self):
#         torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)

#     def forward(self, signal_vectors,  xyz_vectors, noise):
#         """
#         :param signal_vectors: torch tensor of shape [batch, num_voxels, num_timepoints_per_voxel]

#         :param noise: torch tensor of noise for reproducibility, e.g. torch.rand(batch_size, seq_length, device=sequence.device)
#         :return:
#             embeddings: [batch, num_voxels * num_unmasked_patch_tokens + 1 CLS token, hidden_size]
#             mask: [batch, num_voxels, num_patch_tokens]
#             ids_restore: [batch, num_voxels, num_patch_tokens]
#         """
        
#         batch, num_timepoints_per_node,num_voxels = signal_vectors.shape  

#         pred_len = int(num_timepoints_per_node*0.25)            ## 6-20
#         first_end = num_timepoints_per_node-pred_len ##24-80
        
       
#         x = self.signal_embedding_projection(signal_vectors)
        
#         x = self.pos_embedding(x)  #### b,2604,1,256
        
#         brain_region_weights = self.region_weight_net(x)
#         x_weighted = x * brain_region_weights
        
       
        
#         embeddings, mask, ids_restore = self.random_masking(x_weighted, first_end,pred_len,noise=noise)
        
        

#         cls_tokens = self.cls_token.expand(embeddings.shape[0], -1, -1)
        
#         embeddings = torch.cat((cls_tokens,embeddings), dim=1)  
        
        
#         # return embeddings, mask, ids_restore,x_weighted
#         return embeddings, mask, ids_restore,first_end,pred_len
    
    
  
# #     def random_masking(self, sequence, noise=None):
# #             """
# #             Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
# #             noise.

# #             Args:
# #                 sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
# #                 noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
# #                     mainly used for testing purposes to control randomness and maintain the reproducibility
# #             """
# #             batch_size, seq_length, dim = sequence.shape
# #             len_keep = int(seq_length * (1 - self.config.mask_ratio))
# #             # len_keep = 35

# #             if noise is None:
# #                 noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]
# #                 noise,_ = torch.sort(noise)

# #             # sort noise for each sample
# #             ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
# #             ids_restore = torch.argsort(ids_shuffle, dim=1)

# #             # keep the first subset
# #             ids_keep = ids_shuffle[:, :len_keep]
# #             sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

# #             # generate the binary mask: 0 is keep, 1 is remove
# #             mask = torch.ones([batch_size, seq_length], device=sequence.device)
# #             mask[:, :len_keep] = 0
# #             # unshuffle to get the binary mask
# #             mask = torch.gather(mask, dim=1, index=ids_restore)
            
# #             # perm = torch.randperm(len_keep)
# #             # shuffle the second dimension according to the random permutation
# #             # sequence_unmasked = sequence_unmasked[:, perm, :]

# #             return sequence_unmasked, mask, ids_restore   


#     def random_masking(self, sequence, first_end,pred_len,noise=None):
#             """
#             Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
#             noise.

#             Args:
#                 sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
#                 noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
#                     mainly used for testing purposes to control randomness and maintain the reproducibility
#             """
            
#             # mask_ratio = 0.2* torch.rand(1) + 0.3
#             # k = torch.rand(1)
            
            
#             all_len = first_end+pred_len
#             # batch_size, seq_length, dim = sequence.shape
            
# #             all_len =  randint(30,seq_length)  ##30-100
# #             pred_len = int(all_len/5)            ## 6-20
            
           
# #             first_end = all_len-pred_len ##24-80
            
#             # first_start = randint(0, max_len_seq-int(max_len_seq*0.7))
#             # first_end = randint(int(max_len_seq*0.7),max_len_seq)
            
#             sequence2 = sequence.clone()
#             batch_size, seq_length, dim = sequence2.shape
            
            
#             seq_len70 = first_end-0
#             seq_len30 = pred_len-0
            
#             seq_70 = sequence2[:,:seq_len70,:]
#             seq_30 = sequence2[:,seq_len70:,:]

#             if noise is None:
#                 noise = torch.rand(batch_size, seq_len70, device=sequence.device)  # noise in [0, 1]
#                 noise,_ = torch.sort(noise)


#             ids_shuffle_70 = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
#             ids_restore_70 = torch.argsort(ids_shuffle_70, dim=1)
#             ids_restore_30 = torch.arange(seq_len70,seq_length,device=sequence.device)
#             ids_restore_30 = ids_restore_30.unsqueeze(0).repeat(batch_size, 1)
            
#             ids_restore = torch.cat((ids_restore_70,ids_restore_30),dim=1)
            
         
#             ids_keep_70 = ids_shuffle_70[:, :int(seq_len70*1)]
#             sequence_unmasked = torch.gather(seq_70, dim=1, index=ids_keep_70.unsqueeze(-1).repeat(1, 1, dim))
            
            
            

#             # generate the binary mask: 0 is keep, 1 is remove
# #             mask_70 = torch.ones([batch_size, seq_len70], device=sequence.device)
# #             mask_30 = torch.ones([batch_size, seq_len30], device=sequence.device)
# #             mask_70[:, :int(seq_len70*1)] = 0
            
# #             mask_70 = torch.gather(mask_70, dim=1, index=ids_restore_70)
            
# #             mask = torch.cat((mask_70,mask_30),dim=1)

#             mask = torch.ones([batch_size, all_len], device=sequence.device)
#             mask[:, :seq_len70] = 0
#             # unshuffle to get the binary mask
#             mask = torch.gather(mask, dim=1, index=ids_restore)

#             return sequence_unmasked, mask, ids_restore


# def create_attention_mask(seq_len, num_heads,num_masked,batch,training):
#     # Initialize to 0 (attention allowed); blocked positions are set to -inf below
#     mask = torch.zeros((seq_len, seq_len),dtype=torch.float32).to(device)
    
#     # mask[-num_masked:,:] = False
#     # for i in range(seq_len):
#     #     for j in range(seq_len):
#     #         if j>i:
#     #             mask[i,j] = torch.tensor(float('-inf'))
            

#     mask[:,-num_masked:] = torch.tensor(float('-inf'))
    
#     mask = mask.unsqueeze(0).repeat(num_heads, 1, 1)
    
#     return mask
 
    
# def create_attention_mask2(seq_len, num_heads,num_masked,batch,training):
#     # Initialize to 0 (attention allowed); blocked positions are set to -inf below
#     mask = torch.zeros((seq_len, seq_len),dtype=torch.float32).to(device)
    
#     # mask[-num_masked:,:] = False
#     for i in range(seq_len):
#         for j in range(seq_len):
#             if j>i:
#                 mask[i,j] = torch.tensor(float('-inf'))
            

   
    
#     mask = mask.unsqueeze(0).repeat(num_heads, 1, 1)
    
#     return mask

    
# class BrainLMEncoder(ViTMAEEncoder):###########################################################
#     def __init__(self, config):
#         super().__init__(config)
#         self.layer = nn.ModuleList(
#             [NystromformerLayer(config) for _ in range(config.num_hidden_layers)] ######################################## config.num_hidden_layers
#         )

#     def forward(
#         self,
#         hidden_states: torch.Tensor,
#         pred_len,
#         head_mask: Optional[torch.Tensor] = None,
#         output_attentions: bool = False,
#         output_hidden_states: bool = False,
#         return_dict: bool = True,
#     ) -> Union[tuple, BaseModelOutput]:
#         all_hidden_states = () if output_hidden_states else None
#         all_self_attentions = () if output_attentions else None
        
#         attention_mask = create_attention_mask2(hidden_states.shape[1], 8, pred_len,hidden_states.shape[0],self.training)################ seq_len,num_heads,no_att_len,batch_size

#         for i, layer_module in enumerate(self.layer):
#             if output_hidden_states:
#                 all_hidden_states = all_hidden_states + (hidden_states,)

#             # layer_head_mask = head_mask[i] if head_mask is not None else None

#             if self.gradient_checkpointing and self.training:

#                 def create_custom_forward(module):
#                     def custom_forward(*inputs):
#                         return module(*inputs, output_attentions)

#                     return custom_forward

#                 layer_outputs = torch.utils.checkpoint.checkpoint(
#                     create_custom_forward(layer_module),
#                     hidden_states,
#                     # layer_head_mask,  Nystromformer doesn't accept head_mask
#                 )
#             else:
#                 # layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
#                 # Nystromformer attention layer does not accept head_mask parameter
#                 layer_outputs = layer_module(
#                     hidden_states, attention_mask=None,output_attentions=output_attentions
#                 )

#             hidden_states = layer_outputs[0]

#             if output_attentions:
#                 all_self_attentions = all_self_attentions + (layer_outputs[1],)

#         if output_hidden_states:
#             all_hidden_states = all_hidden_states + (hidden_states,)

#         if not return_dict:
            
#             print("sssssssssssssssssss")
#             return tuple(
#                 v
#                 for v in [hidden_states, all_hidden_states, all_self_attentions]
#                 if v is not None
#             )
#         return BaseModelOutput(
#             last_hidden_state=hidden_states,
#             hidden_states=all_hidden_states,
#             attentions=all_self_attentions,
#         )


# class BrainLMModel(ViTMAEModel):
#     def __init__(self, config):
#         super().__init__(config)
#         self.embeddings = BrainLMEmbeddings(config)  
#         self.encoder = BrainLMEncoder(config)

#         self.post_init()

#     def forward(
#         self,
#         signal_vectors: torch.Tensor = None,
#         xyz_vectors: torch.Tensor = None,
#         head_mask: Optional[torch.FloatTensor] = None,
#         output_attentions: Optional[bool] = None,
#         output_hidden_states: Optional[bool] = None,
#         return_dict: Optional[bool] = None,
#         noise: Optional[bool] = None,
#     ) -> Union[Tuple, ViTMAEModelOutput]:

#         output_attentions = (
#             output_attentions
#             if output_attentions is not None
#             else self.config.output_attentions
#         )
#         output_hidden_states = (
#             output_hidden_states
#             if output_hidden_states is not None
#             else self.config.output_hidden_states
#         )
#         return_dict = (
#             return_dict if return_dict is not None else self.config.use_return_dict
#         )
#         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        

         
        
#         embedding_output, mask, ids_restore,first_end,pred_len = self.embeddings(
#             signal_vectors,  xyz_vectors,noise
#         )

#         encoder_outputs = self.encoder(
#             embedding_output,
#             pred_len,
#             head_mask=head_mask,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict,
#         )
#         # print(encoder_outputs[0].shape)
#         # print(encoder_outputs[1][0])
#         sequence_output = encoder_outputs[0]
#         sequence_output = self.layernorm(sequence_output)

#         if not return_dict:
#             return (sequence_output, mask) + encoder_outputs[1:]

#         return ViTMAEModelOutput(
#             last_hidden_state=sequence_output,
#             mask=mask,
#             ids_restore=ids_restore,
#             hidden_states=encoder_outputs.hidden_states,
#             attentions=encoder_outputs.attentions,
#         ),first_end,pred_len




# class CustomTransformerLayer(nn.Module):
#     def __init__(self, embed_size, num_heads, forward_expansion, dropout_rate,training,seq_len):
#         super(CustomTransformerLayer, self).__init__()
#         self.embed_size = embed_size
#         self.num_heads = num_heads
#         self.dropout_rate = dropout_rate
        
       
        
#         # Multi-head attention
#         if training:
#             self.attention = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout_rate,batch_first=True)
#         else:
#             self.attention = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout_rate,batch_first=True).eval()
        
#         # Feed-forward network
#         self.feed_forward = nn.Sequential(
#             nn.Linear(embed_size, forward_expansion * embed_size),
#             nn.ReLU(),
#             nn.Linear(forward_expansion * embed_size, embed_size)
#         )
        
#         # Layer normalization
#         self.norm1 = nn.LayerNorm(embed_size)
#         self.norm2 = nn.LayerNorm(embed_size)
        
#         # Dropout
#         self.dropout = nn.Dropout(dropout_rate)

#     def forward(self, x, mask=None):
#         # x: (batch_size, seq_len, embed_size), since the attention uses batch_first=True
      
#         attention_output, attention_weights = self.attention(x, x, x, attn_mask=mask,average_attn_weights=False)
#         # mean_weight = torch.mean(attention_weights, dim=2, keepdim=True)
#         # mask = attention_weights >= mean_weight
#         # attention_weights = attention_weights * mask.float()
        
      
        
#         # Add residual connection and LayerNorm
#         x = self.norm1(attention_output + x)
        
#         # 2. Feed-forward network
#         forward_output = self.feed_forward(x)
        
#         # Add residual connection and LayerNorm
#         out = self.norm2(forward_output + x)
        
#         return out, attention_weights
    


    

    
# class BrainLMDecoder(ViTMAEDecoder):
#     def __init__(self, config, num_patches):
#         super().__init__(config, num_patches)
#         self.decoder_pos_embed = None  # Not using positional embedding
#         self.num_brain_voxels = 1000
#         self.mask_ratio = config.mask_ratio
#         self.timepoint_patching_size = config.timepoint_patching_size
#         self.use_tanh_decoder = config.use_tanh_decoder
        
#         self.transformer_layer = CustomTransformerLayer(1024, config.decoder_num_attention_heads, 2, 0,self.training,51)  ### embed_size, num_heads, forward_expansion, dropout_rate,training


       
#         self.pos_embedding = PositionalEncoding_abs(d_model=config.hidden_size)

#         # Decoder Linear Attention Transformer Layers
#         decoder_config = deepcopy(config)
#         decoder_config.hidden_size = config.decoder_hidden_size
#         decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
#         decoder_config.num_attention_heads = config.decoder_num_attention_heads
#         decoder_config.intermediate_size = config.decoder_intermediate_size
#         self.decoder_layers = nn.ModuleList(
#             [
#                 NystromformerLayer(decoder_config)
#                 for _ in range(config.decoder_num_hidden_layers)
#             ]
#         )
        
        
#         # self.decoder_layers = nn.ModuleList(
#         #     [
#         #         self.transformer_layer
#         #         for _ in range(config.decoder_num_hidden_layers)
#         #     ]
#         # )

       
#         # self.decoder_pred1 = nn.Linear(
#         #     in_features=256, 
#         #     out_features=512, 
#         #     bias=True,
#         # )
#         # self.decoder_pred_nonlinearity = nn.GELU()
#         # self.decoder_pred2 = nn.Linear(
#         #     in_features=1024, 
#         #     out_features=1000,     
#         #     bias=False,
#         # )
        
        
        
        
      

#         if self.use_tanh_decoder:
#             self.decoder_pred_nonlinearity2 = nn.Tanh()

#         self.initialize_weights(num_patches)

#     def initialize_weights(self, _):
#         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
#         torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)

#     def forward(
#         self,
#         hidden_states,

#         ids_restore,
#         pred_len,
#         output_attentions=True,
#         output_hidden_states=True,
#         return_dict=True,
#     ):
#         # embed tokens
#         x = self.decoder_embed(hidden_states)

#         # Unflatten sequence
#         batch_size, flatten_seq_len, hidden_dim = x[:, 1:, :].shape
#         num_mask_tokens = ids_restore.shape[1] - flatten_seq_len

#         # append mask tokens to sequence
#         mask_tokens = self.mask_token.repeat(batch_size, num_mask_tokens, 1)
#         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token  
#         x_ = torch.gather(
#             x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, hidden_dim)
#         )  # unshuffle
#         # --> x_ is shape [batch_size, ids_restore.shape[1], hidden_dim]

       
#         # x_ = torch.reshape(
#         #     x_, shape=(batch_size, self.num_brain_voxels,hidden_dim)
#         # )  # --> [batch_size, num_voxels, unmasked_timepoints_per_voxel, hidden_size]





#         # Add positional encoding for time signal
#         x_ = self.pos_embedding(x_)

#         # # Flatten again
#         # x_ = torch.flatten(x_, start_dim=1, end_dim=2)  # --> [batch, seq_len, dim]  

#         hidden_states = torch.cat([ x[:, :1, :],x_], dim=1)  # append cls token  
        
#         # print(hidden_states.shape)
        
       
#         attention_mask = create_attention_mask(hidden_states.shape[1], 8, pred_len,hidden_states.shape[0],self.training)################ seq_len,num_heads,no_att_len,batch_size
        
        
        
        
#         # hidden_states = self.pos_embedding(x)  # No positional embedding

#         # apply Transformer layers (blocks)
#         all_hidden_states = () if output_hidden_states else None
#         all_self_attentions = () if output_attentions else None
#         for i, layer_module in enumerate(self.decoder_layers):
#             if output_hidden_states:
#                 all_hidden_states = all_hidden_states + (hidden_states,)

#             if self.gradient_checkpointing and self.training:
                

#                 def create_custom_forward(module):
#                     def custom_forward(*inputs):
#                         return module(*inputs, output_attentions)

#                     return custom_forward

#                 layer_outputs = torch.utils.checkpoint.checkpoint(
#                     create_custom_forward(layer_module),
#                     hidden_states,
#                     # None,  Nystromformer layer does not accept argument head_mask
#                 )
#             else:
#                 # layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
#                 # Nystromformer layer does not accept argument head_mask
#                 layer_outputs = layer_module(
#                     hidden_states, attention_mask=None,output_attentions=output_attentions
#                 )
#                 # if self.training:
#                 #     layer_outputs = layer_module(
#                 #         hidden_states,attention_mask
#                 #     )
#                 # else:
#                 #     # print("1111111111")
#                 #     layer_outputs = layer_module(
#                 #         hidden_states,attention_mask
#                 #     )

#             hidden_states = layer_outputs[0]

#             if output_attentions:
#                 all_self_attentions = all_self_attentions + (layer_outputs[1],)

#         if output_hidden_states:
#             all_hidden_states = all_hidden_states + (hidden_states,)

#         # hidden_states = self.decoder_norm(hidden_states)


#         # logits = self.decoder_pred2(hidden_states)

        
        
#         if self.use_tanh_decoder:
#             logits = self.decoder_pred_nonlinearity2(logits)


#         if not return_dict:
#             print("fffffffffff")
#             return tuple(
#                 v
#                 for v in [logits, all_hidden_states, all_self_attentions]
#                 if v is not None
#             )
#         return ViTMAEDecoderOutput(
#             logits=hidden_states,
#             hidden_states=all_hidden_states,
#             attentions=all_self_attentions,
#         )    


# class BrainLMForPretraining(ViTMAEForPreTraining):
#     """
#     Model definition is for pretraining on fMRI signal data. Will calculate loss on forward
#     pass through model.
#     """

#     def __init__(self, config):
#         super().__init__(config)
#         self.vit = BrainLMModel(config)
#         self.decoder = BrainLMDecoder(
#             config, num_patches=self.vit.embeddings.num_patches
#         )

        
        
        
#         self.decoder_pred = nn.Linear(
#             in_features=1024, 
#             out_features=1000,     
#             bias=True,
#         )

#         self.post_init()

#     def init_weights(self):
        
#         # Prune heads if needed
#         if self.config.pruned_heads:
#             self.prune_heads(self.config.pruned_heads)

#         # Initialize weights
#         self.apply(self._initialize_weights)

#         # Tie weights should be skipped when not initializing all weights
#         # since from_pretrained(...) calls tie weights anyways
#         self.tie_weights()

#     def _init_weights(self, module):  #
#         if isinstance(module, nn.Linear):
#             # we use xavier_uniform following official JAX ViT:
#             torch.nn.init.xavier_uniform_(module.weight)
#             # torch.nn.init.kaiming_uniform_(module.weight)
#             if isinstance(module, nn.Linear) and module.bias is not None:
#                 nn.init.constant_(module.bias, 0)
#             #     torch.nn.init.xavier_uniform_(module.bias)
#         elif isinstance(module, nn.LayerNorm):
#             nn.init.constant_(module.bias, 0)
#             nn.init.constant_(module.weight, 1.0)
#         elif isinstance(module, nn.Embedding):
#             torch.nn.init.kaiming_uniform_(module.weight)

#     def forward_loss(self, signal_values, pred_values,mask):
       
#         assert signal_values.shape == pred_values.shape
        
#         # mask = torch.ones(pred_values.shape,device=pred_values.device)

#         if self.config.loss_fn == "mse":
#             loss = (
#                 ((pred_values - signal_values) ** 2) * mask
#             ).sum() / mask.sum()  # MSE
            
#             # loss = abs((pred_values - signal_values) * mask).sum() / mask.sum()  # MAE
            
#         elif self.config.loss_fn == "mae":
#             loss = abs((pred_values - signal_values) * mask).sum() / mask.sum()  # MAE
#         else:
#             raise NotImplementedError("Unknown loss function specified.")

#         return loss


     
    

#     def forward(
#         self,
#         signal_vectors: torch.Tensor = None,
#         signal_vectors1: torch.Tensor = None,
#         xyz_vectors: torch.Tensor = None,
#         labels: torch.Tensor = None,  # not used
#         input_ids: torch.Tensor = None,  # not used, 
#         head_mask: Optional[torch.FloatTensor] = None,
#         output_attentions: Optional[bool] = None,  ########## train to drop
#         output_hidden_states: Optional[bool] = None,
#         return_dict: Optional[bool] = None,
#         noise: Optional[bool] = None,
#     ) -> Union[Tuple, ViTMAEForPreTrainingOutput]:

#         return_dict = (
#             return_dict if return_dict is not None else self.config.use_return_dict
#         )
     
#         # Encoder will perform BrainLM fmri embedding rather than VitMAE Image Embedding
#         outputs = self.vit(
#             signal_vectors=signal_vectors,
#             xyz_vectors=xyz_vectors,
#             head_mask=head_mask,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict,
#             noise=noise,
#         ) 
      
#         outputs1 = outputs[0]
#         first_end = outputs[1]
#         pred_len = outputs[2]
   
#         ids_restore = outputs1.ids_restore
#         mask = outputs1.mask
#         latent_all = outputs1.hidden_states
#         latent = outputs1.last_hidden_state  ###  b,2604*6+1,256
        
#         encoder_attentions = outputs1.attentions
        
        
#         # encoder_out = self.decoder_pred(latent)
    
        
#         decoder_outputs = self.decoder(latent,  ids_restore,pred_len)
#         logits = (
#             decoder_outputs.logits  ###b,424,100
#         )  # 
        
#         logits = self.decoder_pred(logits)[:,:-1,:]
#         mask2 = mask.unsqueeze(-1).repeat(1,1, 1000)
        
        
        
#         mask4 = mask2.clone()
#         mask4 = mask4-1
#         mask4 = torch.where(mask4 == -1, torch.tensor(1), mask4)
        
        
#         loss =  0.75*self.forward_loss(signal_vectors, logits,mask2)+0.25*self.forward_loss(signal_vectors, logits,mask4)
#         #0.1*self.forward_loss(seq[:,:first_end,:], logits[:,:first_end,:],mask2[:,:first_end,:])  +
        
#         # logits3 = torch.cat((logits,torch.zeros(signal_vectors.shape[0],100-(first_end+pred_len),1000,device=signal_vectors.device)),dim=1)
        
        
#         # print(signal_vectors.shape)


       
        

#         # mask4 = torch.cat((mask2,torch.zeros(signal_vectors.shape[0],100-(first_end+pred_len),1000,device=signal_vectors.device)),dim=1)
        

    
#         mask3 = mask2.transpose(2,1)


#         if not return_dict:
#             output = (logits1, mask) + outputs[2:]
#             return ((loss,) + output) if loss is not None else output

        
#         return ViTMAEForPreTrainingOutput(
#             loss=loss,
#             logits=(logits, latent),
#             mask=mask3,
#             hidden_states=outputs[0].hidden_states,
#             attentions=outputs[0].attentions,
#         )
        
#         # return ViTMAEForPreTrainingOutput(
#         #     loss=None,
#         #     logits=(logits,latent,decoder_outputs.attentions),
#         #     mask=mask3,
#         #     hidden_states=None,
#         #     attentions=decoder_outputs.attentions,
#         # )
#         # return ViTMAEForPreTrainingOutput(
#         #     loss=None,
#         #     logits=(logits,latent,encoder_attentions,weighted,logits3,decoder_outputs.hidden_states,latent_all),
#         #     mask=mask3,
#         #     hidden_states=None,
#         #     attentions=decoder_outputs.attentions,
#         # )
    
    
    
    
    
    
    

















import math
from copy import deepcopy
import random
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Union, Tuple
import torch.nn.functional as F
from transformers.models.vit_mae.modeling_vit_mae import (
    ViTMAEForPreTraining,
    ViTMAEEncoder,
    ViTMAEModel,
    ViTMAEEmbeddings,
    ViTMAEForPreTrainingOutput,
    ViTMAEModelOutput,
    ViTMAEDecoder,
    ViTMAEDecoderOutput,
)
from transformers.models.nystromformer.modeling_nystromformer import NystromformerLayer
from transformers.modeling_outputs import BaseModelOutput
from random import randint
from brainlm_mae.configuration_brainlm import BrainLMConfig
from nitime.timeseries import TimeSeries
from nitime.analysis import SpectralAnalyzer, FilterAnalyzer, NormalizationAnalyzer
import numpy as np



# Module-level default device: CUDA when it is available, otherwise CPU.
# create_attention_mask() below places its mask on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PositionalEncoding_abs(nn.Module):
    """
    Fixed (non-trainable) sinusoidal positional encoding, adapted from the
    PyTorch tutorial.

    Positional Encoding Formula:
    - PE(pos, 2i) = sin(pos / ( 10000^{2i/d_model} ))  # Even dimensions = sin frequency
    - PE(pos, 2i+1) = cos(pos / ( 10000^{2i/d_model} ))  # Odd dimensions = cosine frequency

    10000 is a user-defined variable, chosen as 10000 by authors of original Transformer paper
    - Scaling by 1/10000 makes 1 cycle very long => guarantees unique positional encodings
    - If you plot sin(x / 10000), takes > 60k to complete 1 cycle of sin

    The table is stored in the buffer ``pe`` of shape ``[max_len, d_model]``.

    Args:
        d_model: size of the embedding dimension (odd values are supported).
        dropout: probability for ``self.dropout``. The module is created for
            attribute compatibility, but ``forward`` does not apply it.
        max_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 512):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # div_term creates this part of expression: 1 / ( 10000^{2i/d_model} )
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)  # [max_seq_len, hidden_size]
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one cosine column fewer than sine
        # columns, so the angle table is trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Add the positional encoding to ``x``.

        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` of the table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional table "
                f"size {self.pe.size(0)}."
            )
        # expand() creates a broadcast view instead of copying the table once
        # per batch element.
        pos_encoding = self.pe[:seq_len].unsqueeze(0).expand(x.shape[0], -1, -1)
        return x + pos_encoding
    

    
    
    
    
# class PositionalEncoding_abs(nn.Module):
#     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 512):
#         super(PositionalEncoding_abs, self).__init__()
#         self.position_embedding = nn.Parameter(torch.empty(max_len, d_model))
#         self.dropout = nn.Dropout(p=dropout)
#         self.init_weights()
        

#     def init_weights(self):
#         #nn.init.uniform_(self.position_embedding, a=-0.1, b=0.1)
#         nn.init.xavier_uniform_(self.position_embedding)

#     def forward(self, x):
#         # Assuming x is of shape (batch_size, seq_len, embedding_dim)
#         seq_len = x.size(1)
#         pos_encoding = self.position_embedding[:seq_len, :].unsqueeze(0).repeat(x.shape[0], 1, 1)
#         return x + pos_encoding
    
class BrainLMEmbeddings(ViTMAEEmbeddings):
    """
    Embed a multivariate signal and split it into visible context and a
    held-out tail.

    Each time step (a vector of 1000 values) is linearly projected to
    ``config.hidden_size``, offset by a sinusoidal positional encoding and
    re-weighted element-wise by the softmax output of ``region_weight_net``.
    The leading time steps are forwarded (behind a CLS token); the trailing
    ``int(T * 0.3)`` steps are withheld and flagged in the returned mask.
    """

    def __init__(self, config):
        super().__init__(config)
        # The patch / position embeddings set up by the parent are not used.
        self.patch_embeddings = None
        self.position_embeddings = None

        self.num_brain_voxels = 1000
        self.num_timepoints_per_voxel = 200
        self.mask_ratio = 0.7
        self.tokens = 1000

        hidden_size = config.hidden_size
        self.pos_embedding = PositionalEncoding_abs(d_model=hidden_size)

        # Projects the 1000 values of one time step to the hidden size.
        self.signal_embedding_projection = nn.Linear(1000, hidden_size, bias=True)
        # Not called by forward() in this class.
        self.signal_embedding_projection2 = nn.Linear(
            hidden_size, hidden_size, bias=False
        )

        # Produces a weight vector per token; softmax keeps it normalized.
        self.region_weight_net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.Softmax(dim=-1),
        )

    def initialize_weights(self):
        """Draw the CLS token from a normal distribution (std = initializer_range)."""
        torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)

    def forward(self, signal_vectors, xyz_vectors, noise):
        """
        :param signal_vectors: 3-D tensor unpacked as [batch, num_timepoints, num_features];
            the last dimension is fed to a ``Linear(1000, hidden_size)``.
        :param xyz_vectors: accepted for interface compatibility; not used here.
        :param noise: optional tensor forwarded to ``random_masking``; ``None``
            lets ``random_masking`` draw its own.
        :return:
            embeddings: CLS token followed by the visible tokens,
                [batch, 1 + first_end, hidden_size]
            mask: [batch, num_timepoints], 0 = visible, 1 = held out
            ids_restore: [batch, num_timepoints] indices used to restore order
            first_end: number of leading (visible) time steps
            pred_len: number of trailing (held-out) time steps
        """
        _, num_steps, _ = signal_vectors.shape

        pred_len = int(num_steps * 0.3)
        first_end = num_steps - pred_len

        projected = self.signal_embedding_projection(signal_vectors)
        projected = self.pos_embedding(projected)

        # Element-wise gating of each token by its own softmax weights.
        gated = projected * self.region_weight_net(projected)

        embeddings, mask, ids_restore = self.random_masking(
            gated, first_end, pred_len, noise=noise
        )

        cls_tokens = self.cls_token.expand(embeddings.shape[0], -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        return embeddings, mask, ids_restore, first_end, pred_len

    def random_masking(self, sequence, first_end, pred_len, noise=None):
        """
        Keep the first ``first_end`` tokens (ordered by ``noise``) and mark the
        remaining tokens as masked.

        Args:
            sequence (`torch.Tensor` of shape `(batch_size, sequence_length, dim)`)
            first_end (`int`): number of leading tokens that stay visible.
            pred_len (`int`): number of trailing tokens that are masked.
            noise (`torch.FloatTensor` of shape `(batch_size, first_end)`, *optional*):
                ordering scores for the visible tokens. When omitted, random
                noise is drawn and then sorted, so the resulting order is
                effectively the original one.

        Returns:
            sequence_unmasked: visible tokens, `(batch_size, first_end, dim)`
            mask: `(batch_size, first_end + pred_len)`, 0 is keep, 1 is remove
            ids_restore: `(batch_size, sequence_length)` restore indices; the
                masked tail maps onto itself.
        """
        total_len = first_end + pred_len

        working = sequence.clone()
        batch_size, seq_length, dim = working.shape
        context = working[:, :first_end, :]

        if noise is None:
            noise = torch.rand(batch_size, first_end, device=sequence.device)  # noise in [0, 1]
            noise, _ = torch.sort(noise)

        # ascending order: small noise comes first
        shuffle_ctx = torch.argsort(noise, dim=1)
        restore_ctx = torch.argsort(shuffle_ctx, dim=1)

        tail_positions = torch.arange(first_end, seq_length, device=sequence.device)
        restore_tail = tail_positions.unsqueeze(0).repeat(batch_size, 1)
        ids_restore = torch.cat((restore_ctx, restore_tail), dim=1)

        # Every context token is kept; only its order follows the noise.
        keep_ids = shuffle_ctx[:, :first_end]
        gather_index = keep_ids.unsqueeze(-1).repeat(1, 1, dim)
        sequence_unmasked = torch.gather(context, dim=1, index=gather_index)

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, total_len], device=sequence.device)
        mask[:, :first_end] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore



    
    
class BrainLMEncoder(ViTMAEEncoder):
    """ViTMAE encoder whose layer stack is replaced by Nystromformer layers."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the parent's layers with config.num_hidden_layers Nystromformer layers.
        self.layer = nn.ModuleList(
            [NystromformerLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Run ``hidden_states`` through every layer in turn.

        Args:
            hidden_states: input embeddings.
            head_mask: accepted for interface compatibility only; the
                Nystromformer layers do not take a head mask, so it is ignored.
            output_attentions: collect ``layer_outputs[1]`` of every layer.
            output_hidden_states: collect the input of every layer plus the
                final output.
            return_dict: return a ``BaseModelOutput`` instead of a tuple.

        Returns:
            ``BaseModelOutput`` or, when ``return_dict`` is false, a tuple of
            the non-``None`` values among (last hidden state, all hidden
            states, all attentions).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # Must be passed by keyword: the second positional
                        # parameter of the layer is the attention mask, not
                        # output_attentions.
                        return module(*inputs, output_attentions=output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, output_attentions=output_attentions
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class BrainLMModel(ViTMAEModel):
    """ViTMAE backbone using ``BrainLMEmbeddings`` and ``BrainLMEncoder``."""

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = BrainLMEmbeddings(config)
        self.encoder = BrainLMEncoder(config)

        self.post_init()

    def forward(
        self,
        signal_vectors: torch.Tensor = None,
        xyz_vectors: torch.Tensor = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        noise: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Tuple[ViTMAEModelOutput, int, int]]:
        """
        Embed the signal, encode it and layer-normalize the result.

        Args:
            signal_vectors: input signal, passed to the embeddings.
            xyz_vectors: passed through to the embeddings.
            head_mask: converted with ``get_head_mask`` and handed to the
                encoder (``BrainLMEncoder`` ignores it).
            output_attentions / output_hidden_states / return_dict: fall back
                to the corresponding config values when ``None``.
            noise: optional ordering noise forwarded to the embeddings'
                masking step.

        Returns:
            With ``return_dict``: the 3-tuple
            ``(ViTMAEModelOutput, first_end, pred_len)``.
            Without it: ``(sequence_output, mask) + encoder_outputs[1:]``.
            NOTE(review): the tuple form carries neither ``ids_restore`` nor
            ``first_end`` / ``pred_len``, so the two return forms are not
            interchangeable for callers.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output, mask, ids_restore, first_end, pred_len = self.embeddings(
            signal_vectors, xyz_vectors, noise
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.layernorm(encoder_outputs[0])

        if not return_dict:
            return (sequence_output, mask) + encoder_outputs[1:]

        model_output = ViTMAEModelOutput(
            last_hidden_state=sequence_output,
            mask=mask,
            ids_restore=ids_restore,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        return model_output, first_end, pred_len




class CustomTransformerLayer(nn.Module):
    """
    Single transformer block: multi-head self-attention and a two-layer
    feed-forward network, each followed by a residual connection and LayerNorm.

    Args:
        embed_size: feature dimension of the tokens.
        num_heads: number of attention heads.
        forward_expansion: the feed-forward hidden width is
            ``forward_expansion * embed_size``.
        dropout_rate: dropout probability inside the attention module (also
            used for ``self.dropout``, which ``forward`` does not apply).
        training: when falsy, the attention sub-module is switched to eval
            mode at construction time.
        seq_len: accepted but not used.
    """

    def __init__(self, embed_size, num_heads, forward_expansion, dropout_rate, training, seq_len):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Multi-head attention (batch-first inputs).
        attention = nn.MultiheadAttention(
            embed_size, num_heads, dropout=dropout_rate, batch_first=True
        )
        if not training:
            attention.eval()
        self.attention = attention

        # Feed-forward network.
        inner_size = forward_expansion * embed_size
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, inner_size),
            nn.ReLU(),
            nn.Linear(inner_size, embed_size),
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Dropout
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):
        """
        Args:
            x: tokens of shape (batch_size, seq_len, embed_size); the attention
                module is built with ``batch_first=True``.
            mask: optional ``attn_mask`` handed to the attention module.

        Returns:
            (output, attention_weights); the weights are per head because
            ``average_attn_weights=False``.
        """
        # 1. Self-attention, then residual connection and LayerNorm.
        attended, attention_weights = self.attention(
            x, x, x, attn_mask=mask, average_attn_weights=False
        )
        normed = self.norm1(attended + x)

        # 2. Feed-forward network, then residual connection and LayerNorm.
        out = self.norm2(self.feed_forward(normed) + normed)

        return out, attention_weights
    
def create_attention_mask(seq_len, num_heads, num_masked, batch, training, mask_device=None):
    """
    Build a boolean attention mask that blocks the last ``num_masked`` key
    positions for every query.

    Args:
        seq_len: sequence length; the per-head mask is ``seq_len x seq_len``.
        num_heads: number of attention heads.
        num_masked: number of trailing key columns set to ``True`` (blocked).
            Values <= 0 block nothing; values >= ``seq_len`` block everything.
        batch: batch size.
        training: accepted for interface compatibility; not used.
        mask_device: device for the mask. Defaults to the module-level
            ``device``.

    Returns:
        Bool tensor of shape ``(batch * num_heads, seq_len, seq_len)`` where
        ``True`` marks a blocked position.
    """
    if mask_device is None:
        mask_device = device

    # Start with nothing blocked, allocated directly on the target device.
    mask = torch.zeros((seq_len, seq_len), dtype=torch.bool, device=mask_device)

    # Slice from an explicit start index: ``-num_masked:`` with
    # num_masked == 0 would select (and block) every column.
    if num_masked > 0:
        mask[:, max(seq_len - num_masked, 0):] = True

    return mask.unsqueeze(0).repeat(batch * num_heads, 1, 1)

    

    
class BrainLMDecoder(ViTMAEDecoder):
    """
    ViTMAE decoder with Nystromformer layers and a sinusoidal positional
    encoding.

    The decoder re-inserts mask tokens at the held-out positions, restores
    the original token order, adds the positional encoding and runs the
    Nystromformer stack. The final hidden states are returned as ``logits``;
    no prediction head is applied inside this class.
    """

    def __init__(self, config, num_patches):
        super().__init__(config, num_patches)
        self.decoder_pos_embed = None  # Not using the parent's positional embedding
        self.num_brain_voxels = 1000
        self.mask_ratio = config.mask_ratio
        self.timepoint_patching_size = config.timepoint_patching_size
        self.use_tanh_decoder = config.use_tanh_decoder

        # Arguments: embed_size, num_heads, forward_expansion, dropout_rate,
        # training, seq_len. Not called by forward(); kept so the registered
        # sub-modules (and therefore the state_dict keys) stay the same.
        self.transformer_layer = CustomTransformerLayer(
            512, config.decoder_num_attention_heads, 2, 0, self.training, 51
        )

        # NOTE(review): built with config.hidden_size but added to tensors of
        # the decoder width in forward(); this presumably requires
        # hidden_size == decoder_hidden_size - confirm against the config.
        self.pos_embedding = PositionalEncoding_abs(d_model=config.hidden_size)

        # Decoder Linear Attention Transformer Layers
        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [
                NystromformerLayer(decoder_config)
                for _ in range(config.decoder_num_hidden_layers)
            ]
        )

        if self.use_tanh_decoder:
            self.decoder_pred_nonlinearity2 = nn.Tanh()

        self.initialize_weights(num_patches)

    def initialize_weights(self, _):
        """Draw the mask token from a normal distribution (std = initializer_range)."""
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)

    def forward(
        self,
        hidden_states,

        ids_restore,
        pred_len,
        output_attentions=True,
        output_hidden_states=True,
        return_dict=True,
    ):
        """
        Decode the encoder output back to a full-length token sequence.

        Args:
            hidden_states: encoder output with the CLS token at position 0,
                ``[batch, 1 + num_visible, hidden]``.
            ids_restore: ``[batch, total_len]`` indices that put visible and
                mask tokens back into their original order.
            pred_len: number of held-out tokens. Accepted for interface
                compatibility; the number of mask tokens is derived from
                ``ids_restore`` instead.
            output_attentions: collect ``layer_outputs[1]`` of every layer.
            output_hidden_states: collect the input of every layer plus the
                final output.
            return_dict: return a ``ViTMAEDecoderOutput`` instead of a tuple.

        Returns:
            ``ViTMAEDecoderOutput`` whose ``logits`` are the final hidden
            states (passed through tanh when ``use_tanh_decoder`` is set), or
            the tuple of non-``None`` values among (logits, all hidden states,
            all attentions).
        """
        # embed tokens
        x = self.decoder_embed(hidden_states)
        visible = x[:, 1:, :]  # everything except the CLS token

        batch_size, num_visible, hidden_dim = visible.shape
        num_mask_tokens = ids_restore.shape[1] - num_visible

        # append mask tokens to sequence, then unshuffle
        mask_tokens = self.mask_token.repeat(batch_size, num_mask_tokens, 1)
        x_ = torch.cat([visible, mask_tokens], dim=1)
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, hidden_dim)
        )  # --> [batch_size, total_len, hidden_dim]

        # Add positional encoding for time signal
        x_ = self.pos_embedding(x_)

        hidden_states = torch.cat([x[:, :1, :], x_], dim=1)  # prepend cls token

        # apply Transformer layers (blocks)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for layer_module in self.decoder_layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # Must be passed by keyword: the second positional
                        # parameter of the layer is the attention mask, not
                        # output_attentions.
                        return module(*inputs, output_attentions=output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                )
            else:
                # Nystromformer layer does not accept argument head_mask
                layer_outputs = layer_module(
                    hidden_states, output_attentions=output_attentions
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # The decoder output is the final hidden state; it has to be bound to
        # ``logits`` before the optional tanh and the tuple return use it.
        logits = hidden_states
        if self.use_tanh_decoder:
            logits = self.decoder_pred_nonlinearity2(logits)

        if not return_dict:
            return tuple(
                v
                for v in [logits, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return ViTMAEDecoderOutput(
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class BrainLMForPretraining(ViTMAEForPreTraining):
    """
    Pretraining model for BrainLM fMRI signals: encoder backbone, decoder and
    a linear prediction head. The reconstruction loss is computed in the
    forward pass.
    """

    def __init__(self, config):
        super().__init__(config)
        self.vit = BrainLMModel(config)
        self.decoder = BrainLMDecoder(
            config, num_patches=self.vit.embeddings.num_patches
        )

        # Maps decoder hidden states to the 1000 signal values per time step.
        # NOTE(review): in_features is hard-coded; it presumably has to equal
        # config.decoder_hidden_size - confirm against the config in use.
        self.decoder_pred = nn.Linear(
            in_features=1024,
            out_features=1000,
            bias=True,
        )

        self.post_init()

    def init_weights(self):
        """Prune configured heads, initialize all weights, then tie weights."""
        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # Initialize weights
        self.apply(self._initialize_weights)

        # Tie weights should be skipped when not initializing all weights
        # since from_pretrained(...) calls tie weights anyways
        self.tie_weights()

    def _init_weights(self, module):
        """Xavier-uniform Linear weights (zero bias), unit LayerNorm, Kaiming Embedding."""
        if isinstance(module, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.kaiming_uniform_(module.weight)

    def forward_loss(self, signal_values, pred_values, mask):
        """
        Masked reconstruction error.

        Args:
            signal_values: target tensor.
            pred_values: prediction tensor of the same shape.
            mask: weights (1 = position counts, 0 = ignored); the summed error
                is divided by ``mask.sum()``.

        Returns:
            Scalar loss: masked MSE for ``config.loss_fn == "mse"``, masked
            MAE for ``"mae"``.

        Raises:
            NotImplementedError: for any other ``config.loss_fn``.
        """
        assert signal_values.shape == pred_values.shape, (
            "signal_values and pred_values must have the same shape"
        )

        error = pred_values - signal_values
        if self.config.loss_fn == "mse":
            loss = ((error ** 2) * mask).sum() / mask.sum()  # MSE
        elif self.config.loss_fn == "mae":
            loss = (error * mask).abs().sum() / mask.sum()  # MAE
        else:
            raise NotImplementedError("Unknown loss function specified.")

        return loss

    def forward(
        self,
        signal_vectors: torch.Tensor = None,
        signal_vectors1: torch.Tensor = None,
        xyz_vectors: torch.Tensor = None,
        labels: torch.Tensor = None,  # not used
        input_ids: torch.Tensor = None,  # not used
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        noise: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
        """
        Encode the visible part of the signal, decode the full sequence and
        compute the reconstruction loss.

        Args:
            signal_vectors: input signal; also the reconstruction target.
            signal_vectors1, labels, input_ids: accepted but not used.
            xyz_vectors, head_mask, noise: forwarded to the backbone.
            output_attentions / output_hidden_states: forwarded to the backbone.
            return_dict: falls back to ``config.use_return_dict`` when ``None``.

        Returns:
            ``ViTMAEForPreTrainingOutput`` with ``logits=(predictions, latent)``
            and the expanded, transposed mask; or, when ``return_dict`` is
            false, ``(loss, predictions, mask)`` followed by the backbone's
            hidden states / attentions when those were requested.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Encoder will perform BrainLM fmri embedding rather than VitMAE Image Embedding.
        # The structured output is always requested here: ids_restore and the
        # split lengths are only available in that form. ``return_dict`` only
        # decides how this method packages its own result.
        encoder_output, _, pred_len = self.vit(
            signal_vectors=signal_vectors,
            xyz_vectors=xyz_vectors,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            noise=noise,
        )

        ids_restore = encoder_output.ids_restore
        mask = encoder_output.mask
        latent = encoder_output.last_hidden_state

        decoder_outputs = self.decoder(latent, ids_restore, pred_len)

        # Project to signal space and drop the CLS position.
        logits = self.decoder_pred(decoder_outputs.logits)[:, 1:, :]

        # Expand the per-time-step mask across the predicted feature dimension.
        mask2 = mask.unsqueeze(-1).repeat(1, 1, logits.shape[-1])
        # Complement of the mask: 1 where a token was visible, 0 where held out.
        visible_mask = 1 - mask2

        # Weighted sum: 0.75 on the held-out positions, 0.25 on the visible ones.
        loss = 0.75 * self.forward_loss(
            signal_vectors, logits, mask2
        ) + 0.25 * self.forward_loss(signal_vectors, logits, visible_mask)

        if not return_dict:
            extras = tuple(
                v
                for v in (encoder_output.hidden_states, encoder_output.attentions)
                if v is not None
            )
            return (loss, logits, mask) + extras

        mask3 = mask2.transpose(2, 1)

        return ViTMAEForPreTrainingOutput(
            loss=loss,
            logits=(logits, latent),
            mask=mask3,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
    
    
    
    
    
    
    
