# Standard library imports
from abc import ABC, abstractmethod

# Third party imports
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable as Variable
from torch.nn import functional as F
import lightning as L

# Tiny additive floor applied to probabilities before taking a log
# (e.g. ``torch.log(params + EPS)``) so that zero weights do not yield -inf.
EPS = 1.0e-15

class NodeLayer(L.LightningModule, ABC):
    """Abstract base class shared by every circuit layer in this module.

    Concrete layers are Lightning modules and must provide :meth:`initialize`,
    which the layers in this file use to set up their parameters or buffers
    (some layers need none and leave it empty).
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def initialize(self):
        """Set up the layer's parameters/buffers; must be overridden."""

class SumLayer(NodeLayer):
    """A layer of ``num`` sum nodes over ``child_num`` children.

    Children are split into ``num`` consecutive groups of ``child_per_node``
    entries; each sum node mixes one group with its own weights. The weights
    are stored in the ``logparams`` parameter created by :meth:`initialize`.
    """

    def __init__(self, num, child_num, normalize=True, p_conn=0.2, sparsity_temp=0.1):
        """
        :param num: the number of sum nodes in this layer
        :param child_num: the number of child nodes feeding this layer; assumed
            to be a multiple of ``num`` (the ``view`` calls below require it)
        :param normalize: if True, :meth:`get_params` renormalizes each node's
            weights so they sum to one
        :param p_conn: connection probability of the sparsity prior used in :meth:`sample`
        :param sparsity_temp: scale applied to the sparsity prior term in :meth:`sample`
        """
        super().__init__()
        self.num = num
        self.child_num = child_num
        self.normalize = normalize
        # Sparsity control (a buffer, so it follows the module across devices)
        self.register_buffer('p_conn', torch.tensor(p_conn))
        self.sparsity_temp = sparsity_temp

        self.child_per_node = self.child_num // self.num

    def initialize(self, params=None):
        """Create the ``logparams`` parameter.

        :param params: optional tensor of initial log-weights whose elements can
            be viewed as (num, child_per_node). When omitted, log-weights are
            drawn as ``log(U(0, 1) / child_per_node)``.
        """
        # Explicit None check: ``if params:`` would call Tensor.__bool__, which
        # raises for multi-element tensors and treats a scalar 0 as "absent".
        if params is not None:
            self.register_parameter('logparams', nn.Parameter(params))
        else:
            init = torch.ones([self.child_num]).uniform_() / self.child_per_node
            self.register_parameter('logparams', nn.Parameter(torch.log(init)))

    def forward_no_log(self, input):
        """Weighted sum of children in linear (non-log) space.

        Args:
            input: tensor: (batch, child_num): child values (not log-values)

        Return:
            node_output: tensor: (batch, self.num): the output of nodes in this layer
        """
        batch, _ = input.size()

        input = input.view(batch, self.num, self.child_per_node)
        # exp(log(params + EPS)): the weights with a tiny additive floor.
        params = torch.exp(self.get_logparams())

        return (input * params).sum(dim=-1)

    def forward(self, input):
        """Weighted sum of children in log space.

        Args:
            input: tensor: (batch, child_num): child log-values

        Return:
            node_output: tensor: (batch, self.num): log-output of each sum node
        """
        batch, _ = input.size()

        input = input.view(batch, self.num, self.child_per_node)
        logparams = torch.log(self.get_params())

        # log(sum_j w_j * exp(x_j)), computed stably.
        return torch.logsumexp(input + logparams, dim=-1)

    def sample(self, input, eps=1e-6):
        """Draw a hard one-hot child selection per node via Gumbel-softmax.

        Logits are ``log(w + eps)`` plus ``sparsity_temp`` times the log-likelihood
        of a Bernoulli(``p_conn``) connection prior (without the binomial
        coefficient), evaluated on the number of ones counted in ``input``.
        NOTE(review): ``input.sum(1)`` must have ``num * child_per_node`` elements
        for the ``view`` below — confirm the expected input layout with callers.

        Returns:
            one-hot tensor of shape (num, child_per_node).
        """
        params = self.get_params()
        num_ones = input.sum(1)
        binom_fit = torch.log(self.p_conn) * num_ones \
            + torch.log(1 - self.p_conn) * (self.child_per_node - num_ones)
        logits = torch.log(params + eps) + self.sparsity_temp * binom_fit.view(self.num, self.child_per_node)
        return F.gumbel_softmax(logits, tau=1, hard=True)

    def entropy(self, input):
        """Entropy of each sum node from its children's entropies.

        Computes ``sum_j w_j * H_j - sum_j w_j * log(w_j)``.

        Args:
            input: tensor: (batch, child_num): entropies of the child nodes

        Return:
            tensor: (batch, self.num)
        """
        params = self.get_params()
        wt_sum = self.forward_no_log(input)
        # unsqueeze a leading batch dim so it broadcasts against wt_sum
        wlogw = (params * torch.log(params)).sum(dim=-1).unsqueeze(0)
        return wt_sum - wlogw

    def get_params(self):
        """Return the weights as a (num, child_per_node) tensor, normalized per node if enabled."""
        params = torch.exp(self.logparams).view(self.num, self.child_per_node)
        if self.normalize:
            params = params / torch.sum(params, dim=-1, keepdim=True)
        return params

    def get_logparams(self):
        """Return ``log(get_params() + EPS)`` with shape (num, child_per_node)."""
        return torch.log(self.get_params() + EPS)

class ProductLayer(NodeLayer):
    """A layer of product nodes; each node combines two children picked by fixed index maps."""

    def __init__(self, num, copies, partitions):
        """
        :param num: the number of product nodes in this layer
        :param copies: the number of copies of each partition
        :param partitions: number of variable partitions in this layer
        invariant: num = copies x partitions
        """
        super().__init__()
        self.num = num
        self.copies = copies
        self.partitions = partitions
        assert self.num == self.copies * self.partitions

        self.ch_copies = np.round(np.sqrt(self.copies)).astype(int)

        # Children are read in blocks of 2 * ch_copies per partition: ch1 picks
        # from the first ch_copies entries of the block and ch2 from the next
        # ch_copies. When copies == ch_copies ** 2, every (ch1, ch2) pair within
        # a block is produced exactly once.
        y = torch.arange(self.num)
        group = y // self.copies
        offset = y % self.copies
        self.ch1 = group * (2 * self.ch_copies) + offset // self.ch_copies
        self.ch2 = group * (2 * self.ch_copies) + self.ch_copies + offset % self.ch_copies

    def initialize(self, params=None):
        """No-op: product nodes have no trainable parameters."""
        pass

    def forward(self, input):
        """Product in log space: sum of the two selected children's log-values.

        Args:
            input: tensor: (batch, child.num): the input to the layer

        Return:
            node_output: tensor: (batch, self.num): the output to the layer
        """
        return input[:, self.ch1] + input[:, self.ch2]

    def forward_no_log(self, input):
        """Product in linear space of the two selected children.

        Args:
            input: tensor: (batch, child.num): the input to the layer

        Return:
            node_output: tensor: (batch, self.num): the output to the layer
        """
        return input[:, self.ch1] * input[:, self.ch2]

    def entropy(self, input):
        """Entropy of each product node as the sum of its two children's entropies.

        NOTE(review): this equals :meth:`forward` on entropy inputs and assumes the
        two children have disjoint scopes.
        """
        return self.forward(input)

class BernoulliLayer(NodeLayer):
    def __init__(self, num, var):
        """
        A leaf layer made up of Bernoulli nodes.

        Args:
            num: int: the number of leaf nodes
            var: tensor: (self.num): scope variable of the nodes (0-indexed)
        """
        super().__init__()
        self.num = num
        self.var = var
        self.initialize()

    def initialize(self):
        """Register ``p``: P(True) alternating 0, 1 across the ``num`` leaves."""
        # NOTE(review): presumably the negative and positive literal of each
        # variable, matching the pairing in ``var`` — confirm with the caller.
        self.register_buffer('p', torch.tensor([0.0, 1.0]).repeat(self.num // 2))

    def entropy(self):
        """Zero entropy for every leaf (each ``p`` is exactly 0 or 1); shape (1, num)."""
        return torch.zeros((1, self.p.size(0)), device=self.device)

    def get_pparams(self):
        return self.p

    def sample_set(self):
        """Signed 1-based variable index per leaf: +(var+1) where p == 1, -(var+1) where p == 0.

        Return:
            tensor: (num, 1)
        """
        # ``var`` is a plain attribute (not a buffer), so it does not follow the
        # module to another device; align it with ``p`` before mixing them.
        var = self.var.to(self.p.device)
        return ((var + 1) * (2 * self.p - 1)).view(-1, 1)

    def sample_onehot(self):
        """
        Returns:
            torch.tensor: one-hot encoding of the Bernoulli nodes, with rows of the
            leaves whose mask entry is 0 zeroed out; shape (num, num // 2)
        """
        # Both operands must live on the module's device (matches BernoulliLayerMulti).
        leaves = F.one_hot(self.var.long(), num_classes=self.num // 2).to(self.device)
        mask = torch.tensor([[0], [1]]).repeat(self.num // 2, 1).to(self.device)
        return leaves * mask
        

class SumLayerMulti(NodeLayer):
    def __init__(self,
                 num,
                 child_num,
                 remainder_num,
                 parallel,
                 normalize=True,
                 tau=1.0,
                 p_conn=0.2,
                 sparsity_temp=0.1):
        """Parallel sum layer.

        Along the node dimension, the first ``child_num`` inputs are grouped into
        ``num`` sum nodes of ``child_per_node`` children each. The trailing
        ``remainder_num`` inputs are passed through unchanged.

        Args:
            num (int): total number of sum nodes in the layer
            child_num (int): total number of child nodes excluding the remainder nodes
            remainder_num (int): total number of nodes not performing selection
            parallel (int): number of parallel PC layers
            normalize (bool, optional): whether to use normalized probability. Defaults to True.
                NOTE(review): not read by any method of this class.
            tau (float, optional): temperature for Gumbel softmax. Defaults to 1.0.
            p_conn (float, optional): density parameter. Defaults to 0.2.
            sparsity_temp (float, optional): weight for controlling the sparsity prior. Defaults to 0.1.
        """
        super().__init__()
        self.num = num
        self.child_num = child_num
        self.remainder_num = remainder_num
        self.parallel = parallel
        self.normalize = normalize
        self.tau = tau
        # Sparsity control
        self.register_buffer('p_conn', torch.tensor(p_conn))
        self.sparsity_temp = sparsity_temp

        self.child_per_node = self.child_num // self.num

    def initialize(self, params=None):
        """Create ``logparams`` of shape (parallel, child_num).

        Args:
            params (torch.Tensor, optional): initial unnormalized log-weights.
                When omitted, zeros are Xavier-normal initialized.
        """
        # Explicit None check: ``if params:`` calls Tensor.__bool__, which raises
        # for multi-element tensors.
        if params is not None:
            self.register_parameter('logparams', nn.Parameter(params))
        else:
            self.register_parameter('logparams', nn.Parameter(torch.zeros([self.parallel, self.child_num])))
            nn.init.xavier_normal_(self.logparams)

    @property
    def params(self):
        """Per-node softmax of ``logparams``, flattened back to (parallel, child_num)."""
        logits = self.logparams.view(self.parallel, self.num, self.child_per_node)
        return torch.softmax(logits, -1).view(self.parallel, self.child_num)

    @property
    def numel(self):
        return self.logparams.numel()

    def get_weight(self, data_in, softmax=False):
        """Selection logits: learned weights plus a scaled sparsity prior.

        Args:
            data_in: input whose first ``n_total - remainder_num`` entries along
                dim 2 are the selectable children. It is summed over its last dim
                to count ones. NOTE(review): presumably
                (batch, parallel, n_total, num_vars) — confirm with callers.
            softmax (bool, optional): if True, return per-node probabilities
                instead of logits. Defaults to False.

        Returns:
            tensor of shape (batch, parallel, num, child_per_node).
        """
        batch, n_total = data_in.shape[0], data_in.shape[2]

        # shape: (batch, parallel, child_num)
        num_ones = data_in[:, :, :n_total - self.remainder_num].sum(-1)
        # Bernoulli(p_conn) log-likelihood of the counted ones (no binomial coefficient)
        binom_fit = torch.log(self.p_conn) * num_ones \
            + torch.log(1 - self.p_conn) * (self.child_per_node - num_ones)
        logits = (self.logparams + self.sparsity_temp * binom_fit).view(
            batch, self.parallel, self.num, self.child_per_node)
        if softmax:
            # The original computed the softmax and discarded it; return it.
            return torch.softmax(logits, -1)
        return logits

    def forward_no_log(self, data_in):
        """Weighted sums in linear space, with remainder nodes appended unchanged.

        Args:
            data_in: tensor: (batch, parallel, child_num + remainder_num): the input to the layer

        Return:
            node_output: tensor: (batch, parallel, self.num + self.remainder_num): the output of nodes in this layer
        """
        data_in = data_in.to(self.device)
        batch, n_total = data_in.shape[0], data_in.shape[2]
        split = n_total - self.remainder_num
        weighted = (data_in[:, :, :split] * self.params).view(batch, self.parallel, self.num, self.child_per_node)
        node_output = weighted.sum(dim=3)
        return torch.cat([node_output, data_in[:, :, split:]], dim=-1)

    def forward(self, data_in):
        """Weighted sums in log space, with remainder nodes appended unchanged.

        Args:
            data_in: tensor: (batch, parallel, child_num + remainder_num): child log-values

        Return:
            node_output: tensor: (batch, parallel, self.num + self.remainder_num): the output of nodes in this layer
        """
        batch, n_total = data_in.shape[0], data_in.shape[2]
        split = n_total - self.remainder_num
        scores = (data_in[:, :, :split] + torch.log(self.params)).view(
            batch, self.parallel, self.num, self.child_per_node)
        node_output = torch.logsumexp(scores, dim=3)
        return torch.cat([node_output, data_in[:, :, split:]], dim=-1)

    def sample(self, data_in, hard=True):
        """Randomly sample one child per node with Gumbel-softmax.

        Args:
            data_in (:class:`torch.tensor`): the input to the layer; see :meth:`get_weight`
                for the expected layout
            hard (bool, optional): whether to use hard sampling. Defaults to True.

        Return:
            :class:`torch.tensor`: (batch, parallel, num, child_per_node) one-hot encoding of selected child nodes
        """
        logits = self.get_weight(data_in)
        return F.gumbel_softmax(logits, tau=self.tau, hard=hard)

    def sample_deterministic(self, hard=True):
        """Pick input nodes with the largest weight.

        Args:
            hard (bool, optional): if False, return the per-node softmax instead. Defaults to True.

        Return:
            :class:`torch.tensor`: (parallel, num, child_per_node) one-hot encoding of selected child nodes
        """
        logits = self.logparams.view(self.parallel, self.num, self.child_per_node)
        if not hard:
            return F.softmax(logits, -1)
        return F.one_hot(torch.argmax(logits, 2), self.child_per_node)

    def entropy(self, entropy_in):
        """Computes the entropy of each node; remainder entries pass through.

        Args:
            entropy_in (:class:`torch.tensor`): entropy of input nodes (parallel, child_num + remainder_num)

        Return:
            :class:`torch.tensor`: (parallel, num + remainder_num)
        """
        params = self.params
        wt_sum = self.forward_no_log(entropy_in.unsqueeze(0)).squeeze(0)
        wlogw = (params * torch.log(params)).view(self.parallel, self.num, self.child_per_node).sum(dim=-1)
        split = wt_sum.shape[-1] - self.remainder_num
        return torch.cat([wt_sum[:, :split] - wlogw, wt_sum[:, split:]], dim=-1)

    def kl_random(self, kl_in):
        """Accumulates KL(weights || uniform over children) on top of the children's KL terms.

        Args:
            kl_in (:class:`torch.tensor`): KL terms of input nodes (parallel, child_num + remainder_num)

        Return:
            :class:`torch.tensor`: (parallel, num + remainder_num); remainder entries pass through
        """
        params = self.params
        logparams = torch.log(params)
        kl_sum = self.forward_no_log(kl_in.unsqueeze(0)).squeeze(0)
        qlogq = (params * logparams).view(self.parallel, self.num, self.child_per_node).sum(dim=-1)
        # log-probability of the uniform distribution over child_per_node children
        qlogp = (params * (-np.log(self.child_per_node))).view(self.parallel, self.num, self.child_per_node).sum(dim=-1)
        split = kl_sum.shape[-1] - self.remainder_num
        return torch.cat([kl_sum[:, :split] + qlogq - qlogp, kl_sum[:, split:]], dim=-1)


class ProductLayerMulti(NodeLayer):
    def __init__(self, num, copies, partitions, remainder_num, parallel):
        """
        Initialize the SPN Nodes layer.

        Args:
            num (int): number of product nodes in this layer.
            copies (int): number of copies of each partition, excluding the remainder partitions.
            partitions (int): number of variable partitions in this layer.
            remainder_num (int): number of remaining child nodes not participating in the
                product operation; they are appended unchanged after the product outputs.
            parallel (int): number of parallel nodes.
        """
        super().__init__()
        self.num = num
        self.copies = copies
        self.partitions = partitions
        self.remainder_num = remainder_num
        self.parallel = parallel
        assert self.num == self.copies * self.partitions

        self.ch_copies = np.round(np.sqrt(self.copies)).astype(int)

        # Index maps of shape (num,), shared by all parallel copies. Children are
        # read in blocks of 2 * ch_copies per partition: ch1 picks from the first
        # ch_copies entries of a block and ch2 from the next ch_copies. When
        # copies == ch_copies ** 2, every (ch1, ch2) pair in a block appears once.
        y = torch.arange(self.num)
        group = y // self.copies
        offset = y % self.copies
        self.ch1 = group * (2 * self.ch_copies) + offset // self.ch_copies
        self.ch2 = group * (2 * self.ch_copies) + self.ch_copies + offset % self.ch_copies

    @property
    def numel(self):
        # Product nodes have no trainable parameters.
        return 0

    def initialize(self, params=None):
        """No-op: product nodes have no trainable parameters."""
        pass

    def forward(self, data_in):
        """Products in log space, with the last ``remainder_num`` inputs appended unchanged.

        Args:
            data_in: tensor: (batch, parallel, child.num): the input to the layer

        Return:
            node_output: tensor: (batch, parallel, self.num+self.remainder_num): the output to the layer
        """
        n_total = data_in.size(2)
        node_output = data_in[:, :, self.ch1] + data_in[:, :, self.ch2]
        return torch.cat([node_output, data_in[:, :, n_total - self.remainder_num:]], 2)

    def forward_no_log(self, data_in):
        """Products in linear space, with the last ``remainder_num`` inputs appended unchanged.

        Args:
            data_in: tensor: (batch, parallel, child.num): the input to the layer

        Return:
            node_output: tensor: (batch, parallel, self.num+self.remainder_num): the output to the layer
        """
        n_total = data_in.size(2)
        node_output = data_in[:, :, self.ch1] * data_in[:, :, self.ch2]
        return torch.cat([node_output, data_in[:, :, n_total - self.remainder_num:]], 2)

    def entropy(self, entropy_in):
        """Sum of the two children's entropies per product node; input/output are (parallel, ...)."""
        return self.forward(entropy_in.unsqueeze(0)).squeeze(0)

    def kl_random(self, kl_in):
        """Sum of the two children's KL terms per product node; input/output are (parallel, ...)."""
        return self.forward(kl_in.unsqueeze(0)).squeeze(0)


class BernoulliLayerMulti(NodeLayer):
    def __init__(self, num, var, parallel):
        """
        A leaf layer made up of Bernoulli nodes for multiple parallel PCs.

        Args:
            num: int: the number of leaf nodes
            var: tensor: (self.num): scope variable of the nodes (0-indexed)
            parallel: int: number of parallel nodes
        """
        super().__init__()
        self.num = num
        self.var = var
        self.parallel = parallel
        self.initialize()

    def initialize(self):
        """Register ``p`` of shape (parallel, num): P(True) alternating 0, 1 along the leaves."""
        self.register_buffer('p', torch.tensor([0., 1.]).repeat((self.parallel, self.num // 2)))

    @property
    def numel(self):
        # Leaves are fixed buffers: no trainable parameters.
        return 0

    def entropy(self):
        """Zero entropy per leaf (each ``p`` is exactly 0 or 1); shape (parallel, num)."""
        return torch.zeros((self.parallel, self.p.size(1)), device=self.device)

    def kl_random(self):
        """Zero KL contribution per leaf; shape (parallel, num)."""
        return torch.zeros((self.parallel, self.p.size(1)), device=self.device)

    def get_pparams(self):
        return self.p

    def sample_set(self):
        """Signed 1-based variable index per leaf and parallel copy, flattened to (parallel * num, 1).

        Entries are +(var+1) where p == 1 and -(var+1) where p == 0.
        """
        # ``var`` is a plain attribute (not a buffer), so it does not follow the
        # module to another device; align it with ``p`` before mixing them.
        var = self.var.to(self.p.device)
        return ((var + 1) * (2 * self.p - 1)).view(-1, 1)

    def sample_onehot(self, batch_size=None):
        """
        Args:
            batch_size (int, optional): if given, a leading batch dimension of this size is added.

        Returns:
            torch.tensor: one-hot encoding of the Bernoulli nodes, with rows of the leaves whose
            mask entry is 0 zeroed out; shape (parallel, num, num // 2), or
            (batch_size, parallel, num, num // 2) when ``batch_size`` is an int
        """
        leaves = F.one_hot(self.var.long(), num_classes=self.num // 2).to(self.device)
        mask = torch.tensor([[0], [1]]).repeat(self.num // 2, 1).to(self.device)
        if isinstance(batch_size, int):
            leaves = leaves.repeat(batch_size, self.parallel, 1, 1)
            mask = mask.repeat(batch_size, self.parallel, 1, 1)
        else:
            leaves = leaves.repeat(self.parallel, 1, 1)
            mask = mask.repeat(self.parallel, 1, 1)
        return leaves * mask


class FinalLayerMulti(NodeLayer):
    def __init__(self,
                 num_vars,
                 max_copies,
                 parallel,
                 normalize=True,
                 tau=1.0,
                 p_conn=0.2,
                 sparsity_temp=0.1):
        """Root layer: one softmax-weighted mixture per child partition, combined multiplicatively.

        Args:
            num_vars (int): total number of variables
            max_copies (int): size budget
            parallel (int): number of parallel PC layers
            normalize (bool, optional): whether to use normalized probability. Defaults to True.
            tau (float, optional): temperature for Gumbel softmax. Defaults to 1.0.
            p_conn (float, optional): density parameter. Defaults to 0.2.
            sparsity_temp (float, optional): weight for controlling the sparsity prior. Defaults to 0.1.
        """
        super().__init__()
        self.num_vars = num_vars
        # Number of child nodes in each consecutive partition of the input.
        self.partition_sizes = self._compute_partition_size(num_vars, max_copies)

        self.parallel = parallel
        self.normalize = normalize
        self.tau = tau
        # Sparsity control
        self.register_buffer('p_conn', torch.tensor(p_conn))
        self.sparsity_temp = sparsity_temp

        # Populated by initialize().
        self.logparams = None

    def _compute_partition_size(self, num_vars, max_copies):
        """Return the number of child nodes in each partition reaching this layer.

        Starts from two leaf nodes per variable and repeatedly halves the number
        of partitions. In each round the first ``n // 2`` sizes are squared, and
        trailing remainder partitions (one is added per odd ``n``) are carried
        along at their current size. All sizes are clipped to ``max_copies``
        whenever the first one exceeds ``max_copies ** 2``.
        NOTE(review): presumably mirrors how the Sum/Product layers below are
        stacked — keep the two in sync.
        """
        sizes = (np.ones((num_vars))*2).astype(int)
        n = num_vars
        num_rem = 0
        max_copies_square = max_copies * max_copies
        while n > 1:
            if sizes[0] > max_copies_square:
                sizes = np.clip(sizes, None, max_copies)
            if n % 2 == 1:
                num_rem += 1
            if num_rem > 0:
                # keep the trailing remainder partitions at their current size
                sizes = np.concatenate([sizes[:n//2]**2, sizes[-num_rem:]])
            else:
                sizes = sizes[:n//2]**2
            n = n // 2

        # Final cap: first partition at max_copies**2, the others at max_copies.
        if sizes[0] > max_copies * max_copies:
            sizes = np.concatenate([[max_copies_square], np.clip(sizes[1:], None, max_copies)])
        return sizes

    def initialize(self, params=None):
        """Register one logit tensor per partition.

        Args:
            params (sequence of tensors, optional): initial logits, one per
                partition, each of shape (parallel, partition_size). When omitted
                or empty, zero tensors are registered and Xavier-normal initialized.

        Each tensor is registered as ``logparam_{i}`` and also collected, in
        partition order, into the ``logparams`` ParameterList used by the other methods.
        """
        # ``len`` rather than truthiness so a stacked tensor does not hit Tensor.__bool__.
        use_given = params is not None and len(params) > 0
        if use_given:
            tensors = list(params)
        else:
            tensors = [torch.zeros([self.parallel, int(s)]) for s in self.partition_sizes]

        # Build the list explicitly in partition order instead of appending to it
        # while iterating named_parameters().
        self.logparams = nn.ParameterList()
        for i, tensor in enumerate(tensors):
            # Wrap each individual tensor (the original wrapped the whole ``params``).
            param = nn.Parameter(tensor)
            if not use_given:
                nn.init.xavier_normal_(param)
            self.register_parameter(f'logparam_{i}', param)
            self.logparams.append(param)

    @property
    def numel(self):
        return np.sum([param.numel() for param in self.logparams])

    def forward_no_log(self, data_in):
        """Root value in linear space: product over partitions of each weighted mixture.

        Args:
            data_in: tensor: (batch, parallel, n_children): child values (not log-values)

        Return:
            node_output: tensor: (batch, parallel)
        """
        end_indices = np.cumsum(self.partition_sizes)
        start_idx = 0
        out_per_partition = []
        for end_idx, logparam in zip(end_indices, self.logparams):
            params = F.softmax(logparam, dim=-1)  # parallel x partition size
            out_per_partition.append((data_in[:, :, start_idx:end_idx] * params).sum(-1))
            start_idx = end_idx

        return torch.prod(torch.stack(out_per_partition), dim=0)

    def forward(self, data_in):
        """Root log-value: sum over partitions of each log-space weighted mixture.

        Args:
            data_in: tensor: (batch, parallel, n_children): child log-values

        Return:
            node_output: tensor: (batch, parallel)
        """
        end_indices = np.cumsum(self.partition_sizes)
        start_idx = 0
        node_output = 0

        for end_idx, logparam in zip(end_indices, self.logparams):
            params = F.softmax(logparam, dim=-1)
            node_output = node_output + torch.logsumexp(
                data_in[:, :, start_idx:end_idx] + torch.log(params), dim=2)
            start_idx = end_idx

        return node_output

    def sample(self, data_in, hard=True):
        """Randomly sample one child per partition with Gumbel-softmax and combine them.

        Args:
            data_in (:class:`torch.tensor`): the input to the layer of shape (batch, parallel, num_child * child_per_node, num_vars)
            hard (bool, optional): whether to use hard (one-hot) samples. Defaults to True.

        Return:
            :class:`torch.tensor`: (batch, parallel, num_vars) binary representation of the final output subset
        """
        batch = data_in.size(0)
        end_indices = np.cumsum(self.partition_sizes)
        start_idx = 0

        out = 0

        for end_idx, logparam in zip(end_indices, self.logparams):
            n_child_nodes = end_idx - start_idx
            # shape: (batch, parallel, n_child_nodes)
            num_ones = data_in[:, :, start_idx:end_idx].sum(-1)
            # Bernoulli(p_conn) log-likelihood of the counted ones (no binomial coefficient)
            binom_fit = torch.log(self.p_conn)*num_ones \
                + torch.log(1 - self.p_conn)*(n_child_nodes-num_ones)
            logits = logparam + self.sparsity_temp*binom_fit  # batch x parallel x partition size
            onehot_samples = F.gumbel_softmax(logits, tau=self.tau, hard=hard).view(
                batch, self.parallel, n_child_nodes, 1)
            out = out + (data_in[:, :, start_idx:end_idx] * onehot_samples).sum(2)

            start_idx = end_idx

        return out

    def sample_deterministic(self, data_in, hard=True):
        """Pick, per partition, the child with the largest weight and combine the picks.

        Args:
            data_in (:class:`torch.tensor`): the input to the layer; indexed as
                (parallel, n_children, ...). NOTE(review): presumably
                (parallel, n_children, num_vars) — confirm with callers.
            hard (bool, optional): if False, weight children by their softmax
                probabilities instead of an argmax one-hot. Defaults to True.

        Return:
            :class:`torch.tensor`: binary representation of the final output subset
        """
        end_indices = np.cumsum(self.partition_sizes)
        start_idx = 0

        out = 0
        for end_idx, logparam in zip(end_indices, self.logparams):
            num = end_idx - start_idx
            # shape: (parallel, num_child, 1)
            if hard:
                onehot_samples = F.one_hot(torch.argmax(logparam, 1), num).view(self.parallel, num, 1)
            else:
                onehot_samples = torch.softmax(logparam, 1).view(self.parallel, num, 1)
            out = out + (data_in[:, start_idx:end_idx] * onehot_samples).sum(1)
            start_idx = end_idx

        return out

    def entropy(self, entropy_in):
        """Computes the entropy: per partition ``sum_j w_j H_j - sum_j w_j log w_j``, summed.

        Args:
            entropy_in (:class:`torch.tensor`): entropies of the input nodes, shape (parallel, n_children)

        Return:
            :class:`torch.tensor`: (parallel,)
        """
        end_indices = np.cumsum(self.partition_sizes)
        start_idx = 0

        out = 0
        for end_idx, logparam in zip(end_indices, self.logparams):
            params = F.softmax(logparam, dim=-1)
            wt_sum = (entropy_in[:, start_idx:end_idx] * params).sum(-1)
            wlogw = (params * torch.log(params)).sum(dim=-1)
            out = out + (wt_sum - wlogw)
            start_idx = end_idx

        return out

    def kl_random(self, kl_in):
        """Accumulates, per partition, KL(weights || uniform) on top of the children's KL terms.

        Args:
            kl_in (:class:`torch.tensor`): KL terms of the input nodes, shape (parallel, n_children)

        Return:
            :class:`torch.tensor`: (parallel,)
        """
        end_indices = np.cumsum(self.partition_sizes)
        start_idx = 0

        out = 0
        for end_idx, logparam in zip(end_indices, self.logparams):
            params = F.softmax(logparam, dim=-1)
            logparams = torch.log(params)
            kl_sum = (kl_in[:, start_idx:end_idx] * params).sum(-1)
            qlogq = (params * logparams).sum(dim=-1)
            # log-probability of the uniform distribution over this partition's children
            qlogp = (params * (-np.log(end_idx - start_idx))).sum(-1)
            out = out + (kl_sum + qlogq - qlogp)
            start_idx = end_idx

        return out
