from dataset import *
from train import *
import torch
from genotypes import *


def search(search_num, *, config_path='./config.yaml',
           model_path='./results/trainmodel-2022-04-11~19:54:01/best_model.pth.tar',
           device="cuda:0"):
    """Score ``search_num`` generated candidates with a trained model and print the best.

    Args:
        search_num: Number of candidates requested from
            ``DataSetDarts(0).generate_random_dataset``.
        config_path: YAML file read with ``read_yaml``; the parsed config is
            forwarded to ``test_batch``.
        model_path: Checkpoint loaded with ``torch.load``. The loaded object is
            used directly as the model (``.to(device)`` and passed to
            ``test_batch``), so it is presumably a whole pickled module rather
            than a ``state_dict`` -- confirm against the training script.
        device: Torch device the model is loaded onto and that ``test_batch``
            is told to use.

    The keyword defaults are the values this script was written for, so a
    plain ``search(n)`` call behaves exactly as before.
    """
    config = read_yaml(config_path)
    # map_location deserializes tensors straight onto `device`. Without it,
    # torch.load restores them on the device they were saved from, which fails
    # when that device is absent (e.g. saved on cuda:1, or a CPU-only host).
    # SECURITY: torch.load unpickles arbitrary objects -- only load trusted files.
    # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # whole-module pickles; pass weights_only=False there if loading fails.
    model = torch.load(model_path, map_location=device).to(device)
    dataset = DataSetDarts(0).generate_random_dataset(search_num)
    acc, tup = test_batch(model, dataset, config, device)
    print("Best acc: {:05f}, Best tuple: {}".format(acc, str(tup)))
    print(transfer_geno(tup))


def transfer_geno(best_tuple):
    """Build a ``Genotype`` from a ``(normal_cell, reduce_cell)`` search result.

    Each cell is an iterable of ``(node, op)`` pairs, where ``op`` indexes into
    ``PRIMITIVES``. Every pair is emitted as ``(PRIMITIVES[op], node)``, i.e. the
    operation name first and the node second. Both cells use the fixed
    concat list ``[2, 3, 4, 5]``.

    Args:
        best_tuple: Sequence whose item 0 is the normal cell and item 1 the
            reduction cell.

    Returns:
        A ``Genotype`` with ``normal``, ``normal_concat``, ``reduce`` and
        ``reduce_concat`` populated.
    """
    def decode(cell):
        # Swap to (op_name, node) order while resolving the op index.
        return [(PRIMITIVES[op_idx], src) for src, op_idx in cell]

    normal_edges = decode(best_tuple[0])
    reduce_edges = decode(best_tuple[1])
    return Genotype(
        normal=normal_edges,
        normal_concat=[2, 3, 4, 5],
        reduce=reduce_edges,
        reduce_concat=[2, 3, 4, 5],
    )


if __name__ == "__main__":
    # Script entry point: run `search` over 100000 generated candidates.
    search(search_num=100_000)