import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch


class Net(nn.Module):
    """Fully connected network: hidden blocks followed by a linear classifier.

    Each hidden block is ``Linear (no bias) -> normalization -> activation``.
    The normalization and activation modules are created by ``event``, and the
    layer widths and flags are read from ``state``.

    Values read from ``state``:
        ``state["layers"]``: sequence of hidden layer widths.
        ``state["last_layer"]``: when falsy, the final entry of
            ``state["layers"]`` is dropped before the hidden blocks are built.
        ``state["short"]``: when truthy, the network input is added to the
            output of every hidden block in ``forward``.
        ``state.all["dataset.num_channels"]``: input feature count.
        ``state.all["dataset.num_classes"]``: output feature count.

    Calls made on ``event``:
        ``event.normalization_layer(width)``: module applied after each hidden
            linear layer.
        ``event.activation_layer()``: module applied after each normalization.
    """

    def __init__(self, state, event):
        """Build the hidden blocks and the final classifier layer.

        Args:
            state: Configuration object; see the class docstring for the keys
                that are read.
            event: Factory object providing ``normalization_layer`` and
                ``activation_layer``.
        """
        super().__init__()
        self.event = event
        self.state = state

        # Hidden linear layers carry no bias; only the classifier has one.
        bias = False

        num_channels = state.all["dataset.num_channels"]

        # Copy so the configuration's own sequence is never aliased, and so
        # tuples are accepted as well as lists.
        self.num_filters = list(state["layers"])
        if not self.state["last_layer"]:
            self.num_filters = self.num_filters[:-1]

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.relus = nn.ModuleList()

        # Track the running input width explicitly so the classifier width is
        # still defined when there are no hidden layers at all.
        in_features = num_channels
        for out_features in self.num_filters:
            self.layers.append(nn.Linear(in_features, out_features, bias=bias))
            self.norms.append(event.normalization_layer(out_features))
            self.relus.append(event.activation_layer())
            in_features = out_features

        # Final classifier. It is also appended to ``self.layers``, so its
        # parameters appear under both ``fc_last.*`` and ``layers.<n>.*`` in
        # the state dict; this is kept so existing checkpoints keep loading.
        self.fc_last = nn.Linear(
            in_features, state.all["dataset.num_classes"], bias=True
        )
        self.layers.append(self.fc_last)

    def forward(self, net):
        """Run the hidden blocks and the classifier on ``net``.

        Args:
            net: Input tensor whose last dimension matches
                ``state.all["dataset.num_channels"]``.

        Returns:
            The output of the final linear layer.
        """
        # No clone is required: nothing below modifies the input in place.
        shortcut = net

        # ``self.layers`` also holds ``fc_last``; zip stops at the shorter
        # ``norms``/``relus`` lists, so only the hidden blocks are visited.
        for layer, normalization, activation in zip(
            self.layers, self.norms, self.relus
        ):
            net = layer(net)
            net = normalization(net)
            if not isinstance(activation, nn.Identity):
                net = activation(net)

            # Shortcut connection. The add must be out of place: an in-place
            # add would overwrite the activation output that autograd saved
            # for the backward pass. The block output has to be
            # broadcast-compatible with the network input for this to work.
            if self.state["short"]:
                net = net + shortcut

        return self.fc_last(net)
