## Transfer learning NAS-Bench-201 optimisation setup

# The script first runs optimisation on NAS-Bench-201 on the base task. We then use the motifs generated by
# the surrogate GP as a prior to run optimisation on the remaining tasks (CIFAR-100 and ImageNet16-120 by default).

import argparse
import datetime
import os
import pickle
import time

import torch
from tabulate import tabulate

import bayesopt
from bayesopt.generate_test_graphs import random_sampling, mutation
from bayesopt.gp import GraphGP
from bayesopt.interpreter import Interpreter
from benchmarks import NAS201
from kernels import WeisfilerLehman
from misc.find_stuctures import find_wl_feature

# The three NAS-Bench-201 tasks; the base task is one of them and the rest are transfer targets.
tasks = ['cifar10-valid', 'cifar100', 'ImageNet16-120']

parser = argparse.ArgumentParser(description='Transfer Learning NAS-Bench-201')
# ``choices`` validates the task at parse time (an ``assert`` would be stripped under ``python -O``).
parser.add_argument('--base_task', default='cifar10-valid', choices=tasks,
                    help='the base task to first run optimisation on')
parser.add_argument('--n_repeat', type=int, default=20,
                    help='number of repeats of the whole base + transfer experiment')
parser.add_argument('--data_path', default='./data',
                    help='directory passed to the NAS201 benchmark and holding the optional pickle cache')
parser.add_argument('--n_init', type=int, default=200,
                    help='number of initial random architectures per optimisation run')
parser.add_argument('--base_max_iters', type=int, default=48,
                    help='number of BO iterations on the base task')
parser.add_argument('--transfer_max_iters', type=int, default=38,
                    help='number of BO iterations on each transfer task')
parser.add_argument('--save_path', default='./results',
                    help='directory the results pickle is written to')
parser.add_argument('--batch_size', type=int, default=5,
                    help='number of architectures proposed per BO iteration')
parser.add_argument('--fixed_query_seed', type=int, default=None,
                    help='seed passed to the NAS201 benchmark')
parser.add_argument('--load_from_cache', action='store_true',
                    help='load the benchmark from <data_path>/nasbench201.pickle if it exists')
# Without ``type`` a command-line value would arrive as a str while the default is an int;
# ``type=int`` keeps both paths consistent.
parser.add_argument('--threshold', type=int, default=30,
                    help='threshold passed to the motif Interpreter')

args = parser.parse_args()
options = vars(args)
print(options)

columns = ['Iteration', 'Best func val', 'Best func test', 'Time', ]


def filter_pool(pool, include_list, exclude_list, kernel, ):
    """Filter candidate architectures by the WL features (motifs) they contain.

    A candidate ``p`` is kept only when both conditions hold:
      * ``include_list`` is None/empty, or ``p`` contains at least one feature from it;
      * ``exclude_list`` is None/empty, or ``p`` contains none of the features in it.
    Membership is tested with ``find_wl_feature(p, (f,), kernel)``.

    Args:
        pool: iterable of candidate architectures.
        include_list: features of which at least one must be present, or None/empty for no constraint.
        exclude_list: features none of which may be present, or None/empty for no constraint.
        kernel: kernel object forwarded to ``find_wl_feature``.

    Returns:
        ``pool`` itself when both lists are None/empty; otherwise a new list of the accepted
        candidates in their original order.
    """
    has_include = include_list is not None and len(include_list) > 0
    has_exclude = exclude_list is not None and len(exclude_list) > 0
    if not has_include and not has_exclude:
        return pool

    def contains_any(arch, features):
        # True as soon as any one feature is found in ``arch``.
        return any(find_wl_feature(arch, (f,), kernel) for f in features)

    pruned_pool = []
    for p in pool:
        # An empty include list means "no inclusion constraint", not "reject everything".
        if has_include and not contains_any(p, include_list):
            continue
        # Exclusion applies independently of whether an include feature matched.
        if has_exclude and contains_any(p, exclude_list):
            continue
        pruned_pool.append(p)
    return pruned_pool


def train(sampler, max_iters, include_feats=None, exclude_feats=None, base_kernel=None):
    """Run one Bayesian-optimisation campaign on NAS-Bench-201 with a WL-kernel GP surrogate.

    The run starts from ``args.n_init`` randomly sampled architectures that pass ``filter_pool``.
    For ``max_iters`` iterations it then fits the GP and builds a pool of 200 filtered candidates
    from ``mutation``. From that pool it proposes ``args.batch_size`` architectures using
    ``GraphExpectedImprovement``.

    Args:
        sampler: benchmark object. ``sampler.eval(arch)`` returns a pair (objective, details) and
            ``sampler.test(arch)`` returns a tensor-like test value.
        max_iters: number of BO iterations to run after the initial design.
        include_feats: optional WL features; every candidate must contain at least one of them.
        exclude_feats: optional WL features; candidates containing any of them are discarded.
        base_kernel: kernel forwarded to ``filter_pool`` for feature matching.

    Returns:
        ``(gp, kernel, best_vals, best_tests)``:
        - the surrogate GP, with all observations set via ``reset_XY``;
        - its WL kernel;
        - one 0-d tensor per iteration giving ``exp(-max(.))`` of the validation values observed so far;
        - the same quantity for the test values.
    """
    columns = ['Iteration', 'Best func val', 'Best func test', 'Time', ]
    start_time = time.time()
    best_tests = []
    best_vals = []

    # Initial design: random architectures, topped up until enough pass the motif filter.
    # NOTE(review): this never terminates if ``filter_pool`` rejects every sampled architecture.
    x = []
    while len(x) < args.n_init:
        cand = random_sampling(args.n_init, benchmark='nasbench201', )[0]
        x += filter_pool(cand, include_feats, exclude_feats, base_kernel)
    x = x[:args.n_init]
    init_evals = [sampler.eval(x_) for x_ in x]
    y = torch.tensor([ev[0] for ev in init_evals]).float()
    train_details = [ev[1] for ev in init_evals]
    test = torch.tensor([sampler.test(x_) for x_ in x])

    # Initialise the surrogate: a GP over graphs with a single WL kernel (h=1, learnable).
    k = WeisfilerLehman(oa=False, h=1, requires_grad=True)
    base_gp = GraphGP(x, y, [k], )

    for i in range(max_iters):
        base_gp.fit(wl_subtree_candidates=())

        # Candidate pool: 200 filtered architectures produced by ``mutation`` from the data so far.
        pool = []
        while len(pool) < 200:
            cand = mutation(x, y, benchmark='nasbench201', pool_size=200, n_best=10, n_mutate=100,
                            allow_isomorphism=True)[0]
            pool += filter_pool(cand, include_feats, exclude_feats, base_kernel)
        pool = pool[:200]
        acquisition = bayesopt.GraphExpectedImprovement(base_gp)
        next_x, _, _ = acquisition.propose_location(top_n=args.batch_size, candidates=pool)

        # Evaluate the proposals through ``sampler`` (not a module-level global) so the
        # function honours the benchmark it was given.
        new_evals = [sampler.eval(x_) for x_ in next_x]
        next_y = [ev[0] for ev in new_evals]
        train_details += [ev[1] for ev in new_evals]
        next_test = [sampler.test(x_).item() for x_ in next_x]

        x.extend(next_x)
        y = torch.cat((y, torch.tensor(next_y).view(-1))).float()
        test = torch.cat((test, torch.tensor(next_test).view(-1)))
        base_gp.reset_XY(x, y)
        end_time = time.time()

        # Current best, mapped back with exp(-max(.)).
        # NOTE(review): assumes the benchmark reports values as negative log-errors -- confirm in NAS201.
        best_val = torch.exp(-torch.max(y))
        best_test = torch.exp(-torch.max(test))

        values = [str(i), best_val.item(), best_test.item(), str(end_time - start_time), ]
        table = tabulate([values], headers=columns, tablefmt='simple', floatfmt='8.4f')

        best_vals.append(best_val)
        best_tests.append(best_test)

        # Every 40 iterations print the full table (framed header); otherwise just the data row.
        if i % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)
    return base_gp, k, best_vals, best_tests


# Build the NAS-Bench-201 benchmark wrapper, optionally restoring it from a pickled cache.
cache_path = os.path.join(args.data_path, 'nasbench201.pickle')
o = None
if args.load_from_cache and os.path.exists(cache_path):
    try:
        # NOTE: unpickling can execute arbitrary code; only point --data_path at trusted files.
        with open(cache_path, 'rb') as cache_file:
            cached = pickle.load(cache_file)
        cached.seed = args.fixed_query_seed
        cached.task = args.base_task
    except Exception as err:
        # Best effort: an unreadable or incompatible cache falls back to rebuilding below.
        # ``o`` is only assigned once the cached object is fully configured.
        print('Could not load benchmark cache %s (%s); rebuilding.' % (cache_path, err))
    else:
        o = cached
if o is None:
    o = NAS201(args.data_path, args.base_task, seed=args.fixed_query_seed)

# Create the output directory up-front so a bad --save_path fails before the long experiment runs.
os.makedirs(args.save_path, exist_ok=True)

all_res = []
for run in range(args.n_repeat):
    print('######## STARTING REPEAT %d / %d ########' % (run + 1, args.n_repeat))
    # Each 'transfer<k>' entry holds [best_vals, best_tests] for the k-th transfer task.
    # 'iterative' is not populated by this script.
    res = {'iterative': None, 'transfer1': None, 'transfer2': None}
    o.task = args.base_task
    print('-------- Starting Base Task Optimisation ----------')
    base_gp, base_kernel, base_val, base_test = train(o, args.base_max_iters)
    print('------------- Base Task Completed --------------')
    # Keep the base-task trajectory as well, rather than discarding it.
    res['base'] = [base_val, base_test]
    transfer_tasks = [t for t in tasks if t != args.base_task]
    for task_idx, t in enumerate(transfer_tasks):
        interpreter = Interpreter(gp=base_gp, thres=args.threshold)
        o.task = t
        # Motifs flagged by the interpreter; when non-empty, every transfer candidate must contain
        # at least one of them (enforced by ``filter_pool`` inside ``train``).
        include = [interpreter.feat_list[j] for j in interpreter.good_idx]
        print('Preserving Motifs: ', include)
        _, _, tr_val, tr_test = train(o, args.transfer_max_iters, include, None, base_kernel)
        res['transfer' + str(task_idx + 1)] = [tr_val, tr_test]
    all_res.append(res)

time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
with open(os.path.join(args.save_path, 'transfer_201' + time_string + '.pickle'), 'wb') as out_file:
    pickle.dump(all_res, out_file)
print('All done')
