import torch
import torch.nn as nn
from normflows.flows import Flow

from .nets import FNN, CustomActivation, NodeNetwork

class CustomSplit(Flow):
    """ Split features into three sets: target node, parents and non-parents.

    The layer only selects / re-assembles columns, so it reports a
    log-determinant of 0 in both directions.
    """

    def __init__(self, node, parents, nonparents):
        """Constructor

        Arguments:
        ----------
        node : int
            Index of target node
        parents : list of ints
            Indices of node's parents
        nonparents : list of ints
            Indices of non-parent nodes
        """
        super().__init__()
        # Wrap the index in a list so that z[:, self.node] keeps the column
        # axis, i.e. z_node has shape [batch, 1] instead of [batch].
        self.node = [node]
        self.parents = parents
        self.nonparents = nonparents

    def forward(self, z):
        """Split the columns of z into node, parent and non-parent groups.

        Arguments:
        ----------
        z : tensor
            Tensor whose second axis is indexed by the node indices.

        Returns:
        --------
        list of tensors
            [z_node, z_parents, z_nonparents]
        int
            Log-determinant contribution, always 0.
        """
        z_node = z[:, self.node]
        z_parents = z[:, self.parents]
        z_nonparents = z[:, self.nonparents]

        log_det = 0

        return [z_node, z_parents, z_nonparents], log_det

    def inverse(self, z):
        """Re-assemble a single tensor from the three column groups.

        Arguments:
        ----------
        z : sequence of three tensors
            (z_node, z_parents, z_nonparents) as produced by forward.

        Returns:
        --------
        tensor
            Tensor of shape [n_samples, n_nodes] with every group written
            back to its original column positions.
        int
            Log-determinant contribution, always 0.
        """
        z_node, z_parents, z_nonparents = z

        n_samples = z_node.shape[0]
        n_nodes = 1 + z_parents.shape[1] + z_nonparents.shape[1]

        # Allocate the output with the dtype and device of the inputs.
        # A plain torch.empty(...) would always be default-dtype and on the
        # CPU, which breaks the indexed assignments below for float64 or
        # GPU tensors.
        # NOTE: the buffer is uninitialised; every column is only filled if
        # node, parents and nonparents together cover range(n_nodes).
        merged = z_node.new_empty((n_samples, n_nodes))
        merged[:, self.node] = z_node
        merged[:, self.parents] = z_parents
        merged[:, self.nonparents] = z_nonparents

        log_det = 0

        return merged, log_det

class CustomMerge(CustomSplit):
    """ Mirror image of CustomSplit.

    The forward pass merges the three column groups into one tensor and the
    inverse pass splits a tensor into the three groups, i.e. the two
    directions of CustomSplit are swapped.
    """

    def __init__(self, node, parents, nonparents):
        """Constructor

        Arguments:
        ----------
        node : int
            Index of target node
        parents : list of ints
            Indices of node's parents
        nonparents : list of ints
            Indices of non-parent nodes
        """
        super().__init__(node, parents, nonparents)

    def forward(self, z):
        """Merge (z_node, z_parents, z_nonparents) via CustomSplit.inverse."""
        merged, log_det = super().inverse(z)
        return merged, log_det

    def inverse(self, z):
        """Split z into its three groups via CustomSplit.forward."""
        groups, log_det = super().forward(z)
        return groups, log_det


class HyperCoupling(Flow):
    """ Hypernetwork Coupling layer.

    A hypernetwork maps the parent values to the parameters of a shallow
    node network, which is then applied to the node value. Parent and
    non-parent values are passed through unchanged.
    """

    def __init__(self, DAG, node, hypernet_nhs=(32, 32), nodenet_nh=16):
        """Constructor

        Parameters:
        -----------
        DAG : object
            Graph object; only DAG.get_parents(node) is used here, to
            determine the input width of the hypernetwork.
        node : int
            Index of the target node, passed to DAG.get_parents.
        hypernet_nhs : sequence of ints
            Number of neurons in each hidden layer of the hypernetwork.
        nodenet_nh : int
            Number of hidden neurons in the node's (shallow) network.
        """
        super().__init__()

        self.nh = nodenet_nh
        n_parents = len(DAG.get_parents(node))
        # the hypernet outputs the parameters of the node network
        # [weight1, weight2, bias1, bias2]: three blocks of width nh plus a
        # single trailing column, hence 3*nh + 1 outputs.
        self.hypernet = FNN(
            layers=[n_parents] + list(hypernet_nhs) + [3 * self.nh + 1],
            activation='relu',
        )
        # activation function of node network; within this class only its
        # alpha attribute is read and handed to NodeNetwork
        self.act = CustomActivation(alpha=0.05)

    def _node_network(self, z_parents):
        """Build the node network whose parameters are predicted from z_parents.

        Shared by forward and inverse so both directions use identical
        parameter slicing.
        """
        nh = self.nh
        param = self.hypernet(z_parents)

        # enforce positive weights by taking absolute value
        weight1 = torch.abs(param[:, :nh])
        weight2 = torch.abs(param[:, nh:2 * nh])
        bias1 = param[:, 2 * nh:3 * nh]
        # remaining column(s) after the first 3*nh, with an extra axis
        # inserted at position 1
        bias2 = param[:, 3 * nh:][:, None]

        return NodeNetwork(weight1, weight2, bias1, bias2, self.act.alpha)

    def forward(self, z):
        """Transform the node value with the parent-conditioned node network.

        Parameters:
        -----------
        z : sequence of three tensors
            (z_node, z_parents, z_nonparents)

        Returns:
        --------
        list of tensors
            [transformed z_node, z_parents, z_nonparents]
        log_det
            Log-determinant as returned by NodeNetwork.forward.
        """
        z_node, z_parents, z_nonparents = z

        node_net = self._node_network(z_parents)
        z_node, log_det = node_net.forward(z_node)

        return [z_node, z_parents, z_nonparents], log_det

    def inverse(self, z):
        """Invert the node transformation given the same parent values.

        Parameters:
        -----------
        z : sequence of three tensors
            (z_node, z_parents, z_nonparents)

        Returns:
        --------
        list of tensors
            [inverted z_node, z_parents, z_nonparents]
        log_det
            Log-determinant as returned by NodeNetwork.inverse.
        """
        z_node, z_parents, z_nonparents = z

        node_net = self._node_network(z_parents)
        z_node, log_det = node_net.inverse(z_node)

        return [z_node, z_parents, z_nonparents], log_det

class HyperCouplingBlock(Flow):
    """ Complete hypernetwork coupling step: split -> coupling -> merge.

    Chains a CustomSplit, a HyperCoupling and a CustomMerge so that the
    block maps a single tensor to a single tensor.
    """

    def __init__(self, DAG, node, hypernet_nhs=[32, 32], nodenet_nh=16):
        """Constructor

        Parameters:
        -----------
        DAG : object
            Graph object queried through get_parents(node) and
            get_nonparents(node).
        node : int
            Index of the target node.
        hypernet_nhs : list of ints
            Forwarded to HyperCoupling.
        nodenet_nh : int
            Forwarded to HyperCoupling.
        """
        super().__init__()

        # The three sub-flows are applied in this order in the forward pass.
        self.flows = nn.ModuleList([
            CustomSplit(node, DAG.get_parents(node), DAG.get_nonparents(node)),
            HyperCoupling(DAG, node, hypernet_nhs=hypernet_nhs, nodenet_nh=nodenet_nh),
            CustomMerge(node, DAG.get_parents(node), DAG.get_nonparents(node)),
        ])

    def forward(self, z):
        """Apply all sub-flows in order and accumulate their log-determinants."""
        total = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
        for layer in self.flows:
            z, contribution = layer(z)
            total += contribution
        return z, total

    def inverse(self, z):
        """Apply the inverses of all sub-flows in reverse order."""
        total = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
        for layer in reversed(self.flows):
            z, contribution = layer.inverse(z)
            total += contribution
        return z, total

