import random
from network_designer.design_space.core.design_space_base import DesignSpace
from network_designer.design_space.core.search_space import SearchSpace, mutate_arch
from network_designer.design_space.core.utils import remove_zero_rows_cols
from network_designer.design_space.nasbench101.nasbench1_spec import ModelSpec
import numpy as np
import sys

class NB101LikeDesignSpace(DesignSpace):
    """NAS-Bench-101-like cell design space backed by a ``SearchSpace``.

    The class samples (adjacency, operation-feature) pairs from the wrapped
    ``SearchSpace``, can de-duplicate them through a ``ModelSpec`` hash, and
    forwards mutation requests to the module-level ``mutate_arch`` helper.
    """

    def __init__(self, args, n_classes=10):
        """Set up the operation vocabulary and the underlying search space.

        Args:
            args: Forwarded unchanged to ``DesignSpace.__init__``.
            n_classes: Number of target classes; stored on the instance.
        """
        super(NB101LikeDesignSpace, self).__init__(args=args)

        # Operation vocabulary. A node's operation-feature row is decoded
        # into one of these labels by taking its argmax (see get_hash).
        self.candidate_operations = [
            'conv3x3-bn-relu',
            'conv1x1-bn-relu',
            'maxpool3x3',
            'input',
            'output',
        ]
        # Original note: only 5 intermediate nodes. The meaning of the second
        # argument is defined by SearchSpace and is not visible in this file.
        self.design_space = SearchSpace(5, 4)

        # Original note: only CIFAR-10 is available, hence the default of 10.
        self.n_classes = n_classes

        # Node positions (counted AFTER get_hash drops the leading row/column)
        # that always receive a label, even without incoming edges.
        self.INPUT_NODE = 0
        self.OUTPUT_NODE = 6

    def get_hash(self, adj, ops):
        """Return the ``ModelSpec`` hash of one sampled architecture.

        Args:
            adj: Square adjacency matrix (numpy array); its first row and
                column are discarded before hashing.
            ops: Per-node operation-feature matrix (numpy array); its first
                row is discarded before hashing.

        Returns:
            Whatever ``ModelSpec.hash_spec(self.candidate_operations)``
            returns for the reduced graph.
        """
        # Cast first, then slice. The previous code cast `ops` and then
        # overwrote the result with an uncast slice, so only `adj` was
        # actually converted to integers.
        node_ops = ops.astype(int)[1:]
        # Drop the leading node and keep only edges strictly above the
        # diagonal.
        edges = np.triu(adj.astype(int)[1:, 1:], 1)

        always_labelled = (self.INPUT_NODE, self.OUTPUT_NODE)
        labels = []
        for idx, op_row in enumerate(node_ops):
            if op_row.sum() <= 0:
                # No operation is set for this node.
                continue
            # A node is labelled when it has at least one incoming edge or
            # sits at the input/output position.
            if edges[:, idx].sum() > 0 or idx in always_labelled:
                labels.append(self.candidate_operations[np.argmax(op_row)])

        # NOTE(review): presumably strips all-zero rows/columns so the matrix
        # lines up with `labels` -- confirm against the helper.
        edges = remove_zero_rows_cols(edges)
        spec = ModelSpec(edges, labels)
        return spec.hash_spec(self.candidate_operations)

    def sample_unique_graphs(self, sample_size, seed):
        """Sample architectures until ``sample_size`` distinct hashes exist.

        At most ``sample_size * 10`` samples are drawn, so fewer than
        ``sample_size`` graphs may be returned if duplicates dominate.
        A carriage-return progress line is written to stdout for every newly
        accepted graph.

        Args:
            sample_size: Number of unique graphs requested.
            seed: Seed passed to ``random.seed``.

        Returns:
            A tuple ``(adj_matrix, ops_features, hash_list)`` of three
            parallel lists, in the order the graphs were accepted.
        """
        # NOTE(review): this seeds only the stdlib `random` module. Whether
        # SearchSpace.sample draws from it (rather than numpy's RNG) is not
        # visible here -- confirm that sampling is actually reproducible.
        random.seed(seed)

        adj_matrix = []
        ops_features = []
        hash_list = []
        # Companion set for constant-time duplicate checks; the list above
        # keeps the acceptance order that callers receive.
        # NOTE(review): assumes hash_spec returns a hashable value -- confirm.
        seen_hashes = set()

        attempts = 0
        max_attempts = sample_size * 10
        while len(hash_list) < sample_size and attempts < max_attempts:
            attempts += 1
            adj, ops = self.design_space.sample()
            ops = ops.astype(int)
            adj = adj.astype(int)
            graph_hash = self.get_hash(adj, ops)
            if graph_hash in seen_hashes:
                continue

            seen_hashes.add(graph_hash)
            hash_list.append(graph_hash)
            ops_features.append(ops)
            adj_matrix.append(adj)
            sys.stdout.write('\r| Sampling [%3d/%3d]'
                             % (len(hash_list), sample_size))
            sys.stdout.flush()

        return adj_matrix, ops_features, hash_list

    def sample(self):
        """Draw one sample from the underlying ``SearchSpace``."""
        return self.design_space.sample()

    def mutate_arch(self, adj, ops):
        """Delegate to the module-level ``mutate_arch(adj, ops)`` helper."""
        return mutate_arch(adj, ops)
            

    
    