"""Train an ensemble of models jointly, with an optional diversity objective."""
import torch
import torch.nn as nn
import argparse
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import time
import os

from ml_models import get_model
from ml_datasets import get_dataloaders, ds_choices
from ml_common import set_seed, save_model, enable_gpu_benchmarking, get_device
from utils import train_diverse

# Module-level setup runs at import time, before main() builds any models or
# dataloaders. Presumably set_seed seeds the framework RNGs so model
# initialisation and shuffling are repeatable — confirm in ml_common.
# NOTE(review): if enable_gpu_benchmarking turns on cuDNN autotuning, runs may
# not be bit-for-bit reproducible despite the fixed seed; keep the call order
# (seed first) unless both helpers are checked.
set_seed(2020)
enable_gpu_benchmarking()


def _parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: Argument list to parse. ``None`` means argparse falls back to
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="train ensemble")
    parser.add_argument(
        "--dataset", type=str, default="mnist", help="dataset", choices=ds_choices,
    )
    parser.add_argument(
        "--dataset_ood",
        type=str,
        default="kmnist",
        help="out-of-distribution dataset whose train loader is passed to the trainer",
        choices=ds_choices,
    )
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument("--model", type=str, default="conv3", help="model name")
    parser.add_argument(
        "--n_models", type=int, default=1, help="number of models to train"
    )
    parser.add_argument(
        "--exp_id", type=str, default="alpha", help="name of the experiment"
    )
    parser.add_argument(
        "--lamb",
        type=float,
        default=0.0,
        help="Weightage given for diversity objective",
    )
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")

    # The flag is forwarded to get_dataloaders(augment=...), so it controls
    # data augmentation (the previous help text wrongly said "pretrained").
    parser.add_argument("--augment", action="store_true", help="Use data augmentation")
    parser.add_argument(
        "--pretrained", action="store_true", help="Use pretrained model"
    )

    parser.add_argument("--disable_pbar", action="store_true", help="Disable pbar")

    args = parser.parse_args(argv)
    # With zero models the optimizer would receive an empty parameter list and
    # fail with an unrelated-looking error; report it as a CLI error instead.
    if args.n_models < 1:
        parser.error("--n_models must be >= 1")
    return args


def main(argv=None):
    """Train an ensemble of ``--n_models`` models jointly and save each to disk.

    The models share a single SGD optimizer and cosine LR schedule. They are
    trained by ``train_diverse`` with the in-distribution train/test loaders,
    the OOD train loader and the diversity weight ``--lamb``. Each trained
    model is saved as ``./exp/<dataset>/<exp_id>/T<i>.pt``.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``); useful
            for calling ``main`` programmatically.
    """
    args = _parse_args(argv)

    device = get_device()

    dataloader_train, dataloader_test = get_dataloaders(
        args.dataset, args.batch_size, augment=args.augment
    )
    # Only the first loader returned for the OOD dataset is used.
    dataloader_ood, _ = get_dataloaders(
        args.dataset_ood, args.batch_size, augment=args.augment
    )

    exp_path = f"./exp/{args.dataset}/{args.exp_id}"
    # Create the output directory before training so a path problem fails
    # immediately instead of after a long training run.
    os.makedirs(exp_path, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    model_list = [
        get_model(args.model, args.dataset, args.pretrained).to(device)
        for _ in range(args.n_models)
    ]
    # One optimizer over every member's parameters, so all models are
    # updated together in each optimizer step.
    model_params = [p for model in model_list for p in model.parameters()]

    opt = SGD(model_params, lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # NOTE(review): T_max=epochs assumes train_diverse calls sch.step() once
    # per epoch — confirm in utils.train_diverse.
    sch = CosineAnnealingLR(opt, args.epochs, last_epoch=-1)

    train_diverse(
        model_list,
        dataloader_train,
        dataloader_ood,
        dataloader_test,
        opt,
        criterion,
        args.epochs,
        sch,
        args.lamb,
        args.disable_pbar,
    )

    for i, model in enumerate(model_list):
        save_model(model, os.path.join(exp_path, f"T{i}.pt"))


if __name__ == "__main__":
    # perf_counter is monotonic, so the measured duration is not skewed by
    # system clock adjustments during a long training run (time.time() is).
    start = time.perf_counter()
    main()
    runtime = time.perf_counter() - start
    print(f"Runtime: {runtime:.2f} s")
