""" Implementation of neural network modules. """

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SDNCell(nn.Module):
    """GRU-style update cell used by the SDN layers.

    The gates are computed from two inputs:
      * the concatenation of ``prev_states`` along dim 1, which must total
        ``3 * state_size`` features (``weight_ih`` is ``3S x 3S``);
      * the current value of the state being updated (``state_prior``).
    The result interpolates between a tanh candidate and ``state_prior``.

    TODO: Use nn.GRUCell once it works with AMP.
    """

    def __init__(self, state_size):
        super().__init__()
        gate_width = 3 * state_size
        # Registration order matters: it fixes both the RNG draw order and the
        # order in which self.parameters() is re-initialized below.
        layout = (
            ('weight_ih', (gate_width, gate_width)),  # projects concatenated previous states
            ('bias_ih', (gate_width,)),
            ('weight_hh', (gate_width, state_size)),  # projects the state prior
            ('bias_hh', (gate_width,)),
        )
        for param_name, param_shape in layout:
            self.register_parameter(param_name, nn.Parameter(torch.randn(param_shape)))
        # Overwrite the randn draws with U(-1/sqrt(S), 1/sqrt(S)).
        bound = 1.0 / math.sqrt(state_size)
        with torch.no_grad():
            for param in self.parameters():
                param.uniform_(-bound, bound)

    def forward(self, prev_states, state_prior):
        """Return the updated state, same shape as ``state_prior``.

        Args:
            prev_states: sequence of 2-D tensors concatenated along dim 1.
            state_prior: 2-D tensor of shape ``(N, state_size)``.
        """
        from_prev = torch.addmm(self.bias_ih, torch.cat(prev_states, dim=1), self.weight_ih.t())
        from_prior = torch.addmm(self.bias_hh, state_prior, self.weight_hh.t())
        reset_p, update_p, cand_p = torch.chunk(from_prev, 3, dim=1)
        reset_s, update_s, cand_s = torch.chunk(from_prior, 3, dim=1)
        reset = torch.sigmoid(reset_p + reset_s)
        update = torch.sigmoid(update_p + update_s)
        candidate = torch.tanh(cand_p + (reset * cand_s))
        keep = 1 - update
        return candidate * update + state_prior * keep


class _SDNLayer(nn.Module):
    """Sequentially update a grid of states with an SDNCell along one spatial axis.

    ``states`` is a 4-D tensor whose dim 1 holds ``state_size`` channels
    (conventionally ``(batch, state_size, H, W)``). The sweep visits the lines
    of the grid along one axis, one after another. Every position of a line
    is recomputed by ``self.cell`` from two inputs:
      * its current value (the prior);
      * the three neighbouring states in the previously visited line,
        zero-padded at the borders.
    The first line of a sweep keeps its input values.

    Direction codes (``dir``):
        0 -- along dim 3, increasing index
        1 -- along dim 3, decreasing index
        2 -- along dim 2, increasing index
        3 -- along dim 2, decreasing index (any other value selects this too)
    """

    def __init__(self, state_size, dir=0):
        super().__init__()
        self.state_size = state_size
        self.cell = SDNCell(state_size)
        # Resolve the direction code once into (swept dim, sweep backwards?).
        if dir == 0:
            self._axis, self._reverse = 3, False
        elif dir == 1:
            self._axis, self._reverse = 3, True
        elif dir == 2:
            self._axis, self._reverse = 2, False
        else:
            self._axis, self._reverse = 2, True

    def forward(self, states):
        """Apply this layer's configured sweep (see ``_sweep``)."""
        return self._sweep(states, self._axis, self._reverse)

    def forward0(self, states):
        """Sweep along dim 3 in increasing index order (direction 0)."""
        return self._sweep(states, 3, False)

    def forward1(self, states):
        """Sweep along dim 3 in decreasing index order (direction 1)."""
        return self._sweep(states, 3, True)

    def forward2(self, states):
        """Sweep along dim 2 in increasing index order (direction 2)."""
        return self._sweep(states, 2, False)

    def forward3(self, states):
        """Sweep along dim 2 in decreasing index order (direction 3)."""
        return self._sweep(states, 2, True)

    def _sweep(self, states, axis, reverse):
        """Run one directional sweep and return the updated states.

        Args:
            states: 4-D tensor with ``state_size`` channels in dim 1. It is
                never modified.
            axis: spatial dim to sweep along (2 or 3). The number of steps is
                taken from ``states.shape[axis]``, so non-square grids work.
            reverse: if True, start from the last line and move towards index 0.

        Returns:
            A new tensor of the same shape, in the default contiguous layout.
        """
        batch = states.shape[0]
        length = states.shape[axis]
        # Always copy. The loop below writes into ``states`` in place. A bare
        # ``.contiguous(channels_last)`` returns the caller's tensor itself when
        # it is already channels_last. That would mutate the caller's tensor,
        # and possibly one that autograd saved for backward.
        states = states.clone(memory_format=torch.channels_last)
        if reverse:
            steps, offset = range(length - 2, -1, -1), 1
        else:
            steps, offset = range(1, length), -1
        for d in steps:
            # Take the previously updated line and zero-pad it by one at each
            # end, so every position has three neighbours. F.pad keeps the
            # dtype/device of ``states``, whereas torch.zeros always produced
            # float32, mismatching half-precision states.
            prev = F.pad(states.select(axis, d + offset), (1, 1)).transpose(1, 2)
            neighbours = [prev[:, :-2, :].reshape(-1, self.state_size),
                          prev[:, 1:-1, :].reshape(-1, self.state_size),
                          prev[:, 2:, :].reshape(-1, self.state_size)]
            # The reshaped prior may be a view of ``states``, which is
            # overwritten below, so clone it before handing it to the cell.
            prior = states.select(axis, d).transpose(1, 2).reshape(-1, self.state_size).clone(
                memory_format=torch.preserve_format)
            updated = self.cell(prev_states=neighbours, state_prior=prior)
            states.select(axis, d).copy_(updated.reshape(batch, -1, self.state_size).transpose(1, 2))
        return states.contiguous(memory_format=torch.contiguous_format)


class SDN(nn.Module):
    """Spatial dependency network block: project in, sweep, project out.

    1. ``pre_cnn`` maps ``in_ch`` to ``state_size`` channels, followed by tanh.
       It is a Conv2d, or a ConvTranspose2d when ``upsample`` is set.
    2. One ``_SDNLayer`` per entry of ``dirs`` is applied in order.
    3. A 1x1 ``post_cnn`` maps the states to ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch, state_size, dirs, kernel_size, stride, padding, upsample):
        super().__init__()
        conv_cls = nn.ConvTranspose2d if upsample else nn.Conv2d
        self.pre_cnn = conv_cls(in_ch, state_size, kernel_size, stride, padding)
        self.sdn_update_network = nn.Sequential(
            *[_SDNLayer(state_size, dir=direction) for direction in dirs]
        )
        self.post_cnn = nn.Conv2d(state_size, out_ch, 1)

    def forward(self, x):
        states = torch.tanh(self.pre_cnn(x))
        states = self.sdn_update_network(states)
        return self.post_cnn(states)


class ResSDN(nn.Module):
    """Gated mix of an SDN branch and a plain convolution branch.

    The SDN branch emits ``2 * out_ch`` channels, split into candidate
    features and gate logits. With ``g = sigmoid(logits)`` the output is
    ``g * cnn(x) + (1 - g) * features``.
    """

    def __init__(self, in_ch, out_ch, state_size, dirs, kernel_size, stride, padding, upsample):
        super().__init__()
        self.sdn = SDN(in_ch, 2 * out_ch, state_size, dirs, kernel_size, stride, padding, upsample)
        conv_cls = nn.ConvTranspose2d if upsample else nn.Conv2d
        self.cnn = conv_cls(in_ch, out_ch, kernel_size, stride, padding)

    def forward(self, input):
        conv_branch = self.cnn(input)
        features, logits = torch.chunk(self.sdn(input), 2, dim=1)
        g = torch.sigmoid(logits)
        return (1 - g) * features + g * conv_branch


class LadderLayer(nn.Module):
    """One stochastic layer of a ladder-style (bottom-up / top-down) latent model.

    ``up`` runs the bottom-up pass and caches posterior parameters in
    ``self.logqzx_params_up``. ``down`` runs the top-down pass. It obtains
    ``z`` in one of three ways:
      * by merging the cached parameters with its own and sampling the posterior;
      * by sampling the prior;
      * by using a given ``z``.
    It returns the updated hidden state together with KL terms.

    Args:
        post_model: posterior helper; must provide ``params_per_dim()`` and
            ``reparameterize(params)``.
        prior_model: prior helper; must provide ``params_per_dim()``,
            ``sample_once(params, temperature)`` and
            ``conditional_log_prob(params, z)``.
        z_size: number of latent channels.
        h_size: number of hidden channels.
        free_bits: threshold used to adjust ``kl_obj`` in ``down``.
        downsample: if True, ``up`` halves and ``down`` doubles the spatial size.
        sdn_state, sdn_dirs_a, sdn_dirs_b: SDN state size and sweep directions
            for the two top-down blocks (only used when ``use_sdn``).
        use_sdn: use ResSDN blocks for the top-down pass instead of plain convs.
        sampling_temperature: stored on the module; ``up``/``down`` use their
            own ``temperature`` argument instead.
    """

    def __init__(self, post_model, prior_model, z_size, h_size, free_bits, downsample, sdn_state,
                 sdn_dirs_a, sdn_dirs_b, use_sdn, sampling_temperature):
        super().__init__()

        # initializations
        self.post_model = post_model
        self.prior_model = prior_model
        self.logqzx_params_up = None  # written by up(), read by down()
        self.free_bits = free_bits
        self.downsample = downsample
        self.use_sdn = use_sdn
        self.sampling_temperature = sampling_temperature
        # Out-of-place ELU. An in-place ELU here rewrote ``input`` in up()/down(),
        # so the skip path mixed in ELU(input) and the caller's tensor was
        # mutated. It also fails on leaf tensors that require grad and on the
        # ``split`` output views in up().
        self.act = nn.ELU()

        # infer CNN parameters based on whether we do downsampling or not
        kernel_size, stride, padding = (4, 2, 1) if downsample else (3, 1, 1)

        # create modules for bottom-up pass
        self.up_a_layout = [h_size, z_size * post_model.params_per_dim()]
        self.up_a = nn.Conv2d(h_size, sum(self.up_a_layout), kernel_size, stride, padding)
        self.up_b = nn.Conv2d(h_size, 2*h_size, 3, 1, 1)

        # create modules for top-down pass
        self.down_a_layout = [h_size, z_size * post_model.params_per_dim(), z_size * prior_model.params_per_dim()]
        if use_sdn:
            self.down_a = ResSDN(in_ch=h_size, out_ch=sum(self.down_a_layout), state_size=sdn_state, dirs=sdn_dirs_a,
                                 kernel_size=3, stride=1, padding=1, upsample=False)
            self.down_b = ResSDN(in_ch=h_size + z_size, out_ch=2 * h_size, state_size=sdn_state, dirs=sdn_dirs_b,
                                 kernel_size=kernel_size, stride=stride, padding=padding, upsample=downsample)
        else:
            self.down_a = nn.Conv2d(h_size, sum(self.down_a_layout), 3, 1, 1)
            cnn_module = nn.ConvTranspose2d if downsample else nn.Conv2d
            self.down_b = cnn_module(h_size + z_size, 2 * h_size, kernel_size, stride, padding)

    def up(self, input):
        """Bottom-up step: cache posterior parameters and return the gated hidden state.

        ``input`` is not modified. When ``downsample`` is set the output has
        half the spatial size of ``input``.
        """
        x = self.act(input)
        x = self.up_a(x)

        h, self.logqzx_params_up = x.split(self.up_a_layout, 1)

        h = self.act(h)
        h = self.up_b(h)

        h, gate = h.chunk(2, 1)
        # The -1 shift lowers the gate for a given pre-activation, giving more
        # weight to the skip path (1 - gate).
        gate = torch.sigmoid(gate-1)

        # possibly downsample input (F.upsample is deprecated; same nearest mode)
        if self.downsample:
            input = F.interpolate(input, scale_factor=0.5)

        return (1-gate) * input + gate * h

    def down(self, input, sample=False, temperature=1.0, fixed_z=None):
        """Top-down step.

        Args:
            input: hidden state from the layer above; not modified.
            sample: if True, draw ``z`` from the prior via
                ``prior_model.sample_once`` with ``temperature``. KL terms are zero.
            temperature: passed to ``sample_once``. When both ``sample`` and
                ``fixed_z`` are given, ``z`` becomes
                ``z * temperature + fixed_z * (1 - temperature)``.
            fixed_z: if given without ``sample``, used as ``z`` directly
                (KL terms zero).

        Returns:
            ``(new hidden state, kl, kl_obj, z)``. Without ``sample``/``fixed_z``
            the posterior uses the parameters cached by the latest ``up`` call.
        """
        x = self.act(input)
        x = self.down_a(x)

        h_det, logqzx_params_down, logpz_params = x.split(self.down_a_layout, 1)

        if sample:
            z = self.prior_model.sample_once(logpz_params, temperature)
            if fixed_z is not None:
                z = z * temperature + fixed_z * (1-temperature)
            kl = kl_obj = torch.zeros(input.size(0), device=input.device)
        elif fixed_z is not None:
            z = fixed_z
            kl = kl_obj = torch.zeros(input.size(0), device=input.device)
        else:
            # merge posterior parameters
            q_params = self.logqzx_params_up + logqzx_params_down
            # sample and compute E[log p(z|x)]
            z, logqzx = self.post_model.reparameterize(q_params)
            # compute E[log p(z)]
            logpz = self.prior_model.conditional_log_prob(logpz_params, z)
            # compute KL[p(z|x)||p(z)]
            kl = kl_obj = logqzx - logpz
            # free bits are computed per layer
            # NOTE(review): kl_extra is divided by kl_obj.size(0), so the mean of
            # kl_obj becomes mean + (free_bits - mean) / size(0) rather than
            # free_bits -- confirm this matches how kl_obj is reduced downstream.
            kl_extra = (max(kl_obj.mean(), self.free_bits) - kl_obj.mean()) / kl_obj.size(0)
            kl_obj = kl_obj + kl_extra

        h = torch.cat((z, h_det), 1)
        h = self.act(h)
        h = self.down_b(h)

        h, gate = h.chunk(2, 1)
        gate = torch.sigmoid(gate-1)

        # possibly upsample input (F.upsample is deprecated; same nearest mode)
        if self.downsample:
            input = F.interpolate(input, scale_factor=2.0)

        return (1-gate) * input + gate * h, kl, kl_obj, z


class BaselineBlock1(nn.Module):
    """CNN baseline: gated mix of a single transposed conv and a small conv stack.

    Output: ``g * res_cnn(x) + (1 - g) * candidate``. Here ``candidate`` and
    the gate logits are the two halves of ``cnn_stack(x)``, and
    ``g = sigmoid(logits)``. Both branches always start with a
    ConvTranspose2d; ``dirs`` and ``upsample`` are accepted but unused.
    """

    def __init__(self, in_ch, out_ch, state_size, dirs, kernel_size, stride, padding, upsample):
        super().__init__()
        self.res_cnn = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, padding)
        stack = [
            nn.ConvTranspose2d(in_ch, state_size, kernel_size, stride, padding),
            nn.ELU(),
            nn.Conv2d(state_size, 2 * out_ch, 5, 1, 2),
        ]
        self.cnn_stack = nn.Sequential(*stack)

    def forward(self, input):
        residual = self.res_cnn(input)
        candidate, logits = torch.chunk(self.cnn_stack(input), 2, dim=1)
        weight = torch.sigmoid(logits)
        return (1 - weight) * candidate + weight * residual


class BaselineBlock2(nn.Module):
    """Plain CNN baseline without gating: transposed conv -> ELU -> 5x5 conv.

    ``dirs`` and ``upsample`` are accepted but unused; the first layer is
    always a ConvTranspose2d.
    """

    def __init__(self, in_ch, out_ch, state_size, dirs, kernel_size, stride, padding, upsample):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_ch, state_size, kernel_size, stride, padding),
            nn.ELU(),
            nn.Conv2d(state_size, out_ch, 5, 1, 2),
        ]
        self.cnn_stack = nn.Sequential(*layers)

    def forward(self, input):
        output = self.cnn_stack(input)
        return output

class BaselineBlock3(nn.Module):
    """Deeper gated CNN baseline: like BaselineBlock1 but with an extra 3x3 conv layer.

    Output: ``g * res_cnn(x) + (1 - g) * candidate``. Here ``candidate`` and
    the gate logits are the two halves of ``cnn_stack(x)``, and
    ``g = sigmoid(logits)``. ``dirs`` and ``upsample`` are accepted but unused.
    """

    def __init__(self, in_ch, out_ch, state_size, dirs, kernel_size, stride, padding, upsample):
        super().__init__()
        self.res_cnn = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, padding)
        stack = [
            nn.ConvTranspose2d(in_ch, state_size, kernel_size, stride, padding),
            nn.ELU(),
            nn.Conv2d(state_size, state_size, 3, 1, 1),
            nn.ELU(),
            nn.Conv2d(state_size, 2 * out_ch, 3, 1, 1),
        ]
        self.cnn_stack = nn.Sequential(*stack)

    def forward(self, input):
        residual = self.res_cnn(input)
        candidate, logits = torch.chunk(self.cnn_stack(input), 2, dim=1)
        weight = torch.sigmoid(logits)
        return (1 - weight) * candidate + weight * residual


class BaselineBlock4(nn.Module):
    """Single-convolution baseline with fixed kernel sizes.

    * ``upsample`` false: 7x7 Conv2d, stride 1, padding 3 (keeps the spatial size).
    * ``upsample`` true: 8x8 ConvTranspose2d, stride 2, padding 3 (doubles it).

    ``state_size``, ``dirs``, ``kernel_size``, ``stride`` and ``padding``
    are accepted but ignored.
    """

    def __init__(self, in_ch, out_ch, state_size, dirs, kernel_size, stride, padding, upsample):
        super().__init__()
        if upsample:
            self.cnn = nn.ConvTranspose2d(in_ch, out_ch, 8, 2, 3)
        else:
            self.cnn = nn.Conv2d(in_ch, out_ch, 7, 1, 3)

    def forward(self, input):
        result = self.cnn(input)
        return result
