import torch
from torch import nn
import torch.nn.functional as F
from model.models.pc import LocFC, LocConv

# Activation name -> activation module. Each name maps to one shared module
# instance; every lookup of the same key returns that same object.
nonlinearity = dict(
    relu=nn.ReLU(),
    sigmoid=nn.Sigmoid(),
    tanh=nn.Tanh(),
    leaky_relu=nn.LeakyReLU(),
    softplus=nn.Softplus(),
    linear=nn.Identity(),
)

class Classifier(nn.Module):
    """Stack of project-local ``LocFC`` / ``LocConv`` layers applied in order.

    Args:
        args: Config namespace. Reads ``n_layers`` (only when ``L == 0``),
            ``cnn``, ``act`` (a key of the module-level ``nonlinearity`` dict)
            and ``flip``. ``args`` is also forwarded to every layer
            constructor, which may read further fields.
        L: Depth override. ``0`` (default) builds ``args.n_layers - 1``
            layers; any other value builds ``L - 1`` layers and passes
            ``o2c=True`` to ``LocFC``.
        device: Device each layer is moved to after construction. Defaults
            to ``'cuda'`` (the previous hard-coded behavior). Pass ``'cpu'``
            on a machine without a GPU, or ``None`` to leave the layers
            where their constructors put them.
    """

    def __init__(self, args, L=0, device='cuda'):
        super().__init__()
        # Both branches count one more than the number of layers built
        # (the count includes the input), hence the "- 1".
        self.L = args.n_layers - 1 if L == 0 else L - 1
        # NOTE(review): o2c is switched on only for an explicit depth
        # override. Its meaning lives in LocFC; confirm there.
        o2c = L != 0
        layers = [self._make_layer(args, l, o2c) for l in range(self.L)]
        if device is not None:
            layers = [layer.to(device) for layer in layers]
        self.fcs = nn.ModuleList(layers)
        self.act = nonlinearity[args.act]
        self.flip = args.flip
        # TODO(review): an earlier note said the output should be flattened
        # when args.cnn is set; no flattening happens here.

    @staticmethod
    def _make_layer(args, l, o2c):
        """Build layer ``l``: ``LocConv`` when ``args.cnn``, else ``LocFC``."""
        if args.cnn:
            return LocConv(args, l)
        return LocFC(args, l, o2c=o2c)

    def _forward(self, x):
        """Run ``x`` through every layer.

        Returns:
            ``(z, zs)``. ``z`` is the first output of the last layer, or the
            (possibly activated) input when there are no layers. ``zs[0]`` is
            the raw input ``x``. It is not activated even when ``flip`` is
            set, and it is not detached. Each later entry is the detached
            second value returned by a layer (named ``z_pre``; presumably a
            pre-activation, so confirm in LocFC/LocConv).
        """
        z = x
        zs = [z]
        # With flip the activation is also applied to the input before
        # the first layer.
        if self.flip:
            z = self.act(z)
        for l in range(self.L):
            z, z_pre = self.fcs[l](z)
            zs.append(z_pre.detach())
        return z, zs

    def forward(self, x):
        """Return only the final output of the layer stack."""
        return self._forward(x)[0]

    def forward_layer(self, x):
        """Return ``(output, per-layer values)`` as produced by ``_forward``."""
        return self._forward(x)


if __name__ == '__main__':
    # Smoke test: build a CNN Classifier and push one random batch through it.
    # The input is placed on 'cuda', so a CUDA device is required.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--ds', type=int, nargs='+', default=[784, 128, 10])
    # Classifier.__init__ reads args.act and args.flip. Without them the
    # constructor raised AttributeError.
    parser.add_argument('--act', default='relu', choices=sorted(nonlinearity))
    parser.add_argument('--flip', action='store_true')
    args = parser.parse_args()
    args.ecs = [1, 32, 64, 32]
    args.eks = [3, 3, 3]
    args.ess = [2, 2, 2]
    args.eps = [1, 1, 1]
    args.cnn = True
    # Classifier builds n_layers - 1 layers. One per consecutive channel pair
    # in ecs gives three, matching eks/ess/eps.
    # NOTE(review): this assumes LocConv indexes these lists by layer; confirm.
    args.n_layers = len(args.ecs)
    model = Classifier(args)
    print(model)
    x = torch.randn(32, 1, 28, 28).to('cuda')
    # Run the forward pass once and reuse the result for both prints.
    out = model(x)
    print(out)
    print(out.shape)
