import logging

import torch

from .backbone.meta_resnet import meta_resnet18
from .backbone.resnet import resnet18
from .backbone.resnet_cifar import resnet18_cifar
from .byol import BYOL
from .simclr import SimCLR
from .simsiam import SimSiam
from torch import nn
from torchmeta.modules import MetaModule


class MetaIdentity(nn.Identity, MetaModule):
    # Identity layer that can sit inside a torchmeta ``MetaModule`` network
    # (e.g. as the replacement ``fc`` head of a meta backbone). Its ``forward``
    # accepts the extra ``params`` keyword used by MetaModule layers and
    # ignores it. The class docstring is copied from ``nn.Identity`` so
    # introspection/help() shows the base layer's documentation.
    __doc__ = nn.Identity.__doc__

    def forward(self, input, params=None):
        # An identity layer has no weights, so any supplied ``params`` is
        # irrelevant; hand the input back untouched.
        del params
        return input


def get_backbone(num_class, pssl_optimizer, backbone, castrate=True):
    """Build the ResNet-18 backbone used by the SSL models.

    Args:
        num_class: Passed as ``num_classes`` to the backbone constructor.
        pssl_optimizer: Optimizer name. ``"perFedAvg"`` selects
            ``meta_resnet18`` (torchmeta-compatible); any other value selects
            a standard ResNet-18 variant chosen by ``backbone``.
        backbone: Backbone name. ``"resnet18_cifar"`` selects
            ``resnet18_cifar``; any other value falls back to ``resnet18``.
            Ignored when ``pssl_optimizer == "perFedAvg"``.
        castrate: If True, record ``fc.in_features`` as ``output_dim`` and
            replace ``fc`` with an identity layer so the backbone returns the
            features that fed the classifier head.

    Returns:
        The constructed backbone module.
    """
    use_meta = pssl_optimizer == "perFedAvg"
    if use_meta:
        # The meta variant is always used for perFedAvg; `backbone` is not consulted.
        builder = meta_resnet18
    elif backbone == "resnet18_cifar":
        builder = resnet18_cifar
    else:
        if backbone != "resnet18":
            logging.warning(
                "Unrecognized backbone %r; falling back to resnet18", backbone
            )
        builder = resnet18

    # Log the builder that is actually used (lazy %-args for logging).
    logging.info("backbone = %s", getattr(builder, "__name__", builder))
    model = builder(num_classes=num_class)

    if castrate:
        # Capture the feature width before the classifier head is removed.
        model.output_dim = model.fc.in_features
        # The meta backbone needs an identity whose forward accepts `params`.
        model.fc = MetaIdentity() if use_meta else nn.Identity()
    return model


def get_ssl_model(num_class, ssl_method, pssl_optimizer, backbone):
    """Construct the self-supervised model named by ``ssl_method``.

    The model wraps a backbone built by ``get_backbone(num_class,
    pssl_optimizer, backbone)``.

    Args:
        num_class: Forwarded to ``get_backbone``.
        ssl_method: One of ``"simsiam"``, ``"byol"`` or ``"simclr"``.
        pssl_optimizer: Forwarded to ``get_backbone``.
        backbone: Backbone name forwarded to ``get_backbone``.

    Returns:
        The SSL model instance wrapping the backbone.

    Raises:
        NotImplementedError: if ``ssl_method`` is ``"swav"`` (not supported
            yet) or any other unrecognized name.
    """
    # Select the model class first so the backbone is only built for a
    # supported method.
    if ssl_method == "simsiam":
        model_cls = SimSiam
    elif ssl_method == "byol":
        model_cls = BYOL
    elif ssl_method == "simclr":
        model_cls = SimCLR
    elif ssl_method == "swav":
        raise NotImplementedError("SSL method 'swav' is not supported yet")
    else:
        raise NotImplementedError(
            f"Unknown SSL method {ssl_method!r}; "
            "expected one of 'simsiam', 'byol', 'simclr'"
        )
    return model_cls(get_backbone(num_class, pssl_optimizer, backbone))
