from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch import nn
# from utils.dag_utils import subgraph, custom_backward_subgraph

from .gat_conv import AGNNConv
from .gcn_conv import AggConv
from .deepset_conv import DeepSetConv
from .gated_sum_conv import GatedSumConv
from .mlp import MLP
from .layernorm_gru import LayerNormGRU
from .layernorm_lstm import LayerNormLSTM

from .dgdagrnn import HardEvaluator

from torch.nn import LSTM, GRU


# Registry of message-aggregation layers, keyed by the `aggr_function` name
# accepted by DeepSAT.__init__. The selected class is instantiated there as
# `cls(dim_aggr, dim_hidden)` and, for the reverse pass, with `reverse=True`.
_aggr_function_factory = {
    'agnnconv': AGNNConv,
    'deepset': DeepSetConv,
    'gated_sum': GatedSumConv,
    'conv_sum': AggConv,
}

# Registry of recurrent update cells, keyed by the `update_function` name
# accepted by DeepSAT.__init__. The selected class is instantiated there as
# `cls(input_size, hidden_size)`. DeepSAT.forward_features dispatches on the
# substrings 'lstm' / 'gru' of the key, so every key must contain one of them.
_update_function_factory = {
    'lstm': LSTM,
    'gru': GRU,
    'layernorm_lstm': LayerNormLSTM,
    'layernorm_gru': LayerNormGRU,
}

def subgraph(target_idx, edge_index, edge_attr=None, dim=0):
    '''
    Select the edges that have one of the given nodes as an endpoint
    (function adapted from DAGNN).

    Args:
        target_idx: iterable of node ids (e.g. a 1-D index tensor).
        edge_index: tensor indexed as ``edge_index[dim]`` / ``edge_index[:, e]``,
            i.e. one row per endpoint and one column per edge.
        edge_attr: optional 2-D tensor with one row per edge, or None.
        dim: which row of ``edge_index`` is matched against ``target_idx``.

    Returns:
        ``(edge_index[:, sel], edge_attr[sel, :] or None)``, where ``sel`` lists
        the matching edges grouped by target node, in ``target_idx`` order.
        An empty ``target_idx`` yields an empty selection instead of raising.
    '''
    endpoints = edge_index[dim]
    per_node = [(endpoints == n).nonzero().squeeze(-1) for n in target_idx]
    if per_node:
        le_idx = torch.cat(per_node, dim=-1)
    else:
        # torch.cat rejects an empty list, so build the empty index explicitly.
        le_idx = torch.empty(0, dtype=torch.long, device=edge_index.device)

    lp_edge_index = edge_index[:, le_idx]
    lp_edge_attr = edge_attr[le_idx, :] if edge_attr is not None else None
    return lp_edge_index, lp_edge_attr


class DeepSAT(nn.Module):
    '''
    DeepSAT Graph Neural Networks for Satisfiability problems.
    '''
    def __init__(self, num_rounds=1, reverse=True, mask=True, num_aggr=3, dim_node_feature=3, dim_hidden=64, dim_mlp=32, dim_pred=1, num_fc=3, wx_update=False, wx_mlp=False, dim_edge_feature=16, aggr_function='agnnconv', update_function='lstm', norm_layer='batchnorm', activation_layer='relu', **kwargs):
        '''
        Build the aggregation, update and prediction sub-modules.

        Args:
            num_rounds: number of forward(+backward) sweeps over the graph levels.
            reverse: also build and run the backward (reverse) sweep.
            mask: clamp the hidden state of nodes with a known value (see imply_mask).
            num_aggr: stored on the instance; not used elsewhere in this class.
            dim_node_feature: width of the node features ``G.x``.
            dim_hidden: width of the per-node hidden state.
            dim_mlp: hidden width of the predictor MLP.
            dim_pred: output width of the predictor MLP.
            num_fc: number of layers of the predictor MLP.
            wx_update: concatenate ``G.x`` to the message fed to the update cell.
            wx_mlp: concatenate ``G.x`` to the embedding fed to the predictor.
            dim_edge_feature: stored on the instance; not used elsewhere in this class.
            aggr_function: key into ``_aggr_function_factory``.
            update_function: key into ``_update_function_factory``.
            norm_layer: forwarded to the predictor MLP.
            activation_layer: forwarded to the predictor MLP.
            **kwargs: accepted and ignored.

        Raises:
            KeyError: if ``aggr_function`` or ``update_function`` is not registered.
        '''
        super(DeepSAT, self).__init__()
        # configuration
        self.num_rounds = num_rounds
        self.reverse = reverse
        self.mask = mask

        # dimensions
        self.num_aggr = num_aggr
        self.dim_node_feature = dim_node_feature
        self.dim_hidden = dim_hidden
        self.dim_mlp = dim_mlp
        self.dim_pred = dim_pred
        self.num_fc = num_fc
        self.wx_update = wx_update
        self.wx_mlp = wx_mlp
        self.dim_edge_feature = dim_edge_feature

        self.hard_evaluator = HardEvaluator(temperature=0.001, use_aig=True)

        # 1. message/aggr-related
        if aggr_function not in _aggr_function_factory:
            raise KeyError('no support {} aggr function.'.format(aggr_function))
        aggr_cls = _aggr_function_factory[aggr_function]
        dim_aggr = self.dim_hidden
        self.aggr_forward = aggr_cls(dim_aggr, self.dim_hidden)
        if self.reverse:
            self.aggr_backward = aggr_cls(dim_aggr, self.dim_hidden, reverse=True)

        # 2. update-related
        self.update_function_ = update_function
        if update_function not in _update_function_factory:
            raise KeyError('no support {} update function.'.format(update_function))
        update_cls = _update_function_factory[update_function]
        # The cell input is the aggregated message, optionally concatenated
        # with the raw node features.
        dim_update_in = self.dim_hidden + (self.dim_node_feature if self.wx_update else 0)
        self.update_forward = update_cls(dim_update_in, self.dim_hidden)
        if self.reverse:
            self.update_backward = update_cls(dim_update_in, self.dim_hidden)

        # Constant input of the learnable initial-embedding layer, so that the
        # LSTM/GRU hidden state starts from a learned non-zero vector.
        # Registered as a non-persistent buffer: it follows .to()/.cuda() with
        # the module (instead of being pinned to CUDA) and is kept out of the
        # state_dict, so existing checkpoints still load.
        self.register_buffer('one', torch.ones(1), persistent=False)
        self.emd_int = nn.Linear(1, self.dim_hidden)

        # 3. predictor-related
        # TODO: support multiple predictors. Use a nn.ModuleList to handle it.
        self.norm_layer = norm_layer
        self.activation_layer = activation_layer
        dim_predictor_in = self.dim_hidden + (self.dim_node_feature if self.wx_mlp else 0)
        self.predictor = MLP(dim_predictor_in, self.dim_mlp, self.dim_pred,
                             num_layer=self.num_fc, norm_layer=self.norm_layer,
                             act_layer=self.activation_layer, sigmoid=False, tanh=False)

    def forward_features(self, G):
        num_nodes = G.num_nodes
        num_layers_f = max(G.forward_level).item() + 1
        num_layers_b = max(G.backward_level).item() + 1
        one = self.one
        h_init = self.emd_int(one).view(1, 1, -1) # (1 x 1 x dim_hidden)
        h_init = h_init.repeat(1, num_nodes, 1) # (1 x num_nodes x dim_hidden)
        # h_init = torch.empty(1, num_nodes, self.dim_hidden).to(self.device)
        # nn.init.normal_(h_init)

        if self.mask:
            h_true = torch.ones_like(h_init).cuda()
            h_false = -torch.ones_like(h_init).cuda()
            h_true.requires_grad = False
            h_false.requires_grad = False
            h_init = self.imply_mask(G, h_init, h_true, h_false)
        else:
            h_true = None
            h_false = None

        if 'lstm' in self.update_function_:
            node_embedding = self._lstm_forward(G, h_init, num_layers_f, num_layers_b, num_nodes, h_true, h_false)
        elif 'gru' in self.update_function_:
            node_embedding = self._gru_forward(G, h_init, num_layers_f, num_layers_b, h_true, h_false)
        else:
            raise NotImplementedError('The update function should be specified as one of lstm and gru.')
        
        return node_embedding

    def forward_head(self, G, node_embedding):

        if self.wx_mlp:
            pred = self.predictor(torch.cat([node_embedding, G.x], dim=1))
        else:
            pred = self.predictor(node_embedding)

        return pred

    def forward(self, G):
        node_embedding = self.forward_features(G)
        pred = self.forward_head(G, node_embedding)
        return pred
            
    
    def _lstm_forward(self, G, h_init, num_layers_f, num_layers_b, num_nodes, h_true=None, h_false=None):
        x, edge_index = G.x, G.edge_index
        edge_attr = None
        
        node_state = (h_init, torch.zeros(1, num_nodes, self.dim_hidden).cuda()) # (h_0, c_0). here we only initialize h_0. TODO: option of not initializing the hidden state of LSTM.
        

        for _ in range(self.num_rounds):
            for l_idx in range(1, num_layers_f):
                # forward layer
                layer_mask = G.forward_level == l_idx
                l_node = G.forward_index[layer_mask]

                l_state = (torch.index_select(node_state[0], dim=1, index=l_node), 
                            torch.index_select(node_state[1], dim=1, index=l_node))

                l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=1)
                msg = self.aggr_forward(node_state[0].squeeze(0), l_edge_index, l_edge_attr)
                l_msg = torch.index_select(msg, dim=0, index=l_node)
                l_x = torch.index_select(x, dim=0, index=l_node)
                
                if self.wx_update:
                    _, l_state = self.update_forward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state)
                else:
                    _, l_state = self.update_forward(l_msg.unsqueeze(0), l_state)

                node_state[0][:, l_node, :] = l_state[0].to(dtype=node_state[0].dtype)
                node_state[1][:, l_node, :] = l_state[1].to(dtype=node_state[1].dtype)

                if self.mask:
                    node_state[0][:] = self.imply_mask(G, node_state[0], h_true, h_false)

            if self.reverse:
                for l_idx in range(1, num_layers_b):
                    # backward layer
                    layer_mask = G.backward_level == l_idx
                    l_node = G.backward_index[layer_mask]
                    
                    l_state = (torch.index_select(node_state[0], dim=1, index=l_node), 
                                torch.index_select(node_state[1], dim=1, index=l_node))
                    
                    l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=0)
                    msg = self.aggr_backward(node_state[0].squeeze(0), l_edge_index, l_edge_attr)
                    l_msg = torch.index_select(msg, dim=0, index=l_node)
                    l_x = torch.index_select(x, dim=0, index=l_node)
                    
                    if self.wx_update:
                        _, l_state = self.update_backward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state)
                    else:
                        _, l_state = self.update_backward(l_msg.unsqueeze(0), l_state)
                    
                    node_state[0][:, l_node, :] = l_state[0].to(dtype=node_state[0].dtype)
                    node_state[1][:, l_node, :] = l_state[1].to(dtype=node_state[1].dtype)

                    if self.mask:
                        node_state[0][:] = self.imply_mask(G, node_state[0], h_true, h_false)
               

        node_embedding = node_state[0].squeeze(0)

        return node_embedding
    
    def _gru_forward(self, G, h_init, num_layers_f, num_layers_b, h_true=None, h_false=None):
        x, edge_index = G.x, G.edge_index
        edge_attr = None

        node_state = h_init # (h_0). here we initialize h_0. TODO: option of not initializing the hidden state of GRU.


        for _ in range(self.num_rounds):
            for l_idx in range(1, num_layers_f):
                # forward layer
                layer_mask = G.forward_level == l_idx
                l_node = G.forward_index[layer_mask]
                
                l_state = torch.index_select(node_state, dim=1, index=l_node)

                l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=1)
                msg = self.aggr_forward(node_state.squeeze(0), l_edge_index, l_edge_attr)
                l_msg = torch.index_select(msg, dim=0, index=l_node)
                l_x = torch.index_select(x, dim=0, index=l_node)
                
                if self.wx_update:
                    _, l_state = self.update_forward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state)
                else:
                    _, l_state = self.update_forward(l_msg.unsqueeze(0), l_state)
                node_state[:, l_node, :] = l_state.to(dtype=node_state.dtype)
                
                if self.mask:
                    node_state = self.imply_mask(G, node_state, h_true, h_false)
            
            if self.reverse:
                for l_idx in range(1, num_layers_b):
                    # backward layer
                    layer_mask = G.backward_level == l_idx
                    l_node = G.backward_index[layer_mask]
                    
                    l_state = torch.index_select(node_state, dim=1, index=l_node)

                    l_edge_index, l_edge_attr = subgraph(l_node, edge_index, edge_attr, dim=0)
                    msg = self.aggr_backward(node_state.squeeze(0), l_edge_index, l_edge_attr)
                    l_msg = torch.index_select(msg, dim=0, index=l_node)
                    l_x = torch.index_select(x, dim=0, index=l_node)
                    
                    if self.wx_update:
                        _, l_state = self.update_backward(torch.cat([l_msg, l_x], dim=1).unsqueeze(0), l_state)
                    else:
                        _, l_state = self.update_backward(l_msg.unsqueeze(0), l_state)                
                    
                    node_state[:, l_node, :] = l_state.to(dtype=node_state.dtype)

                    if self.mask:
                        node_state = self.imply_mask(G, node_state, h_true, h_false)


        node_embedding = node_state.squeeze(0)

        return node_embedding

    
    def imply_mask(self, G, h, h_true, h_false):
        # logic implication using masking
        true_mask = (G.mask == 1.0).unsqueeze(0)
        false_mask = (G.mask == 0.0).unsqueeze(0)
        normal_mask = (G.mask == -1.0).unsqueeze(0)
        h_mask = h * normal_mask + h_true * true_mask + h_false * false_mask
        return h_mask

    def decode_assignment(self, g):
        # get the solution (assigned during data generation)
        # layer_mask = g.forward_level == 0
        # l_node = g.forward_index[layer_mask]

        # set PO as 1.
        layer_mask = g.backward_level == 0
        l_node = g.backward_index[layer_mask]
        g.mask[l_node] = torch.tensor(1.0).cuda()

        # check # PIs
        # literal index
        layer_mask = g.forward_level == 0
        l_node = g.forward_index[layer_mask]

        # for backtracking
        ORDER = []
        change_ind = -1
        mask_backup = g.mask.clone().detach()


        for i in range(len(l_node)):
            # print('==> # ', i+1, 'solving..')
            output = self.forward(g)

            # mask
            one_mask = torch.zeros(g.y.size(0)).cuda()
            one_mask = one_mask.scatter(dim=0, index=l_node, src=torch.ones(len(l_node)).cuda()).unsqueeze(1)
            
            max_val, max_ind = torch.max(output * one_mask, dim=0)
            min_val, min_ind = torch.min(output + (1 - one_mask), dim=0)

            ext_val, ext_ind = (max_val, max_ind) if (max_val > (1 - min_val)) else (min_val, min_ind)
            # ext_val, ext_ind = torch.min(torch.abs(output * one_mask - 0.5), dim=0)
            # print('Assign No. ', ext_ind.item(), 'with prob: ', ext_val.item(), 'as value: ', 1.0 if ext_val > 0.5 else 0.0)
            g.mask[ext_ind] = torch.tensor(1.0).cuda() if ext_val > 0.5 else torch.tensor(0.0).cuda()
            # push the current index to Q
            ORDER.append(ext_ind)
            
            l_node_new = []
            for i in l_node:
                if i != ext_ind:
                    l_node_new.append(i)
            l_node = torch.tensor(l_node_new)
        

        sat = self.sat_evaluate(g)
        if sat:
            layer_mask = g.forward_level == 0
            l_node = g.forward_index[layer_mask]
            return g.mask[l_node]
        
        # do the backtracking
        while ORDER:
            # renew the mask
            g.mask = mask_backup.clone().detach()
            change_ind = ORDER.pop()
            # print('Change the values when solving No. ', change_ind.item(), 'PIs')
            # literal index
            layer_mask = g.forward_level == 0
            l_node = g.forward_index[layer_mask]

            for i in range(len(l_node)):
                output = self.forward(g)
                # mask
                one_mask = torch.zeros(g.y.size(0)).cuda()
                one_mask = one_mask.scatter(dim=0, index=l_node, src=torch.ones(len(l_node)).cuda()).unsqueeze(1)
                
                max_val, max_ind = torch.max(output * one_mask, dim=0)
                min_val, min_ind = torch.min(output + (1 - one_mask), dim=0)

                ext_val, ext_ind = (max_val, max_ind) if (max_val > (1 - min_val)) else (min_val, min_ind)
                # ext_val, ext_ind = torch.min(torch.abs(output * one_mask - 0.5), dim=0)
                g.mask[ext_ind] = torch.tensor(1.0).cuda() if ext_val > 0.5 else torch.tensor(0.0).cuda()
                # push the current index to Q
                if ext_ind == change_ind:
                    g.mask[ext_ind] = 1 - g.mask[ext_ind]
                # print('Assign No. ', ext_ind.item(), 'with prob: ', ext_val.item(), 'as value: ', g.mask[ext_ind].item())
                
                l_node_new = []
                for i in l_node:
                    if i != ext_ind:
                        l_node_new.append(i)
                l_node = torch.tensor(l_node_new)
            sat = self.sat_evaluate(g)
            if sat:
                # layer_mask = g.forward_level == 0
                # l_node = g.forward_index[layer_mask]
                # return g.mask[l_node]
                return g.mask
            
        return None

    def sat_evaluate(self, G):
        
        evaluator = self.hard_evaluator
        x, edge_index = G.x, G.edge_index

        pred = G.mask

        num_layers_f = max(G.forward_level).item() + 1
        for l_idx in range(1, num_layers_f):
            # forward layer
            layer_mask = G.forward_level == l_idx
            l_node = G.forward_index[layer_mask]
            
            l_edge_index, _ = subgraph(l_node, edge_index, dim=1)
            msg = evaluator(pred, l_edge_index, x)
            l_msg = torch.index_select(msg, dim=0, index=l_node)
            
            pred[l_node, :] = l_msg
        
        # sink index
        layer_mask = G.backward_level == 0
        sink_node = G.backward_level[layer_mask]

        sat = torch.index_select(pred, dim=0, index=sink_node)
            
        return sat