import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch


class Net(nn.Module):
    """Convolutional network: a stem conv followed by stacked conv groups.

    The first layer is a single 3x3 ``nn.Conv2d``.  Every later layer is a
    *group* of ``state["groups"]`` parallel 3x3 convolutions whose activated,
    normalized outputs are averaged.  Each layer has one normalization module
    and one activation module (shared by all convs of the group), both
    supplied by ``event``.  The head averages over dims 2 and 3, applies
    ``BatchNorm1d`` and a bias-free linear layer with 100 outputs.

    Args:
        state: Mapping-like configuration object.  Keys read here:
            ``"conv_blocks"`` (sequence of output channel counts, one per
            conv layer), ``"first_stride"``, ``"stride"``, ``"groups"``,
            ``"short"`` and, only when ``"short"`` is truthy,
            ``"resblock_size"``.  The number of input channels is read from
            ``state.all["dataset.num_channels"]``.
        event: Object providing ``normalization_layer(num_channels)``,
            ``activation_layer()`` and ``init_conv(conv_module)``.

    Raises:
        ValueError: If ``state["conv_blocks"]`` is empty.
    """

    def __init__(self, state, event):
        super().__init__()
        self.event = event
        self.state = state

        # conv net settings
        bias = False
        self.num_filters = state["conv_blocks"]
        if len(self.num_filters) == 0:
            # The head below is sized from the last conv layer's width, so at
            # least one layer is required; fail with a clear message.
            raise ValueError(
                'state["conv_blocks"] must contain at least one filter count'
            )

        # Input width of layer k is the dataset channel count for k == 0 and
        # the output width of layer k-1 otherwise.
        in_channels = [state.all["dataset.num_channels"], *self.num_filters[:-1]]

        # create convs & norms & activations (one entry per layer)
        self.convs = nn.ModuleList([])
        self.norms = nn.ModuleList([])
        self.relus = nn.ModuleList([])
        for i, (f, flast) in enumerate(zip(self.num_filters, in_channels)):
            if i == 0:
                # stem: a single conv with its own stride setting
                self.convs.append(
                    nn.Conv2d(
                        flast, f, 3,
                        bias=bias, stride=state["first_stride"], padding=1,
                    )
                )
            else:
                # group of parallel convs; forward() averages their outputs
                self.convs.append(
                    nn.ModuleList([
                        nn.Conv2d(
                            flast, f, 3,
                            bias=bias, stride=state["stride"], padding=1,
                        )
                        for _ in range(state["groups"])
                    ])
                )

            # norm
            self.norms.append(event.normalization_layer(f))

            # activation
            self.relus.append(event.activation_layer())

        # fully connected head, sized from the last conv layer's width
        last_width = self.num_filters[-1]
        self.last_bn = nn.BatchNorm1d(last_width)
        self.fc1 = nn.Linear(last_width, 100, bias=bias)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                event.init_conv(m)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # weight/bias are None when the layer was built with
                # affine=False; there is nothing to initialise in that case.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, net):
        """Run the network on ``net`` and return the output of ``fc1``.

        ``net`` is fed to ``nn.Conv2d`` layers and later averaged over
        dims 2 and 3, so it must be a 4-D tensor.

        When ``state["short"]`` is truthy, the activation entering layer
        ``l`` is saved whenever ``(l - 1) % state["resblock_size"] == 0`` and
        added back after layer ``l`` whenever
        ``l % state["resblock_size"] == 0`` (layers counted from 1 after the
        stem).  The addition requires both tensors to have compatible
        shapes, which depends on the configured strides and filter counts.
        """
        # first conv
        net = self.relus[0](self.norms[0](self.convs[0](net)))

        # stacked conv groups; `layer` counts from 1 like the original index
        stacked = zip(self.convs[1:], self.norms[1:], self.relus[1:])
        for layer, (conv_group, normalization, relu) in enumerate(stacked, start=1):
            # "resblock_size" is only looked up when shortcuts are enabled.
            if self.state["short"] and (layer - 1) % self.state["resblock_size"] == 0:
                shortcut = net

            # Every conv of the group sees the same input and shares the
            # layer's norm/activation; the branch outputs are averaged.
            branches = [relu(normalization(conv(net))) for conv in conv_group]
            net = torch.stack(branches, -1).mean(-1)

            if self.state["short"] and layer % self.state["resblock_size"] == 0:
                net = net + shortcut

        # fully connected
        net = net.mean([2, 3])
        net = self.last_bn(net)
        net = self.fc1(net)

        return net
