import os
import sys
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.utils
import torch.backends.cudnn as cudnn

from search_model_predictor import NASNetwork as Network
import random

import genotypes
import shutil
# Command-line interface. The parsed `args` namespace is read by main() and is
# also handed to model._initialize_predictor / model.set_thre, so options that
# look unused in this file may still be consumed there.
parser = argparse.ArgumentParser("NAT")
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--init_channels', type=int, default=20, help='number of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--save', type=str, default='derive',
                    help='experiment name (output subdirectory created under --prefix)')
parser.add_argument('--seed', type=int, default=1234, help='random seed')
parser.add_argument('--n_archs', type=int, default=10, help='number of candidate archs')
parser.add_argument('--prefix', type=str, default='.', help='parent save path')
parser.add_argument('--pw', type=str, default='transformer.pt',
                    help='path to the pretrained arch-transformer weights')
parser.add_argument('--pwp', type=str, default='predictor.pt',
                    help='path to the pretrained predictor state_dict')
args = parser.parse_args()

# Resolve the experiment directory as <prefix>/<save>, creating the parent
# first if it is missing.
if not os.path.exists(args.prefix):
    os.makedirs(args.prefix)
args.save = os.path.join(args.prefix, args.save)

# Create the experiment directory, passing every *.py, then *.sh, then *.yml
# file from the current working directory as `scripts_to_save`.
utils.create_exp_dir(
    args.save,
    scripts_to_save=[path
                     for pattern in ('*.py', '*.sh', '*.yml')
                     for path in glob.glob(pattern)],
)

# Root-logger configuration: INFO and above go to stdout, and every record is
# mirrored into <save>/log.txt. The file formatter is given no datefmt, so its
# timestamps use logging's default asctime format.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
logger = logging.getLogger()  # the root logger, i.e. what logging.info() uses
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logger.addHandler(fh)

# Parallel lists consumed by main(): arch_list[i] is the attribute name of a
# genotype in the `genotypes` module, name_list[i] is the output subdirectory
# and filename tag used for that genotype's results.
arch_list = ['ResBlock', 'VGG', 'Mobilenetv2']
name_list = ['ResNet20', 'VGG16', 'MobileNetV2']

def main():
    """Derive optimized/disguised variants of the reference genotypes.

    Seeds the RNGs and picks the device. Builds the NAS network and loads the
    pretrained arch transformer (``args.pw``) and predictor (``args.pwp``).

    Then, for every ``(arch_list[i], name_list[i])`` pair, it:

    * runs ``model.derive_optimized_arch`` on that genotype;
    * saves the result to ``<args.save>/<name>/result_<name>``;
    * keeps ``target.pdf`` / ``disguised_target.pdf`` under name-tagged
      filenames;
    * deletes the intermediate derive renderings.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed(args.seed)
        cudnn.benchmark = True
        cudnn.enabled = True
        logging.info('GPU device = %d' % args.gpu)
    else:
        logging.info('no GPU available, use CPU!!')
    device = torch.device("cuda" if use_cuda else "cpu")
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    logging.info("args = %s" % args)

    # Op set shared by the network and the genotype->arch conversion below.
    op_type = 'LOOSE_END_PRIMITIVES'

    # The depth now honours --layers (default 8, identical to the old literal).
    # 10 is presumably the number of classes (CIFAR-10) -- TODO confirm.
    model = Network(
        args.init_channels, 10, args.layers, None, device,
        steps=4, controller_hid=100, entropy_coeff=[0.0, 0.0], edge_hid=100,
        transformer_nfeat=1024, transformer_nhid=100, transformer_dropout=0,
        transformer_normalize=False, loose_end=True, op_type=op_type,
    )

    # Build fresh sub-modules, then overwrite them with the pretrained weights.
    model.re_initialize_arch_transformer()
    model._initialize_predictor(args, 'WarmUp')
    utils.load(model.arch_transformer, args.pw)
    model.arch_transformer.eval()
    # NOTE(review): torch.load unpickles the file; only point --pwp at
    # checkpoints you trust.
    model.predictor.load_state_dict(torch.load(args.pwp, map_location='cpu'))

    model.to(device)
    model.predictor.eval()

    model.num_limit = True
    model.set_thre(args)
    model.derive = True

    for arch_name, out_name in zip(arch_list, name_list):
        tmp_path = os.path.join(args.save, out_name)
        # exist_ok: rerunning into the same --save dir used to crash here.
        os.makedirs(tmp_path, exist_ok=True)

        # Plain attribute lookup; no need to eval a code string.
        genotype = getattr(genotypes, arch_name)
        arch_normal, arch_reduce = utils.genotype_to_arch(genotype, op_type)

        result = model.derive_optimized_arch(
            arch_normal, arch_reduce, 1, logger, tmp_path, "derive",
            normal_concat=genotype.normal_concat,
            reduce_concat=genotype.reduce_concat,
        )
        torch.save(result, os.path.join(tmp_path, 'result_{}'.format(out_name)))

        # Keep the two target plots under model-specific names.
        # os.replace renames within the directory; this is equivalent to the
        # previous copy-then-delete.
        for stem in ('target', 'disguised_target'):
            os.replace(os.path.join(tmp_path, stem + '.pdf'),
                       os.path.join(tmp_path, '{}_{}.pdf'.format(stem, out_name)))

        # Drop the intermediate derive renderings: each is a .pdf plus a
        # suffix-less file, removed in that order.
        for stem in ('normal_derive', 'reduce_derive',
                     'disguised_normal_derive', 'disguised_reduce_derive'):
            for suffix in ('.pdf', ''):
                os.remove(os.path.join(tmp_path, stem + suffix))

# Script entry point. Note that argument parsing and log/experiment-directory
# setup already ran at import time; only the derivation itself is deferred.
if __name__ == '__main__':
    main()
