import argparse
import sys
sys.path.append("..")

from models.cifar10_vgg_selectivenet_curriculum import cifar10vgg_curr as cifar10Selective_curr_vgg
from models.cifar10_cnn_selectivenet_curriculum import cifar10cnn_curr as cifar10Selective_curr_cnn
from models.svhn_cnn_selectivenet_curriculum import Svhncnn_curr as SVHNSelective_curr_cnn
from selectivnet_utils import *

# Registry of model classes keyed by "<dataset>_<model>", i.e. the values of
# --dataset and --model joined with an underscore.
# The CIFAR-100 entries are registered with the same classes as CIFAR-10.
# NOTE(review): presumably those classes adapt to the dataset (e.g. number of
# classes) elsewhere -- confirm.
_CIFAR_ARCHS = (("vgg", cifar10Selective_curr_vgg),
                ("cnn", cifar10Selective_curr_cnn))

MODELS = {
    "{}_curriculum_{}".format(dataset_name, arch_name): arch_cls
    for dataset_name in ("cifar10", "cifar100")
    for arch_name, arch_cls in _CIFAR_ARCHS
}
MODELS["SVHN_curriculum_cnn"] = SVHNSelective_curr_cnn



# Command-line options as (flag, type, default) triples. They are registered in
# this exact order, so `--help` output and parsing behaviour are unchanged.
# --dataset and --model are joined with "_" to look up the model class in MODELS.
_CLI_OPTIONS = (
    ('--model', str, 'vanilla'),
    ('--dataset', str, 'cifar10'),
    ('--exp_name', str, 'test'),
    ('--baseline', str, 'none'),
    ('--alpha', float, 0.5),
    ('--beta', float, 1),
    ('--lamda', float, 32),
    ('--random_percent', int, -1),
    ('--random_strategy', str, 'feature'),
    ('--curriculum_strategy', str, 'curriculum'),
    ('--order_strategy', str, 'inception'),
    ('--logfile', str, 'training.log'),
    ('--datapath', str, None),
    ('--repeats', int, 1),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)

# Parses sys.argv at import time; this module is meant to be run as a script.
args = parser.parse_args()

print("experiment arguments: {}".format(args))

# Model classes are registered under "<dataset>_<model>" in MODELS.
model_key = args.dataset + "_" + args.model
try:
    model_cls = MODELS[model_key]
except KeyError:
    # e.g. the default --model 'vanilla' has no registered entry; report the
    # valid combinations instead of a bare KeyError traceback.
    parser.error("unknown dataset/model combination '{}'; expected one of: {}".format(
        model_key, ", ".join(sorted(MODELS))))

exp_name = args.exp_name
baseline_name = args.baseline
logfile = args.logfile
datapath = args.datapath
random_percent = args.random_percent
random_strategy = args.random_strategy

# Target coverage levels handed to train_profile.
coverages = [0.95, 0.9, 0.85, 0.8, 0.75, 0.7]

# Keyword arguments shared by both training paths. Keeping them in one place
# ensures --lamda (and the options only reachable through `args`, such as
# --curriculum_strategy) are honoured when a baseline model is used too;
# previously the baseline call silently dropped `lamda` and `args`.
train_kwargs = dict(dataset=args.dataset,
                    alpha=args.alpha,
                    beta=args.beta,
                    lamda=args.lamda,
                    random_percent=random_percent,
                    random_strategy=random_strategy,
                    order_strategy=args.order_strategy,
                    logfile=logfile,
                    datapath=datapath,
                    args=args)

for repeat in range(args.repeats):
    print("====================repeat {}==============".format(repeat))
    if baseline_name == "none":
        results = train_profile(exp_name, model_cls, coverages, **train_kwargs)
    else:
        # The baseline is built inside the loop, as before. to_train(...) is
        # presumably False once "<baseline>.h5" exists, so later repeats reuse
        # the saved weights -- confirm in selectivnet_utils.
        model_baseline = model_cls(train=to_train("{}.h5".format(baseline_name)),
                                   filename="{}.h5".format(baseline_name),
                                   baseline=True)
        results = train_profile(exp_name, model_cls, coverages,
                                model_baseline=model_baseline, **train_kwargs)
