from Network.network import Network
from Network.network_utils import get_inplace_acti

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# TODO: Copied from elsewhere and not yet properly implemented.
class Conv2DNetwork(Network):  # basic 2d conv network
    """Stack of ``nn.Conv2d`` layers with an optional linear reduction head.

    Keyword arguments read here (all kwargs are also forwarded to ``Network``):
        kernel: side length of the square convolution kernel.
        input_dims: the first two entries are used as the spatial size; when
            more than two entries are given, the last one is the number of
            input channels, otherwise a single channel is assumed.
        stride, padding: passed to every conv. ``padding`` must be an int
            because it is used arithmetically to compute the output size.
        output_dim: channels of the final conv (added when ``include_last`` is
            set, or when there are no hidden layers) and output size of the
            linear head.
        reduce: if truthy, flatten the conv output to ``reduce_size`` features
            and apply ``nn.Linear(reduce_size, output_dim)``.
        include_last: if truthy, append a final conv (with no activation after
            it) mapping the last hidden layer to ``output_dim`` channels.
        activation: ``"crelu"`` makes hidden layers use CReLU, which doubles the
            channel count, so the hidden sizes in ``self.hs`` are halved. Any
            other value gives ``nn.ReLU`` hidden activations.
        activation_final: only consulted for crelu; ``"crelu"`` is replaced by
            ``get_inplace_acti("leakyrelu")``. It is not applied in ``forward``.

    ``self.hs`` (hidden channel sizes) is presumably set by ``Network.__init__``.
    """

    class _CReLU(nn.Module):
        """Concatenated ReLU: ``cat(relu(x), relu(-x))`` along the channel axis."""

        def forward(self, x):
            # dim=-3 is the channel axis for both (N, C, H, W) and (C, H, W) input.
            return torch.cat((F.relu(x), F.relu(-x)), dim=-3)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.kernel = kwargs["kernel"]  # square kernel of a single value
        self.input_dims = kwargs["input_dims"]
        self.channels = self.input_dims[-1] if len(self.input_dims) > 2 else 1
        self.stride = kwargs["stride"]
        self.padding = kwargs["padding"]
        self.output_dim = kwargs["output_dim"]
        self.reduce = kwargs["reduce"]
        include_last = kwargs['include_last']
        self.is_crelu = kwargs['activation'] == "crelu"
        crelu_mul = 1  # channel multiplier applied by the hidden activation
        if self.is_crelu:
            # CReLU doubles channels, so halve the hidden sizes to keep the width.
            self.hs = [int(hs / 2) for hs in self.hs]
            crelu_mul = 2
            if kwargs['activation_final'] == "crelu":
                self.activation_final = get_inplace_acti("leakyrelu")

        # Spatial size; the original author notes the order is "technically
        # switched". Kernel/stride/padding are applied equally to both axes,
        # so the order does not affect reduce_size.
        x, y = self.input_dims[:2]
        in_ch = self.channels
        layers = []
        # Every hidden layer is conv + activation. The spatial size is advanced
        # alongside each conv so reduce_size always matches the built stack.
        for hidden in self.hs:
            layers += [self._make_conv(in_ch, hidden), self._hidden_activation()]
            in_ch = hidden * crelu_mul  # channels after the activation
            x, y = self._conv_out_size(x), self._conv_out_size(y)
        # With no hidden layers the single conv must produce output_dim channels;
        # otherwise the output conv is only appended when include_last is set.
        # If we do not include it, we assume a layer follows, so the stack ends
        # with an activation.
        if not self.hs or include_last:
            layers.append(self._make_conv(in_ch, self.output_dim))
            in_ch = self.output_dim
            x, y = self._conv_out_size(x), self._conv_out_size(y)
        self.reduce_size = x * y * in_ch

        self.conv = nn.Sequential(*layers)
        self.model = nn.ModuleList([self.conv])
        self.final = None
        if self.reduce:
            self.final = nn.Linear(self.reduce_size, self.output_dim)
            self.model = nn.ModuleList([self.conv, self.final])
        self.train()
        self.reset_network_parameters()

    def _make_conv(self, in_channels, out_channels):
        """Build one conv layer with this network's kernel/stride/padding."""
        return nn.Conv2d(in_channels, out_channels, self.kernel, self.stride, self.padding)

    def _hidden_activation(self):
        """Return a fresh activation module placed after each hidden conv."""
        return self._CReLU() if self.is_crelu else nn.ReLU(inplace=True)

    def _conv_out_size(self, size):
        """Output length along one spatial axis of a conv (nn.Conv2d formula, dilation 1)."""
        return (size + 2 * self.padding - (self.kernel - 1) - 1) // self.stride + 1

    def forward(self, x):
        """Apply the conv stack and, if ``reduce`` is set, the linear head.

        ``x`` is passed straight to ``nn.Conv2d``, so it presumably must already
        be channel-first, i.e. (batch, channels, H, W) — TODO confirm with callers.
        """
        x = self.conv(x)
        if self.reduce:
            x = x.reshape(-1, self.reduce_size)
            x = self.final(x)
        return x