from typing import Optional


import torch
from torch import nn
from torch_geometric.nn import MessagePassing
from torch import Tensor
from torch_geometric.typing import OptTensor
from torch_geometric.utils import softmax
from satb.models.dgdagrnn import SoftEvaluator, HardEvaluator
from satb.data.ordered_data import OrderedData, return_order_info
from satb.utils import sat_evaluate


def subgraph(target_idx, edge_index, edge_attr=None, dim=0):
    '''
    Select the edges whose endpoint in row ``dim`` of ``edge_index`` is one
    of the nodes in ``target_idx`` (function adapted from DAGNN).

    Edges are returned grouped by ``target_idx`` order (all edges of the
    first target, then all edges of the second, ...) and, inside a group, in
    their original column order. A node listed twice contributes its edges
    twice.

    Args:
        target_idx: iterable of node indices (e.g. a 1-D LongTensor or list).
        edge_index: 2 x E edge tensor; row ``dim`` is compared to the targets.
        edge_attr: optional per-edge attributes indexed along their first
            dimension (``[E]`` or ``[E, F]``), or None.
        dim: which row of ``edge_index`` to match against (0 or 1).

    Returns:
        ``(sub_edge_index, sub_edge_attr)``; ``sub_edge_attr`` is None when
        ``edge_attr`` is None. An empty ``target_idx`` yields a ``[2, 0]``
        edge index (and zero-row attributes) instead of raising.
    '''
    endpoints = edge_index[dim]
    per_target = [(endpoints == n).nonzero().squeeze(-1) for n in target_idx]
    if per_target:
        le_idx = torch.cat(per_target, dim=-1)
    else:
        # torch.cat rejects an empty list; select no edges instead of crashing.
        le_idx = torch.empty(0, dtype=torch.long, device=edge_index.device)
    lp_edge_index = edge_index[:, le_idx]
    # Index only the edge dimension so 1-D attributes work as well as 2-D ones.
    lp_edge_attr = edge_attr[le_idx] if edge_attr is not None else None
    return lp_edge_index, lp_edge_attr

class SoftEvaluator(MessagePassing):
    '''
    Circuit evaluator built on PyG message passing with summed messages.

    Every edge j -> i produces a message chosen by the gate type of the
    target node i, read from the one-hot ``node_attr`` row of i:

    * AND node => soft min: softmax(-x_j / T) weight times x_j
    * OR node  => soft max: softmax(x_j / T) weight times x_j (non-AIG only)
    * NOT node => 1 - x_j

    Nodes of any other type (e.g. INPUT) receive only zero messages.
    Note that ``forward`` binarises ``x`` with ``x > 0.5`` before
    propagating, so the weighted sums operate on 0/1 values.

    Args:
        temperature: softmax temperature T.
        use_aig: if True, only node_attr column 1 (AND) and column 2 (NOT)
            are consulted; otherwise columns 1 (AND), 2 (NOT) and 3 (OR).
    '''
    def __init__(self, temperature=1., use_aig=False):
        super().__init__(aggr='add', flow='source_to_target')
        self.temperature = temperature
        self.use_aig = use_aig

    def forward(self, x, edge_index, node_attr=None):
        binary_x = (x > 0.5).float()
        return self.propagate(edge_index, x=binary_x, node_attr=node_attr)

    @staticmethod
    def _gate_mask(node_attr_i, column):
        # Boolean column vector marking edges whose target has gate type `column`.
        return (node_attr_i[:, column] == 1.0).unsqueeze(1)

    def message(self, x_j, node_attr_i, index: Tensor, ptr: OptTensor, size_i: Optional[int]):
        # One-hot column positions shared by both gate encodings
        # (INPUT sits in column 0 and never produces a message).
        and_col, not_col, or_col = 1, 2, 3

        # Per-target softmax over incoming edges of -x_j / T => soft-min weights.
        min_weights = softmax(-x_j / self.temperature, index, ptr, size_i)
        soft_and = self._gate_mask(node_attr_i, and_col) * (min_weights * x_j)
        soft_not = self._gate_mask(node_attr_i, not_col) * (1 - x_j)

        if self.use_aig:
            # AIG circuits contain only AND and NOT gates.
            return soft_and + soft_not

        # Per-target softmax of +x_j / T => soft-max weights for OR gates.
        max_weights = softmax(x_j / self.temperature, index, ptr, size_i)
        soft_or = self._gate_mask(node_attr_i, or_col) * (max_weights * x_j)
        return soft_and + soft_or + soft_not

    def update(self, aggr_out):
        # Aggregated messages are the node values as-is.
        return aggr_out

# One-hot gate-type features per node, column order (INPUT, AND, NOT):
# nodes 0-2 are inputs, nodes 3-22 are AND gates, nodes 23-34 are NOT gates.
x = torch.repeat_interleave(torch.eye(3), torch.tensor([3, 20, 12]), dim=0)

# Directed edges written as (source, target) pairs, then transposed into
# the 2 x E layout expected by the message-passing code.
edge_index = torch.tensor([
    (1, 23), (23, 3), (0, 3), (23, 4),
    (0, 24), (24, 4), (4, 5), (2, 5),
    (4, 6), (2, 25), (25, 6), (6, 26),
    (26, 7), (5, 27), (27, 7), (1, 8),
    (0, 8), (8, 9), (2, 9), (9, 28),
    (28, 10), (7, 10), (3, 11), (2, 11),
    (11, 29), (29, 12), (10, 12), (3, 13),
    (25, 13), (13, 30), (30, 14), (12, 14),
    (14, 15), (3, 31), (31, 15), (25, 16),
    (23, 16), (16, 32), (32, 17), (15, 17),
    (1, 18), (24, 18), (18, 19), (2, 19),
    (19, 33), (33, 20), (17, 20), (8, 21),
    (25, 21), (21, 34), (34, 22), (20, 22),
]).t().contiguous()


# Level / node-index orderings in the forward and backward directions, as
# produced by return_order_info for this graph.
(forward_level, forward_index,
 backward_level, backward_index) = return_order_info(edge_index, x.size(0))

# Bundle features, edges and orderings into one graph object
# (keyword order kept as-is so the printed summary is unchanged).
G = OrderedData(
    x=x,
    edge_index=edge_index,
    forward_level=forward_level,
    forward_index=forward_index,
    backward_level=backward_level,
    backward_index=backward_index,
)
print(G)

# Initial assignment for the primary inputs: nodes 0, 1, 2 get 0.6, 0.4, 0.5.
# NOTE: SoftEvaluator.forward thresholds its input with `> 0.5`, so 0.5 is
# treated as 0 during propagation.
pred = torch.zeros(size=(x.size(0), 1))
pred[0] = 0.6
pred[1] = 0.4
pred[2] = 0.5


# Level 0 of the forward ordering holds the nodes evaluated first (inputs).
layer_mask = G.forward_level == 0
l_node = G.forward_index[layer_mask]

sol = pred[l_node]
print('Soft Solution: ', sol)

# Softmax temperature; smaller values make the soft-min closer to a hard min.
t = 0.01
evaluator = SoftEvaluator(temperature=t, use_aig=True)

x, edge_index = G.x, G.edge_index
# Tensor.max() reduces in one call; the builtin max() would iterate the
# tensor element by element in Python.
num_layers_f = G.forward_level.max().item() + 1
for l_idx in range(1, num_layers_f):
    # forward layer
    print('logic level: ', l_idx)
    layer_mask = G.forward_level == l_idx
    l_node = G.forward_index[layer_mask]
    print('node index at the level: ', l_node)

    # Keep only the edges whose row-1 endpoint is a node on this level.
    l_edge_index, _ = subgraph(l_node, edge_index, dim=1)
    print('edge_index between previous logic level and current logic level: ', l_edge_index)
    # Messages only flow along this level's edges, so only this level's
    # rows of msg are meaningful; pick them out and write them into pred.
    msg = evaluator(pred, l_edge_index, x)
    l_msg = torch.index_select(msg, dim=0, index=l_node)

    print('values at this logic level: ', l_msg)
    pred[l_node, :] = l_msg

# Sink nodes: level 0 of the backward ordering (presumably the circuit
# outputs, i.e. nodes with no successors — confirm against return_order_info).
layer_mask = G.backward_level == 0
sink_node = G.backward_index[layer_mask]
print('sink_node: ', sink_node)

sat = torch.index_select(pred, dim=0, index=sink_node)

print('Satifiability value: ', sat)


