"""
Our Spatial Broadcast Decoder implementation is based on this: https://github.com/dfdazac/vaesbd/blob/master/model.py
"""
from spaghettini import quick_register
import math

from torch import nn
import torch
import torch.nn.functional as F
from torch.nn import init


@quick_register
class SpatialBroadcastDecoder(nn.Module):
    """Spatial Broadcast Decoder: decodes a flat feature vector into an image.

    Each input vector is tiled ("broadcast") over a spatial grid and concatenated with two fixed coordinate
    channels whose values span [-1, 1]. The result is passed through ``num_conv_layers`` unpadded convolutions.
    Every convolution shrinks both spatial dims by ``kernel_size - 1``. The grid therefore starts larger than
    ``im_size``, so that the output ends up with spatial size exactly ``im_size``.

    Args:
        in_feats: Number of features per input vector (size of dim 1 of the 2-D input to ``forward``).
        im_size: Output spatial size; ``im_size[0]`` / ``im_size[1]`` become dims 2 / 3 of the output.
        hid_channels: Number of channels produced by every convolution except the last.
        num_conv_layers: Total number of convolutions (first + hidden + last); must be at least 3.
        out_channels: Number of channels of the decoded output.
        kernel_size: Side length of every (square) convolution kernel; must be an int, since it is used
            arithmetically to compute the initial grid size.
        act: Activation callable applied after every convolution except the last.
    """

    def __init__(self, in_feats, im_size, hid_channels, num_conv_layers, out_channels, kernel_size, act):
        super().__init__()
        assert num_conv_layers >= 3, "There must be at least 3 convolutional layers in the network. "
        self.in_feats = in_feats
        self.im_size = im_size
        self.hid_channels = hid_channels
        self.num_conv_layers = num_conv_layers
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.act = act

        # Form the "coordinate tiles". No padding is used, so every convolution trims kernel_size - 1 pixels
        # from each spatial dimension. Start from a grid that is larger by that total so the output has the
        # right size.
        slack = num_conv_layers * (kernel_size - 1)
        self.initial_grid_size = (im_size[0] + slack, im_size[1] + slack)
        x = torch.linspace(-1, 1, self.initial_grid_size[0])
        y = torch.linspace(-1, 1, self.initial_grid_size[1])
        # With torch.meshgrid's default "ij" indexing, x_grid varies along dim 0 and y_grid along dim 1.
        x_grid, y_grid = torch.meshgrid(x, y)
        # Shape (1, 2, H0, W0). It is registered as a buffer so it moves with the module (.to/.cuda) and is
        # stored in the state dict.
        grid = torch.cat((x_grid[None, None, ...], y_grid[None, None, ...]), dim=1)
        self.register_buffer('grid', grid)

        # First convolution: consumes the broadcast features plus the 2 coordinate channels.
        self.first_conv = nn.Conv2d(in_channels=self.in_feats + 2, out_channels=self.hid_channels,
                                    kernel_size=self.kernel_size, padding=0)

        # Hidden convolutions: num_conv_layers - 2 of them, since first_conv and last_conv make up the rest.
        self.dec_convs = nn.ModuleList([
            nn.Conv2d(in_channels=self.hid_channels, out_channels=self.hid_channels,
                      kernel_size=self.kernel_size, padding=0)
            for _ in range(self.num_conv_layers - 2)
        ])

        # Final convolution to the output channels. Its bias is initialised to zero and no activation follows it.
        self.last_conv = ZeroBiasConv(in_channels=self.hid_channels, out_channels=self.out_channels,
                                      kernel_size=self.kernel_size)

    def forward(self, feats):
        """Decode a batch of feature vectors.

        Args:
            feats: 2-D tensor of shape ``(batch, in_feats)``.

        Returns:
            Tensor of shape ``(batch, out_channels, im_size[0], im_size[1])``. The last convolution's output is
            returned without an activation.
        """
        assert len(feats.shape) == 2
        bs = feats.shape[0]

        # Broadcast the features across the spatial dims: (bs, in_feats) -> (bs, in_feats, H0, W0).
        feats = feats.view(feats.shape + (1, 1))
        feats = feats.expand(-1, -1, self.initial_grid_size[0], self.initial_grid_size[1])

        # Prepend the coordinate channels: (bs, 2 + in_feats, H0, W0).
        feats_with_grid = torch.cat((self.grid.expand(bs, -1, -1, -1), feats), dim=1)

        # Every convolution except the last is followed by the activation. The activation after first_conv
        # matters: without it, first_conv and the next convolution compose into a single linear map.
        zs = self.act(self.first_conv(feats_with_grid))
        for conv in self.dec_convs:
            zs = self.act(conv(zs))

        return self.last_conv(zs)


class ZeroBiasConv(nn.Conv2d):
    """``nn.Conv2d`` whose bias (when present) is initialised to zero.

    Weights are initialised with Kaiming-uniform (``a=sqrt(5)``). The override is on ``reset_parameters``,
    so the same initialisation is used at construction time and whenever the parameters are reset explicitly.
    """

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            init.zeros_(self.bias)


if __name__ == "__main__":
    # Manual checks. Run from the repository root:
    #     python -m src.dl.models.spatial_broadcast_decoder
    # test_num selects the check: 0 = output-shape smoke test, 1 = MNIST autoencoding with an FC encoder.
    test_num = 1

    if test_num == 0:
        xs = torch.rand(size=(250, 128))
        sbd = SpatialBroadcastDecoder(im_size=(10, 10), in_feats=128, hid_channels=16, num_conv_layers=3,
                                      out_channels=1, kernel_size=3, act=nn.functional.leaky_relu)
        ys = sbd(xs)
        print(ys.shape)

    if test_num == 1:
        # ____ Train SBDs on MNIST autoencoding. ____
        from src.dl.models.fully_connected import FCNetFixedWidth
        from torch.nn import ReLU
        import torch.optim as optim
        from torchvision import datasets, transforms
        from torch.utils.data import DataLoader
        import matplotlib.pyplot as plt

        # Get networks.
        encoder = FCNetFixedWidth(num_inputs=784, num_hidden_dim=256, num_outputs=256, num_hidden_layers=1,
                                  activation_init=ReLU, use_layer_norm=False, add_final_activation=False)

        decoder = SpatialBroadcastDecoder(in_feats=256, im_size=(28, 28), hid_channels=16, num_conv_layers=3,
                                          out_channels=1, kernel_size=3, act=F.relu)

        # Get loader.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
        train_loader = DataLoader(dataset1, batch_size=32, shuffle=True)

        # Get optimizer: a single Adam instance over the parameters of both networks.
        optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.00003)

        # Train for exactly `max_batches` optimisation steps, looping over the dataset as often as needed.
        encoder.train()
        decoder.train()
        max_batches = 10000
        plot_every = 2500
        batch_counter = 0
        while batch_counter < max_batches:
            for data, _ in train_loader:
                batch_counter += 1
                optimizer.zero_grad()
                bs = data.shape[0]
                feats = encoder(data.view(bs, -1))
                reconst = decoder(feats)
                loss = F.mse_loss(data.view(bs, -1), reconst.view(bs, -1))
                loss.backward()
                optimizer.step()
                print(f"Batch idx: {batch_counter}, loss: {float(loss)}")
                # Plot on the first step and then every `plot_every` steps. The global counter is used because
                # the per-epoch batch index is reset each epoch and never reaches `plot_every`.
                if (batch_counter - 1) % plot_every == 0:
                    fig, axs = plt.subplots(2, 1, figsize=(10, 10), dpi=100)
                    axs[0].imshow(data[0, 0], cmap="Greys_r")
                    axs[1].imshow(reconst[0, 0].detach().numpy(), cmap="Greys_r")
                    plt.show()
                # Stop mid-epoch once the step budget is used up instead of finishing the epoch.
                if batch_counter >= max_batches:
                    break
