from collections import namedtuple
import numpy as np

from scipy.special import softmax

# Candidate operations for an edge, in weight-column order: column k of an
# architecture-weight matrix scores PRIMITIVES[k]. genotype() looks 'none' up
# by name and never selects it.
PRIMITIVES = [
    "none",
    "max_pool_3x3",
    "avg_pool_3x3",   # NOTE(review): marked with a bare '#'; reason not recorded
    "skip_connect",
    "sep_conv_3x3",
    "sep_conv_5x5",   # NOTE(review): marked with a bare '#'; reason not recorded
    "dil_conv_3x3",
    "dil_conv_5x5",   # NOTE(review): marked with a bare '#'; reason not recorded
]

# Discrete cell description: `normal` / `reduce` hold (op_name, input_index)
# pairs; `normal_concat` / `reduce_concat` hold the node indices whose outputs
# are concatenated into the cell output.
Genotype = namedtuple(
    "Genotype",
    ["normal", "normal_concat", "reduce", "reduce_concat"],
)

def genotype(weights, steps=4, multiplier=4):
    """Discretize continuous architecture weights into a ``Genotype``.

    For every intermediate node the two strongest incoming edges are kept,
    and each kept edge is labelled with its best operation. The 'none'
    operation is excluded both when ranking edges and when labelling them.

    Args:
        weights: Sequence whose first two entries are the normal-cell and
            reduction-cell weight matrices. Each is 2-D, with raw scores
            that are softmax-normalized along the last axis here. Rows are
            candidate edges laid out node by node: intermediate node ``i``
            (``0 <= i < steps``) owns ``i + 2`` consecutive rows, one per
            possible input index ``0 .. i + 1``. Column ``k`` scores
            ``PRIMITIVES[k]``.
        steps: Number of intermediate nodes per cell.
        multiplier: Number of trailing node indices concatenated into the
            cell output.

    Returns:
        ``Genotype`` whose ``normal`` / ``reduce`` fields are lists of
        ``(op_name, input_index)`` pairs (two per node, strongest edge
        first). Its ``normal_concat`` / ``reduce_concat`` fields are the
        same ``range`` of concatenated indices.

    Raises:
        ValueError: if a weight matrix is not 2-D with exactly
            ``steps * (steps + 3) // 2`` rows, has more columns than
            ``PRIMITIVES`` has entries, or has no column other than 'none'.
    """
    none_idx = PRIMITIVES.index('none')
    # Node i has i + 2 candidate inputs, so the rows total
    # sum(i + 2 for i in range(steps)) == steps * (steps + 3) // 2.
    n_rows = steps * (steps + 3) // 2

    def _parse(probs):
        """Return ``[(op_name, input_index), ...]``, two entries per node."""
        probs = np.asarray(probs)
        if probs.ndim != 2 or probs.shape[0] != n_rows:
            raise ValueError(
                f"expected {n_rows} edge rows for steps={steps}, "
                f"got weights of shape {probs.shape}")
        n_ops = probs.shape[1]
        if n_ops > len(PRIMITIVES):
            raise ValueError(
                f"weights have {n_ops} op columns but PRIMITIVES "
                f"names only {len(PRIMITIVES)}")
        candidates = [k for k in range(n_ops) if k != none_idx]
        if not candidates:
            raise ValueError("weights have no op column other than 'none'")

        op_scores = probs[:, candidates]
        # argmax returns the first maximum, i.e. the lowest op index on ties.
        best_op = op_scores.argmax(axis=1)
        # An edge's strength is its best non-'none' probability.
        strength = op_scores.max(axis=1)

        gene = []
        start = 0
        for i in range(steps):
            n_inputs = i + 2
            node_strength = strength[start:start + n_inputs]
            # Keep the two strongest inputs. sorted() is stable, so ties keep
            # the lower input index; j stays a plain Python int.
            edges = sorted(range(n_inputs),
                           key=lambda x: -node_strength[x])[:2]
            for j in edges:
                gene.append((PRIMITIVES[candidates[best_op[start + j]]], j))
            start += n_inputs
        return gene

    gene_normal = _parse(softmax(weights[0], axis=-1))
    gene_reduce = _parse(softmax(weights[1], axis=-1))

    # Input indices 0 and 1 precede the intermediate nodes (node i is index
    # i + 2), so this selects the last `multiplier` indices up to steps + 1.
    concat = range(2 + steps - multiplier, steps + 2)
    return Genotype(
        normal=gene_normal, normal_concat=concat,
        reduce=gene_reduce, reduce_concat=concat,
    )

# Demo: sample random architecture weights and decode each into a genotype.
# Rows: 14 candidate edges per cell (default steps=4 -> 2 + 3 + 4 + 5), with
# the normal-cell and reduction-cell matrices stacked. Columns: one per entry
# of PRIMITIVES, since genotype() maps column k to PRIMITIVES[k]. With fewer
# columns the trailing operations could never be selected.
size = [14 * 2, len(PRIMITIVES)]
new_weights = [np.random.random_sample(size) for _ in range(10)]
new_genos = [genotype(w.reshape(2, -1, size[-1])) for w in new_weights]

print(new_genos)