# Third party imports
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable as Variable
import lightning as L

# Local application imports
import sys
sys.path.insert(0, './')
from .spn_nodes import *

# Small epsilon constant (1e-15). NOTE(review): not referenced by the code in this
# module as shown; presumably imported by other modules (e.g. via `import *`) —
# confirm before removing or changing it.
EPS = 1e-15


class SPNMulti(L.LightningModule):
    """Bank of ``num_pcs`` sum-product networks (probabilistic circuits) run in parallel.

    The network is assembled incrementally:

    1. ``add_bernoulli_layer`` for the leaves;
    2. ``add_product_layer`` / ``add_sum_layer`` in evaluation order;
    3. ``add_final_layer``.

    Call ``ready()`` afterwards to initialize the layers and register them as
    submodules.

    Args:
        num_vars (int): number of variables per circuit; samples have shape
            (batch, num_pcs, num_vars).
        num_pcs (int): number of circuits evaluated in parallel (forwarded to
            each layer as ``parallel``).
        normalize (bool): forwarded to sum and final layers.
        tau (float): temperature forwarded to sum and final layers.
    """

    def __init__(self, num_vars, num_pcs, normalize=True, tau=1.0):
        super().__init__()
        self.num_vars = num_vars
        self.num_pcs = num_pcs
        self.normalize = normalize
        self.tau = tau

        self.leaf_layer = None  # Leaf distribution layer, set by add_bernoulli_layer.
        # Inner layers in evaluation order. This stays a plain list until ready()
        # wraps it in an nn.ModuleList together with the leaf layer.
        self.layers = []
        self.net = None  # nn.ModuleList built by ready().

        # One shape entry per added layer, leaf first. Product layers record [0];
        # the final layer records a list of [num_pcs, partition_size] pairs.
        self.param_shape_per_layer = []

        self.terms_mask = None

    def ready(self):
        """Initialize every layer and register all layers as submodules.

        Returns:
            SPNMulti: ``self``, to allow chaining.
        """
        self.leaf_layer.initialize()
        for layer in self.layers:
            layer.initialize()

        self.num_roots = 1
        self.net = nn.ModuleList([self.leaf_layer] + self.layers)
        return self

    @property
    def num_params(self):
        """Number of scalar parameters of the inner layers plus the final layer.

        NOTE(review): the leaf layer (entry 0 of ``param_shape_per_layer``) is not
        counted here, although ``reinitialize`` consumes leaf parameters first.
        Confirm which of the two reflects the intended contract.
        """
        inner = np.sum([np.prod(shape) for shape in self.param_shape_per_layer[1:-1]])
        final = np.sum([np.prod(pair) for pair in self.param_shape_per_layer[-1]])
        return inner + final

    @property
    def numel(self):
        """Sum of ``numel`` over the layers in ``self.net`` (requires ``ready()``)."""
        return np.sum([layer.numel for layer in self.net])

    def reinitialize(self, params):
        """Re-initialize all layers from one flat parameter vector.

        Args:
            params (torch.Tensor): 1-D tensor of size ``num_params``, consumed in
                layer order (leaf first, then ``self.layers``).

        Raises:
            AssertionError: if ``params`` has the wrong size or is not fully consumed.
            NotImplementedError: if the leaf layer is not a ``BernoulliLayerMulti``.

        NOTE(review): ``num_params`` excludes the leaf layer. The final layer's shape
        entry is also a list of pairs, which the flat ``np.prod`` / ``view`` below
        does not split per partition. Confirm this path before relying on it.
        """
        # ``num_params`` is a property, not a method. Cast to int because
        # ``np.sum`` over an empty list yields a float, which torch.Size rejects.
        assert params.size() == torch.Size([int(self.num_params)])

        leaf_params_count = int(np.prod(self.param_shape_per_layer[0]))
        leaf_params, params = params[:leaf_params_count], params[leaf_params_count:]

        if not isinstance(self.leaf_layer, BernoulliLayerMulti):
            raise NotImplementedError(
                f"reinitialize does not support leaf layer {type(self.leaf_layer).__name__}"
            )
        self.leaf_layer.initialize(params=leaf_params)

        # Inner layers: entry i+1 of param_shape_per_layer belongs to self.layers[i].
        for shape, layer in zip(self.param_shape_per_layer[1:], self.layers):
            params_count = int(np.prod(shape))
            layer_params, params = params[:params_count], params[params_count:]
            layer.initialize(layer_params.view(*shape))

        assert torch.numel(params) == 0

    def add_sum_layer(self, num, remainder_num, p_conn=0.2, sparsity_temp=0.1):
        """Append a sum-node layer.

        Args:
            num (int): number of sum nodes in the layer.
            remainder_num (int): number of remainder input nodes to the layer.
            p_conn (float): forwarded to ``SumLayerMulti``.
            sparsity_temp (float): forwarded to ``SumLayerMulti``.

        Returns:
            SumLayerMulti: the new layer.
        """
        # Children are the outputs of the previous layer (or the leaves).
        child_num = self.layers[-1].num if self.layers else self.leaf_layer.num

        self.param_shape_per_layer.append([self.num_pcs, child_num])
        self.layers.append(SumLayerMulti(num=num,
                                         child_num=child_num,
                                         remainder_num=remainder_num,
                                         parallel=self.num_pcs,
                                         normalize=self.normalize,
                                         tau=self.tau,
                                         p_conn=p_conn,
                                         sparsity_temp=sparsity_temp))
        return self.layers[-1]

    def add_product_layer(self, num, copies, partitions, remainder_num):
        """Append a product-node layer (it has no parameters; its shape entry is [0]).

        Args:
            num (int): total number of product nodes in the layer.
            copies (int): number of sum nodes for each partition.
            partitions (int): number of partitions of the set {1, 2, ..., n},
                where n is the number of nodes.
            remainder_num (int): number of remainder nodes.

        Returns:
            ProductLayerMulti: the new layer.
        """
        self.param_shape_per_layer.append([0])
        self.layers.append(ProductLayerMulti(num=num,
                                             copies=copies,
                                             partitions=partitions,
                                             remainder_num=remainder_num,
                                             parallel=self.num_pcs))
        return self.layers[-1]

    def add_bernoulli_layer(self, num, var):
        """Set the Bernoulli leaf layer.

        Args:
            num (int): number of nodes in the layer.
            var (np.ndarray): (num,) scope variables of the leaf nodes.

        Returns:
            BernoulliLayerMulti: the new leaf layer.
        """
        var_param = torch.from_numpy(var)
        self.param_shape_per_layer.append([self.num_pcs, num])
        self.leaf_layer = BernoulliLayerMulti(num=num, var=var_param, parallel=self.num_pcs)
        return self.leaf_layer

    def add_final_layer(self, num_vars, max_copies, p_conn=0.2, sparsity_temp=1.0):
        """Append the final layer.

        Its shape entry is one [num_pcs, size] pair per partition, taken from the
        layer's ``partition_sizes``.

        Returns:
            FinalLayerMulti: the new layer.
        """
        self.layers.append(FinalLayerMulti(num_vars=num_vars,
                                           max_copies=max_copies,
                                           parallel=self.num_pcs,
                                           normalize=self.normalize,
                                           tau=self.tau,
                                           p_conn=p_conn,
                                           sparsity_temp=sparsity_temp))
        self.param_shape_per_layer.append(
            [[self.num_pcs, s] for s in self.layers[-1].partition_sizes])
        return self.layers[-1]

    def sample(self, batch_size=1, hard=True):
        """Draw samples from all circuits; this is the path used for backpropagation.

        Args:
            batch_size (int): number of samples per circuit.
            hard (bool): forwarded to each sum layer's ``sample``.

        Returns:
            torch.Tensor: (batch, num_pcs, num_vars) sample produced by the final layer.

        Raises:
            RuntimeError: if no final layer has been added.
        """
        partition_onehot = self.leaf_layer.sample_onehot(batch_size)
        out = None

        for layer in self.layers:
            if isinstance(layer, ProductLayerMulti):
                # Merge the two children of each product node; remainder nodes pass through.
                merged = partition_onehot[:, :, layer.ch1] + partition_onehot[:, :, layer.ch2]
                if layer.remainder_num > 0:
                    merged = torch.cat([merged, partition_onehot[:, :, -layer.remainder_num:]],
                                       dim=2)
                partition_onehot = merged
            elif isinstance(layer, SumLayerMulti):
                r = layer.remainder_num
                n_total = partition_onehot.shape[2]
                # Child selection per sum node: (batch, parallel, num, child_per_node, 1).
                one_hot_samples = layer.sample(partition_onehot, hard).view(
                    batch_size, layer.parallel, layer.num, layer.child_per_node, 1)
                grouped = partition_onehot[:, :, :n_total - r].view(
                    batch_size, layer.parallel, layer.num, layer.child_per_node, -1)
                selected = torch.sum(one_hot_samples * grouped, dim=3)
                partition_onehot = torch.cat([selected, partition_onehot[:, :, n_total - r:]],
                                             dim=2)
            elif isinstance(layer, FinalLayerMulti):
                out = layer.sample(partition_onehot)

        if out is None:
            raise RuntimeError("SPNMulti.sample requires a final layer; call add_final_layer first")
        return out

    def sample_deterministic(self, hard=True):
        """Single deterministic sample per circuit (no batch dimension).

        Args:
            hard (bool): forwarded to the sum and final layers' ``sample_deterministic``.

        Returns:
            torch.Tensor: output of the final layer's ``sample_deterministic``.

        Raises:
            RuntimeError: if no final layer has been added.
        """
        partition_onehot = self.leaf_layer.sample_onehot(1).squeeze(0)
        out = None

        for layer in self.layers:
            if isinstance(layer, ProductLayerMulti):
                merged = partition_onehot[:, layer.ch1] + partition_onehot[:, layer.ch2]
                if layer.remainder_num > 0:
                    merged = torch.cat([merged, partition_onehot[:, -layer.remainder_num:]],
                                       dim=1)
                partition_onehot = merged
            elif isinstance(layer, SumLayerMulti):
                r = layer.remainder_num
                n_total = partition_onehot.shape[1]
                # Child selection per sum node: (parallel, num, child_per_node, 1).
                one_hot_samples = layer.sample_deterministic(hard).view(
                    layer.parallel, layer.num, layer.child_per_node, 1)
                grouped = partition_onehot[:, :n_total - r].view(
                    layer.parallel, layer.num, layer.child_per_node, -1)
                selected = torch.sum(one_hot_samples * grouped, dim=2)
                partition_onehot = torch.cat([selected, partition_onehot[:, n_total - r:]], dim=1)
            elif isinstance(layer, FinalLayerMulti):
                out = layer.sample_deterministic(partition_onehot, hard)

        if out is None:
            raise RuntimeError(
                "SPNMulti.sample_deterministic requires a final layer; call add_final_layer first")
        return out

    def forward(self, x_in, log=False):
        """Compute the joint probability of the given input values.

        Args:
            x_in (torch.Tensor): input values of shape (batch, num_pcs, num_vars).
            log (bool): if True, work in log space (layers' ``forward``); otherwise
                use the layers' ``forward_no_log``.

        Returns:
            torch.Tensor: output of the last layer.
        """
        batch, num_pcs, num_vars = x_in.shape
        # Pair each variable with its complement: [1 - x, x] -> (batch, num_pcs, 2 * num_vars).
        output = torch.cat([1 - x_in.unsqueeze(3), x_in.unsqueeze(3)], 3).view(
            batch, num_pcs, 2 * num_vars)
        if log:
            output = torch.log(output + 1e-16)
        for layer in self.layers:
            output = layer.forward(output) if log else layer.forward_no_log(output)
        return output

    def entropy_selective(self):
        """Compute the entropy of every circuit in parallel.

        Returns:
            torch.Tensor: (num_pcs,) entropy of each circuit.
        """
        ent = self.leaf_layer.entropy()
        for layer in self.layers:
            ent = layer.entropy(ent)
        return ent

    def kl_random(self):
        """KL divergence w.r.t. an SPN with the same structure but uniform sum weights.

        The reference SPN's sum nodes select each child equally likely.
        """
        kl = self.leaf_layer.kl_random()
        for layer in self.layers:
            kl = layer.kl_random(kl)
        return kl


class SPNFG(L.LightningModule):
    """Sum-product network model for factor graphs.

    Edges between ``num_nodes`` graph nodes and ``num_factors`` factors are modelled
    by a bank of parallel SPNs (``self.spn``). A per-node categorical partition
    variable Y, with logits ``logpy`` of shape (num_nodes, num_factors + 1), decides
    which edges of a node are node->factor and which are factor->node.
    """

    def __init__(self,
                 num_nodes,
                 num_factors,
                 spn_target='factor',
                 normalize=True,
                 tau=1.0):
        """
        Args:
            num_nodes (int): Number of graph nodes.
            num_factors (int): Number of factors.
            spn_target (str, optional): Target of the SPN, either 'factor' or 'node'.
                Defaults to 'factor'. If set to 'factor', the SPN models the joint pmf
                on factors, otherwise on nodes.
            normalize (bool, optional): Whether to use normalized probability.
                Defaults to True.
            tau (float, optional): Temperature for Gumbel-Softmax. Defaults to 1.0.
        """
        super().__init__()
        self.spn_target = spn_target
        self.tau = tau
        # Partition variable Y: one categorical over num_factors + 1 classes per node.
        self.register_parameter('logpy', nn.Parameter(torch.zeros(num_nodes, num_factors + 1)))
        if spn_target == 'factor':
            self.spn = SPNMulti(num_vars=num_factors,
                                num_pcs=num_nodes,
                                normalize=normalize,
                                tau=tau)
        else:
            self.spn = SPNMulti(num_vars=num_nodes,
                                num_pcs=num_factors,
                                normalize=normalize,
                                tau=tau)

    @property
    def numel(self):
        """Element count of the SPN plus the partition logits."""
        return self.spn.numel + self.logpy.numel()

    def ready(self):
        """Finalize the SPN and initialize the partition logits."""
        self.spn.ready()
        nn.init.xavier_normal_(self.logpy, gain=0.01)

    def add_sum_layer(self, num, remainder_num, p_conn=0.2, sparsity_temp=1.0):
        return self.spn.add_sum_layer(num, remainder_num, p_conn, sparsity_temp)

    def add_product_layer(self, num, copies, partitions, remainder_num):
        return self.spn.add_product_layer(num, copies, partitions, remainder_num)

    def add_bernoulli_layer(self, num, var):
        return self.spn.add_bernoulli_layer(num, var)

    def add_final_layer(self, num_vars, max_copies, p_conn=0.2, sparsity_temp=1.0):
        # Return the layer like the other add_* helpers do.
        return self.spn.add_final_layer(num_vars, max_copies, p_conn, sparsity_temp)

    @staticmethod
    def _split_edges(mtx, y):
        """Split edge weights into ``(node2factor, factor2node)`` using partition ``y``.

        For ``y[..., k] == 1``, factors ``j >= k`` go to node2factor and factors
        ``j < k`` go to factor2node. Soft ``y`` yields the corresponding mixture.
        Works for batched (batch, nodes, factors) and unbatched (nodes, factors) edge
        matrices. ``y`` has one more category than there are factors.
        """
        is_node2factor = torch.cumsum(y, dim=-1)[..., :-1]
        return mtx * is_node2factor, mtx * (1 - is_node2factor)

    @staticmethod
    def _check_edge_range(mtx):
        """Assert that edge weights do not exceed 1 (up to float tolerance)."""
        peak = mtx.max()
        # new_tensor matches peak's dtype/device; allclose rejects mismatched dtypes.
        assert peak <= 1 or torch.allclose(peak, peak.new_tensor(1.0))

    def _to_graph_orientation(self, mtx):
        """Transpose batched SPN output to (batch, num_nodes, num_factors) when needed."""
        if self.spn_target != 'factor':
            mtx = mtx.transpose(1, 2)  # batch x num vars x parallel
        return mtx

    def weighted_graph(self, batch_size=1):
        """Compute a soft (weighted) factor graph.

        Returns:
            tuple(torch.Tensor, torch.Tensor): ``(node2factor, factor2node)``, each of
            size (batch_size, num_nodes, num_factors).

        NOTE(review): ``SPNMulti`` as defined alongside this class has no
        ``weighted_graph`` method, so the call below raises AttributeError unless
        that method is provided elsewhere. Confirm.
        """
        # F.softmax has no ``tau``/``hard`` arguments (those are F.gumbel_softmax's);
        # apply the temperature explicitly to get soft partition weights.
        y = F.softmax(self.logpy / self.tau, dim=-1).repeat(batch_size, 1, 1)

        mtx = self._to_graph_orientation(self.spn.weighted_graph(batch_size))
        self._check_edge_range(mtx)
        return self._split_edges(mtx, y)

    def sample(self, batch_size=1, hard=True):
        """Sample factor graphs from the SPN.

        Args:
            batch_size (int): number of graphs to sample.
            hard (bool): forwarded to ``SPNMulti.sample``. The partition sample is
                always hard (straight-through Gumbel-Softmax).

        Returns:
            tuple(torch.Tensor, torch.Tensor): ``(node2factor, factor2node)``, each of
            size (batch_size, num_nodes, num_factors).
        """
        y = F.gumbel_softmax(self.logpy.repeat(batch_size, 1, 1), tau=self.tau, hard=True)

        mtx = self._to_graph_orientation(self.spn.sample(batch_size, hard))
        self._check_edge_range(mtx)
        return self._split_edges(mtx, y)

    def sample_deterministic(self, hard=True):
        """Deterministic factor graph: argmax partition plus deterministic SPN sample.

        Returns:
            tuple(torch.Tensor, torch.Tensor): ``(node2factor, factor2node)``, each of
            size (num_nodes, num_factors).
        """
        y = F.one_hot(torch.argmax(self.logpy, 1), self.logpy.shape[1])

        mtx = self.spn.sample_deterministic(hard)
        if self.spn_target != 'factor':
            mtx = mtx.T
        return self._split_edges(mtx, y)

    def prob_mode(self, log=False):
        """Compute the probability of the most likely factor graph.

        Returns:
            tuple: (partition probability, SPN output). Both are log-space if ``log``.
        """
        mtx = self.spn.sample_deterministic(True)
        pb = self.spn.forward(mtx.unsqueeze(0), log)
        if log:
            # log_softmax avoids log(0) = -inf when softmax underflows.
            return F.log_softmax(self.logpy, -1).max(dim=-1).values.sum(-1), pb
        return torch.softmax(self.logpy, -1).max(dim=-1).values.prod(dim=-1), pb

    def prob(self, partition, mtx, log=False):
        """Compute the probability of a given factor graph (no gradients).

        Args:
            partition (torch.LongTensor): partition class index per node.
            mtx (torch.Tensor): edge matrix in the SPN's orientation, with or without
                a leading batch dimension.
            log (bool): return log-probabilities if True.

        Returns:
            tuple: (partition probability, SPN output).
        """
        if mtx.ndim == 2:
            mtx = mtx.unsqueeze(0)
        with torch.no_grad():
            y = F.one_hot(partition, num_classes=self.logpy.shape[1])
            pb = self.spn.forward(mtx, log)
            if log:
                log_py = (F.log_softmax(self.logpy, -1) * y).sum(-1)
                return log_py.sum(-1), pb
            py = (F.softmax(self.logpy, -1) * y).sum(-1)
            return torch.prod(py, dim=-1), pb

    def sample_log_prob(self, batch_size):
        """Sample graphs and return their node-to-node adjacency and log-probability.

        Returns:
            tuple: ``(dag, log_prob)``, where ``dag = node2factor @ factor2node^T`` per
            sample and ``log_prob = log p(Y) + sum of the SPN log outputs``.
        """
        with torch.no_grad():
            y = F.gumbel_softmax(self.logpy.repeat(batch_size, 1, 1), tau=self.tau, hard=True)
            partition = torch.argmax(y, -1)

            mtx = self.spn.sample(batch_size)  # batch x parallel x num vars
            # Score the sample in the SPN's own orientation, before any transpose.
            py, pb = self.prob(partition, mtx=mtx, log=True)

            node2factor, factor2node = self._split_edges(self._to_graph_orientation(mtx), y)
            dag = torch.bmm(node2factor, factor2node.transpose(1, 2))

            return dag, py + pb.sum(1)

    def entropy(self):
        """Entropy of the SPN bank plus the entropy of the partition variable Y."""
        log_py = F.log_softmax(self.logpy, -1)
        return self.spn.entropy_selective().sum() - (log_py.exp() * log_py).sum()

    def kl_random(self):
        """KL divergence w.r.t. a model with uniform weights in all sum nodes.

        The reference model's sum nodes select each child equally likely, and its
        partition Y is uniform over its categories.
        """
        log_py = F.log_softmax(self.logpy, -1)
        # Y has logpy.shape[-1] (= num_factors + 1) categories for either spn_target.
        num_classes = self.logpy.shape[-1]
        kl_y = (log_py.exp() * (log_py + np.log(num_classes))).sum()
        return self.spn.kl_random().sum() + kl_y

    def entropy_y(self, reduction='batch mean'):
        """Entropy of the partition variable Y.

        Args:
            reduction (str): 'mean', 'sum', 'batch mean' (mean over nodes of the
                per-class terms, then summed), or anything else for no reduction.
        """
        log_py = F.log_softmax(self.logpy, -1)
        neg_plogp = -(log_py.exp() * log_py)
        if reduction == 'mean':
            return neg_plogp.mean()
        if reduction == 'sum':
            return neg_plogp.sum()
        if reduction == 'batch mean':
            return neg_plogp.mean(0).sum()
        return neg_plogp


class ISPNFG(SPNFG):
    """Sum-product network model for factor graphs with non-genetic interventions.

    Extends :class:`SPNFG` with a second SPN bank, ``spn_int``, that models the
    intervention-to-factor edges.
    """

    def __init__(self,
                 num_nodes,
                 num_factors,
                 num_intervention,
                 spn_target='factor',
                 normalize=True,
                 tau=1.0):
        """
        Args:
            num_nodes (int): Number of graph nodes.
            num_factors (int): Number of factors.
            num_intervention (int): Number of non-genetic interventions.
            spn_target (str, optional): Target of the SPN, either 'factor' or 'node'.
                Defaults to 'factor'. If set to 'factor', the SPN models the joint pmf
                on factors, otherwise on nodes.
            normalize (bool, optional): Whether to use normalized probability.
                Defaults to True.
            tau (float, optional): Temperature for Gumbel-Softmax. Defaults to 1.0.
        """
        # Forward tau so the main SPN and the partition sampler use the same
        # temperature as spn_int (otherwise the parent default of 1.0 applies).
        super().__init__(num_nodes, num_factors, spn_target, normalize, tau)
        # Each intervention affects factors independently.
        if spn_target == 'factor':
            self.spn_int = SPNMulti(
                num_vars=num_factors,
                num_pcs=num_intervention,
                normalize=normalize,
                tau=tau
            )
        else:
            self.spn_int = SPNMulti(
                num_vars=num_intervention,
                num_pcs=num_factors,
                normalize=normalize,
                tau=tau
            )

    def ready(self):
        """Finalize both SPN banks and initialize the partition logits."""
        super().ready()
        self.spn_int.ready()

    @property
    def numel(self):
        """Element count of both SPN banks plus the partition logits."""
        return super().numel + self.spn_int.numel

    def add_sum_layer(self, num, remainder_num, p_conn=0.5, sparsity_temp=1.0):
        return self.spn.add_sum_layer(num, remainder_num, p_conn, sparsity_temp)

    def add_product_layer(self, num, copies, partitions, remainder_num):
        return self.spn.add_product_layer(num, copies, partitions, remainder_num)

    def add_bernoulli_layer(self, num, var):
        return self.spn.add_bernoulli_layer(num, var)

    def add_final_layer(self, num_vars, max_copies, p_conn=0.5, sparsity_temp=1.0):
        return self.spn.add_final_layer(num_vars, max_copies, p_conn, sparsity_temp)

    def add_sum_layer_int(self, num, remainder_num, p_conn=0.2, sparsity_temp=1.0):
        return self.spn_int.add_sum_layer(num, remainder_num, p_conn, sparsity_temp)

    def add_product_layer_int(self, num, copies, partitions, remainder_num):
        return self.spn_int.add_product_layer(num, copies, partitions, remainder_num)

    def add_bernoulli_layer_int(self, num, var):
        return self.spn_int.add_bernoulli_layer(num, var)

    def add_final_layer_int(self, num_vars, max_copies, p_conn=0.2, sparsity_temp=1.0):
        return self.spn_int.add_final_layer(num_vars, max_copies, p_conn, sparsity_temp)

    def sample(self, batch_size=1, hard=True):
        """Sample factor graphs and intervention targets.

        Returns:
            tuple: ``(node2factor, factor2node, int2factor)``. The first two have size
            (batch_size, num_nodes, num_factors). ``int2factor`` is transposed to the
            same orientation as the node edges when ``spn_target != 'factor'``.
        """
        node2factor, factor2node = super().sample(batch_size, hard)
        int2factor = self.spn_int.sample(batch_size, hard)

        if self.spn_target != 'factor':
            int2factor = int2factor.transpose(1, 2)  # batch x num_vars x parallel
        return node2factor, factor2node, int2factor

    def sample_deterministic(self, hard=True):
        """Deterministic factor graph and intervention targets (no batch dimension).

        Returns:
            tuple: ``(node2factor, factor2node, int2factor)``.
        """
        node2factor, factor2node = super().sample_deterministic(hard)
        int2factor = self.spn_int.sample_deterministic(hard)
        if self.spn_target != 'factor':
            int2factor = int2factor.transpose(0, 1)  # num_vars x parallel
        return node2factor, factor2node, int2factor

    def entropy(self):
        """Entropy of both SPN banks plus the entropy of the partition variable Y."""
        log_py = F.log_softmax(self.logpy, -1)
        return (self.spn.entropy_selective().sum()
                + self.spn_int.entropy_selective().sum()
                - (log_py.exp() * log_py).sum())

    def kl_random(self):
        """KL divergence w.r.t. a model with uniform weights in all sum nodes.

        The reference model's sum nodes (of both SPN banks) select each child equally
        likely, and its partition Y is uniform over its categories.
        """
        log_py = F.log_softmax(self.logpy, -1)
        # Y has logpy.shape[-1] (= num_factors + 1) categories for either spn_target.
        num_classes = self.logpy.shape[-1]
        kl_y = (log_py.exp() * (log_py + np.log(num_classes))).sum()
        return self.spn.kl_random().sum() + self.spn_int.kl_random().sum() + kl_y

