from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .sup_con_original import SupConResNet, SupCEResNet, LinearClassifier

# ResNet models (built by ResNet18 from .resnet)
def resnet18(
    num_classes=10,
    norm_layer_type='bn',
    conv_layer_type='conv',
    linear_layer_type='linear',
    activation_layer_type='relu',
    etf_fc=False,
):
    """Build and return a ``ResNet18`` instance, forwarding every option as a keyword.

    Args:
        num_classes: Forwarded as ``num_classes`` (default 10).
        norm_layer_type: Name of the normalization layer variant
            (default ``'bn'``). The accepted values are defined by
            ``ResNet18``; presumably ``'bn'`` means batch norm, so confirm there.
        conv_layer_type: Name of the convolution layer variant (default ``'conv'``).
        linear_layer_type: Name of the linear layer variant (default ``'linear'``).
        activation_layer_type: Name of the activation variant (default ``'relu'``).
        etf_fc: Boolean flag forwarded as ``etf_fc`` (default ``False``).
            NOTE(review): this likely toggles a fixed/ETF-style final classifier
            inside ``ResNet18``; confirm there.

    Returns:
        Whatever ``ResNet18(...)`` returns for these keyword arguments.
    """
    options = dict(
        num_classes=num_classes,
        norm_layer_type=norm_layer_type,
        conv_layer_type=conv_layer_type,
        linear_layer_type=linear_layer_type,
        activation_layer_type=activation_layer_type,
        etf_fc=etf_fc,
    )
    return ResNet18(**options)

# Cross-entropy models (built by SupCEResNet from .sup_con_original)
def ce_resnet18(name='resnet18', head='mlp', remove_last_relu=False):
    """Build a ``SupCEResNet`` whose ``name`` defaults to ``'resnet18'``.

    Args:
        name: Forwarded as ``SupCEResNet(name=...)``. Presumably it selects the
            backbone architecture; confirm in ``SupCEResNet``.
        head: Forwarded as ``head`` (default ``'mlp'``).
        remove_last_relu: Boolean flag forwarded as ``remove_last_relu``
            (default ``False``).

    Returns:
        The ``SupCEResNet`` instance built from these keyword arguments.
    """
    model = SupCEResNet(
        name=name,
        head=head,
        remove_last_relu=remove_last_relu,
    )
    return model

def ce_resnet50(name='resnet50', head='mlp', remove_last_relu=False):
    """Build a ``SupCEResNet`` whose ``name`` defaults to ``'resnet50'``.

    Args:
        name: Forwarded as ``SupCEResNet(name=...)``. Presumably it selects the
            backbone architecture; confirm in ``SupCEResNet``.
        head: Forwarded as ``head`` (default ``'mlp'``).
        remove_last_relu: Boolean flag forwarded as ``remove_last_relu``
            (default ``False``).

    Returns:
        The ``SupCEResNet`` instance built from these keyword arguments.
    """
    model = SupCEResNet(
        name=name,
        head=head,
        remove_last_relu=remove_last_relu,
    )
    return model