"""Prior Network"""
import torch
import torch.nn as nn
from functorch import jacrev,jacfwd,vmap
from .mlp import NLayerLeakyMLP


class NPTransitionPrior(nn.Module):
    """Transition prior built from one scalar-output network per latent dimension.

    For every latent dimension ``i`` an ``NLayerLeakyMLP`` receives the
    flattened ``lags`` previous latent vectors together with the current
    value of dimension ``i`` and returns a single residual value.
    """

    def __init__(
        self,
        lags,
        latent_size,
        num_layers=2,
        hidden_dim=64
        ):
        """Create ``latent_size`` per-dimension networks.

        Args:
            lags: Number of past time steps fed to each network.
            latent_size: Number of latent dimensions (one network each).
            num_layers: Forwarded to ``NLayerLeakyMLP``.
            hidden_dim: Forwarded to ``NLayerLeakyMLP``.
        """
        super().__init__()
        self.L = lags
        # Every network sees all lagged latents plus one current component.
        n_inputs = lags * latent_size + 1
        per_dim_nets = [
            NLayerLeakyMLP(
                in_features=n_inputs,
                out_features=1,
                num_layers=num_layers,
                hidden_dim=hidden_dim,
            )
            for _ in range(latent_size)
        ]
        self.gs = nn.ModuleList(per_dim_nets)

    def forward(self, x, masks=None):
        """Compute residuals and the summed log-abs derivative terms.

        Args:
            x: Tensor of shape [BS, T, D].
            masks: Optional indexable; ``masks[i]`` is multiplied elementwise
                with the flattened history before it is fed to network ``i``
                (presumably broadcastable to [BS*(T-L), L*D] -- confirm
                against the caller).

        Returns:
            A tuple ``(residuals, sum_log_abs_det_jacobian)`` where
            ``residuals`` has shape [BS, T-L, D] and the second entry has
            shape [BS]: the log-abs derivatives summed over latent
            dimensions and over the T-L windows.
        """
        batch_size, length, input_dim = x.shape
        window = self.L + 1

        # [BS, T, D] -> [BS, T-L, D, L+1] -> [BS, T-L, L+1, D]
        windows = x.unfold(dimension=1, size=window, step=1)
        windows = torch.swapaxes(windows, 2, 3)
        # Collapse batch and time: [BS*(T-L), L+1, D]
        windows = windows.reshape(-1, window, input_dim)

        current = windows[:, -1:]                       # [BS*(T-L), 1, D]
        history = windows[:, :-1].reshape(-1, self.L * input_dim)

        residual_cols = []
        logabsdet_terms = []
        for i in range(input_dim):
            net = self.gs[i]
            context = history if masks is None else history * masks[i]
            inputs = torch.cat((context, current[:, :, i]), dim=-1)

            residual_cols.append(net(inputs))

            # Per-sample Jacobian of the scalar output w.r.t. the input vector.
            with torch.enable_grad():
                jac = vmap(jacrev(net))(inputs)
            # The determinant of a lower-triangular matrix is the product of
            # its diagonal entries, so only the derivative w.r.t. the current
            # component (the last input column) is needed.
            logabsdet_terms.append(jac[:, 0, -1].abs().log())

        residuals = torch.cat(residual_cols, dim=-1)
        residuals = residuals.reshape(batch_size, -1, input_dim)

        # Builtin sum keeps the left-to-right accumulation starting from 0.
        per_window = sum(logabsdet_terms)
        per_window = per_window.reshape(batch_size, length - self.L)
        sum_log_abs_det_jacobian = torch.sum(per_window, dim=1)
        return residuals, sum_log_abs_det_jacobian

class NPChangeTransitionPrior(nn.Module):
    """Transition prior whose per-dimension networks also see an embedding.

    The raw embedding is first projected to ``latent_size`` features by
    ``self.fc``; each per-dimension ``NLayerLeakyMLP`` then receives the
    projected embedding, the flattened ``lags`` previous latent vectors and
    the current value of its own latent dimension.
    """

    def __init__(
        self,
        lags,
        latent_size,
        embedding_dim,
        num_layers=3,
        hidden_dim=64):
        """Create the embedding projection and ``latent_size`` networks.

        Args:
            lags: Number of past time steps fed to each network.
            latent_size: Number of latent dimensions (one network each); also
                the output width of the embedding projection.
            embedding_dim: Width of the raw embedding passed to ``forward``.
            num_layers: Forwarded to each per-dimension ``NLayerLeakyMLP``.
            hidden_dim: Forwarded to each per-dimension ``NLayerLeakyMLP``.
        """
        super().__init__()
        self.L = lags
        # Projected embedding + all lagged latents + one current component.
        n_inputs = latent_size + lags * latent_size + 1
        per_dim_nets = [
            NLayerLeakyMLP(
                in_features=n_inputs,
                out_features=1,
                num_layers=num_layers,
                hidden_dim=hidden_dim,
            )
            for _ in range(latent_size)
        ]
        self.gs = nn.ModuleList(per_dim_nets)
        self.fc = NLayerLeakyMLP(
            in_features=embedding_dim,
            out_features=latent_size,
            num_layers=2,
            hidden_dim=embedding_dim,
        )

    def forward(self, x, embeddings, masks=None):
        """Compute residuals and the summed log-abs derivative terms.

        Args:
            x: Tensor of shape [BS, T, D].
            embeddings: Tensor passed through ``self.fc`` and then flattened
                to two dimensions. It is concatenated row-wise with the
                BS*(T-L) windows, so it must provide one embedding per
                window (e.g. shape [BS, T-L, embedding_dim]).
            masks: Optional indexable; ``masks[i]`` is multiplied elementwise
                with the flattened history before it is fed to network ``i``
                (presumably broadcastable to [BS*(T-L), L*D] -- confirm
                against the caller).

        Returns:
            A tuple ``(residuals, sum_log_abs_det_jacobian)`` where
            ``residuals`` has shape [BS, T-L, D] and the second entry has
            shape [BS]: the log-abs derivatives summed over latent
            dimensions and over the T-L windows.
        """
        batch_size, length, input_dim = x.shape
        window = self.L + 1

        # Project the embedding and flatten it to one row per window.
        projected = self.fc(embeddings)
        projected = projected.reshape(-1, projected.shape[-1])

        # [BS, T, D] -> [BS, T-L, D, L+1] -> [BS, T-L, L+1, D]
        windows = x.unfold(dimension=1, size=window, step=1)
        windows = torch.swapaxes(windows, 2, 3)
        # Collapse batch and time: [BS*(T-L), L+1, D]
        windows = windows.reshape(-1, window, input_dim)

        current = windows[:, -1:]                       # [BS*(T-L), 1, D]
        history = windows[:, :-1].reshape(-1, self.L * input_dim)

        residual_cols = []
        logabsdet_terms = []
        for i in range(input_dim):
            net = self.gs[i]
            context = history if masks is None else history * masks[i]
            inputs = torch.cat((projected, context, current[:, :, i]), dim=-1)

            residual_cols.append(net(inputs))

            # Per-sample Jacobian of the scalar output w.r.t. the input vector.
            with torch.enable_grad():
                jac = vmap(jacrev(net))(inputs)
            # The determinant of a lower-triangular matrix is the product of
            # its diagonal entries, so only the derivative w.r.t. the current
            # component (the last input column) is needed.
            logabsdet_terms.append(jac[:, 0, -1].abs().log())

        residuals = torch.cat(residual_cols, dim=-1)
        residuals = residuals.reshape(batch_size, -1, input_dim)

        # Builtin sum keeps the left-to-right accumulation starting from 0.
        per_window = sum(logabsdet_terms)
        per_window = per_window.reshape(batch_size, length - self.L)
        sum_log_abs_det_jacobian = torch.sum(per_window, dim=1)
        return residuals, sum_log_abs_det_jacobian
