import torch 
import torch.nn as nn   

####
# Forward passes for pruned ensemble models
####

def forward_pruned_layers(self, input):
    """Run a stack of pruned ensemble blocks on a shared input.

    The input is replicated once per ensemble member along a new axis 1
    and then threaded through every block of ``self`` in order.

    Args:
        self: indexable, iterable container of blocks (presumably an
            ``nn.Sequential``); ``self[0].self_attention.num_models`` gives
            the ensemble size.
        input: tensor, presumably (B, T, D) -- TODO confirm; after the new
            axis it is repeated with a 4-element repeat spec.

    Returns:
        Output of the last block in ``self``.
    """
    ensemble_size = self[0].self_attention.num_models
    # (B, T, D) -> (B, 1, T, D) -> (B, ensemble_size, T, D); repeat copies data.
    hidden = input[:, None].repeat(1, ensemble_size, 1, 1)
    for block in self:
        hidden = block(hidden)
    return hidden

def forward_pruned_encoder(self, input: torch.Tensor):
    """
    Forward pass for a transformer block in an ensemble model.

    Pre-norm block: ``x = input + dropout(self_attention(ln_1(input)))``,
    then ``x + mlp(ln_2(x))``, applied to a 4-D input carrying an extra
    ensemble axis.

    Args:
        self: the block; must provide ``ln_1``, ``self_attention`` (called as
            ``(q, k, v, need_weights=False)`` and returning a 2-tuple),
            ``dropout``, ``ln_2`` and ``mlp``.
        input: 4-D tensor, expected as
            (batch_size, num_models, seq_length, hidden_dim).

    Returns:
        ``x + mlp(ln_2(x))`` where ``x`` is the post-attention residual stream.

    Raises:
        AssertionError: if ``input`` is not 4-D (raised by ``torch._assert``).
    """
    torch._assert(input.dim() == 4, f"Expected (batch_size, num_models, seq_length, hidden_dim) got {input.shape}")
    x = self.ln_1(input)
    x, _ = self.self_attention(x, x, x, need_weights=False)
    x = self.dropout(x)
    x = x + input

    y = self.ln_2(x)
    y = self.mlp(y)
    return x + y

def forward_mlp_block_pruned(self, input, num_models):
    """Legacy chunk-by-chunk forward pass for an MLP block in an ensemble model.

    Dim 1 of ``input`` is treated as ``num_models`` contiguous slices of
    length ``input.shape[1] // num_models``. Every ``nn.Linear`` in ``self``
    is applied to each slice separately and the results are concatenated
    back along dim 1; any other sub-module sees the whole tensor at once.

    Note: if ``input.shape[1]`` is not a multiple of ``num_models``, the
    trailing tokens are dropped at the first ``nn.Linear``.

    Args:
        self: iterable of sub-modules (presumably an ``nn.Sequential``).
        input: 3-D tensor indexed as (batch, tokens, features) -- layout
            assumed, TODO confirm.
        num_models: number of ensemble members packed along dim 1.

    Returns:
        Output of the last sub-module.
    """
    chunk_len = input.shape[1] // num_models
    hidden = input
    for layer in self:
        if not isinstance(layer, nn.Linear):
            hidden = layer(hidden)
            continue
        pieces = [
            layer(hidden[:, k * chunk_len:(k + 1) * chunk_len, :])
            for k in range(num_models)
        ]
        hidden = torch.concat(pieces, dim=1)
    return hidden

####
# Forward passes for token-pruned models
####
def pruned_token_forward(self, input: torch.Tensor):
    """Pre-norm transformer block forward with a token-pruning attention layer.

    The attention output has only the tokens listed in
    ``self.self_attention.tokens_kept``. The matching rows of ``input`` are
    gathered along dim 1 so the residual connection lines up. The MLP
    sub-block then runs on ``ln_2`` of the residual stream.

    Args:
        self: the block; must provide ``ln_1``, ``self_attention`` (callable
            as ``(q, k, v, need_weights=False)`` returning a 2-tuple, with
            ``tokens_kept`` and ``embed_dim`` attributes), ``dropout``,
            ``ln_2`` and ``mlp``.
        input: 3-D tensor; tokens are indexed along dim 1.

    Returns:
        ``x + mlp(ln_2(x))`` where ``x`` is the pruned residual stream.

    Raises:
        AssertionError: if ``input`` is not 3-D (raised by ``torch._assert``).
    """
    torch._assert(input.dim() == 3, "Input tensor must be 3D")
    x = self.ln_1(input)
    x, _ = self.self_attention(x, x, x, need_weights=False)
    x = self.dropout(x)

    # tokens_kept is 2-D (indices into dim 1); broadcast it over the feature
    # axis so gather picks whole token rows of the residual input.
    tokens_kept = self.self_attention.tokens_kept.unsqueeze(-1).expand(-1, -1, self.self_attention.embed_dim)
    x = x + torch.gather(input, dim=1, index=tokens_kept)

    # The MLP consumes the ln_2-normalised stream (same as forward_pruned_encoder).
    y = self.mlp(self.ln_2(x))
    return x + y

def deit_pruned_token_forward(self, hidden_states: torch.Tensor, head_mask=None, output_attentions=False, return_dict=True):
    """DeiT-style layer forward with a token-pruning attention module.

    Pre-norm attention: query, key and value all come from
    ``layernorm_before(hidden_states)``, matching ``pruned_token_forward``.
    The attention output only has the tokens in ``self.attention.tokens_kept``.
    The matching rows of ``hidden_states`` are gathered for the first
    residual. ``self.output`` then receives the attention residual stream as
    its second argument (the comment in the original code says the second
    residual happens there).

    Args:
        self: layer providing ``layernorm_before``, ``attention`` (callable as
            ``(q, k, v, need_weights=...)`` returning ``(out, weights)`` and
            exposing ``tokens_kept``), ``layernorm_after``, ``intermediate``
            and ``output``.
        hidden_states: 3-D tensor; tokens are indexed along dim 1.
        head_mask: accepted for interface compatibility; unused here.
        output_attentions: forwarded as ``need_weights`` to the attention.
        return_dict: accepted for interface compatibility; unused here.

    Returns:
        Tuple ``(layer_output, attn_weights)``; ``attn_weights`` is whatever
        the attention module returned for ``need_weights=output_attentions``.

    Raises:
        AssertionError: if ``hidden_states`` is not 3-D (via ``torch._assert``).
    """
    torch._assert(hidden_states.dim() == 3, "Input tensor must be 3D")
    # Normalise once and use it for q, k and v (pre-norm attention).
    normed = self.layernorm_before(hidden_states)
    x, attn_weights = self.attention(normed, normed, normed, need_weights=output_attentions)

    # tokens_kept is 2-D (indices into dim 1); broadcast over features so
    # gather selects the surviving token rows for the residual connection.
    tokens_kept = self.attention.tokens_kept.unsqueeze(-1).expand(-1, -1, x.shape[-1])
    hidden_states = x + torch.gather(hidden_states, dim=1, index=tokens_kept)

    # in DeiT, layernorm is also applied after self-attention
    layer_output = self.layernorm_after(hidden_states)
    layer_output = self.intermediate(layer_output)

    # second residual connection is done here
    layer_output = self.output(layer_output, hidden_states)

    return (layer_output, attn_weights)