import torch
import torch.nn as nn
from .utils import spectral_projector

__all__ = ['Base']


class Base(nn.Module):
    """
    Build a SimCLR model.

    A single backbone encoder is shared by both augmented views. Its final
    ``fc`` layer is replaced either by a projection head
    (``spectral_projector``) or by ``nn.Identity``.
    """
    def __init__(self, base_ecnoder, need_projector=True, dim=512, hidden_dim=1024):
        """
        base_ecnoder: callable that builds the backbone. It is called as
            ``base_ecnoder(num_classes=dim)``. The returned module must expose
            an ``fc`` attribute that has a ``weight`` tensor. (The misspelled
            parameter name is kept so existing keyword callers keep working.)
        need_projector: if True, ``fc`` is replaced by
            ``spectral_projector(prev_dim, dim, hidden_dim)``. Otherwise it is
            replaced by ``nn.Identity()``, so the encoder returns the features
            that would have been fed into ``fc``. (default: True)
        dim: feature dimension passed to the backbone as ``num_classes`` and to
            ``spectral_projector``. It is presumably the projector's output
            size. (default: 512)
        hidden_dim: hidden dimension of the projector; only used when
            ``need_projector`` is True. (default: 1024)
        """
        super().__init__()

        # Build the backbone. num_classes sets the output width of its fc
        # layer, but that layer is replaced below; only its input width is used.
        self.encoder = base_ecnoder(num_classes=dim)
        # Input width of the original fc layer = width of the backbone features
        # (assumes an nn.Linear-like weight of shape (out_features, in_features)).
        prev_dim = self.encoder.fc.weight.shape[1]
        # Swap the classification head for the projection head or a pass-through.
        self.encoder.fc = (
            spectral_projector(prev_dim, dim, hidden_dim)
            if need_projector
            else nn.Identity()
        )

    def forward(self, x1, x2):
        """
        Encode both augmented views with the shared encoder.

        Input:
            x1: first views of images
            x2: second views of images
        Output:
            q, k: encoder outputs for x1 and x2, respectively (same weights
                are used for both views)
        """
        # Each view goes through the encoder separately, so any batch-dependent
        # layers (e.g. BatchNorm, if present) see each view's batch on its own.
        q = self.encoder(x1)
        k = self.encoder(x2)

        return q, k
