import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from ncpn.layers import *
# from ncpn.layers import NetworkInNetwork, DownShiftedConv2dTranpose, DownRightShiftedConv2dTranspose
from ncpn.utils import *
from ncpn.axial_attention import *
import numpy as np
import functools

class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    A random frequency vector ``W`` (standard-normal samples multiplied by
    ``scale``) is drawn once at construction time and kept frozen.
    ``forward`` multiplies every input value by every frequency and by
    ``2 * pi`` and returns the sines followed by the cosines of those angles.

    Args:
        embedding_size: number of random frequencies; the embedding has
            twice as many features (sine and cosine halves).
        scale: multiplier applied to the standard-normal frequency samples.
    """

    def __init__(self, embedding_size=256, scale=16.0):
        super().__init__()
        frequencies = torch.randn(embedding_size) * scale
        # Registered as a Parameter, but with gradients disabled.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Return ``cat([sin(a), cos(a)], dim=-1)`` for ``a = x * W * 2 * pi``.

        A new axis is inserted at position 1 of ``x`` and broadcast against
        the frequency vector.
        """
        angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        return torch.cat((sin_part, cos_part), dim=-1)

class ResBlock(nn.Module):
    """Gated residual block with optional skip, time and conditioning inputs.

    The block computes ``x + h1 * sigmoid(h2)`` where ``h1`` and ``h2`` are
    the two channel halves of the output of ``conv_out``.

    Args:
        num_filters: channel count of ``x`` (and of the block output).
        conv_op: callable ``conv_op(in_channels, out_channels)`` building the
            convolution layers.
        nonlinearity: activation applied before each layer. The layer sizes
            assume it doubles the channel count (the original comment names
            concat ELU as the reason).
        dropout: probability handed to ``nn.Dropout2d``.
        skip_connection: multiple of ``num_filters`` carried by the auxiliary
            input ``a``; ``0`` means no skip projection is created.
        nin: factory for the 1x1-style projection of the auxiliary input.
        temb_dim: size of the time embedding, or ``None`` to disable it.
        c_dim: size of the conditioning vector, or ``None`` to disable it.
    """

    def __init__(
        self, num_filters, conv_op, nonlinearity,
        dropout, skip_connection=0, nin=NetworkInNetwork,
        temb_dim=None, c_dim=None
    ):
        super().__init__()
        self.skip_connection = skip_connection
        self.nonlinearity = nonlinearity

        # Inputs to every layer have passed through the channel-doubling
        # nonlinearity, hence the factor of two on the input widths.
        doubled = 2 * num_filters
        self.conv_input = conv_op(doubled, num_filters)

        if skip_connection != 0:
            self.nin_skip = nin(doubled * skip_connection, num_filters)

        if temb_dim is not None:
            self.temb_mlp = Linear(2 * temb_dim, num_filters)

        if c_dim is not None:
            # Projects straight onto the pre-gate activations (2 * num_filters).
            self.c_mlp = Linear(c_dim, doubled, init_scale=0.05)

        self.dropout = nn.Dropout2d(dropout)

        # Twice the filters on the way out: one half is the value, the other the gate.
        self.conv_out = conv_op(doubled, doubled)

    def forward(self, x, a=None, temb=None, c=None):
        """Apply the block to ``x``.

        ``a`` (auxiliary feature map), ``temb`` (time embedding) and ``c``
        (conditioning vector) are each optional; when given they are
        projected and added to the hidden activations. ``temb`` and ``c``
        projections are broadcast over the two trailing spatial axes.
        """
        hidden = self.conv_input(self.nonlinearity(x))

        if a is not None:
            hidden += self.nin_skip(self.nonlinearity(a))

        if temb is not None:
            time_bias = self.temb_mlp(self.nonlinearity(temb))
            hidden += time_bias[:, :, None, None]

        hidden = self.conv_out(self.dropout(self.nonlinearity(hidden)))

        if c is not None:
            cond_bias = self.c_mlp(c)
            hidden += cond_bias[:, :, None, None]

        value, gate = torch.chunk(hidden, 2, dim=1)
        return x + value * gate.sigmoid()

class AttentionBlock(nn.Module):
    """Axial attention block built from per-permutation attention layers.

    A first stack (``u_layer``) applies one ``Attentionpp`` per permutation
    returned by ``calculate_permutations``; only the last of these is built
    with ``causal=True``. Unless ``u_only`` is set, a second stack
    (``layer``) of causal attention layers over all permutations except the
    last is created and applied to a shifted combination of the first
    stack's output and the position-embedded input.

    Args:
        nf: feature width handed to the positional embedding and attention.
        shape: shape tuple; ``shape[1:]`` is passed to the positional
            embedding and ``len(shape) - 1`` is the number of axes.
        channel_dim: index of the channel axis.
        heads: number of attention heads.
        causal: when the second stack exists, selects whether the input is
            row-shifted before being added to the column-shifted output.
        u_only: build only the first stack.
        learnable_positions: forwarded to ``AxialPositionalEmbedding``.
    """

    def __init__(self, nf, shape, channel_dim, heads, causal=False, u_only=False, learnable_positions=False):
        super().__init__()
        self.causal = causal
        self.shape = shape

        # get permutations
        permutations = calculate_permutations(len(shape) - 1, channel_dim)

        self.pos_emb = AxialPositionalEmbedding(
            nf, shape[1:], channel_dim, learnable=learnable_positions
        )

        # First stack: only the attention over the final permutation is causal.
        final_index = len(permutations) - 1
        first_stack = []
        for index, permutation in enumerate(permutations):
            attention = Attentionpp(nf, heads, causal=(index == final_index))
            first_stack.append(PermuteToFrom(permutation, attention))
        self.u_layer = Sequential(nn.ModuleList(first_stack))

        if not u_only:
            # Second stack: causal attention over every permutation but the last.
            second_stack = []
            for permutation in permutations[:-1]:
                attention = Attentionpp(nf, heads, causal=True)
                second_stack.append(PermuteToFrom(permutation, attention))
            self.layer = Sequential(nn.ModuleList(second_stack))

    def forward(self, x):
        """Add positional embeddings to ``x`` and run the attention stack(s)."""
        x = self.pos_emb(x)
        out = self.u_layer(x)

        # `layer` only exists when the block was built with u_only=False.
        if not hasattr(self, 'layer'):
            return out

        shifted = col_shift(out)
        if self.causal:
            combined = shifted + row_shift(x)
        else:
            combined = shifted + x
        return self.layer(combined)

class NCPNUp(nn.Module):
    """Stack of paired residual blocks that records every intermediate output.

    Two parallel streams of ``nr_resnet`` ``ResBlock`` modules are kept:
    ``u_stream`` (stream from pixels above) and ``ul_stream`` (stream from
    pixels above and to the left). Each ``ul`` block receives the current
    ``u`` activation as its auxiliary input. When ``attn_block_cls`` is
    given, one attention block per level is applied to ``u`` before it is
    fed to the ``ul`` block.

    Args:
        nr_resnet: number of residual blocks per stream.
        nr_filters: channel width of both streams.
        act: nonlinearity passed to every ``ResBlock``.
        dropout: dropout probability passed to every ``ResBlock``.
        attn_block_cls: callable ``attn_block_cls(nr_filters, shape)`` or
            ``None`` to disable attention.
        shape: shape argument forwarded to ``attn_block_cls``.
        temb_dim: time-embedding size forwarded to every ``ResBlock``.
        c_dim: conditioning size forwarded to every ``ResBlock``.
    """

    def __init__(self, nr_resnet, nr_filters, act, dropout, attn_block_cls=None, shape=None, temb_dim=None, c_dim=None):
        super().__init__()
        self.nr_resnet = nr_resnet

        # stream from pixels above
        u_blocks = [
            ResBlock(
                nr_filters, DownShiftedConv2d, act, dropout,
                skip_connection=0, temb_dim=temb_dim, c_dim=c_dim,
            )
            for _ in range(nr_resnet)
        ]
        self.u_stream = nn.ModuleList(u_blocks)

        # stream from pixels above and to the left
        ul_blocks = [
            ResBlock(
                nr_filters, DownRightShiftedConv2d, act, dropout,
                skip_connection=1, temb_dim=temb_dim, c_dim=c_dim,
            )
            for _ in range(nr_resnet)
        ]
        self.ul_stream = nn.ModuleList(ul_blocks)

        if attn_block_cls is not None:
            attention_blocks = [attn_block_cls(nr_filters, shape) for _ in range(nr_resnet)]
            self.attblocks = nn.ModuleList(attention_blocks)

    def forward(self, u, ul, temb=None, c=None):
        """Run both streams and return ``(u_list, ul_list)``.

        The returned lists hold the ``u`` and ``ul`` activations after each
        of the ``nr_resnet`` levels, in order.
        """
        # `attblocks` is only set when an attention class was supplied.
        with_attention = hasattr(self, 'attblocks')
        u_outputs = []
        ul_outputs = []

        for level in range(self.nr_resnet):
            u = self.u_stream[level](u, temb=temb, c=c)
            if with_attention:
                u = self.attblocks[level](u)
            ul = self.ul_stream[level](ul, a=u, temb=temb, c=c)
            u_outputs.append(u)
            ul_outputs.append(ul)

        return u_outputs, ul_outputs


class NCPNDown(nn.Module):
    """Stack of paired residual blocks that consumes stored skip activations.

    Mirrors ``NCPNUp``: at every level the ``u`` block takes the most recent
    entry popped from ``u_list`` as its auxiliary input, and the ``ul`` block
    takes the channel-wise concatenation of the fresh ``u`` activation and
    the most recent entry popped from ``ul_list``.

    Args:
        nr_resnet: number of residual blocks per stream.
        nr_filters: channel width of both streams.
        act: nonlinearity passed to every ``ResBlock``.
        dropout: dropout probability passed to every ``ResBlock``.
        attn_block_cls: callable ``attn_block_cls(nr_filters, shape)`` or
            ``None`` to disable attention.
        shape: shape argument forwarded to ``attn_block_cls``.
        temb_dim: time-embedding size forwarded to every ``ResBlock``.
        c_dim: conditioning size forwarded to every ``ResBlock``.
    """

    def __init__(self, nr_resnet, nr_filters, act, dropout, attn_block_cls=None, shape=None, temb_dim=None, c_dim=None):
        super().__init__()
        self.nr_resnet = nr_resnet

        # stream from pixels above (one skip input of width nr_filters)
        u_blocks = [
            ResBlock(
                nr_filters, DownShiftedConv2d, act, dropout,
                skip_connection=1, temb_dim=temb_dim, c_dim=c_dim,
            )
            for _ in range(nr_resnet)
        ]
        self.u_stream = nn.ModuleList(u_blocks)

        # stream from pixels above and to the left (skip input is two
        # concatenated maps, hence skip_connection=2)
        ul_blocks = [
            ResBlock(
                nr_filters, DownRightShiftedConv2d, act, dropout,
                skip_connection=2, temb_dim=temb_dim, c_dim=c_dim,
            )
            for _ in range(nr_resnet)
        ]
        self.ul_stream = nn.ModuleList(ul_blocks)

        if attn_block_cls is not None:
            attention_blocks = [attn_block_cls(nr_filters, shape) for _ in range(nr_resnet)]
            self.attblocks = nn.ModuleList(attention_blocks)

    def forward(self, u, ul, u_list, ul_list, temb=None, c=None):
        """Run both streams and return the final ``(u, ul)`` pair.

        ``u_list`` and ``ul_list`` are mutated: one entry is popped from the
        end of each list per level.
        """
        # `attblocks` is only set when an attention class was supplied.
        with_attention = hasattr(self, 'attblocks')

        for level in range(self.nr_resnet):
            u = self.u_stream[level](u, a=u_list.pop(), temb=temb, c=c)
            if with_attention:
                u = self.attblocks[level](u)
            skip = torch.cat((u, ul_list.pop()), 1)
            ul = self.ul_stream[level](ul, a=skip, temb=temb, c=c)

        return u, ul


class NCPN(nn.Module):
    """Two-stream residual network with skip connections across resolutions.

    The forward pass has two phases:

    1. ``up_layers`` are applied at ``nr_resolutions`` successive
       resolutions, with strided convolutions (``downsize_*``) halving the
       spatial size between them. Every intermediate activation is pushed
       onto a stack.
    2. ``down_layers`` are applied while transposed convolutions
       (``upsize_*``) restore the spatial size, each residual block popping
       one stored activation as its skip input.

    The final ``ul`` activation is passed through ELU and ``nin_out``,
    producing ``num_mix * nr_logistic_mix`` output channels where
    ``num_mix`` is 3 for single-channel inputs and 10 otherwise.

    Args:
        shape: ``(channels, height, width)`` of the input images.
        nr_resnet: residual blocks per resolution in the first phase. In the
            second phase the first level uses ``nr_resnet`` blocks and the
            remaining levels ``nr_resnet + 1``, so that the stack is
            consumed exactly.
        nr_filters: channel width of both streams.
        nr_resolutions: number of resolutions processed.
        nr_logistic_mix: multiplier for the number of output channels.
        time_cond: when true, build the Fourier/time-embedding layers and
            require ``cond1`` in ``forward``.
        dropout: dropout probability for every residual block.
        attn: enable attention blocks (all first-phase levels, and all
            second-phase levels except the first).
        c_dim: size of the optional conditioning vector ``c``.
    """

    def __init__(
        self,
        shape,
        nr_resnet=5,
        nr_filters=80,
        nr_resolutions=3,
        nr_logistic_mix=10,
        time_cond=False,
        dropout=0.5,
        attn=True,
        c_dim=None
    ):
        super().__init__()
        self.act = concat_elu
        self.time_cond = time_cond
        self.c_dim = c_dim
        self.nr_filters = nr_filters
        c, h, w = shape
        self.nr_logistic_mix = nr_logistic_mix
        self.nr_resolutions = nr_resolutions

        if attn:
            AttnBlock = functools.partial(AttentionBlock,
                                          channel_dim=1,
                                          heads=8,
                                          u_only=True
                                         )
        else:
            AttnBlock = None

        # The second phase needs one extra block per level after the first
        # to also consume the activations stored by the downsizing convs.
        down_nr_resnet = [nr_resnet] + [nr_resnet + 1] * (self.nr_resolutions - 1)

        # NOTE: the construction order of the submodules below is kept
        # unchanged so parameter ordering and random initialisation match
        # existing checkpoints/seeds.
        # Second-phase levels run from the coarsest resolution (i == 0) to
        # the finest; no attention is used at the coarsest level.
        self.down_layers = nn.ModuleList([NCPNDown(
            down_nr_resnet[i], nr_filters,
            self.act, dropout, AttnBlock if i != 0 else None,
            (nr_filters, h // pow(2, self.nr_resolutions - 1 - i), w // pow(2, self.nr_resolutions - 1 - i)),
            temb_dim=nr_filters * 4 if time_cond else None,
            c_dim=c_dim
        ) for i in range(nr_resolutions)])

        # First-phase levels run from the finest resolution to the coarsest.
        self.up_layers   = nn.ModuleList([NCPNUp(
            nr_resnet, nr_filters,
            self.act, dropout, AttnBlock,
            (nr_filters, h // pow(2, i), w // pow(2, i)),
            temb_dim=nr_filters * 4 if time_cond else None,
            c_dim=c_dim
        ) for i in range(nr_resolutions)])

        self.downsize_u_stream  = nn.ModuleList([DownShiftedConv2d(nr_filters, nr_filters,
                                                    stride=(2, 2)) for _ in range(nr_resolutions - 1)])

        self.downsize_ul_stream = nn.ModuleList([DownRightShiftedConv2d(nr_filters,
                                                    nr_filters, stride=(2, 2)) for _ in range(nr_resolutions - 1)])

        self.upsize_u_stream  = nn.ModuleList([DownShiftedConv2dTranpose(nr_filters, nr_filters,
                                                    stride=(2, 2)) for _ in range(nr_resolutions - 1)])

        self.upsize_ul_stream = nn.ModuleList([DownRightShiftedConv2dTranspose(nr_filters,
                                                    nr_filters, stride=(2, 2)) for _ in range(nr_resolutions - 1)])

        # Input layers take c + 1 channels because `pad` appends a ones-plane.
        self.u_init = DownShiftedConv2d(c + 1, nr_filters, filter_size=(2,3),
                        shift_output_down=True)

        self.ul_init = nn.ModuleList([DownShiftedConv2d(c + 1, nr_filters,
                                            filter_size=(1, 3), shift_output_down=True),
                                       DownRightShiftedConv2d(c + 1, nr_filters,
                                            filter_size=(2, 1), shift_output_right=True)])

        num_mix = 3 if c == 1 else 10
        self.nin_out = NetworkInNetwork(nr_filters, num_mix * nr_logistic_mix)
        # Cached ones-plane used by `pad`; rebuilt lazily when the input changes.
        self.init_padding = None

        if time_cond:
            self.fourier_emb = GaussianFourierProjection(
                embedding_size=nr_filters, scale=16.
            )
            self.nin_temb0 = Linear(nr_filters * 2, nr_filters * 2)
            self.nin_temb1 = Linear(nr_filters * 4, nr_filters * 4)

    def pad(self, x):
        """Append one channel filled with ones to the 4-D tensor ``x``.

        The ones-plane is cached in ``self.init_padding`` and rebuilt
        whenever its shape, device or dtype no longer matches ``x``.
        Comparing only the batch size (as was done previously) reused a
        stale plane when the spatial size changed between calls, which made
        the concatenation fail.

        Returns:
            Tensor of shape ``(B, C + 1, H, W)``.
        """
        batch, _, height, width = x.shape
        wanted_shape = (batch, 1, height, width)

        padding = self.init_padding
        stale = (
            padding is None
            or tuple(padding.shape) != wanted_shape
            or padding.device != x.device
            or padding.dtype != x.dtype
        )
        if stale:
            # new_ones inherits dtype and device from x.
            padding = x.new_ones(wanted_shape)
            self.init_padding = padding

        return torch.cat([x, padding], dim=1)

    def forward(self, x, cond1=None, c=None):
        """Compute the network output for a batch of images.

        Args:
            x: input batch of shape ``(B, C, H, W)``.
            cond1: optional per-sample positive scalars; their logarithm is
                embedded and passed to every residual block. Required when
                the model was built with ``time_cond=True``.
            c: optional conditioning tensor forwarded to every residual block.

        Returns:
            The output of ``nin_out`` applied to the final ``ul`` activation.

        Raises:
            AssertionError: if ``cond1`` is missing although ``time_cond`` is
                set, or if the skip stacks are not fully consumed.
        """
        cond = cond1
        x = self.pad(x)

        # get time embeddings
        if cond is not None:
            temb = self.fourier_emb(torch.log(cond))
            temb = self.nin_temb0(temb)
            temb = self.nin_temb1(self.act(temb))
        else:
            assert not self.time_cond, "cond1 is required when time_cond=True"
            temb = None

        ###    UP PASS    ###
        u_list  = [self.u_init(x)]
        ul_list = [self.ul_init[0](x) + self.ul_init[1](x)]
        last_level = self.nr_resolutions - 1
        for i in range(self.nr_resolutions):
            # resnet block
            u_out, ul_out = self.up_layers[i](u_list[-1], ul_list[-1], temb=temb, c=c)
            u_list  += u_out
            ul_list += ul_out

            if i != last_level:
                # downscale (between resolutions only)
                u_list  += [self.downsize_u_stream[i](u_list[-1])]
                ul_list += [self.downsize_ul_stream[i](ul_list[-1])]

        ###    DOWN PASS    ###
        u  = u_list.pop()
        ul = ul_list.pop()

        for i in range(self.nr_resolutions):
            # resnet block (pops its skip inputs from the two lists)
            u, ul = self.down_layers[i](u, ul, u_list, ul_list, temb=temb, c=c)

            # upscale (between resolutions only)
            if i != last_level:
                u  = self.upsize_u_stream[i](u)
                ul = self.upsize_ul_stream[i](ul)

        x_out = self.nin_out(F.elu(ul))

        # Every stored activation must have been used exactly once. Fail with
        # a message instead of dropping into an interactive debugger, which
        # would hang non-interactive runs.
        assert len(u_list) == len(ul_list) == 0, (
            "skip stacks not fully consumed: %d u / %d ul entries left"
            % (len(u_list), len(ul_list))
        )

        return x_out

