import os, torch, optuna, joblib
import tempfile
from torch import nn
from options import parse_args
from train import Trainer, get_dataloaders, init_model, init_opt, mean

def objective(trial):
    """Optuna objective: train under a sampled ``ent_cfg`` and return mean dev accuracy.

    Samples ``ent_cfg`` as a dict keyed ``0..2``, each entry holding a ``c1``
    value (-10..10, step 2) and a ``c2`` value (-0.5..1.5, step 0.25). It then
    trains one model per seed in ``range(3)`` and returns the mean of the best
    dev accuracies.

    Per-seed dev accuracy, test accuracy and best step are stored on the trial
    as user attributes. Best step is expressed in epochs, i.e. divided by
    ``epoch_size``.

    Relies on module-level globals created in the ``__main__`` block:
    ``args``, ``device``, ``epoch_size``, ``train_dataloader``,
    ``dev_dataloader``, ``test_dataloader`` and ``test_samples``.

    Args:
        trial: the Optuna trial being evaluated.

    Returns:
        Mean best dev accuracy over the three seeds (the study maximizes it).
    """
    import uuid

    # Trial parameter names ('<i>-c1', '<i>-c2') must stay stable so that
    # studies resumed from disk keep a consistent search space.
    ent_cfg = {i: {'c1': trial.suggest_float('%d-c1'%i, -10, 10, step=2),
        'c2': trial.suggest_float('%d-c2'%i, -0.5, 1.5, step=0.25)} for i in range(3)}

    bests, steps, tests = [], [], []
    for seed in range(3):
        # NOTE(review): study.optimize runs with n_jobs=3 (threads), while
        # torch.manual_seed sets process-global RNG state. Concurrent trials
        # can therefore reseed each other mid-run, so per-seed results are
        # not strictly reproducible while n_jobs > 1.
        torch.manual_seed(seed)
        model, tokenizer, curr, _, step = init_model(args, device, ent_cfg)
        # Fresh unique run name per training run. This is presumably used by
        # Trainer for checkpoint/output naming (TODO confirm). It replaces the
        # private tempfile._get_candidate_names() API.
        name = uuid.uuid4().hex
        print(name)
        optimizer = init_opt(model, args)
        crit = nn.CrossEntropyLoss(reduction='none')
        writer = None

        trainer = Trainer(model, tokenizer, crit, optimizer, curr, args.epochs,
                writer, name, step, epoch_size, args.debug, device, args)

        try:
            print('[Starting Training]')
            trainer.train(train_dataloader, dev_dataloader, test_samples)
            print('[Testing]')
            best_acc, best_step = trainer.load_best()
            # Element 2 of evaluate()'s result is used as test accuracy.
            test_acc = trainer.evaluate(test_dataloader)[2]
            print('Acc:', best_acc)
        finally:
            # Always run cleanup, even when training/evaluation raises, so a
            # failed trial does not leave the trainer's resources behind.
            trainer.cleanup()

        bests.append(best_acc)
        tests.append(test_acc)
        steps.append(best_step/epoch_size)

    trial.set_user_attr('best_steps', steps)
    trial.set_user_attr('test_accs', tests)
    trial.set_user_attr('dev_accs', bests)
    return mean(bests)

# Script entry point: load data once, then run (or resume) the Optuna study,
# persisting it to <study_dir>/<study_name> after every finished trial.
if __name__ == '__main__':
    import threading

    args = parse_args()
    device = torch.device('cuda:%d'%args.gpu if torch.cuda.is_available() else 'cpu')

    # These become module globals that objective() reads directly.
    (train_dataloader, dev_dataloader, test_dataloader, test_samples,
     train_dataset, _) = get_dataloaders(args)
    epoch_size = len(train_dataloader)

    # Create the study directory up front. Otherwise the first save would only
    # fail after an entire trial (three training runs) had already finished.
    if args.study_dir:
        os.makedirs(args.study_dir, exist_ok=True)
    study_path = os.path.join(args.study_dir, args.study_name)
    save_lock = threading.Lock()

    def saver(study, _trial):
        """Persist ``study`` to ``study_path`` after a trial finishes.

        With ``n_jobs > 1``, Optuna invokes callbacks from its worker threads.
        Saves are therefore serialized with a lock. Each dump is written to a
        temp file and moved into place with ``os.replace``, so an overlapping
        or interrupted dump never leaves a truncated pickle at ``study_path``.
        """
        # NOTE(review): other worker threads may still be updating the study
        # while it is pickled; confirm this is safe for the storage in use.
        with save_lock:
            tmp_path = study_path + '.tmp'
            joblib.dump(study, tmp_path)
            os.replace(tmp_path, study_path)

    if os.path.isfile(study_path):
        study = joblib.load(study_path)
    else:
        # sampler = optuna.samplers.CmaEsSampler(n_startup_trials = 30)
        study = optuna.create_study(study_name = args.study_name, direction='maximize')

    # n_trials counts trials run by this call. Resuming a saved study adds
    # another 400 trials on top of those already stored.
    study.optimize(objective, n_trials=400, callbacks = [saver], n_jobs = 3)
