from data.Dataloaders import pick_dataset
from models.PrescribedGAN import PresGAN
from utils.util import parse_args_PresGAN
import torch
import wandb

args = parse_args_PresGAN()

# Set device: use CUDA when torch reports it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# (img_size, channels) for each supported dataset. img_size is passed to both
# the dataloader and the model; channels becomes the model's `nc` argument.
_DATASET_SHAPES = {
    'mnist': (32, 1),
    'cifar10': (32, 3),
    'tinyimagenet': (64, 3),
}

# Fail fast on an unknown dataset. Without this check img_size/channels would be
# left undefined and the script would crash later with a NameError, after a
# wandb run had already been started.
try:
    img_size, channels = _DATASET_SHAPES[args.dataset]
except KeyError:
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected one of "
        f"{sorted(_DATASET_SHAPES)}"
    ) from None

# Start a wandb run named after the dataset. Every hyperparameter listed below
# is copied from the parsed CLI args into the run config under the same name.
_WANDB_CONFIG_KEYS = (
    "dataset",
    "batch_size",
    "nz",
    "ngf",
    "ndf",
    "lrD",
    "lrG",
    "beta1",
    "n_epochs",
    "sigma_lr",
    "num_gen_images",
    "restrict_sigma",
    "sigma_min",
    "sigma_max",
    "stepsize_num",
    "lambda_",
    "burn_in",
    "num_samples_posterior",
    "leapfrog_steps",
    "hmc_learning_rate",
    "hmc_opt_accept",
    "flag_adapt",
)

wandb.init(
    project="PresGAN",
    config={key: getattr(args, key) for key in _WANDB_CONFIG_KEYS},
    name=f"PresGAN_{args.dataset}",
)

# Build the training dataloader for the selected dataset.
train_dataloader = pick_dataset(
    name=args.dataset,
    train=True,
    batch_size=args.batch_size,
    img_size=img_size,
)

# Construct the PresGAN model from the CLI hyperparameters.
model = PresGAN(
    imgSize=img_size,
    nz=args.nz,
    ngf=args.ngf,
    ndf=args.ndf,
    nc=channels,
    device=device,
    beta1=args.beta1,
    lrD=args.lrD,
    lrG=args.lrG,
    n_epochs=args.n_epochs,
    sigma_lr=args.sigma_lr,
    num_gen_images=args.num_gen_images,
    restrict_sigma=args.restrict_sigma,
    sigma_min=args.sigma_min,
    sigma_max=args.sigma_max,
    stepsize_num=args.stepsize_num,
    lambda_=args.lambda_,
    burn_in=args.burn_in,
    num_samples_posterior=args.num_samples_posterior,
    leapfrog_steps=args.leapfrog_steps,
    hmc_learning_rate=args.hmc_learning_rate,
    hmc_opt_accept=args.hmc_opt_accept,
    flag_adapt=args.flag_adapt,
    sample_and_save_freq=args.sample_and_save_freq,
    dataset=args.dataset,
)

# Run training on the dataloader.
model.train_model(train_dataloader)

# Mark the current wandb run as finished.
wandb.finish()