import random
import numpy as np
import torch
import os
import copy
from utils.config import parse_args
from utils.data_loader import get_data_loader
from models.wgan_gradient_penalty_VRAda import WGAN_GP_VRAda
from models.wgan_gradient_penalty import WGAN_GP


def main(args):
    """Build the model selected by ``args.optim`` and train it.

    ``args.optim == 'VRAda'`` selects ``WGAN_GP_VRAda``; every other optimizer
    name falls back to ``WGAN_GP``. ``get_data_loader`` is expected to return a
    (train, test) pair; only the training loader is used here.

    Args:
        args: Parsed run configuration, forwarded unchanged to the model
            constructor, the data loader factory and ``model.train``.
    """
    model_cls = WGAN_GP_VRAda if args.optim == 'VRAda' else WGAN_GP
    model = model_cls(args)

    # Unpacking (rather than indexing) keeps the "exactly two loaders" check.
    train_loader, _unused_test_loader = get_data_loader(args)

    # Empty list handed to model.train; presumably filled with inception
    # scores during training — NOTE(review): confirm in WGAN_GP.train.
    inception_history = []
    model.train(train_loader, inception_history, args)


if __name__ == '__main__':
    # Seed used for every run in the sweep. It is re-applied before each
    # main() call (not just once up front). Otherwise later runs would inherit
    # the RNG state left by earlier ones, making results depend on loop order
    # and impossible to reproduce for a single configuration in isolation.
    SEED = 8
    lr_values = [0.005]
    datasets = ['cifar10', 'cifar100', 'stl10']
    optimizers = ['adam', 'tiada', 'tiada-adam', 'VRAda']
    base_args = parse_args()
    for optimizer in optimizers:
        for dataset in datasets:
            for lr in lr_values:
                # Shallow copy suffices: only scalar attributes are overwritten.
                args = copy.copy(base_args)
                args.optim = optimizer
                args.lr_x = lr
                args.lr_y = lr  # lr_x and lr_y are deliberately tied together
                args.lr = lr
                args.dataset = dataset
                # Reset all RNGs so this configuration starts from a fixed state.
                random.seed(SEED)
                np.random.seed(SEED)
                torch.manual_seed(SEED)
                # Run the main function with the current hyperparameter combination
                main(args)
