import torch
import torch.nn as nn
    
class SCHead(nn.Module):
    """
    Predict the self-consistency score.

    An MLP of ``config_sc.n_layers`` linear layers mapping ``latent_dim`` to
    ``c_out``. Every layer except the last is followed by GELU and Dropout;
    the final layer is a bare Linear, so the output is unnormalized logits.

    Args:
        config_sc: config object exposing ``latent_dim``, ``ff_dim``,
            ``c_out``, ``n_layers`` and ``dropout`` attributes.
        **kwargs: accepted and ignored (kept for call-site compatibility).

    Raises:
        ValueError: if ``config_sc.n_layers`` is less than 1. With zero layers
            the head would be an empty ``Sequential`` that returns its input
            unchanged, i.e. ``[*, latent_dim]`` instead of ``[*, c_out]``.
    """

    def __init__(self, config_sc, **kwargs):
        super().__init__()

        self.config = config_sc

        self.latent_dim = config_sc.latent_dim
        self.ff_dim = config_sc.ff_dim
        self.c_out = config_sc.c_out
        self.n_layers = config_sc.n_layers
        self.dropout = config_sc.dropout

        if self.n_layers < 1:
            raise ValueError(
                f"SCHead requires n_layers >= 1, got {self.n_layers}"
            )

        # Layer widths: latent_dim -> ff_dim -> ... -> ff_dim -> c_out.
        dims = [self.latent_dim] + [self.ff_dim] * (self.n_layers - 1) + [self.c_out]
        last = self.n_layers - 1
        layers = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(in_dim, out_dim))
            if i != last:
                # The (Linear, GELU, Dropout) order fixes the state_dict keys
                # (head.0, head.3, ...); keep it stable for checkpoint loading.
                layers.extend((nn.GELU(), nn.Dropout(self.dropout)))
        self.head = nn.Sequential(*layers)

    def forward(self, latent_z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            latent_z:
                [*, latent_dim] latent_z
        Returns:
            [*, c_out] self-consistency logits
        """
        # [*, c_out]
        logits = self.head(latent_z)
        return logits
