import torch


def subgraph(target_idx, edge_index, edge_attr=None, dim=0):
    """Select the edges whose endpoint in row ``dim`` is one of ``target_idx``.

    Function from DAGNN.

    Args:
        target_idx: iterable of node ids (a tensor or a plain sequence).
        edge_index: tensor indexed as ``edge_index[dim]`` and
            ``edge_index[:, k]``; row ``dim`` is compared against each id.
        edge_attr: optional per-edge attributes indexed along their first
            dimension by edge position.
        dim: row of ``edge_index`` to match against (0 or 1 for a two-row
            edge index).

    Returns:
        ``(edge_index_subset, edge_attr_subset)``. Edges are grouped in the
        order the ids appear in ``target_idx``; an id listed twice yields its
        edges twice. ``edge_attr_subset`` is ``None`` when ``edge_attr`` is
        ``None``. An empty ``target_idx`` yields an empty selection.
    """
    endpoints = edge_index[dim]
    per_node = [(endpoints == n).nonzero().squeeze(-1) for n in target_idx]
    if per_node:
        le_idx = torch.cat(per_node, dim=-1)
    else:
        # torch.cat rejects an empty list, so build the empty index directly.
        le_idx = torch.empty(0, dtype=torch.long, device=edge_index.device)

    lp_edge_index = edge_index[:, le_idx]
    # Indexing only the first dimension equals ``[le_idx, :]`` for 2-D
    # attributes and additionally accepts 1-D attributes.
    lp_edge_attr = None if edge_attr is None else edge_attr[le_idx]
    return lp_edge_index, lp_edge_attr


def sat_evaluate(evaluator, G, pred, cnf=False):
    """Propagate ``pred`` level by level through ``G`` and read the sink value.

    Args:
        evaluator: callable invoked as ``evaluator(pred, l_edge_index, x)``;
            its result is indexed by node id along dimension 0.
        G: graph object providing ``x``, ``edge_index``, ``forward_level``,
            ``forward_index``, ``backward_level`` and ``backward_index``.
        pred: per-node tensor of values, or ``None``. It is updated IN PLACE:
            rows of every forward level >= 1 are overwritten with the
            evaluator output.
        cnf: when true, nothing is evaluated and ``None`` is returned.

    Returns:
        A ``(1, 1)`` zero tensor when ``pred`` is ``None`` (on CUDA when
        available, otherwise on CPU); ``None`` when ``cnf`` is true;
        otherwise the rows of ``pred`` at the nodes whose backward level is 0.

    Side effects:
        Prints a comparison between the logic simulation result and the
        propagated value.
    """
    if pred is None:
        # Fall back to the CPU instead of crashing on hosts without CUDA.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.zeros((1, 1), device=device)
    if cnf:  # NOTE: the CNF path is not implemented.
        return None

    # Values of the level-0 nodes, used as the input pattern for simulation.
    layer_mask = G.forward_level == 0
    l_node = G.forward_index[layer_mask]
    sol = pred[l_node]
    # NOTE(review): `sol` is passed unthresholded, while the gate logic
    # compares signals against exact 0/1 -- confirm that the level-0 entries
    # of `pred` are hard 0/1 values.
    sat_simulated = pyg_simulation(G, sol)[0]
    x, edge_index = G.x, G.edge_index

    num_layers_f = G.forward_level.max().item() + 1
    for l_idx in range(1, num_layers_f):
        # Forward layer: nodes sitting at the current level.
        layer_mask = G.forward_level == l_idx
        l_node = G.forward_index[layer_mask]

        # Only edges whose row-1 endpoint is a node of this level.
        l_edge_index, _ = subgraph(l_node, edge_index, dim=1)
        msg = evaluator(pred, l_edge_index, x)
        l_msg = torch.index_select(msg, dim=0, index=l_node)

        # In-place write so later levels see the updated values.
        pred[l_node, :] = l_msg

    # Sink index: nodes at backward level 0.
    layer_mask = G.backward_level == 0
    sink_node = G.backward_index[layer_mask]

    sat = torch.index_select(pred, dim=0, index=sink_node)

    print('sat simulation equals soft evaluation: ', (sat>0.5).float() == sat_simulated)
    print('sat simulation: ', sat_simulated)
    print('sat soft evaluation: ', sat)

    return sat


def logic(gate_type, signals):
    """Evaluate a single gate on its input signals.

    Gate codes: 1 = AND, 2 = NAND, 3 = OR, 4 = NOR, 5 = NOT, 6 = XOR.
    No BUFF gate is implemented.

    Signals are compared against exactly 0 and 1; a value equal to neither
    is treated as "not 0" by AND/NAND and "not 1" by OR/NOR/NOT.

    Returns:
        0 or 1. ``None`` is returned for an unknown gate code and for NOT
        with no inputs.
    """
    if gate_type == 1:  # AND: any 0 forces 0
        return 0 if any(s == 0 for s in signals) else 1

    if gate_type == 2:  # NAND: any 0 forces 1
        return 1 if any(s == 0 for s in signals) else 0

    if gate_type == 3:  # OR: any 1 forces 1
        return 1 if any(s == 1 for s in signals) else 0

    if gate_type == 4:  # NOR: any 1 forces 0
        return 0 if any(s == 1 for s in signals) else 1

    if gate_type == 5:  # NOT: only the first input is considered
        for first in signals:
            return 0 if first == 1 else 1
        return None

    if gate_type == 6:  # XOR: 0 when every input is 0 or every input is 1
        every_zero = all(s == 0 for s in signals)
        every_one = all(s == 1 for s in signals)
        return 0 if (every_zero or every_one) else 1

    return None

def pyg_simulation(g, pattern=None):
    """Simulate the circuit graph ``g`` level by level for one input pattern.

    Args:
        g: graph object providing ``forward_level`` (one level per node),
            ``edge_index`` (row 0 = source ids, row 1 = destination ids) and
            ``x`` (per-node features; columns 1, 2 and 3 select gate codes
            1 (AND), 5 (NOT) and 3 (OR) of ``logic`` respectively).
        pattern: values assigned to the level-0 nodes, in increasing node
            index order. Defaults to an empty list.

    Returns:
        ``(output_value, pattern)`` where ``output_value`` is the value of
        the single node on the last level.

    Raises:
        ValueError: if a node with fan-in has none of feature columns 1-3
            set, or if the last level holds more than one node.
        IndexError: if ``pattern`` is shorter than the number of level-0
            nodes, or if the graph has no nodes.
    """
    if pattern is None:
        # A fresh list per call; a shared default would leak to callers
        # because `pattern` is part of the return value.
        pattern = []

    # PI, level list
    levels = [int(ele) for ele in g.forward_level]
    max_level = max([0, *levels])
    level_list = [[] for _ in range(max_level + 1)]
    for idx, level in enumerate(levels):
        level_list[level].append(idx)
    PI_indexes = [idx for idx, level in enumerate(levels) if level == 0]

    # Fan-in list: sources of every node, in edge order.
    fanin_list = [[] for _ in levels]
    for src, dst in zip(g.edge_index[0], g.edge_index[1]):
        fanin_list[int(dst)].append(int(src))

    ######################
    # Simulation
    ######################
    y = [0] * len(g.x)

    for j, pi_idx in enumerate(PI_indexes):
        y[pi_idx] = pattern[j]

    for level_nodes in level_list[1:]:
        for node_idx in level_nodes:
            source_signals = [y[pre_idx] for pre_idx in fanin_list[node_idx]]
            if not source_signals:
                # Nodes without fan-in keep their initial value 0.
                continue
            features = g.x[node_idx]
            if int(features[1]) == 1:
                gate_type = 1
            elif int(features[2]) == 1:
                gate_type = 5
            elif int(features[3]) == 1:
                gate_type = 3
            else:
                raise ValueError("This is PI")
            y[node_idx] = logic(gate_type, source_signals)

    # Output
    if len(level_list[-1]) > 1:
        raise ValueError('Too many POs')
    return y[level_list[-1][0]], pattern