import torch.nn as nn
import torchvision.models as models


def build_encoder(args):
    """Build a torchvision ResNet backbone configured from ``args``.

    Reads ``args.architecture`` ('ResNet18' or 'ResNet50'), ``args.supervised``,
    ``args.num_classes`` (supervised only), ``args.feature_dim`` (unsupervised
    only) and ``args.dataset``.

    * Supervised: the network is built with ``num_classes=args.num_classes``
      and keeps its stock ``fc`` head.
    * Otherwise: ``fc`` outputs ``args.feature_dim`` values and is followed by
      a non-affine ``BatchNorm1d`` over those features.
    * If ``args.dataset`` contains "CIFAR", the stem convolution is replaced by
      a 3x3, stride-1 convolution and the max-pool is removed.

    Both variants are built with ``zero_init_residual=True``.

    Returns:
        The configured ResNet module.

    Raises:
        ValueError: if ``args.architecture`` is not a supported name.
    """
    factories = {
        'ResNet18': models.resnet18,
        'ResNet50': models.resnet50,
    }
    factory = factories.get(args.architecture)
    if factory is None:
        raise ValueError(f"Invalid architecture: {args.architecture}. Choose: ResNet18 or ResNet50.")

    head_dim = args.num_classes if args.supervised else args.feature_dim
    encoder = factory(num_classes=head_dim, zero_init_residual=True)
    if not args.supervised:
        # Normalise the feature output; affine=False means no learnable
        # scale/shift on top of the normalisation.
        encoder.fc = nn.Sequential(encoder.fc, nn.BatchNorm1d(args.feature_dim, affine=False))

    if "CIFAR" in args.dataset:
        # Lighter stem: 3x3 stride-1 conv and no max-pool, so small inputs are
        # downsampled less aggressively than with torchvision's default stem.
        encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        encoder.maxpool = nn.Identity()
    # Any other dataset (e.g. "ImageNet") keeps torchvision's default stem.

    return encoder


def build_projector(args):
    """Build the projection MLP head.

    The head maps ``args.feature_dim`` inputs to ``args.feature_dim`` outputs
    using ``args.projection_layer`` linear layers (1, 2 or 3). Every linear
    layer except the last one is bias-free and followed by ``BatchNorm1d`` and
    an in-place ``ReLU``. The hidden width is ``args.projection_dim``. The final
    linear layer keeps its bias and has no normalisation or activation.

    Args:
        args: namespace providing ``feature_dim``, ``projection_dim`` and
            ``projection_layer``.

    Returns:
        nn.Sequential: the projection head.

    Raises:
        ValueError: if ``args.projection_layer`` is not 1, 2 or 3.
    """
    in_features = args.feature_dim
    width = args.projection_dim
    out_features = args.feature_dim
    depth = args.projection_layer

    if depth not in (1, 2, 3):
        raise ValueError(f"Invalid projection layer: {args.projection_layer}. Choose: 1, 2 or 3.")

    modules = []
    current = in_features
    # One Linear -> BatchNorm -> ReLU stage per hidden layer.
    for _ in range(int(depth) - 1):
        # The bias would be redundant: BatchNorm re-centres the activations.
        modules += [
            nn.Linear(current, width, bias=False),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
        ]
        current = width
    modules.append(nn.Linear(current, out_features))
    return nn.Sequential(*modules)


def build_predictor(args):
    """Build the prediction MLP head.

    The head maps ``args.feature_dim`` inputs to ``args.feature_dim`` outputs
    using ``args.prediction_layer`` linear layers (1, 2 or 3). Every linear
    layer except the last one is bias-free and followed by ``BatchNorm1d`` and
    an in-place ``ReLU``. The hidden width is ``args.prediction_dim``. The final
    linear layer keeps its bias and has no normalisation or activation.

    Args:
        args: namespace providing ``feature_dim``, ``prediction_dim`` and
            ``prediction_layer``.

    Returns:
        nn.Sequential: the prediction head.

    Raises:
        ValueError: if ``args.prediction_layer`` is not 1, 2 or 3.
    """
    in_features = args.feature_dim
    width = args.prediction_dim
    out_features = args.feature_dim
    depth = args.prediction_layer

    if depth not in (1, 2, 3):
        raise ValueError(f"Invalid prediction layer: {args.prediction_layer}. Choose 1, 2 or 3.")

    modules = []
    current = in_features
    # One Linear -> BatchNorm -> ReLU stage per hidden layer.
    for _ in range(int(depth) - 1):
        # The bias would be redundant: BatchNorm re-centres the activations.
        modules += [
            nn.Linear(current, width, bias=False),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
        ]
        current = width
    modules.append(nn.Linear(current, out_features))
    return nn.Sequential(*modules)