import torch
import torch.nn as nn
import torchvision
from torch.nn import Module

def generate_net(name, device=torch.device('cpu'), init='rand', num_classes=10, pretrained=True,
                 input_size=None, hidden_size=None):
    """Build a network by name and move it to *device*.

    Args:
        name: One of ``'small'``, ``'resnet50'``, ``'resnet18'`` or ``'resnet10'``.
        device: Device the returned network is moved to.
        init: Currently unused; kept so existing callers keep working.
        num_classes: Number of output units of the final layer.
        pretrained: For ``'resnet50'``/``'resnet18'``, whether ``torch.hub`` loads
            pretrained weights. ``'resnet10'`` has no pretrained weights and is
            always built untrained (a message is printed if this is True).
            Ignored for ``'small'``.
        input_size: Flattened input dimension for ``'small'`` (defaults to 784).
        hidden_size: Hidden layer widths for ``'small'``; ``None`` means the
            network has no hidden layer.

    Returns:
        The constructed ``nn.Module``, placed on *device*.

    Raises:
        ValueError: If *name* is not one of the supported networks.
    """
    if name == 'small':
        if input_size is None:
            input_size = 784
        # Forward num_classes; without it SmallNet silently falls back to 10 outputs.
        network = SmallNet(device, n_classes=num_classes, input_features=input_size,
                           hidden_size=hidden_size)
    elif name in ('resnet50', 'resnet18'):
        # Honor the caller's `pretrained` flag instead of always loading weights.
        network = torch.hub.load('pytorch/vision:v0.8.0', name, pretrained=pretrained)
        # Replace the classification head so it emits `num_classes` logits.
        network.fc = nn.Linear(network.fc.in_features, num_classes)
    elif name == 'resnet10':
        if pretrained:
            print("Unable to load a pretrained resnet10 model. Loading the untrained model instead.")
        # ResNet with one BasicBlock per stage, built directly from torchvision's ResNet class.
        network = torchvision.models.ResNet(torchvision.models.resnet.BasicBlock, [1, 1, 1, 1],
                                            num_classes=num_classes)
    else:
        raise ValueError(f"Unknown network: {name}")
    return network.to(device)



class SmallNet(Module):
    """Fully connected classifier: flatten, optional hidden layers, linear output.

    Each hidden layer is a ``Linear`` followed by the activation ``act``; the
    final ``Linear`` maps to ``n_classes`` outputs with no activation after it.
    """

    def __init__(self, device=torch.device('cpu'), n_classes=10, input_features=784, hidden_size=None,
                 act=None):
        """Build the layer stack.

        Args:
            device: Device the layers are moved to.
            n_classes: Number of output units.
            input_features: Size of each flattened input sample.
            hidden_size: Width(s) of the hidden layers. Pass an iterable of ints,
                a single int for exactly one hidden layer, or ``None`` for none.
            act: Activation module inserted after each hidden ``Linear``. It
                defaults to a new ``nn.Sigmoid()`` created for this network. The
                same module instance is reused after every hidden layer, so an
                activation with parameters (e.g. ``nn.PReLU``) shares them
                across layers.
        """
        super().__init__()
        if act is None:
            # Create the default per instance rather than sharing one module
            # object (evaluated once at def time) across every SmallNet.
            act = nn.Sigmoid()
        if hidden_size is None:
            hidden_size = ()
        elif isinstance(hidden_size, int):
            # A bare int would otherwise fail when iterated below.
            hidden_size = (hidden_size,)

        self.flatten = nn.Flatten()
        modules = []
        in_features = input_features
        for size in hidden_size:
            modules.append(nn.Linear(in_features, size))
            modules.append(act)
            in_features = size
        modules.append(nn.Linear(in_features, n_classes))
        self.net = nn.Sequential(*modules)
        self.net.to(device)

    def forward(self, x):
        """Flatten *x* with ``nn.Flatten`` and return the network's raw outputs."""
        return self.net(self.flatten(x))

def get_learnable_parameters_as_vector(net):
    """Return every parameter of *net* flattened and concatenated into a 1-D tensor.

    The order is the order in which ``net.parameters()`` yields them.
    ``set_learnable_parameters_from_vector`` walks ``net.parameters()`` the same
    way, so a vector produced here can be written back with it.

    NOTE: ``net.parameters()`` also yields parameters with
    ``requires_grad=False``, so they are included despite the function's name.
    """
    param_iter = net.parameters()
    flat_vector = nn.utils.parameters_to_vector(param_iter)
    return flat_vector

def set_learnable_parameters_from_vector(net, w):
    """Write the flat tensor *w* back into the parameters of *net*, in place.

    Consecutive slices of *w* are assigned to the parameters in the order
    ``net.parameters()`` yields them. This is the same order
    ``get_learnable_parameters_as_vector`` uses to build such a vector.
    Returns ``None``.

    NOTE(review): PyTorch's ``vector_to_parameters`` rebinds each
    parameter's ``.data`` to a view of *w*. The parameters may therefore
    share storage with *w*; clone *w* first if it will be modified afterwards.
    """
    destination = net.parameters()
    nn.utils.vector_to_parameters(w, destination)




