import os, torch, optuna, joblib
import numpy as np
from torch import nn
from options import parse_args
from train import Trainer, get_dataloaders, init_model, init_opt, mean

def objective(trial, n_seeds=5):
    """Optuna objective: tune ``args.dp_alpha`` and score it over several seeds.

    Samples ``alpha`` from {0.0, 0.1, ..., 1.0}, writes it into the global
    ``args``, then trains and evaluates one model per seed
    (``torch.manual_seed(0 .. n_seeds - 1)``).

    This function relies on module-level globals created in the ``__main__``
    block: ``args``, ``device``, ``epoch_size``, ``train_dataloader``,
    ``dev_dataloader``, ``test_dataloader`` and ``test_samples``.

    Per-seed results are attached to the trial as user attributes:
      * ``dev_accs``   -- first value returned by ``trainer.load_best()``
      * ``test_accs``  -- element ``[2]`` of ``trainer.evaluate(test_dataloader)``
      * ``best_steps`` -- best step returned by ``load_best()`` divided by
        ``epoch_size`` (``len(train_dataloader)``), i.e. measured in epochs

    An earlier variant of this search also tuned ``args.sl_lam``
    (log-uniform over [1e-5, 100]).

    Args:
        trial: the Optuna trial being evaluated.
        n_seeds: number of seeds to train and average over (default 5).

    Returns:
        Mean of the per-seed ``dev_accs``; the study maximizes this value.
    """
    alpha = trial.suggest_float('alpha', 0, 1, step=0.1)
    args.dp_alpha = alpha

    bests, steps, tests = [], [], []
    for seed in range(n_seeds):
        torch.manual_seed(seed)
        model, tokenizer, curr, name, step = init_model(args, device)
        optimizer = init_opt(model, args)
        crit = nn.CrossEntropyLoss(reduction='none')
        writer = None  # no writer is passed to the Trainer during the search

        trainer = Trainer(model, tokenizer, crit, optimizer, curr, args.epochs,
                          writer, name, step, epoch_size, args.debug, device, args)

        # Run cleanup even if training or evaluation raises.
        try:
            print('[Starting Training]')
            trainer.train(train_dataloader, dev_dataloader, test_samples)
            print('[Testing]')
            best_acc, best_step = trainer.load_best()
            test_acc = trainer.evaluate(test_dataloader)[2]
        finally:
            trainer.cleanup()
        print('Dev acc:', best_acc, 'Test acc:', test_acc)

        bests.append(best_acc)
        tests.append(test_acc)
        steps.append(best_step / epoch_size)

        # Drop this seed's references so they can be freed before the next
        # seed's init_model() allocates a new model and optimizer.
        del trainer, model, optimizer

    trial.set_user_attr('best_steps', steps)
    trial.set_user_attr('test_accs', tests)
    trial.set_user_attr('dev_accs', bests)
    return mean(bests)

if __name__ == '__main__':
    args = parse_args()
    device = torch.device('cuda:%d' % args.gpu if torch.cuda.is_available() else 'cpu')

    # These module-level names are read as globals by objective().
    train_dataloader, dev_dataloader, test_dataloader, test_samples, \
        train_dataset, _ = get_dataloaders(args)
    epoch_size = len(train_dataloader)

    if args.curr == 'dp':
        # Median of the training set's 'diff' values.
        args.dp_tao = np.percentile(train_dataset['diff'], 50)

    study_path = os.path.join(args.study_dir, args.study_name)
    # Create the directory now rather than failing in the first callback,
    # after a full trial's worth of training has already been spent.
    if args.study_dir:
        os.makedirs(args.study_dir, exist_ok=True)

    def saver(study, _trial):
        """Optuna callback: persist the whole study after each trial.

        The study is dumped to a temporary file, which is then renamed over
        ``study_path``. A crash mid-write therefore leaves the previous
        snapshot intact instead of a truncated pickle.
        """
        tmp_path = study_path + '.tmp'
        joblib.dump(study, tmp_path)
        os.replace(tmp_path, study_path)

    if os.path.isfile(study_path):
        # NOTE: joblib.load unpickles the file; only resume from study files
        # you created yourself.
        study = joblib.load(study_path)
    else:
        study = optuna.create_study(study_name=args.study_name, direction='maximize')

    # n_trials counts trials for this optimize() call, so a resumed study
    # runs 100 additional trials on top of those already stored.
    study.optimize(objective, n_trials=100, callbacks=[saver])
