import torch
from torch import nn
from torch.nn import Module
from torch.nn import functional as F
from distributions.bijector import BlockLowerTriangular, ScaleShift, NLSq
from . import log_normal
from . import StandardGaussian



class MADE(Module):
    """Autoregressive density model built from ``BlockLowerTriangular`` layers.

    The network maps ``input`` through ``layers`` block-lower-triangular
    transforms (with an activation between consecutive transforms) to
    ``params_size`` values per input feature.  Those values are handed,
    together with the original input, to ``base_distribution``, whose result
    is returned by :meth:`forward`.

    When ``conditioned_size > 0`` an extra conditioning tensor modulates the
    network: every hidden pre-activation is shifted and divided by a softplus
    scale computed from the conditioning tensor, and a linear function of
    the conditioning tensor is added to the final layer's output.

    Args:
        input_size: Number of input features (passed as ``features`` to
            every ``BlockLowerTriangular`` layer).
        conditioned_size: Feature size of the conditioning tensor; ``0``
            disables conditioning.
        base_distribution: Callable invoked as
            ``base_distribution(input, params)``; its return value is the
            output of :meth:`forward` (presumably a log-density -- confirm
            against the callable that is actually supplied).
        params_size: Number of values produced per input feature by the
            final layer.
        layers: Total number of autoregressive layers; must be at least 2
            (one input layer and one output layer).
        hidden_upscale: Hidden block size per input feature.  Fixed sized
            hidden layers for now.
        activation: Zero-argument factory (e.g. an ``nn.Module`` subclass)
            producing the activation placed after each non-final layer.

    Raises:
        ValueError: If ``layers`` is smaller than 2.
    """

    def __init__(self, input_size, conditioned_size=0,
                 base_distribution=StandardGaussian,
                 params_size=2,
                 layers=3, hidden_upscale=8, activation=nn.Tanh):
        super().__init__()

        # With fewer than two layers the output layer would never be reached
        # by forward(), which would silently emit hidden activations (or fail
        # on an unbound local for layers == 0).  Reject it up front.
        if layers < 2:
            raise ValueError(
                "MADE requires layers >= 2 (an input and an output layer), "
                "got layers={}".format(layers))

        self.params_size = params_size

        # Input layer: the only one built with strict=True (presumably this
        # masks the block diagonal so output i does not see input i --
        # TODO confirm against BlockLowerTriangular).
        input_layer = BlockLowerTriangular(
            features=input_size,
            in_block=1,
            out_block=hidden_upscale,
            bias=True, strict=True)
        hidden_layers = [
            BlockLowerTriangular(
                features=input_size,
                in_block=hidden_upscale,
                out_block=hidden_upscale,
                bias=True, strict=False)
            for _ in range(layers - 2)
        ]
        output_layer = BlockLowerTriangular(
            features=input_size,
            in_block=hidden_upscale,
            out_block=self.params_size,
            bias=True, strict=False)
        self.autoreg = nn.ModuleList(
            [input_layer] + hidden_layers + [output_layer])
        # Optional zero-initialisation of the output layer (currently off):
        # nn.init.zeros_(self.autoreg[-1].transform.weight)
        # nn.init.zeros_(self.autoreg[-1].transform.bias)
        self.activations = nn.ModuleList(
            [activation() for _ in range(layers - 1)])

        self.conditioned_features = conditioned_size
        if conditioned_size > 0:
            def build_hidden_transform():
                # Emits a shift and a raw scale for every hidden unit, hence
                # the factor 2; forward() splits them with chunk(2, dim=-1).
                linear = nn.Linear(conditioned_size,
                                   2 * input_size * hidden_upscale)
                nn.init.xavier_uniform_(linear.weight)
                nn.init.zeros_(linear.bias)
                return linear

            # Additive term for the final parameters.  Zero weights and bias
            # make it contribute exactly nothing at initialisation.
            output_transform = nn.Linear(conditioned_size,
                                         input_size * self.params_size)
            nn.init.zeros_(output_transform.weight)
            nn.init.zeros_(output_transform.bias)

            self.cond_transform = nn.ModuleList(
                [build_hidden_transform() for _ in range(layers - 1)]
                + [output_transform])
        else:
            self.cond_transform = None

        self.layers = layers
        self.base_distribution = base_distribution

    def forward(self, input, conditioned=None):
        """Evaluate ``base_distribution`` on ``input`` with predicted parameters.

        Args:
            input: Tensor fed to the first autoregressive layer and to
                ``base_distribution``.
            conditioned: Conditioning tensor; required when the model was
                built with ``conditioned_size > 0`` and ignored otherwise.

        Returns:
            Whatever ``base_distribution(input, params)`` returns.

        Raises:
            ValueError: If the model is conditional but ``conditioned`` is
                ``None``.
        """
        use_conditioning = self.cond_transform is not None
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O`` and reports a clear message.
        if use_conditioning and conditioned is None:
            raise ValueError(
                "`conditioned` must be provided because this MADE was built "
                "with conditioned_size={}".format(self.conditioned_features))

        last = self.layers - 1
        hidden = input
        for i in range(last):
            pre_activation = self.autoreg[i](hidden)
            if use_conditioning:
                shift, lin_scale = (
                    self.cond_transform[i](conditioned).chunk(2, dim=-1))
                # softplus keeps the divisor strictly positive.
                pre_activation = (pre_activation - shift) / F.softplus(lin_scale)
            hidden = self.activations[i](pre_activation)

        # The output layer has no activation and no shift/scale modulation;
        # conditioning enters additively instead.
        lin_final = self.autoreg[last](hidden)
        if use_conditioning:
            lin_final = lin_final + self.cond_transform[-1](conditioned)

        # Debug aid: per-output gradients with respect to `input` can be
        # inspected with
        #   torch.autograd.grad(lin_final[0, i * self.params_size + j],
        #                       input, retain_graph=True)
        # for feature index i and parameter index j.
        log_px = self.base_distribution(input, lin_final)
        return log_px

    def sample(self, batch_size, conditioned=None, eps=None, device=None):
        """Sampling is not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

if __name__ == "__main__":
    # Smoke test: build a small unconditional model and push one random
    # sample through it.  The file uses relative imports, so it has to be
    # launched as a module of its package (``python -m ...``).
    model = MADE(input_size=5, params_size=5)
    sample_input = torch.randn(1, 5, requires_grad=True)
    model(sample_input)
