from .networks import resnet32
import torch

def build_cifar():
    """Create and return the network used for the CIFAR setup.

    The network is built by ``resnet32`` (imported from ``.networks``) with
    ``num_classes=100``, ``num_exps=3`` and ``use_norm=True``.

    Returns:
        The freshly constructed model object returned by ``resnet32``.
    """
    # NOTE(review): num_exps is presumably the number of expert branches and
    # num_classes=100 suggests CIFAR-100 -- confirm against .networks.
    return resnet32(num_classes=100, num_exps=3, use_norm=True)


class SHIKEWrapper(torch.nn.Module):
    """Wrap a multi-output network and optionally average its outputs.

    The wrapped network is stored in ``self.model`` and must be created
    (e.g. via :meth:`build_cifar`) before the wrapper is called. On a forward
    pass the network is invoked with ``crt=True`` and its outputs are either
    returned as-is or reduced to their arithmetic mean, depending on
    ``div_out``.

    Args:
        div_out: If true, ``forward`` returns the network's outputs
            unchanged; otherwise it returns their mean.
    """

    def __init__(self, div_out: bool = False):
        super().__init__()
        # Placeholders kept for callers that may read them; this class never
        # assigns anything other than None to them.
        self.forward_func = None
        self.div_out = div_out
        self.featurizer, self.classifier = None, None
        # Populated by build_cifar(); forward() refuses to run while None.
        self.model = None

    def build_cifar(self):
        """Instantiate the wrapped network and store it in ``self.model``.

        Uses ``resnet32`` with ``num_classes=100``, ``num_exps=3`` and
        ``use_norm=True``.
        """
        # NOTE(review): num_exps is presumably the number of expert branches,
        # i.e. the number of outputs forward() receives -- confirm.
        self.model = resnet32(num_classes=100, num_exps=3, use_norm=True)

    def forward(self, input, **kwargs):
        """Run the wrapped network on ``input``.

        Args:
            input: Batch forwarded unchanged to the wrapped network.
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            The network's outputs unchanged when ``self.div_out`` is true,
            otherwise ``sum(outs) / len(outs)``.

        Raises:
            RuntimeError: If no model has been built yet.
        """
        # Fail with an actionable message instead of the opaque
        # "'NoneType' object is not callable" raised by calling None.
        if self.model is None:
            raise RuntimeError(
                "SHIKEWrapper has no model; call build_cifar() before forward()."
            )

        outs = self.model(input, crt=True)
        if self.div_out:
            return outs
        # assumes outs is a non-empty sequence of same-shaped tensors, making
        # this an element-wise mean -- TODO confirm against the network.
        return sum(outs) / len(outs)

