from network_designer.design_space.nasbench201.graph_base import Graph
from .operations import OPS as PRIMITIVE
import torch.nn as nn
import numpy as np
from networkx.algorithms.dag import lexicographical_topological_sort

class Identity(nn.Module):
    """Pass-through module whose forward returns its argument unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x
    
class NB201_Graph(Graph):
    """Cell graph in which every node carries one operation module.

    The constructor forwards its arguments to ``Graph`` and then builds one
    module per graph node. This class reads ``self.graph``, ``self.in_dim``,
    ``self.out_dim``, ``self.search_space``, ``self.input_index`` and
    ``self.output_index`` without setting them itself, so they are expected
    to be provided by the ``Graph`` base class (not visible in this file).
    """

    def __init__(self, adj_matrix, ops, search_space, channel_in, channel_out,  affine=True, track_running_stats=True):
        """Initialise the base graph and create the per-node modules.

        Args:
            adj_matrix: Adjacency matrix; must support ``.shape`` and
                element-wise ``== 0``.
            ops: Per-node operation feature vectors, indexed by node index.
            search_space: Sequence of operation names used as keys into
                ``PRIMITIVE``.
            channel_in: Forwarded to ``Graph.__init__``.
            channel_out: Forwarded to ``Graph.__init__``.
            affine: Forwarded to every primitive constructor.
            track_running_stats: Forwarded to every primitive constructor.
        """
        super().__init__(adj_matrix, ops, search_space, channel_in, channel_out)

        # Entry node is index 0, exit node is index 7.
        self.configure_input_output_index(0, 7)

        # Both flags must be set before the nodes are built, because the
        # primitive constructors receive them.
        self.affine = affine
        self.track_running_stats = track_running_stats

        self.create_nodes_from_ops(ops, adj_matrix)

    def aggregate(self, x):
        """Combine several predecessor outputs by summing them.

        Uses the builtin ``sum``, so an empty list yields the int ``0``.
        """
        return sum(x)

    def create_nodes_from_ops(self, ops, adj):
        """Populate ``self.op_nodes`` with one module per graph node.

        Nodes are visited in lexicographical topological order and stored
        under ``str(node_idx)``:

        * a node whose row and column in ``adj`` are entirely zero gets the
          ``'none'`` primitive;
        * otherwise the input and output nodes get ``Identity``;
        * any other node gets the primitive selected by ``ops[node_idx]``,
          or ``'none'`` when that feature vector sums to zero.

        Args:
            ops: Per-node operation feature vectors, indexed by node index.
            adj: Adjacency matrix of the graph.
        """
        is_zero = adj == 0
        # Indices of nodes with neither incoming nor outgoing edges.
        isolated = np.arange(adj.shape[0])[
            np.all(is_zero, axis=0) & np.all(is_zero, axis=1)
        ]

        def build_primitive(op_name):
            # Every primitive is created with the same argument list.
            return PRIMITIVE[op_name](
                self.in_dim, self.out_dim, 1, self.affine, self.track_running_stats
            )

        self.op_nodes = nn.ModuleDict()
        for node_idx in lexicographical_topological_sort(self.graph):
            if node_idx in isolated:
                module = build_primitive('none')
            elif node_idx in (self.input_index, self.output_index):
                module = Identity()
            else:
                feature = ops[node_idx]
                if sum(feature) == 0:
                    chosen = 'none'
                else:
                    # Offset by one: per the original author's note, identity
                    # was removed from the op encoding and is expressed
                    # through the adjacency matrix instead.
                    chosen = self.search_space[np.argmax(feature) + 1]
                module = build_primitive(chosen)
            self.op_nodes[str(node_idx)] = module

    def forward(self, inputs):
        """Run ``inputs`` through the graph and return the output node's result.

        Nodes are evaluated in lexicographical topological order. A node
        without predecessors is never evaluated; only the input node has a
        value in that case, seeded directly with ``inputs``.

        Args:
            inputs: Value fed to the input node.

        Returns:
            The value stored for ``self.output_index``.
        """
        # Cache of node index -> computed output, seeded with the entry point.
        node_outputs = {self.input_index: inputs}

        for node_idx in lexicographical_topological_sort(self.graph):
            preds = list(self.graph.predecessors(node_idx))
            if not preds:
                continue

            if len(preds) == 1 and preds[0] in node_outputs:
                # Single available predecessor: pass its output straight through.
                node_input = node_outputs[preds[0]]
            else:
                # Predecessors without a cached output (not reachable from the
                # input) are ignored.
                # NOTE(review): if none of them has a cached output, aggregate
                # receives an empty list and sum([]) yields the int 0, which is
                # then passed to this node's op -- confirm this is intended.
                node_input = self.aggregate(
                    [node_outputs[p] for p in preds if p in node_outputs]
                )
            node_outputs[node_idx] = self.op_nodes[str(node_idx)](node_input)

        return node_outputs[self.output_index]