from data.Dataloaders import pick_dataset
from models.DCGAN import DCGAN
from utils.util import parse_args_DCGAN
import torch
import wandb

# Parse the command-line options for this DCGAN training run
args = parse_args_DCGAN()

# Set device: use CUDA when it is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set image size and number of channels.
# Maps each supported dataset name to (img_size, channels); both values are
# handed to the dataloader and the model further down in this script.
_DATASET_SPECS = {
    'mnist': (32, 1),
    'cifar10': (32, 3),
    'tinyimagenet': (64, 3),
}

# Fail fast on an unsupported dataset. Without this check img_size and
# channels would stay unbound and the script would only crash later with a
# NameError, after a wandb run had already been started.
try:
    img_size, channels = _DATASET_SPECS[args.dataset]
except KeyError:
    raise ValueError(
        "Unsupported dataset {!r}; expected one of: {}".format(
            args.dataset, ', '.join(_DATASET_SPECS)
        )
    ) from None

# Initialize wandb: start a run in the 'DCGAN' project, named after the
# dataset, and record the run's hyperparameters as its config.
logged_hyperparameters = (
    'dataset',
    'batch_size',
    'n_epochs',
    'latent_dim',
    'd',
    'lrg',
    'lrd',
    'beta1',
    'beta2',
    'sample_and_save_freq',
)
wandb.init(
    project='DCGAN',
    # Each config entry is the identically named attribute of the parsed args
    config={key: getattr(args, key) for key in logged_hyperparameters},
    name='DCGAN_{}'.format(args.dataset),
)

# Load dataset: build the training dataloader for the selected dataset
train_dataloader = pick_dataset(
    name=args.dataset,
    train=True,
    batch_size=args.batch_size,
    img_size=img_size,
)

# Initialize model from the parsed hyperparameters and the dataset settings
model = DCGAN(
    n_epochs=args.n_epochs,
    device=device,
    latent_dim=args.latent_dim,
    d=args.d,
    channels=channels,
    lrg=args.lrg,
    lrd=args.lrd,
    beta1=args.beta1,
    beta2=args.beta2,
    img_size=img_size,
    sample_and_save_freq=args.sample_and_save_freq,
    dataset=args.dataset,
)

# Train on the training dataloader
model.train_model(train_dataloader)

# Finish wandb: close the run once training has returned
wandb.finish()