from torch import nn
import torch.nn.functional as F

class LinearModel(nn.Module):
    """Single fully-connected classification head.

    Input width is the sum of ``args.backbone_args['h_dims']`` and output
    width is ``args.n_cls``; the layer has a bias term.
    """

    def __init__(self,
                 args):
        super().__init__()
        in_features = sum(args.backbone_args['h_dims'])
        # Attribute name ``linear`` is part of the state-dict key layout.
        self.linear = nn.Linear(in_features, args.n_cls, bias=True)

    def forward(self, features):
        """Return the raw linear output (logits, no activation applied)."""
        return self.linear(features)

    def reset_params(self):
        """Re-initialise the weight and bias of the linear layer."""
        self.linear.reset_parameters()

class SGC_MLP(nn.Module):
    """Multi-layer perceptron classifier with ReLU between hidden layers.

    Layer widths follow ``args.d_data -> *args.SGC_args['h_dims'] -> args.n_cls``.
    An empty ``h_dims`` yields a single linear layer; no activation is applied
    after the final layer.

    Args:
        args: config object exposing ``d_data`` (input width), ``n_cls``
            (output width) and ``SGC_args['h_dims']`` (hidden layer widths,
            possibly empty).

    Raises:
        ValueError: if any layer width is non-positive.
    """

    def __init__(self,
                 args):
        super().__init__()
        h_dims = args.SGC_args['h_dims']
        # Full chain of widths; each consecutive pair defines one Linear layer.
        dims = [args.d_data, *h_dims, args.n_cls]
        # The original ``else`` branch here was unreachable (len() is never
        # negative); validate the widths so the error is actually raised.
        if any(d <= 0 for d in dims):
            raise ValueError('no valid MLP dims are given')
        # Layers are created in input-to-output order, same as before, so
        # parameter initialisation consumes the RNG identically.
        self.mlp_layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x):
        """Apply every layer, with ReLU after all but the last one."""
        for layer in self.mlp_layers[:-1]:
            x = F.relu(layer(x))
        # Final layer output is returned without activation (logits).
        return self.mlp_layers[-1](x)

    def reset_params(self):
        """Re-initialise the parameters of every Linear layer."""
        for layer in self.mlp_layers:
            layer.reset_parameters()
