"""Train a model for the ensemble and save its weights (state_dict)."""
import torch
import torch.nn as nn
import argparse
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
import time
import os

from ml_models import get_model
from ml_datasets import get_dataloaders, ds_choices
from ml_common import train, set_seed, enable_gpu_benchmarking, get_device

# Module-level setup: these calls run as soon as the module is imported,
# before main() is invoked.
# Seed with a fixed value (2020) for reproducible runs; which RNGs are covered
# is decided by ml_common.set_seed — presumably torch/numpy/random, TODO confirm.
set_seed(2020)
# NOTE(review): presumably enables cuDNN autotuning (see ml_common). If set_seed
# also enables deterministic mode, benchmarking may undermine reproducibility —
# confirm the intended interaction between these two calls.
enable_gpu_benchmarking()


def _parse_args():
    """Build this script's command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the training options defined below.
    """
    parser = argparse.ArgumentParser(description="train ensemble")
    parser.add_argument(
        "--dataset", type=str, default="mnist", help="dataset", choices=ds_choices,
    )
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs")
    # NOTE(review): --n_classes is parsed but never read by main(); confirm
    # whether it should be forwarded to get_model / get_dataloaders.
    parser.add_argument(
        "--n_classes", type=int, default=5, help="number of hash buckets"
    )
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument("--model", type=str, default="conv3", help="model name")
    parser.add_argument(
        "--exp_id", type=str, default="alpha", help="name of the experiment"
    )
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument(
        "--exp_dir",
        type=str,
        default="./exp/hash",
        help="directory where the trained weights are saved",
    )

    parser.add_argument("--augment", action="store_true", help="augment dataset")
    parser.add_argument(
        "--pretrained", action="store_true", help="Use pretrained model"
    )

    return parser.parse_args()


def main():
    """Train one model with SGD + cosine annealing and save its weights.

    Options come from the command line (see ``_parse_args``). After
    ``train`` returns, the model's ``state_dict`` is written to
    ``<exp_dir>/<exp_id>.pt`` (``./exp/hash/alpha.pt`` with the defaults).
    """
    args = _parse_args()

    dataloader_train, dataloader_test = get_dataloaders(
        args.dataset, args.batch_size, augment=args.augment
    )

    device = get_device()

    # exist_ok=True creates the directory if needed without the
    # check-then-create race of os.path.exists() + os.makedirs().
    os.makedirs(args.exp_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    print("\n==Training Model==")

    model = get_model(args.model, args.dataset, args.pretrained)
    model = model.to(device)

    opt = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # T_max = epochs: presumably train() steps the scheduler once per epoch so
    # the cosine schedule spans the whole run — confirm in ml_common.train.
    sch = CosineAnnealingLR(opt, args.epochs, last_epoch=-1)

    train(
        model, dataloader_train, dataloader_test, criterion, opt, args.epochs, sch,
    )
    # os.path.join works whether or not exp_dir ends with a separator.
    save_path = os.path.join(args.exp_dir, f"{args.exp_id}.pt")
    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    # Time the whole run. perf_counter() is monotonic, so the measured
    # duration is not skewed by system-clock adjustments (e.g. NTP) the way
    # time.time() can be.
    start = time.perf_counter()
    main()
    runtime = time.perf_counter() - start
    print("Runtime: {:.2f} s".format(runtime))
