from spaghettini import quick_register

import numpy as np

import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.utils import spectral_norm, remove_spectral_norm

from src.utils.misc import recursively_remove_spectral_norm

# Standard deviation of the truncated-normal weight initialization used by the SmallInit* layers.
SMALL_SCALE = 1e-4


class SmallInitConv2D(nn.Conv2d):
    """`nn.Conv2d` whose parameters start out close to zero.

    The kernel is drawn from a zero-mean truncated normal distribution with
    standard deviation `SMALL_SCALE` (torch's default truncation bounds); the
    bias, when the layer has one, is set to zero.
    """

    def reset_parameters(self) -> None:
        """Re-initialize kernel and bias in place."""
        bias = self.bias
        if bias is not None:
            init.zeros_(bias)
        init.trunc_normal_(self.weight, std=SMALL_SCALE, mean=0.)


class DiracInitConv2D(nn.Conv2d):
    """`nn.Conv2d` initialized with a Dirac-delta kernel and a zero bias.

    With matching input/output channel counts and "same" padding the freshly
    constructed layer passes its input through unchanged.
    """

    def reset_parameters(self) -> None:
        """Re-initialize kernel and bias in place."""
        # Forward `groups` so that a grouped convolution gets the identity taps
        # inside every group; with the default (groups=1) `dirac_` only fills the
        # first output channels of a grouped weight tensor.
        init.dirac_(self.weight, groups=self.groups)
        if self.bias is not None:
            init.zeros_(self.bias)


class SmallInitConvTranspose2d(nn.ConvTranspose2d):
    """`nn.ConvTranspose2d` whose parameters start out close to zero.

    The kernel is drawn from a zero-mean truncated normal distribution with
    standard deviation `SMALL_SCALE` (torch's default truncation bounds); the
    bias, when the layer has one, is set to zero.
    """

    def reset_parameters(self) -> None:
        """Re-initialize kernel and bias in place."""
        bias = self.bias
        if bias is not None:
            init.zeros_(bias)
        init.trunc_normal_(self.weight, std=SMALL_SCALE, mean=0.)


def double_conv(in_channels, out_channels, hid_channels=None, use_spectral_norm=False):
    """Build two stacked (3x3 conv -> InstanceNorm2d -> LeakyReLU) stages.

    Both convolutions are `SmallInitConv2D` layers with padding 1, so the
    spatial size of the input is preserved.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the second stage.
        hid_channels: Number of channels between the two stages; defaults to
            `out_channels` when None.
        use_spectral_norm: If truthy, each convolution is wrapped in
            `spectral_norm`.

    Returns:
        An `nn.Sequential` holding the six layers in order.
    """
    if hid_channels is None:
        hid_channels = out_channels

    layers = []
    # Stages are built strictly one after the other so that parameter
    # initialization happens in the order conv1, conv2.
    for stage_in, stage_out in ((in_channels, hid_channels), (hid_channels, out_channels)):
        conv = SmallInitConv2D(stage_in, stage_out, (3, 3), padding=(1, 1))
        if use_spectral_norm:
            conv = spectral_norm(module=conv)
        layers.append(conv)
        layers.append(nn.InstanceNorm2d(stage_out))
        layers.append(nn.LeakyReLU(inplace=True))
    return nn.Sequential(*layers)


@quick_register
class UNet(nn.Module):
    """
    U-Net assembled from `double_conv` blocks.

    Taken from https://github.com/stelzner/monet/blob/master/model.py

    The encoder applies `num_blocks` double-conv blocks with 2x max pooling in between; block `i`
    produces `channel_base * 2 ** i` channels. Each decoder step upsamples by 2, halves the channel
    count with a 3x3 convolution, concatenates the encoder output of the same resolution and applies
    another double-conv block. A final 1x1 convolution maps to `out_channels`.

    The spatial size of the input should be divisible by `2 ** (num_blocks - 1)`; otherwise the
    upsampled maps no longer match the skip connections and the concatenation fails.

    Args:
        num_blocks: Number of encoder blocks (resolution levels); must be at least 1.
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output.
        channel_base: Number of channels produced by the first encoder block.
        add_input_img: If True, the network input is concatenated (channel-wise) to the last decoder
            feature map before the final 1x1 convolution.
        use_spectral_norm: If True, every convolution is wrapped in `spectral_norm`.

    Raises:
        ValueError: If `num_blocks` is smaller than 1.
    """

    def __init__(self, num_blocks, in_channels, out_channels, channel_base=64, add_input_img=False,
                 use_spectral_norm=False):
        super().__init__()
        if num_blocks < 1:
            # Without any encoder block `forward` has no bottleneck to run.
            raise ValueError(f"num_blocks must be at least 1, got {num_blocks}.")
        self.num_blocks = num_blocks
        self.add_input_img = add_input_img

        # Encoder: channel count doubles at every level.
        self.down_convs = nn.ModuleList()
        cur_in_channels = in_channels
        for i in range(num_blocks):
            cur_out_channels = channel_base * 2 ** i
            self.down_convs.append(double_conv(cur_in_channels, cur_out_channels,
                                               use_spectral_norm=use_spectral_norm))
            cur_in_channels = cur_out_channels

        # Decoder upsampling: nearest-neighbour upsample followed by a 3x3 conv that halves the channels.
        self.tconvs = nn.ModuleList()
        for i in range(num_blocks - 1, 0, -1):
            curr_conv_module = SmallInitConv2D(channel_base * 2 ** i, channel_base * 2 ** (i - 1),
                                               kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            if use_spectral_norm:
                curr_conv_module = spectral_norm(curr_conv_module)
            self.tconvs.append(nn.Sequential(nn.Upsample(scale_factor=2),
                                             curr_conv_module))
            # TODO: Give option to use spectral norm for SmallInitConvTranspose2d as well.
            # self.tconvs.append(SmallInitConvTranspose2d(channel_base * 2 ** i,
            #                                             channel_base * 2 ** (i - 1),
            #                                             2, stride=2))

        # Decoder blocks: input is the upsampled map concatenated with the skip connection, hence
        # twice the channels of the level it outputs.
        self.up_convs = nn.ModuleList()
        for i in range(num_blocks - 2, -1, -1):
            self.up_convs.append(double_conv(channel_base * 2 ** (i + 1), channel_base * 2 ** i,
                                             use_spectral_norm=use_spectral_norm))

        # When the input is appended in `forward` it contributes `in_channels` extra channels
        # (not always exactly one).
        final_conv_channels = channel_base + in_channels if add_input_img else channel_base
        self.final_conv = nn.Conv2d(final_conv_channels, out_channels, 1)
        if use_spectral_norm:
            self.final_conv = spectral_norm(self.final_conv)

    def forward(self, x):
        """Run the U-Net on `x` and return the output of the final 1x1 convolution."""
        skips = []
        cur = x
        for down_conv in self.down_convs[:-1]:
            cur = down_conv(cur)
            skips.append(cur)
            cur = nn.functional.max_pool2d(cur, 2)

        # Bottleneck: the last encoder block is not followed by pooling.
        cur = self.down_convs[-1](cur)

        # Walk back up, pairing each decoder step with the deepest unused skip connection.
        for tconv, up_conv, skip in zip(self.tconvs, self.up_convs, reversed(skips)):
            cur = tconv(cur)
            cur = torch.cat((cur, skip), 1)
            cur = up_conv(cur)

        if self.add_input_img:
            cur = torch.cat((cur, x), 1)

        return self.final_conv(cur)


@quick_register
class ClunkyResNet(nn.Module):
    """Stack of `double_conv` blocks with residual connections around the middle ones.

    Layout: an input block (`in_channels` -> `hid_channels`), `num_hid_blocks` residual blocks
    (`hid_channels` -> `hid_channels`, each followed by `Dropout2d` on the residual branch) and an
    output block (`hid_channels` -> `out_channels`).

    Args:
        in_channels: Number of channels of the input.
        hid_channels: Number of channels used by the residual blocks.
        out_channels: Number of channels of the output.
        num_hid_blocks: Number of residual blocks.
        final_flatten: If True, the output is reshaped to `(batch, -1)`.
        residual_dropout_rate: Drop probability of the `Dropout2d` applied to each residual branch.
        use_spectral_norm: Forwarded to every `double_conv` call.
    """

    def __init__(self, in_channels, hid_channels, out_channels, num_hid_blocks, final_flatten=False,
                 residual_dropout_rate=0., use_spectral_norm=False):
        super().__init__()
        self.num_hid_blocks = num_hid_blocks
        self.residual_dropout_rate = residual_dropout_rate
        self.final_flatten = final_flatten

        def make_block(block_in, block_out):
            # All blocks share the same spectral-norm setting.
            return double_conv(in_channels=block_in, out_channels=block_out,
                               use_spectral_norm=use_spectral_norm)

        self.first_conv = make_block(in_channels, hid_channels)
        self.middle_convs = nn.ModuleList()
        self.middle_dropouts = nn.ModuleList()
        for _ in range(num_hid_blocks):
            self.middle_convs.append(make_block(hid_channels, hid_channels))
            self.middle_dropouts.append(nn.Dropout2d(p=residual_dropout_rate))
        self.final_conv = make_block(hid_channels, out_channels)

    def forward(self, x):
        """Apply the network to `x`; the result is flattened per sample when `final_flatten` is set."""
        batch_size = x.shape[0]
        hidden = self.first_conv(x)
        for idx in range(self.num_hid_blocks):
            branch = self.middle_convs[idx](hidden)
            hidden = hidden + self.middle_dropouts[idx](branch)
        out = self.final_conv(hidden)
        if self.final_flatten:
            return out.view(batch_size, -1)
        return out


@quick_register
class DoubleConvBlock(nn.Module):
    """Registered `nn.Module` wrapper around a single `double_conv` stack.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output.
        hid_channels: Number of channels between the two convolutions.
        use_spectral_norm: Forwarded to `double_conv`.
    """

    def __init__(self, in_channels, out_channels, hid_channels, use_spectral_norm=False):
        super().__init__()
        block_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            hid_channels=hid_channels,
            use_spectral_norm=use_spectral_norm,
        )
        self.block = double_conv(**block_kwargs)

    def forward(self, x):
        """Apply the wrapped block to `x`."""
        out = self.block(x)
        return out


@quick_register
class SingleResnetBlockWithClassifier(nn.Module):
    """One `double_conv` block followed by a linear classifier.

    The block output is averaged over the channel dimension, flattened per sample and fed to an
    `nn.Linear` layer. The linear layer has `prod(im_size)` input features, so `im_size` has to
    describe the spatial size of the inputs (the block itself keeps the spatial size unchanged).

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the block.
        hid_channels: Number of channels between the block's two convolutions.
        im_size: Spatial size of the input; its product is the classifier's input width.
        num_classes: Number of outputs of the classifier.
        use_spectral_norm: Forwarded to `double_conv` (the linear layer is left untouched).
    """

    def __init__(self, in_channels, out_channels, hid_channels, im_size, num_classes, use_spectral_norm=False):
        super().__init__()

        self.block = double_conv(
            in_channels=in_channels,
            out_channels=out_channels,
            hid_channels=hid_channels,
            use_spectral_norm=use_spectral_norm,
        )
        num_pixels = np.prod(im_size)
        self.classifier = nn.Linear(in_features=num_pixels, out_features=num_classes)

    def forward(self, x):
        """Return the classifier outputs for the batch `x`."""
        batch_size = x.shape[0]
        features = self.block(x)
        # Collapse the channels to a single map, then flatten it per sample.
        channel_avg = features.mean(dim=1, keepdims=True)
        flat = channel_avg.view(batch_size, -1)
        return self.classifier(flat)


if __name__ == "__main__":
    # Manual smoke tests.
    # Run from root.
    # python -m src.dl.models.convolutional
    test_num = 4

    def _demo_on_mnist(make_net, plot=True, num_imgs=10, new_size=32):
        """Run a freshly built network on the first `num_imgs` samples of an NDigitMNIST dataset.

        Args:
            make_net: Zero-argument callable returning the network under test; it is called after
                the dataset has been created.
            plot: If True, show channel 0 of every output as a heatmap; otherwise print the shape
                of every output.
            num_imgs: Number of dataset samples to feed through the network.
            new_size: Value handed to `resize_normalize` to build the dataset transform.
        """
        # Sample MNIST Inputs used in CMNIST experiments.
        # Project-local and plotting imports stay local so importing this module never needs them.
        from src.data.datasets.n_digit_mnist import NDigitMNIST
        from src.data.data_loading.transforms import resize_normalize
        import matplotlib.pyplot as plt

        # NOTE(review): `secondary_transform` receives the `resize_normalize` function itself rather
        # than `resize_normalize(new_size)` - kept as in the original, confirm this is intended.
        dataset = NDigitMNIST(root="./data", digits=[4, 9], train=True, transform=resize_normalize(new_size),
                              secondary_transform=resize_normalize)

        # Initialize the network.
        net = make_net()

        for i in range(num_imgs):
            data_dict = dataset[i]
            # Add a leading batch dimension.
            img = data_dict["raw_img"][None, ...]

            outs = net(img)
            if not plot:
                print(outs.shape)
                continue

            plt.imshow(outs.detach().numpy()[0, 0], cmap="coolwarm")
            plt.colorbar()
            plt.show()

    if test_num == 0:
        # ____ Test UNet. ____
        _demo_on_mnist(lambda: UNet(num_blocks=3, in_channels=1, out_channels=1, channel_base=32,
                                    add_input_img=True))

    if test_num == 1:
        # ____ Test ClunkyResNet. ____
        _demo_on_mnist(lambda: ClunkyResNet(in_channels=1, out_channels=1, hid_channels=32, num_hid_blocks=1))

    if test_num == 2:
        # ____ Test single ResNet block. ____
        _demo_on_mnist(lambda: DoubleConvBlock(in_channels=1, out_channels=1, hid_channels=32))

    if test_num == 3:
        # ____ Test single resnet with classifier. ____
        _demo_on_mnist(lambda: SingleResnetBlockWithClassifier(in_channels=1, out_channels=32, hid_channels=32,
                                                               im_size=(32, 32), num_classes=10),
                       plot=False)

    if test_num == 4:
        import copy

        # Check if spectral_norm accepts SmallInitConv2D.
        xs = torch.randn(size=(5, 5, 10, 10))
        convnet = double_conv(in_channels=5, out_channels=5, use_spectral_norm=True)
        # Forward pass first: spectral_norm recomputes the weight in a forward pre-hook, and that is
        # the state the copies below are taken in.
        ys = convnet(xs)
        try:
            cm_copy = copy.deepcopy(convnet)
        except Exception as err:
            # Report why the copy failed instead of silently swallowing every exception
            # (a bare `except:` would also trap KeyboardInterrupt / SystemExit).
            print(f"Failed to copy. Reason: {err!r}")
        # torch's `remove_spectral_norm` acts on a single module, so the project helper is used for
        # the whole Sequential (presumably it recurses into the children - see src.utils.misc).
        # cm = remove_spectral_norm(module=convnet)
        recursively_remove_spectral_norm(module=convnet)
        cm = copy.deepcopy(convnet)
        # NOTE(review): interactive stop kept from the original for inspecting `cm`; remove it if
        # this script should ever run unattended.
        breakpoint()
        print("Success!")
