""" Train ensemble"""
import torch
import torch.nn as nn
import argparse
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
import time
import os

from ml_models import get_model
from ml_datasets import get_dataloaders, ds_choices
from ml_common import train, set_seed, enable_gpu_benchmarking, get_device

_SEED = 2020  # fixed seed shared by the whole run

# NOTE(review): set_seed presumably seeds the global RNGs (see ml_common).
# It is called exactly once, at import time, so the models trained in main()
# draw from one advancing RNG stream instead of being re-seeded per model.
set_seed(_SEED)

# NOTE(review): if this turns on cuDNN autotuning (torch.backends.cudnn.benchmark),
# kernel selection can be non-deterministic despite the seed above — confirm.
enable_gpu_benchmarking()


def _parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (``None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="train ensemble")
    parser.add_argument(
        "--dataset", type=str, default="mnist", help="dataset", choices=ds_choices,
    )
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument("--model", type=str, default="conv3", help="model name")
    parser.add_argument(
        "--n_models", type=int, default=1, help="number of models to train"
    )
    parser.add_argument(
        "--exp_id", type=str, default="alpha", help="name of the experiment"
    )
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")

    parser.add_argument("--augment", action="store_true", help="augment dataset")
    parser.add_argument(
        "--pretrained", action="store_true", help="Use pretrained model"
    )
    parser.add_argument("--disable_pbar", action="store_true", help="disable pbar")
    return parser.parse_args(argv)


def main():
    """Train ``--n_models`` models on ``--dataset`` and save each state dict.

    The train/test data loaders and the cross-entropy loss are built once and
    shared by every ensemble member. Each member gets a fresh model from
    ``get_model``, an SGD optimizer (momentum 0.9, weight decay 5e-4) and a
    ``CosineAnnealingLR`` schedule with ``T_max = --epochs``. After training,
    its ``state_dict`` is written to ``./exp/<dataset>/<exp_id>/T<model_id>.pt``.
    """
    args = _parse_args()

    dataloader_train, dataloader_test = get_dataloaders(
        args.dataset, args.batch_size, augment=args.augment
    )
    exp_path = os.path.join("./exp", args.dataset, args.exp_id)
    device = get_device()

    # exist_ok=True avoids the race in "check os.path.exists, then makedirs",
    # where another process could create the directory in between.
    os.makedirs(exp_path, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    for model_id in range(args.n_models):
        print(f"\n==Training Model {model_id+1}/{args.n_models}==")

        # A new model per member. The global RNG is not re-seeded here, so
        # members presumably start from different random inits (depends on
        # get_model / --pretrained; confirm).
        model = get_model(args.model, args.dataset, args.pretrained)
        model = model.to(device)

        opt = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        # NOTE(review): T_max=epochs assumes `train` calls sch.step() once per
        # epoch rather than per batch — confirm against ml_common.train.
        sch = CosineAnnealingLR(opt, args.epochs, last_epoch=-1)

        train(
            model,
            dataloader_train,
            dataloader_test,
            criterion,
            opt,
            args.epochs,
            sch,
            disable_pbar=args.disable_pbar,
        )
        torch.save(model.state_dict(), os.path.join(exp_path, f"T{model_id}.pt"))


if __name__ == "__main__":
    # perf_counter is a monotonic high-resolution clock, so the reported
    # runtime is not skewed by system clock adjustments the way
    # time.time() can be.
    start = time.perf_counter()
    main()
    runtime = time.perf_counter() - start
    print(f"Runtime: {runtime:.2f} s")
