import torch
import torch.nn as nn
import torch.nn.functional as F

def freeze_model(m):
    """Turn off gradient tracking for ``m`` in place.

    Calls ``m.requires_grad_`` with ``False``; for an ``nn.Module`` this flips
    ``requires_grad`` to ``False`` on every registered parameter.
    Nothing is returned — the argument itself is mutated.
    """
    m.requires_grad_(requires_grad=False)

def hot_model(m):
    """Turn gradient tracking back on for ``m`` in place.

    Calls ``m.requires_grad_`` with ``True``; for an ``nn.Module`` this flips
    ``requires_grad`` to ``True`` on every registered parameter.
    Nothing is returned — the argument itself is mutated.
    """
    m.requires_grad_(requires_grad=True)

class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    Stacks ``num_layers`` ``nn.Linear`` layers mapping
    ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``, with a
    ReLU after every layer except the last. With ``num_layers == 1`` it is a
    single ``nn.Linear(input_dim, output_dim)``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every intermediate layer.
        output_dim: Size of the last dimension of the output.
        num_layers: Total number of linear layers; must be >= 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1 (previously such
            values silently produced a single layer that disagreed with
            ``self.num_layers``).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # Pair consecutive sizes: (in, hid), (hid, hid), ..., (hid, out).
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        """Apply the layers in order, with ReLU between (not after) them.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``output_dim``.
        """
        # Derive the last index from the layers actually built so the ReLU
        # placement can never drift from the module structure.
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = F.relu(x)
        return x

class LightCrossAttnCfg:
    """Fixed hyper-parameters for a lightweight cross-attention module.

    Every value is a plain class attribute, readable from the class itself or
    from any instance. The field names match those of a BERT-style config
    object; presumably this class is handed to code expecting such a config —
    confirm against its consumers.
    """

    # Sizes and depth.
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 8
    intermediate_size: int = 3072

    # NOTE(review): an alternative value of 1024 was noted here; presumably
    # the width of the features being cross-attended to — confirm.
    encoder_width: int = 768

    # Activation function name.
    hidden_act: str = "gelu"

    # Position embeddings.
    position_embedding_type: str = "absolute"
    max_position_embeddings: int = 512

    # Dropout probabilities.
    attention_probs_dropout_prob: float = 0.1
    hidden_dropout_prob: float = 0.1

    # Layer-norm epsilon.
    layer_norm_eps: float = 1e-12

