import torch.nn as nn


class NGDModel(nn.Module):
    """Fully connected network whose Linear layers include bias terms.

    ``architecture`` lists layer widths with the input width first, e.g.
    ``[2, 16, 16, 1]``. One ``nn.Linear`` is built for every consecutive pair
    of widths. The activation is placed between consecutive Linear layers and
    never after the last one, so the output layer is purely linear.

    Unlike the other models in this module, every Linear layer here has a
    bias.

    Args:
        architecture: Sequence of integer layer widths.
        activation: Name of the hidden activation; only ``"tanh"`` is
            accepted.

    Raises:
        NotImplementedError: If ``activation`` is anything other than
            ``"tanh"``.
    """

    def __init__(self, architecture, activation):
        super().__init__()
        if activation != "tanh":
            raise NotImplementedError(f"{activation} is not supported")
        # One Tanh object is reused at every hidden position; it has no
        # parameters, so sharing it does not tie any weights together.
        shared_act = nn.Tanh()

        self.layers = nn.ModuleList()
        widths = list(architecture)
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx > 0:
                # Insert the activation between the previous Linear and this one.
                self.layers.append(shared_act)
            self.layers.append(nn.Linear(fan_in, fan_out))

    def forward(self, x):
        """Feed ``x`` through every module in ``self.layers``, in order."""
        out = x
        for module in self.layers:
            out = module(out)
        return out


class SGDModel(nn.Module):
    """Fully connected network whose Linear layers have no bias terms.

    ``architecture`` lists layer widths with the input width first. A
    bias-free ``nn.Linear`` is built for each consecutive pair of widths. The
    activation is inserted between consecutive Linear layers, but not after
    the final one.

    Args:
        architecture: Sequence of integer layer widths.
        activation: Name of the hidden activation; only ``"tanh"`` is
            accepted.

    Raises:
        NotImplementedError: If ``activation`` is anything other than
            ``"tanh"``.
    """

    def __init__(self, architecture, activation):
        super().__init__()
        if activation != "tanh":
            raise NotImplementedError(f"{activation} is not supported")
        # A single parameter-free Tanh instance is shared by all hidden slots.
        nonlinearity = nn.Tanh()

        self.layers = nn.ModuleList()
        dims = list(architecture)
        for pos, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            if pos:
                # Separate this Linear from the previous one.
                self.layers.append(nonlinearity)
            self.layers.append(nn.Linear(d_in, d_out, bias=False))

    def forward(self, x):
        """Run ``x`` through ``self.layers`` sequentially and return the result."""
        hidden = x
        for step in self.layers:
            hidden = step(hidden)
        return hidden
    
    
class AdamModel(nn.Module):
    """Fully connected network whose Linear layers have no bias terms.

    ``architecture`` lists layer widths with the input width first. A
    bias-free ``nn.Linear`` is built for each consecutive pair of widths. The
    activation is inserted between consecutive Linear layers, but not after
    the final one.

    Args:
        architecture: Sequence of integer layer widths.
        activation: Name of the hidden activation; only ``"tanh"`` is
            accepted.

    Raises:
        NotImplementedError: If ``activation`` is anything other than
            ``"tanh"``.
    """

    def __init__(self, architecture, activation):
        super().__init__()
        if activation != "tanh":
            raise NotImplementedError(f"{activation} is not supported")
        # Reuse one Tanh module everywhere; it holds no parameters.
        act = nn.Tanh()

        self.layers = nn.ModuleList()
        sizes = list(architecture)
        for k, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if k > 0:
                # Activation goes between the previous Linear and this one.
                self.layers.append(act)
            self.layers.append(nn.Linear(n_in, n_out, bias=False))

    def forward(self, x):
        """Apply each module of ``self.layers`` to ``x`` in sequence."""
        y = x
        for block in self.layers:
            y = block(y)
        return y
    
    
class LBFGSModel(nn.Module):
    """Fully connected network whose Linear layers have no bias terms.

    ``architecture`` lists layer widths with the input width first. A
    bias-free ``nn.Linear`` is built for each consecutive pair of widths. The
    activation is inserted between consecutive Linear layers, but not after
    the final one.

    Args:
        architecture: Sequence of integer layer widths.
        activation: Name of the hidden activation; only ``"tanh"`` is
            accepted.

    Raises:
        NotImplementedError: If ``activation`` is anything other than
            ``"tanh"``.
    """

    def __init__(self, architecture, activation):
        super().__init__()
        if activation != "tanh":
            raise NotImplementedError(f"{activation} is not supported")
        # The same stateless Tanh object is placed at every hidden position.
        tanh_module = nn.Tanh()

        self.layers = nn.ModuleList()
        shape = list(architecture)
        for j, (src, dst) in enumerate(zip(shape[:-1], shape[1:])):
            if j != 0:
                # Interleave the activation before every Linear except the first.
                self.layers.append(tanh_module)
            self.layers.append(nn.Linear(src, dst, bias=False))

    def forward(self, x):
        """Pass ``x`` through ``self.layers`` one module at a time."""
        result = x
        for op in self.layers:
            result = op(result)
        return result