from cgitb import small
from ctypes import BigEndianStructure
import math
import pdb
import torch
from torch import nn

from torchvision.models.resnet import BasicBlock,conv1x1

# Maps an activation name to the corresponding torch activation class
# (the values are classes, not instances).
ACTLIST = dict(tanh=nn.Tanh, sigmoid=nn.Sigmoid)

# Maps a normalisation name to the corresponding 2-D torch norm layer class.
NORMMAP = dict(instance=nn.InstanceNorm2d, batch=nn.BatchNorm2d)

class BaseNetworkClass(nn.Module):
    """Base class for the networks defined in this file.

    Holds the optional ``output_split_sizes`` and provides helpers to split a
    raw network output and to wrap weighted submodules in spectral norm.
    """

    def __init__(self, output_split_sizes):
        """Store ``output_split_sizes`` (``None`` disables output splitting)."""
        super().__init__()
        self.output_split_sizes = output_split_sizes

    def _get_correct_nn_output_format(self, nn_output, split_dim):
        """Return ``nn_output``, split along ``split_dim`` when sizes are set.

        With ``output_split_sizes`` equal to ``None`` the tensor is returned
        untouched; otherwise the result of ``torch.split`` is returned.
        """
        sizes = self.output_split_sizes
        if sizes is None:
            return nn_output
        return torch.split(nn_output, sizes, dim=split_dim)

    def _apply_spectral_norm(self):
        """Apply ``nn.utils.spectral_norm`` to every submodule owning a ``weight``."""
        for submodule in self.modules():
            if "weight" not in submodule._parameters:
                continue
            nn.utils.spectral_norm(submodule)

    def forward(self, x):
        """Abstract forward pass; concrete networks must override it."""
        raise NotImplementedError("Define in child classes.")

def idxs_to_one_hot(idxs, conditioning_dimension):
    """Build a one-hot matrix from a tensor of indices.

    Returns a zero tensor with ``idxs.shape[0]`` rows and
    ``conditioning_dimension`` columns, moved to the device of ``idxs``, in
    which entry ``[i, idxs[i]]`` is set to 1.
    """
    batch_size = idxs.shape[0]
    one_hot = torch.zeros(batch_size, conditioning_dimension).to(idxs.device)
    row_positions = torch.arange(batch_size)
    one_hot[row_positions, idxs] = 1
    return one_hot

class MLP(BaseNetworkClass):
    """Fully connected network with optional one-hot conditioning.

    The network is a stack of ``nn.Linear`` + ``activation()`` pairs, one per
    entry of ``hidden_dims``, followed by a final ``nn.Linear`` without
    activation.
    """

    def __init__(
            self,
            input_dim,
            hidden_dims,
            output_dim,
            activation,
            output_split_sizes=None,
            spectral_norm=False,
            conditioning_dimension=0
    ):
        """Build the layer stack.

        Args:
            input_dim: Number of input features (before conditioning).
            hidden_dims: Widths of the hidden layers; may be empty, in which
                case the network is a single linear layer.
            output_dim: Number of output features.
            activation: Zero-argument callable returning an activation module.
            output_split_sizes: Optional sizes used to split the output along
                the last dimension.
            spectral_norm: When true, spectral norm is applied to every
                submodule that owns a ``weight`` parameter.
            conditioning_dimension: Width of the one-hot conditioning vector
                concatenated to the input; ``0`` disables conditioning.
        """
        super().__init__(output_split_sizes)

        self.conditioning_dimension = conditioning_dimension

        layers = []
        prev_layer_dim = input_dim + conditioning_dimension
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(in_features=prev_layer_dim, out_features=hidden_dim))
            layers.append(activation())
            prev_layer_dim = hidden_dim

        # Use the tracked width rather than hidden_dims[-1] so that an empty
        # hidden_dims yields a valid single-layer network instead of IndexError.
        layers.append(nn.Linear(in_features=prev_layer_dim, out_features=output_dim))
        self.net = nn.Sequential(*layers)

        if spectral_norm:
            self._apply_spectral_norm()

    def forward(self, x, conditioning=None):
        """Run the network, optionally appending a one-hot conditioning vector.

        When ``conditioning_dimension > 0``, ``conditioning`` is converted with
        ``idxs_to_one_hot`` and concatenated to ``x`` along dimension 1. The
        output is split along the last dimension if ``output_split_sizes`` was
        given.
        """
        if self.conditioning_dimension > 0:
            conditioning_vector = idxs_to_one_hot(conditioning, self.conditioning_dimension)
            x = torch.cat((x, conditioning_vector), 1)

        return self._get_correct_nn_output_format(self.net(x), split_dim=-1)

class SharedMLP(BaseNetworkClass):
    """MLP made of up to three stages (beginning, middle, end).

    Each stage is either built here or left out so that a module can be
    plugged in later through ``add_shared_module``. A stage that is absent
    behaves as the identity in ``forward``.
    """

    def __init__(
            self,
            input_dim,
            beginning_hidden_dims,
            middle_hidden_dims,
            end_hidden_dims,
            share_start,
            share_middle,
            share_end,
            output_dim,
            activation,
            output_split_sizes=None,
            spectral_norm=False
    ):
        """Create the stages that are not marked as shared.

        An empty ``middle_hidden_dims`` means there is no middle stage at all;
        the end stage then takes its input width from the beginning stage.
        """
        super().__init__(output_split_sizes)

        def build_stage(widths, is_final_stage=False):
            # Linear layers between consecutive widths; every layer is
            # followed by an activation except the very last layer of the
            # final stage.
            last_linear = len(widths[1:]) - 1
            modules = []
            in_width = widths[0]
            for position, out_width in enumerate(widths[1:]):
                modules.append(nn.Linear(in_features=in_width, out_features=out_width))
                if not (is_final_stage and position == last_linear):
                    modules.append(activation())
                in_width = out_width
            return nn.Sequential(*modules)

        no_middle_stage = len(middle_hidden_dims) == 0

        own_beginning = not share_start
        self.has_beginning = own_beginning
        if own_beginning:
            self.beginning_net = build_stage([input_dim] + beginning_hidden_dims)

        own_middle = (not no_middle_stage) and (not share_middle)
        self.has_middle = own_middle
        if own_middle:
            self.middle_net = build_stage([beginning_hidden_dims[-1]] + middle_hidden_dims)

        own_end = not share_end
        self.has_end = own_end
        if own_end:
            if no_middle_stage:
                end_input_width = beginning_hidden_dims[-1]
            else:
                end_input_width = middle_hidden_dims[-1]
            self.end_net = build_stage(
                [end_input_width] + end_hidden_dims + [output_dim], end_layers := True
            ) if False else build_stage(
                [end_input_width] + end_hidden_dims + [output_dim], True
            )

        if spectral_norm:
            self._apply_spectral_norm()

    def beginning(self, x):
        """Apply the beginning stage, or pass ``x`` through if it is absent."""
        return self.beginning_net(x) if self.has_beginning else x

    def middle(self, x):
        """Apply the middle stage, or pass ``x`` through if it is absent."""
        return self.middle_net(x) if self.has_middle else x

    def end(self, x):
        """Apply the end stage and split its output, or pass ``x`` through."""
        if not self.has_end:
            return x
        return self._get_correct_nn_output_format(self.end_net(x), split_dim=-1)

    def add_shared_module(self, module, location):
        """Install ``module`` as the stage named by ``location``.

        ``location`` must be one of ``"start"``, ``"middle"`` or ``"end"``.
        """
        assert location in ["start", "middle", "end"]

        if location == "start":
            self.beginning_net = module
            self.has_beginning = True
        elif location == "middle":
            self.middle_net = module
            self.has_middle = True
        else:
            self.end_net = module
            self.has_end = True

    def forward(self, x):
        """Chain the beginning, middle and end stages."""
        x = self.beginning(x)
        x = self.middle(x)
        return self.end(x)


class BaseCNNClass(BaseNetworkClass):
    """Shared helpers for the convolutional networks in this file.

    Normalises per-layer stride/kernel/padding specifications and provides the
    image-size arithmetic for ``Conv2d`` and ``ConvTranspose2d`` stacks.
    """

    def __init__(
            self,
            hidden_channels_list,
            stride,
            kernel_size,
            padding,
            output_split_sizes=None,
            do_param_init=True
    ):
        """Store per-layer stride, kernel size and padding lists.

        When ``do_param_init`` is false the three attributes are not set at
        all and the subclass is expected to handle them itself.
        """
        super().__init__(output_split_sizes)

        if do_param_init:
            self.stride = self._get_stride_or_kernel(stride, hidden_channels_list)
            self.kernel_size = self._get_stride_or_kernel(kernel_size, hidden_channels_list)
            self.padding = self._get_stride_or_kernel(padding, hidden_channels_list)

    def _get_stride_or_kernel(self, s_or_k, hidden_channels_list):
        """Return one value per hidden layer.

        A list or tuple is returned unchanged after checking that it has one
        entry per element of ``hidden_channels_list``; any other value is
        repeated once per hidden layer.

        Raises:
            AssertionError: If a list/tuple of the wrong length is given.
        """
        if not isinstance(s_or_k, (list, tuple)):
            return [s_or_k for _ in hidden_channels_list]

        assert len(s_or_k) == len(hidden_channels_list), \
            "Mismatch between stride/kernels provided and number of hidden channels."
        return s_or_k

    def _get_new_image_height(self, height, kernel, stride, padding=0):
        """Return the output height of a ``Conv2d`` layer (dilation assumed 1)."""
        # cf. https://pytorch.org/docs/1.9.1/generated/torch.nn.Conv2d.html
        return math.floor((height + 2*padding - kernel)/stride + 1)

    def _get_new_image_height_and_output_padding(self, height, kernel, stride, pad=0):
        """Invert a ``ConvTranspose2d`` layer (dilation assumed 1).

        Given the desired output ``height`` of the layer, return the input
        height and the ``output_padding`` such that

            (input - 1) * stride - 2 * pad + kernel + output_padding == height

        Returns:
            Tuple ``(input_height, output_padding)``.
        """
        # cf. https://pytorch.org/docs/1.9.1/generated/torch.nn.ConvTranspose2d.html
        # Solving the forward formula for the input height gives "+ 2*pad";
        # the result is unchanged for pad == 0.
        remaining = height + 2*pad - kernel
        output_padding = remaining % stride
        height = (remaining - output_padding) // stride + 1

        return height, output_padding

class SharedModule(nn.Module):
    """Wrap a module and optionally concatenate a shared module's output.

    Until ``attach_shared_module`` is called, or while ``do_shared`` is falsy,
    the wrapper simply forwards to the wrapped module.
    """

    def __init__(self, module, cat_dim, do_shared):
        """Store the wrapped module, the concatenation dim and the share flag."""
        super().__init__()
        self.module = module
        self.has_shared_module = False
        self.cat_dim = cat_dim
        self.do_shared = do_shared

    def attach_shared_module(self, shared_module):
        """Register ``shared_module`` whose output is appended in ``forward``."""
        self.has_shared_module = True
        self.shared_module = shared_module

    def forward(self, xb):
        """Return the own output, concatenated with the shared one if enabled."""
        own_output = self.module(xb)

        use_shared = self.has_shared_module and self.do_shared
        if not use_shared:
            return own_output

        shared_output = self.shared_module(xb)
        return torch.cat((own_output, shared_output), self.cat_dim)

class CNN(BaseCNNClass):
    """Convolutional network built from ``SharedModule``-wrapped ``Conv2d`` layers.

    Each hidden layer is a wrapped convolution, an optional norm layer and an
    activation. The head is either a linear layer on the flattened features
    (``final_linear=True``) or a single-output-channel convolution.
    """

    def __init__(
            self,
            input_channels,
            hidden_channels_list,
            output_dim,
            kernel_size,
            stride,
            padding,
            shared_module,
            image_height,
            activation,
            do_bn=False,
            output_split_sizes=None,
            spectral_norm=False,
            noise_dim=0,
            final_activation=None, # TODO: make config arg
            conv_bias=True, # TODO: make config arg,
            norm=None,
            norm_args=None,
            final_linear=False,
            conditioning_dimension=0
    ):
        """Build the convolutional stack and the output head.

        Args:
            input_channels: Channels of the input image (before conditioning).
            hidden_channels_list: Output channels of each hidden layer.
            output_dim: Output features of the linear head.
            kernel_size, stride, padding: Scalar or per-layer list/tuple.
            shared_module: Scalar or per-layer value. A value ``> 0`` scales
                the channels produced by the layer's own convolution and is
                also stored as the ``SharedModule`` share flag.
            image_height: Input image height; square images are assumed.
            activation: Zero-argument callable returning an activation module.
            do_bn: Accepted but not used by this class.
            output_split_sizes: Optional sizes to split the output along dim 1.
            spectral_norm: Apply spectral norm to weighted submodules.
            noise_dim: Extra input features of the linear head (for ``eps``).
            final_activation: Optional callable returning the last activation.
            conv_bias: ``bias`` flag of the hidden convolutions.
            norm: Optional callable ``norm(channels, **norm_args)``.
            norm_args: Keyword arguments for ``norm``; ``None`` means none.
            final_linear: Select the linear head instead of the conv head.
            conditioning_dimension: ``> 0`` adds one conditioning channel.
        """
        super().__init__(hidden_channels_list, stride, kernel_size, padding, output_split_sizes)

        # Avoid a shared mutable default; None stands for "no extra kwargs".
        norm_args = {} if norm_args is None else norm_args

        self.conditioning_dimension = conditioning_dimension
        self.shared_module = self._get_stride_or_kernel(shared_module, hidden_channels_list)

        cnn_layers = []
        prev_channels = input_channels + (1 if conditioning_dimension > 0 else 0)
        for hidden_channels, k, s, p, sm in zip(
            hidden_channels_list, self.kernel_size, self.stride, self.padding, self.shared_module
        ):
            own_channels = int(hidden_channels * sm) if sm > 0 else hidden_channels
            cnn_layers.append(
                SharedModule(nn.Conv2d(prev_channels, own_channels, k, s, padding=p, bias=conv_bias), 1, sm)
            )

            if norm is not None:
                cnn_layers.append(norm(hidden_channels, **norm_args))

            cnn_layers.append(activation())

            prev_channels = hidden_channels

            # NOTE: Assumes square image. The padding handed to Conv2d above
            # must be included, otherwise the linear head below is sized for
            # the wrong feature map whenever padding is non-zero.
            image_height = self._get_new_image_height(image_height, k, s, p)

        self.cnn_layers = nn.ModuleList(cnn_layers)

        self.final_linear = final_linear

        # NOTE(review): the head uses shared_module[0] (the first layer's
        # value), not the last one -- confirm this is intended.
        head_share = self.shared_module[0]
        if self.final_linear:
            self.fc_layer = SharedModule(
                nn.Linear(
                    in_features=prev_channels*image_height**2+noise_dim,
                    out_features=int(output_dim * head_share) if head_share > 0 else output_dim
                ),
                1, head_share)
        else:
            self.fc_layer = SharedModule(
                nn.Conv2d(
                    int(prev_channels * head_share) if head_share > 0 else prev_channels,
                    1, self.kernel_size[-1], stride=1, padding=0
                ),
                1, head_share)

        self.final_activation = final_activation
        if self.final_activation is not None:
            self.final_activation = final_activation()

        if spectral_norm:
            self._apply_spectral_norm()

    def forward(self, x, conditioning=None, eps=None):
        """Run the network.

        When conditioning is enabled, a single extra channel holding each
        sample's ``conditioning`` value is concatenated to ``x`` (``x`` is
        indexed as a 4-D tensor). With the linear head the features are
        flattened first. ``eps``, if given, is concatenated along dim 1 before
        the head. The result is split along dim 1 if split sizes were given.
        """
        if self.conditioning_dimension > 0:
            conditioning_channel = torch.ones_like(x)[:,0,:,:][:,None,:,:]
            conditioning_channel *= conditioning[:,None,None,None]
            x = torch.cat((x, conditioning_channel), 1)

        for layer in self.cnn_layers:
            x = layer(x)

        if self.final_linear:
            x = torch.flatten(x, start_dim=1)

        net_in = torch.cat((x, eps), dim=1) if eps is not None else x

        net_out = self.fc_layer(net_in)

        if self.final_activation is not None:
            net_out = self.final_activation(net_out)

        return self._get_correct_nn_output_format(net_out, split_dim=1)

class T_CNN(BaseCNNClass):
    """Transposed-convolution network built from ``SharedModule`` wrappers.

    An initial layer (linear or transposed convolution) produces the first
    feature map, which is then upsampled by a stack of ``ConvTranspose2d``
    layers. The spatial size of the first feature map and the per-layer
    ``output_padding`` are inferred by walking the layers backwards from
    ``image_height``.
    """

    def __init__(
            self,
            input_dim,
            hidden_channels_list,
            output_channels,
            kernel_size,
            stride,
            padding,
            shared_module,
            image_height,
            activation,
            do_bn=False,
            output_split_sizes=None,
            spectral_norm=False,
            single_sigma=False,
            final_activation=None, # TODO: make config arg
            conv_bias=True, # TODO: make config arg
            norm=None,
            norm_args=None,
            initial_linear=False,
            conditioning_dimension=0
    ):
        """Build the initial layer and the transposed-convolution stack.

        Args:
            input_dim: Number of input features.
            hidden_channels_list: Channels of the successive feature maps.
            output_channels: Channels produced by the last layer.
            kernel_size, stride, padding: Scalar or per-layer list/tuple.
            shared_module: Scalar or per-layer value. A value ``> 0`` scales
                the channels produced by the layer's own convolution and is
                also stored as the ``SharedModule`` share flag.
            image_height: Target output height; square images are assumed.
            activation: Zero-argument callable returning an activation module.
            do_bn: Accepted but not used by this class.
            output_split_sizes: Optional sizes to split the output along dim 1;
                required when ``single_sigma`` is true.
            spectral_norm: Apply spectral norm to weighted submodules.
            single_sigma: Project the second output split to one scalar per
                sample with an extra linear layer.
            final_activation: Optional callable returning the last activation.
            conv_bias: ``bias`` flag of the hidden transposed convolutions.
            norm: Optional callable ``norm(channels, **norm_args)``.
            norm_args: Keyword arguments for ``norm``; ``None`` means none.
            initial_linear: Use a linear first layer instead of a transposed
                convolution.
            conditioning_dimension: ``> 0`` enables conditioning.
        """
        super().__init__(hidden_channels_list, stride, kernel_size, padding, output_split_sizes)

        # Avoid a shared mutable default; None stands for "no extra kwargs".
        norm_args = {} if norm_args is None else norm_args

        self.shared_module = self._get_stride_or_kernel(shared_module, hidden_channels_list)
        self.conditioning_dimension = conditioning_dimension

        self.single_sigma = single_sigma

        if self.single_sigma:
            # NOTE: In the MLP above, the single_sigma case can be handled by correctly
            #       specifying output_split_sizes. However, here the first output of the
            #       network will be of image shape, which more strongly contrasts with the
            #       desired sigma output of shape (batch_size, 1). We need the additional
            #       linear layer to project the image-like output to a scalar.
            self.sigma_output_layer = nn.Linear(output_split_sizes[-1]*image_height**2, 1)

        # Walk the layers backwards to infer the spatial size expected after
        # the first layer and the output_padding of every transposed conv.
        # NOTE(review): padding enters this inference, but the ConvTranspose2d
        # layers below are created without a padding argument -- confirm that
        # only padding == 0 is meant to be supported here.
        output_paddings = []
        for _, k, s, p in zip(hidden_channels_list, self.kernel_size[::-1], self.stride[::-1], self.padding[::-1]):
            image_height, output_padding = self._get_new_image_height_and_output_padding(
                image_height, k, s, p
            )
            output_paddings.append(output_padding)
        output_paddings = output_paddings[::-1]

        first_share = self.shared_module[0]
        self.initial_linear = initial_linear
        if self.initial_linear:
            full_features = hidden_channels_list[0]*image_height**2
            self.fc_layer = SharedModule(
                nn.Linear(
                    input_dim + self.conditioning_dimension,
                    int(full_features * first_share) if first_share > 0 else full_features
                ),
                1, first_share
            )
            self.post_fc_shape = (hidden_channels_list[0], image_height, image_height)
        else:
            # TODO: verify dimensions
            self.fc_layer = SharedModule(
                nn.ConvTranspose2d(
                    input_dim + (1 if self.conditioning_dimension > 0 else 0),
                    int(hidden_channels_list[0] * first_share) if first_share > 0 else hidden_channels_list[0],
                    self.kernel_size[0], 2, 1
                ),
                1, first_share
            )
            self.post_fc_shape = (hidden_channels_list[0], image_height, image_height)

        t_cnn_layers = [activation()]
        prev_channels = hidden_channels_list[0]
        for hidden_channels, k, s, op, sm in zip(
            hidden_channels_list[1:], self.kernel_size[:-1], self.stride[:-1], output_paddings[:-1], self.shared_module[:-1]
        ):
            t_cnn_layers.append(
                SharedModule(
                    nn.ConvTranspose2d(prev_channels, int(hidden_channels * sm if sm > 0 else hidden_channels), k, s, output_padding=op, bias=conv_bias),
                    1, sm
                )
            )

            # Every layer built in this loop is a hidden layer (the output
            # layer is appended separately below), so norm always applies.
            if norm is not None:
                t_cnn_layers.append(norm(hidden_channels, **norm_args))

            t_cnn_layers.append(activation())

            prev_channels = hidden_channels

        last_share = self.shared_module[-1]
        t_cnn_layers.append(
            SharedModule(
                nn.ConvTranspose2d(
                    prev_channels,
                    int(output_channels * last_share) if last_share > 0 else output_channels,
                    self.kernel_size[-1], self.stride[-1],
                    output_padding=output_paddings[-1]
                ),
                1, last_share
            )
        )

        if final_activation is not None:
            t_cnn_layers.append(final_activation())

        self.t_cnn_layers = nn.ModuleList(t_cnn_layers)

        if spectral_norm:
            self._apply_spectral_norm()

    def forward(self, x, conditioning=None):
        """Run the network.

        Without ``initial_linear`` the input is first reshaped to
        ``(batch, -1, 1, 1)``. Conditioning is appended either as an extra
        channel (input with more than two dims) or as a one-hot vector. With
        ``single_sigma`` the result is ``(mu, log_sigma)`` where ``log_sigma``
        has shape ``(-1, 1, 1, 1)``; otherwise the (possibly split) output of
        the layer stack is returned.
        """
        if not self.initial_linear:
            x = x.reshape(x.shape[0], -1, 1, 1)

        if self.conditioning_dimension > 0:
            if len(x.shape) > 2:
                conditioning_channel = torch.ones_like(x)[:,0,:,:][:,None,:,:]
                conditioning_channel *= conditioning[:,None,None,None]
                x = torch.cat((x, conditioning_channel), 1)
            else:
                conditioning_vector = idxs_to_one_hot(conditioning, self.conditioning_dimension)
                x = torch.cat((x, conditioning_vector), 1)

        x = self.fc_layer(x)

        x = x.reshape(-1, *self.post_fc_shape)

        for layer in self.t_cnn_layers:
            x = layer(x)

        net_output = self._get_correct_nn_output_format(x, split_dim=1)

        if self.single_sigma:
            mu, log_sigma_unprocessed = net_output
            log_sigma = self.sigma_output_layer(log_sigma_unprocessed.flatten(start_dim=1))
            return mu, log_sigma.view(-1, 1, 1, 1)
        else:
            return net_output

class SharedCNN(BaseCNNClass):
    """Convolutional network made of up to three stages (beginning, middle, end).

    Each stage is either built here or left out so that a module can be
    plugged in later through ``add_shared_module``. A stage that is absent
    behaves as the identity in ``forward``.
    """

    def __init__(
            self,
            input_channels,
            beginning_hidden_channels,
            middle_hidden_channels,
            end_hidden_channels,
            share_start,
            share_middle,
            share_end,
            beginning_kernel_size,
            middle_kernel_size,
            end_kernel_size,
            beginning_stride,
            middle_stride,
            end_stride,
            beginning_padding,
            middle_padding,
            end_padding,
            output_dim,
            image_height,
            activation,
            output_split_sizes=None,
            spectral_norm=False,
            noise_dim=0,
            final_activation=None, # TODO: make config arg
            conv_bias=True, # TODO: make config arg,
            norm=None,
            norm_args=None,
            final_linear=False
    ):
        """Create the stages that are not marked as shared.

        The ``*_hidden_channels`` arguments are concatenated with lists and
        must therefore be lists. An empty ``middle_hidden_channels`` means
        there is no middle stage. ``norm_args=None`` means no extra keyword
        arguments for ``norm``.
        """
        super().__init__(beginning_hidden_channels, beginning_stride, beginning_kernel_size, beginning_padding, output_split_sizes, do_param_init=False)

        # Avoid a shared mutable default; None stands for "no extra kwargs".
        norm_args = {} if norm_args is None else norm_args

        self.final_activation = final_activation
        self.final_linear = final_linear

        def make_net(channels, kernel_size, stride, padding, end_layers=False):
            # Build one stage. ``channels[0]`` is the stage's input channel
            # count; the remaining entries are its layers' output channels.
            # The enclosing ``image_height`` is advanced as a side effect so
            # that later stages see the reduced spatial size.
            nonlocal image_height

            cnn_layers = []
            prev_channels = channels[0]
            channels = channels[1:]

            net_stride = self._get_stride_or_kernel(stride, channels)
            net_kernel_size = self._get_stride_or_kernel(kernel_size, channels)
            net_padding = self._get_stride_or_kernel(padding, channels)

            for hidden_channels, k, s, p in zip(channels, net_kernel_size, net_stride, net_padding):
                cnn_layers.append(nn.Conv2d(prev_channels, hidden_channels, k, s, padding=p, bias=conv_bias)) #TODO: dynamic padding

                if norm is not None:
                    cnn_layers.append(norm(hidden_channels, **norm_args))

                cnn_layers.append(activation())
                prev_channels = hidden_channels

                # NOTE: Assumes square image
                image_height = self._get_new_image_height(image_height, k, s, p)

            if end_layers:
                if self.final_linear:
                    cnn_layers.append(nn.Flatten(start_dim=1))
                    cnn_layers.append(nn.Linear(
                        in_features=prev_channels*image_height**2+noise_dim,
                        out_features=output_dim
                    ))
                else:
                    cnn_layers.append(nn.Conv2d(prev_channels, 1, net_kernel_size[-1], stride=1, padding=net_padding[-1]))

                if self.final_activation is not None:
                    cnn_layers.append(final_activation())

            return nn.Sequential(*cnn_layers)

        only_2_layer = len(middle_hidden_channels) == 0

        # Every stage is built even when it is shared: make_net must run to
        # keep image_height in step for the stages that follow. A shared
        # stage's freshly built module is simply discarded.
        beginning_net = make_net([input_channels] + beginning_hidden_channels, beginning_kernel_size, beginning_stride, beginning_padding)
        self.has_beginning = not share_start
        if not share_start:
            self.beginning_net = beginning_net

        self.has_middle = (not share_middle) and (not only_2_layer)
        if not only_2_layer:
            middle_net = make_net([beginning_hidden_channels[-1]] + middle_hidden_channels, middle_kernel_size, middle_stride, middle_padding)
            if not share_middle:
                self.middle_net = middle_net

        end_input_channels = middle_hidden_channels[-1] if not only_2_layer else beginning_hidden_channels[-1]
        end_net = make_net([end_input_channels] + end_hidden_channels, end_kernel_size, end_stride, end_padding, end_layers=True)
        self.has_end = not share_end
        if not share_end:
            self.end_net = end_net

        if spectral_norm:
            self._apply_spectral_norm()

    def beginning(self, x):
        """Apply the beginning stage, or pass ``x`` through if it is absent.

        Errors raised by the stage propagate to the caller.
        """
        if self.has_beginning:
            return self.beginning_net(x)
        else:
            return x

    def middle(self, x):
        """Apply the middle stage, or pass ``x`` through if it is absent."""
        if self.has_middle:
            return self.middle_net(x)
        else:
            return x

    def end(self, x):
        """Apply the end stage and split its output, or pass ``x`` through.

        Raises:
            AssertionError: If the end stage returns a tuple, which is not
                supported yet.
        """
        if self.has_end:
            x = self.end_net(x)

            if isinstance(x, tuple): # TODO: make cleaner, hack when shared module does the split
                # Raised explicitly so the check also holds under ``python -O``.
                raise AssertionError("TODO: handle this with final activation")

            return self._get_correct_nn_output_format(x, split_dim=1)
        else:
            return x

    def add_shared_module(self, module, location):
        """Install ``module`` as the stage named by ``location``.

        ``location`` must be one of ``"start"``, ``"middle"`` or ``"end"``.
        """
        assert location in ["start", "middle", "end"]

        if location == "start":
            self.beginning_net = module
            self.has_beginning = True
        elif location == "middle":
            self.middle_net = module
            self.has_middle = True
        elif location == "end":
            self.end_net = module
            self.has_end = True

    def forward(self, x):
        """Chain the beginning, middle and end stages."""
        return self.end(self.middle(self.beginning(x)))

class Reshape(nn.Module):
    """Module wrapper around ``Tensor.reshape`` with a fixed target shape."""

    def __init__(self, new_shape):
        """Remember the shape handed to ``reshape`` on every forward pass."""
        super().__init__()
        self.new_shape = new_shape

    def forward(self, x):
        """Return ``x`` reshaped to the stored shape."""
        target_shape = self.new_shape
        return x.reshape(target_shape)

class SharedT_CNN(BaseCNNClass):
    """Transposed-convolution network made of up to three stages.

    Each stage (beginning, middle, end) is either built here or left out so
    that a module can be plugged in later through ``add_shared_module``. A
    stage that is absent behaves as the identity in ``forward``.
    """

    def __init__(
            self,
            input_dim,
            beginning_hidden_channels,
            middle_hidden_channels,
            end_hidden_channels,
            share_start,
            share_middle,
            share_end,
            beginning_kernel_size,
            middle_kernel_size,
            end_kernel_size,
            beginning_stride,
            middle_stride,
            end_stride,
            beginning_padding,
            middle_padding,
            end_padding,
            output_channels,
            image_height,
            activation,
            output_split_sizes=None,
            spectral_norm=False,
            single_sigma=False,
            final_activation=None, # TODO: make config arg
            conv_bias=True, # TODO: make config arg
            norm=None,
            norm_args=None,
            initial_linear=False,
            force_zero_op=False # TODO: get rid of
    ):
        """Create the stages that are not marked as shared.

        The per-stage channel, kernel, stride and padding arguments are
        concatenated and indexed with ``[-1]`` below, so they must be lists.
        An empty ``middle_hidden_channels`` means there is no middle stage.
        ``norm_args=None`` means no extra keyword arguments for ``norm``.
        ``force_zero_op`` overrides the inferred ``output_padding`` of the
        hidden layers (1 in the end stage, 0 elsewhere).
        """
        super().__init__(beginning_hidden_channels, beginning_stride, beginning_kernel_size, beginning_padding, output_split_sizes, do_param_init=False)

        # Avoid a shared mutable default; None stands for "no extra kwargs".
        norm_args = {} if norm_args is None else norm_args

        self.single_sigma = single_sigma
        self.initial_linear = initial_linear

        if self.single_sigma and not share_end:
            # NOTE: In the MLP above, the single_sigma case can be handled by correctly
            #       specifying output_split_sizes. However, here the first output of the
            #       network will be of image shape, which more strongly contrasts with the
            #       desired sigma output of shape (batch_size, 1). We need the additional
            #       linear layer to project the image-like output to a scalar.
            self.sigma_output_layer = nn.Linear(output_split_sizes[-1]*image_height**2, 1)

        # Get initial size: walk all layers backwards from the target image
        # height to find the spatial size produced by the initial linear layer.
        all_channels = beginning_hidden_channels + middle_hidden_channels + end_hidden_channels
        total_stride = self._get_stride_or_kernel(beginning_stride + middle_stride + end_stride, all_channels)
        total_kernel_size = self._get_stride_or_kernel(beginning_kernel_size + middle_kernel_size + end_kernel_size, all_channels)
        total_padding = self._get_stride_or_kernel(beginning_padding + middle_padding + end_padding, all_channels)

        beginning_image_height = image_height
        for _, k, s, p in zip(all_channels, total_kernel_size[::-1], total_stride[::-1], total_padding[::-1]):
            # Only the height is needed here; the output paddings are
            # recomputed per stage inside make_net.
            beginning_image_height, _unused_op = self._get_new_image_height_and_output_padding(
                beginning_image_height, k, s, p
            )

        def make_net(channels, kernel_size, stride, padding, beginning_layers=False, end_layers=False):
            # Build one stage. ``channels[0]`` is the channel count entering
            # the stage's transposed convolutions. The enclosing
            # ``image_height`` is updated as a side effect.
            nonlocal image_height

            prev_channels = channels[0]

            net_stride = self._get_stride_or_kernel(stride, channels)
            net_kernel_size = self._get_stride_or_kernel(kernel_size, channels)
            net_padding = self._get_stride_or_kernel(padding, channels)

            output_paddings = []
            for _, k, s, p in zip(channels, net_kernel_size[::-1], net_stride[::-1], net_padding[::-1]):
                # Infer the output_padding of each layer, last layer first.
                image_height, output_padding = self._get_new_image_height_and_output_padding(
                    image_height, k, s, p
                )
                output_paddings.append(output_padding)

            output_paddings = output_paddings[::-1]

            t_cnn_layers = []
            if beginning_layers:
                if self.initial_linear:
                    fc_layer = nn.Linear(input_dim, channels[0]*beginning_image_height**2)
                    post_fc_shape = Reshape((-1, channels[0], beginning_image_height, beginning_image_height))
                    t_cnn_layers = [fc_layer, post_fc_shape]
                    t_cnn_layers.append(activation())
                else:
                    # TODO: verify dimensions
                    fc_layer = nn.ConvTranspose2d(input_dim, channels[0], net_kernel_size[0], 2, 1)
                    self.post_fc_shape = (channels[0], image_height, image_height)
                    t_cnn_layers = [fc_layer]
                    if norm is not None:
                        t_cnn_layers.append(norm(channels[0], **norm_args))
                    t_cnn_layers.append(activation())

            for hidden_channels, k, s, p, op in zip(
                channels[1:], net_kernel_size[:-1], net_stride[:-1], net_padding[:-1], output_paddings[:-1]
            ):
                if force_zero_op:
                    op = 1 if end_layers else 0

                t_cnn_layers.append(
                    nn.ConvTranspose2d(prev_channels, hidden_channels, k, s, padding=p, output_padding=op, bias=conv_bias)
                )

                # Every layer built in this loop is a hidden layer (the output
                # layer is appended separately below), so norm always applies.
                if norm is not None:
                    t_cnn_layers.append(norm(hidden_channels, **norm_args))

                t_cnn_layers.append(activation())

                prev_channels = hidden_channels

            if end_layers:
                t_cnn_layers.append(
                    nn.ConvTranspose2d(
                        prev_channels, output_channels, net_kernel_size[-1], net_stride[-1], padding=net_padding[-1],
                        output_padding=output_paddings[-1]
                    )
                )

                if final_activation is not None:
                    t_cnn_layers.append(final_activation())

            return nn.Sequential(*t_cnn_layers)

        only_2_layer = len(middle_hidden_channels) == 0

        # Every stage is built even when it is shared: make_net must run to
        # keep image_height (and post_fc_shape) up to date. A shared stage's
        # freshly built module is simply discarded.
        beginning_net = make_net(beginning_hidden_channels, beginning_kernel_size, beginning_stride, beginning_padding, beginning_layers=True)
        self.has_beginning = not share_start
        if not share_start:
            self.beginning_net = beginning_net

        self.has_middle = not share_middle and not only_2_layer

        if not only_2_layer:
            # The last kernel/stride/padding of the previous stage is
            # prepended so that the lists line up with the channel list.
            middle_net = make_net(
                [beginning_hidden_channels[-1]] + middle_hidden_channels,
                [beginning_kernel_size[-1]] + middle_kernel_size,
                [beginning_stride[-1]] + middle_stride,
                [beginning_padding[-1]] + middle_padding
            )
            if not share_middle:
                self.middle_net = middle_net

        if only_2_layer:
            prev_channels_list = beginning_hidden_channels
            prev_kernel_size = beginning_kernel_size
            prev_stride = beginning_stride
            prev_padding = beginning_padding
        else:
            prev_channels_list = middle_hidden_channels
            prev_kernel_size = middle_kernel_size
            prev_stride = middle_stride
            prev_padding = middle_padding

        end_net = make_net(
            [prev_channels_list[-1]] + end_hidden_channels,
            [prev_kernel_size[-1]] + end_kernel_size,
            [prev_stride[-1]] + end_stride,
            [prev_padding[-1]] + end_padding,
            end_layers=True
        )
        self.has_end = not share_end
        if not share_end:
            self.end_net = end_net

        if spectral_norm:
            self._apply_spectral_norm()

    def beginning(self, x):
        """Apply the beginning stage, or pass ``x`` through if it is absent.

        Without ``initial_linear`` the input is reshaped to
        ``(batch, -1, 1, 1)`` before entering the stage.
        """
        if self.has_beginning:
            if not self.initial_linear:
                x = x.reshape(x.shape[0], -1, 1, 1)
            return self.beginning_net(x)
        else:
            return x

    def middle(self, x):
        """Apply the middle stage, or pass ``x`` through if it is absent."""
        if self.has_middle:
            return self.middle_net(x)
        else:
            return x

    def end(self, x):
        """Apply the end stage, or pass ``x`` through if it is absent.

        A tuple returned by the stage is passed on unchanged. Otherwise the
        output is split along dim 1 and, with ``single_sigma``, returned as
        ``(mu, log_sigma)`` where ``log_sigma`` has shape ``(-1, 1, 1, 1)``.
        """
        if self.has_end:
            x = self.end_net(x)

            if isinstance(x, tuple):
                return x

            net_output = self._get_correct_nn_output_format(x, split_dim=1)

            if self.single_sigma:
                mu, log_sigma_unprocessed = net_output
                log_sigma = self.sigma_output_layer(log_sigma_unprocessed.flatten(start_dim=1))
                return mu, log_sigma.view(-1, 1, 1, 1)
            else:
                return net_output
        else:
            return x

    def add_shared_module(self, module, location):
        """Install ``module`` as the stage named by ``location``.

        ``location`` must be one of ``"start"``, ``"middle"`` or ``"end"``.
        """
        assert location in ["start", "middle", "end"]

        if location == "start":
            self.beginning_net = module
            self.has_beginning = True
        elif location == "middle":
            self.middle_net = module
            self.has_middle = True
        elif location == "end":
            self.end_net = module
            self.has_end = True

    def forward(self, x):
        """Chain the beginning, middle and end stages."""
        return self.end(self.middle(self.beginning(x)))

class GaussianMixtureLSTM(BaseNetworkClass):
    """LSTM followed by a linear head that emits Gaussian-mixture parameters.

    For every time step the linear head produces ``k_mixture * (2 * input_size + 1)``
    values, which are split along the last axis into ``k_mixture`` mixture
    logits, ``k_mixture * input_size`` means and ``k_mixture * input_size``
    raw sigma values (exponentiated before being returned).

    Args:
        input_size: Feature size of each LSTM input step.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        k_mixture: Number of mixture components.
    """

    def __init__(self, input_size,  hidden_size, num_layers, k_mixture):
        output_split_sizes = [k_mixture, k_mixture * input_size, k_mixture * input_size]
        super().__init__(output_split_sizes)
        self.rnn = nn.LSTM(input_size=input_size,
                           hidden_size=hidden_size,
                           num_layers=num_layers,
                           batch_first=True)
        self.linear = nn.Linear(in_features=hidden_size, out_features=k_mixture*(2*input_size+1))
        self.hidden_size = hidden_size
        self.k = k_mixture

    def forward(self, x, return_h_c=False, h_c=None, not_sampling=True):
        """Compute mixture parameters for every step of ``x``.

        Args:
            x: Input tensor. When ``not_sampling`` is true it is flattened
                from dim 2 onward and its last two axes are swapped, so it
                is presumably laid out as (batch, channels, ...) - confirm
                against the caller. Otherwise it is fed to the LSTM as-is.
            return_h_c: If true, also return the LSTM's final state.
            h_c: Optional initial LSTM state; ``None`` lets the LSTM use its
                default initial state.
            not_sampling: If true, the input is preprocessed as described
                above, its last step is dropped, and a zero hidden vector is
                prepended to the LSTM output so the number of output steps
                equals the number of steps before dropping.

        Returns:
            ``(weights, mus, sigmas)`` or, if ``return_h_c`` is true,
            ``(weights, mus, sigmas, h_c)``. ``weights`` holds logits.
        """
        if not_sampling:
            x = torch.flatten(x, start_dim=2)
            x = torch.permute(x, (0, 2, 1))  # move "channels" to last axis
            x = x[:, :-1]  # last coordinate is never used as input

        # nn.LSTM treats hx=None exactly like an omitted initial state.
        out, h_c = self.rnn(x, h_c)

        if not_sampling:
            # new_zeros inherits both device AND dtype from `out`; a plain
            # torch.zeros would use the default dtype and promote the
            # concatenation away from a lower-precision `out`.
            first_step = out.new_zeros((out.shape[0], 1, self.hidden_size))
            out = torch.cat((first_step, out), dim=1)

        out = self.linear(out)
        weights, mus, sigmas = self.split_transform_and_reshape(out)  # NOTE: weights contains logits
        if return_h_c:
            return weights, mus, sigmas, h_c
        return weights, mus, sigmas

    def split_transform_and_reshape(self, out):
        """Split the linear head's output into mixture logits, means and sigmas.

        ``mus`` and ``sigmas`` are reshaped to
        ``(out.shape[0], out.shape[1], k, -1)``; ``sigmas`` is exponentiated
        first. ``weights`` is returned as produced by the split.
        """
        weights, mus, sigmas = self._get_correct_nn_output_format(out, split_dim=-1)
        mus = torch.reshape(mus, (mus.shape[0], mus.shape[1], self.k, -1))
        sigmas = torch.reshape(torch.exp(sigmas), (sigmas.shape[0], sigmas.shape[1], self.k, -1))
        return weights, mus, sigmas

class ResidualStack(nn.Module):
    """Sequence of residual layers, optionally preceded by 2x bilinear upsampling.

    Args:
        block: Residual block class. It is called as
            ``block(in_channels, channel_dim, stride=..., downsample=..., norm_layer=...)``
            and must expose an ``expansion`` attribute.
        layer_channels: Channel width for each layer.
        blocks_per_layer: Number of blocks in each layer; must have the same
            length as ``layer_channels``.
        norm: Callable building a normalisation layer from a channel count.
        input_channel_dim: Number of channels entering the first layer.
        layer_strides: Stride of the first block of each layer. Defaults to
            2 for every layer.
        upsample_layer: If true, an ``nn.Upsample(scale_factor=2)`` is
            inserted before every layer.
    """

    def __init__(
        self,
        block,
        layer_channels,
        blocks_per_layer,
        norm,
        input_channel_dim,
        layer_strides=None,
        upsample_layer=False
    ):
        super().__init__()
        self.norm = norm
        # Tracks the channel count feeding the next block to be built; it is
        # updated by _make_layer as layers are appended.
        self.input_channel_dim = input_channel_dim

        if layer_strides is None:
            layer_strides = [2] * len(layer_channels)

        layers = []
        assert len(layer_channels) == len(blocks_per_layer)
        for channel_dim, num_blocks, stride in zip(layer_channels, blocks_per_layer, layer_strides):

            #TODO: configurable upsampling conv transpose
            if upsample_layer:
                layers.append(
                    nn.Upsample(
                        scale_factor=2,
                        mode="bilinear",
                        align_corners=False
                    )
                )

            layers.append(
                self._make_layer(
                    block=block,
                    channel_dim=channel_dim,
                    num_blocks=num_blocks,
                    stride=stride
                )
            )
        self.layers = nn.Sequential(*layers)

    def _make_layer(
        self,
        block,
        channel_dim,
        num_blocks,
        stride=2
    ):
        """Build one layer of ``num_blocks`` residual blocks.

        The first block receives ``stride`` and a 1x1-conv + norm projection
        as its ``downsample`` argument; the remaining blocks use the block's
        defaults. ``self.input_channel_dim`` is advanced to the layer's
        output width, ``channel_dim * block.expansion``.
        """
        norm_layer = self.norm
        # A block emits channel_dim * expansion channels, so that (not the
        # bare channel_dim) is what every following block and layer receives.
        out_channels = channel_dim * block.expansion

        downsample = nn.Sequential(
            conv1x1(self.input_channel_dim, out_channels, stride),
            norm_layer(out_channels),
        )

        blocks = [
            block(self.input_channel_dim, channel_dim, stride=stride, downsample=downsample, norm_layer=norm_layer)
        ]
        self.input_channel_dim = out_channels
        blocks.extend(
            block(self.input_channel_dim, channel_dim, norm_layer=norm_layer)
            for _ in range(1, num_blocks)
        )

        return nn.Sequential(*blocks)

    def forward(self, x):
        """Apply all layers to ``x`` in order."""
        return self.layers(x)

class ResidualEncoder(BaseNetworkClass):
    """Residual convolutional encoder.

    The input goes through a 7x7 stride-2 convolution, a norm layer, ReLU and
    a 3x3 stride-2 max-pool, then a ``ResidualStack``, a global average pool
    and a final linear layer producing ``output_dim`` features. The result is
    passed through ``_get_correct_nn_output_format`` with ``split_dim=1``.

    Args:
        layer_channels: Channel width of each residual layer.
        blocks_per_layer: Number of blocks in each residual layer.
        output_dim: Size of the final linear layer's output.
        input_channels: Number of channels of the input.
        output_split_sizes: Forwarded to the base-class constructor.
        input_channel_dim: Number of channels produced by the stem convolution.
        norm: Callable building a normalisation layer from a channel count.
        block: Residual block class; its ``expansion`` scales the width fed
            to the final linear layer.
        list_output: Stored on the instance; not read anywhere in this class.
    """

    def __init__(
        self,
        layer_channels,
        blocks_per_layer,
        output_dim,
        input_channels,
        output_split_sizes=None,
        input_channel_dim=64,
        norm=nn.BatchNorm2d,
        block=BasicBlock,
        list_output=False
    ):
        # TODO: add more configuration
        super().__init__(output_split_sizes)
        self.norm = norm
        self.input_channel_dim = input_channel_dim
        self.list_output = list_output

        # Stem.
        self.conv1 = nn.Conv2d(
            input_channels,
            self.input_channel_dim,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = norm(self.input_channel_dim)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual trunk.
        self.layers = ResidualStack(
            block,
            layer_channels,
            blocks_per_layer,
            norm=self.norm,
            input_channel_dim=self.input_channel_dim,
        )

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        head_in_features = layer_channels[-1] * block.expansion
        self.fc = nn.Linear(head_in_features, output_dim)

        # Kaiming init for convolutions; unit scale / zero shift for
        # batch-norm and group-norm layers.
        for submodule in self.modules():
            if isinstance(submodule, nn.Conv2d):
                nn.init.kaiming_normal_(submodule.weight, mode="fan_out", nonlinearity="relu")
                continue
            if isinstance(submodule, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(submodule.weight, 1)
                nn.init.constant_(submodule.bias, 0)

    def forward(self, x):
        """Encode ``x`` and return the (possibly split) output of the linear head."""
        features = x
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool, self.layers):
            features = stage(features)

        pooled = torch.flatten(self.avgpool(features), 1)
        return self._get_correct_nn_output_format(self.fc(pooled), split_dim=1)

class ResidualDecoder(BaseNetworkClass):
    """Residual convolutional decoder.

    A linear layer maps the input to a feature map with ``layer_channels[0]``
    channels and side ``image_height // 2**num_blocks``, where ``num_blocks``
    is ``len(layer_channels) + 1``. That map is then run through a
    ``ResidualStack`` built with ``upsample_layer=True`` and stride 1 for
    every layer, whose last layer has ``output_channels`` as its channel width.

    Args:
        input_dim: Size of the input vector.
        layer_channels: List of channel widths; ``output_channels`` is
            appended to it to form the stack's layer widths.
        blocks_per_layer: Number of blocks per layer, forwarded to the stack.
        output_channels: Channel width of the last layer.
        image_height: Side length used to size the initial feature map and
            the optional sigma layer.
        output_split_sizes: Forwarded to the base-class constructor.
        single_sigma: If true, the second part of the split output is reduced
            to one value per sample by a linear layer.
        norm: Callable building a normalisation layer from a channel count.
        block: Residual block class.
    """

    def __init__(
        self,
        input_dim,
        layer_channels,
        blocks_per_layer,
        output_channels,
        image_height,
        output_split_sizes=None,
        single_sigma=False,
        norm=nn.BatchNorm2d,
        block=BasicBlock
    ):
        super().__init__(output_split_sizes)
        self.single_sigma = single_sigma
        self.norm = norm
        self.layer_channels = layer_channels + [output_channels]
        self.num_blocks = len(self.layer_channels)

        if single_sigma:
            sigma_in_features = output_channels * image_height ** 2
            self.sigma_output_layer = nn.Linear(sigma_in_features, 1)

        self.layers = ResidualStack(
            block=block,
            layer_channels=self.layer_channels,
            blocks_per_layer=blocks_per_layer,
            layer_strides=[1] * len(self.layer_channels),
            upsample_layer=True,
            norm=self.norm,
            input_channel_dim=layer_channels[0],
        )

        # Side of the feature map produced by the linear layer, before the
        # stack's upsampling steps.
        initial_side = image_height // (2 ** self.num_blocks)
        self.fc_layer = nn.Linear(input_dim, layer_channels[0] * initial_side ** 2)
        self.post_fc_shape = (layer_channels[0], image_height, image_height)

    def forward(self, x):
        """Decode ``x``.

        Returns the stack's output after ``_get_correct_nn_output_format``
        (``split_dim=1``), or a ``(mu, log_sigma)`` pair with ``log_sigma``
        viewed as ``(-1, 1, 1, 1)`` when ``single_sigma`` is set.
        """
        shrink = 2 ** self.num_blocks
        channels, height, width = self.post_fc_shape
        feature_map = self.fc_layer(x).reshape(-1, channels, height // shrink, width // shrink)

        net_output = self._get_correct_nn_output_format(self.layers(feature_map), split_dim=1)
        if not self.single_sigma:
            return net_output

        mu, log_sigma_unprocessed = net_output
        log_sigma = self.sigma_output_layer(log_sigma_unprocessed.flatten(start_dim=1))
        return mu, log_sigma.view(-1, 1, 1, 1)