from collections import OrderedDict
from copy import deepcopy
import ml_collections



# genotype
class Structure:
    """Parsed architecture genotype: per-stage settings plus global head/hidden sizes.

    The genotype string is ``|``-separated. Every field except the last
    describes one stage as comma-separated integers
    ``patch_size,window_size,num_mlps,mlp_ratio[,num_layers]``; the last
    field is ``head[,hidden_size]``. Example::

        "4,4,2,2|2,1,1,3|2,1,3,1|2,1,3,1|32"

    When the optional values are absent, ``num_layers`` defaults to 1 for
    every stage and ``hidden_size`` defaults to 32.
    """

    def __init__(self, genotype_str: str):
        # genotype: e.g., "4,4,2,2|2,1,1,3|2,1,3,1|2,1,3,1|32"
        assert isinstance(genotype_str, str)
        self.genotype_str = genotype_str
        genotype = genotype_str.split('|')
        stages = genotype[:-1]
        self.num_stages = len(stages)
        self.patch_sizes = [int(stage.split(',')[0]) for stage in stages]
        self.window_sizes = [int(stage.split(',')[1]) for stage in stages]
        self.num_mlps = [int(stage.split(',')[2]) for stage in stages]
        self.mlp_ratios = [int(stage.split(',')[3]) for stage in stages]
        head = int(genotype[-1].split(",")[0])
        # Head counts double from stage to stage and end at `head`.
        # NOTE(review): this list always has four entries regardless of
        # num_stages -- confirm that is intended for non-4-stage genotypes.
        self.heads = [head // 8, head // 4, head // 2, head]
        try:
            self.num_layers = [int(stage.split(',')[4]) for stage in stages]
        except (IndexError, ValueError):
            # Optional 5th per-stage field missing/unparsable: one layer per
            # stage, sized to the actual number of stages (was hard-coded 4).
            self.num_layers = [1] * self.num_stages
        try:
            self.hidden_size = int(genotype[-1].split(",")[1])
        except (IndexError, ValueError):
            # Optional second value of the last field missing/unparsable.
            self.hidden_size = 32

    def tostr(self) -> str:
        """Serialize the current attributes back to a genotype string.

        Always emits the long form (five integers per stage, then
        ``head,hidden_size``), so the result can differ textually from the
        string the object was built from. The result is also stored in
        ``self.genotype_str``.
        """
        stage_strs = [
            "%d,%d,%d,%d,%d" % fields
            for fields in zip(self.patch_sizes, self.window_sizes, self.num_mlps,
                              self.mlp_ratios, self.num_layers)
        ]
        genotype_str = '|'.join(stage_strs) + "|" + "%d,%d" % (self.heads[-1], self.hidden_size)
        self.genotype_str = genotype_str
        return genotype_str

    def check_valid(self) -> bool:
        """Return True if the per-stage settings are structurally consistent.

        Checks that there is at least one stage, that every per-stage list
        has exactly ``num_stages`` entries, that patch size, window size,
        MLP ratio and layer count are positive, that ``num_mlps`` is
        non-negative, and that the head count and hidden size are positive.
        NOTE(review): the exact value constraints are a sanity check chosen
        here; tighten them if the search space defines stricter rules.
        """
        per_stage = (self.patch_sizes, self.window_sizes, self.num_mlps,
                     self.mlp_ratios, self.num_layers)
        if self.num_stages < 1:
            return False
        if any(len(values) != self.num_stages for values in per_stage):
            return False
        positive = (self.patch_sizes, self.window_sizes, self.mlp_ratios, self.num_layers)
        if any(v <= 0 for values in positive for v in values):
            return False
        if any(v < 0 for v in self.num_mlps):
            return False
        return self.heads[-1] > 0 and self.hidden_size > 0

    def __repr__(self):
        return ('{name}({node_info})'.format(name=self.__class__.__name__, node_info=self.tostr()))

    def __len__(self) -> int:
        """Number of stages (``num_stages`` is already an int, not a sequence)."""
        return self.num_stages

    def __getitem__(self, index):
        """Return stage ``index`` as a tuple.

        The tuple is ``(patch_size, window_size, num_mlps, mlp_ratio, num_layers)``.
        Negative indices and slices follow list semantics; a slice returns
        a list of such tuples.
        """
        stages = list(zip(self.patch_sizes, self.window_sizes, self.num_mlps,
                          self.mlp_ratios, self.num_layers))
        return stages[index]

    @staticmethod
    def gen_all(search_space, num, return_genotype):
        """Enumerate candidate architectures from ``search_space``.

        Starts with one single-element list per value in
        ``search_space['embedding']``. For each ``inode`` in ``1..num-1``, it
        extends every candidate with each combination returned by
        ``get_combination(search_space['encoder'], inode)``. The result grows
        as a Cartesian product.

        Returns the raw lists when ``return_genotype`` is truthy; otherwise it
        wraps each one in ``Structure``.

        NOTE(review): ``get_combination`` is not defined in the visible part
        of this file -- confirm it is provided elsewhere. Also, ``Structure``
        asserts a ``str`` argument but is handed a list here, so the
        ``return_genotype=False`` path looks broken; confirm the intended
        conversion.
        """
        assert isinstance(search_space, (dict, OrderedDict)), 'invalid class of search-space : {:}'.format(type(search_space))
        assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
        all_archs = [[value] for value in search_space['embedding']]

        for inode in range(1, num):
            print("build node %d..."%inode)
            cur_nodes = get_combination(search_space['encoder'], inode)
            all_archs = [
                previous_arch + [tuple(cur_node)]
                for previous_arch in all_archs
                for cur_node in cur_nodes
            ]
        if return_genotype:
            return all_archs
        # return CellStructure
        return [Structure(x) for x in all_archs]
