import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch


class Net(nn.Module):
    """Stack of 3x3 conv -> normalization -> activation layers with optional
    residual shortcuts, followed by spatial mean pooling, BatchNorm1d and a
    bias-free linear layer with 100 outputs.

    Args:
        state: configuration object. Indexed as ``state[...]`` for
            ``"conv_blocks"`` (sequence of output channel counts, one per conv
            layer), ``"first_stride"`` (stride of the first conv),
            ``"stride"`` (stride of every later conv), ``"short"`` (enable
            residual shortcuts) and ``"resblock_size"`` (conv layers per
            residual block). ``state.all["dataset.num_channels"]`` is used as
            the input channel count of the first conv.
        event: provider of ``normalization_layer(channels)``,
            ``activation_layer()``, ``init_conv(conv_module)`` and the
            ``optional.before_relu(tensor)`` / ``optional.after_relu(tensor)``
            hooks.

    Raises:
        ValueError: if ``state["conv_blocks"]`` is empty.
    """

    def __init__(self, state, event):
        super().__init__()
        self.event = event
        self.state = state

        # conv net settings
        bias = False
        self.num_filters = state["conv_blocks"]
        if not self.num_filters:
            # The head below is sized from the last conv layer, so at least one is required.
            raise ValueError('state["conv_blocks"] must list at least one conv layer')

        # Input channels of layer i are the output channels of layer i-1
        # (the dataset's channel count for the first layer).
        in_channels = [state.all["dataset.num_channels"]] + list(self.num_filters)

        # create convs, normalizations & activations (one of each per layer)
        self.convs, self.norms, self.relus = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for i, (f, flast) in enumerate(zip(self.num_filters, in_channels)):
            stride = state["first_stride"] if i == 0 else state["stride"]
            self.convs.append(nn.Conv2d(flast, f, 3, bias=bias, stride=stride, padding=1))
            self.norms.append(event.normalization_layer(f))
            self.relus.append(event.activation_layer())

        # fully connected head, sized by the last conv layer's channels
        last_channels = self.num_filters[-1]
        self.last_bn = nn.BatchNorm1d(last_channels)
        self.fc1 = nn.Linear(last_channels, 100, bias=bias)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                event.init_conv(m)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Layers built with affine=False have weight/bias set to None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, net):
        """Run the network.

        Args:
            net: 4-D input tensor (N, C, H, W); ``mean([2, 3])`` below
                requires the two trailing spatial dims.

        Returns:
            Tensor of shape (N, 100).
        """
        use_shortcut = self.state["short"]
        # Only read when shortcuts are enabled, matching the short-circuit of the checks below.
        block_size = self.state["resblock_size"] if use_shortcut else None
        shortcut = None

        # stack convs
        for idx, (conv, normalization, activation) in enumerate(
            zip(self.convs, self.norms, self.relus)
        ):
            # Residual blocks span layers 1..block_size, block_size+1..2*block_size, ...;
            # layer 0 is never part of a block. Remember each block's input.
            if use_shortcut and idx > 0 and (idx - 1) % block_size == 0:
                shortcut = net
            net = conv(net)
            net = normalization(net)
            # The activation is a module instance, so it must be tested with isinstance;
            # comparing it to the nn.Identity class with `!=` is always true.
            if not isinstance(activation, nn.Identity):
                self.event.optional.before_relu(net)
                net = activation(net)
                self.event.optional.after_relu(net)
            # Close the block: add its input after the last layer's activation.
            if use_shortcut and idx > 0 and idx % block_size == 0:
                net = net + shortcut

        # fully connected: average over the spatial dims, then normalize and project
        net = net.mean([2, 3])
        net = self.last_bn(net)
        net = self.fc1(net)

        return net
