import sys
sys.path.insert(0, '../')
sys.path.insert(0, '../../')

import numpy as np
from network_designer.design_space.core.base import Model
from network_designer.design_space.nasbench101.network import Network
from network_designer.design_space.nasbench101.nasbench1_spec import ModelSpec
from network_designer.design_space.core.utils import remove_zero_rows_cols


class NB101LikeModel(Model):
    """Model wrapper that builds a NAS-Bench-101-style ``Network`` from a cell graph.

    The cell is described by an adjacency matrix (``adj``) and a per-node op
    matrix (``ops``). The op name for each node is taken from
    ``design_space.ops`` at the row's argmax.
    """

    def __init__(self, adj=None, ops=None, arch=None, design_space=None):
        # All arguments are forwarded unchanged to the ``Model`` base class.
        super().__init__(adj, ops, arch, design_space)

    def create(self):
        """Build the network from ``self.adj`` / ``self.ops`` and store it in ``self.model``."""
        self.model = self.create_model_from_graph(self.adj, self.ops)

    def create_model_from_graph(self, adj, ops):
        """Build a ``Network`` for the cell described by ``adj`` and ``ops``.

        Args:
            adj: Adjacency matrix of the cell; cast to ``int``. Assumes
                ``adj[i, j]`` marks an edge i -> j, so column ``j`` sums to the
                in-degree of node ``j`` (TODO confirm orientation).
            ops: Per-node op matrix; cast to ``int``. Presumably one one-hot row
                per node. An all-zero row marks a node that is not used.

        Returns:
            The constructed ``Network``, moved to the GPU via ``.cuda()``.
            A CUDA device is therefore required.

        Side effects:
            Sets ``self.hash`` to ``spec.hash_spec(self.design_space.ops)``.
        """
        space = self.design_space
        op_rows = ops.astype(int)
        matrix = adj.astype(int)

        # The input node has no incoming edges but must always be kept, so
        # both endpoints are whitelisted explicitly.
        endpoint_nodes = (space.INPUT_NODE, space.OUTPUT_NODE)
        # Incoming-edge count for every node, computed once instead of per node.
        in_degree = matrix.sum(axis=0)

        # Keep a node's label only if it has an op assigned and is reachable
        # (has an incoming edge) or is an endpoint. The `and` short-circuit
        # means in_degree is only indexed for rows that have an op.
        labels = [
            space.ops[np.argmax(op_row)]
            for idx, op_row in enumerate(op_rows)
            if op_row.sum() > 0 and (in_degree[idx] > 0 or idx in endpoint_nodes)
        ]

        # NOTE(review): the labels above must correspond one-to-one with the
        # nodes kept by remove_zero_rows_cols. Verify it drops exactly the
        # nodes filtered out above, otherwise ModelSpec receives a matrix and
        # label list that disagree.
        matrix = remove_zero_rows_cols(matrix)
        spec = ModelSpec(matrix, labels)

        # Fixed macro-architecture hyperparameters; 10 classes with 3-channel
        # input presumably targets CIFAR-10 (confirm).
        model = Network(spec, num_labels=10, in_channels=3,
                        stem_out_channels=128, num_stacks=3,
                        num_modules_per_stack=3,
                        momentum=0.9, eps=1e-5, tf_like=False)

        self.hash = spec.hash_spec(space.ops)

        return model.cuda()