from data.Dataloaders import pick_dataset
from models.DisCoNet import DisCoNet
from utils.util import parse_args_DisCoNet
import torch
import wandb

# Parse the options for this training run.
args = parse_args_DisCoNet()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# (image size, number of channels) used for each supported dataset.
_DATASET_SHAPES = {
    'mnist': (32, 1),
    'cifar10': (32, 3),
    'tinyimagenet': (64, 3),
    'imagenet': (128, 3),
}

# Fail fast on an unsupported dataset name. Without this check `img_size` and
# `channels` would stay undefined and the script would only crash later with a
# NameError, after the wandb run has already been started.
try:
    img_size, channels = _DATASET_SHAPES[args.dataset]
except KeyError:
    raise ValueError(
        'Unsupported dataset {!r}; expected one of: {}'.format(
            args.dataset, ', '.join(sorted(_DATASET_SHAPES))
        )
    ) from None

# Start the wandb run. The logged config is read straight from the parsed
# arguments, one entry per option name listed below.
_logged_options = (
    'dataset',
    'batch_size',
    'n_epochs',
    'latent_dim',
    'hidden_dims',
    'lr',
    'gen_weight',
    'recon_weight',
    'sample_and_save_frequency',
)
wandb.init(
    project='DisCoNet',
    config={option: getattr(args, option) for option in _logged_options},
    name='DisCoNet_{}'.format(args.dataset),
)

# Build the training data loader for the selected dataset.
train_dataloader = pick_dataset(
    name=args.dataset,
    train=True,
    batch_size=args.batch_size,
    img_size=img_size,
)

# Collect the constructor options first so the model call stays readable.
_model_options = dict(
    input_shape=img_size,
    device=device,
    input_channels=channels,
    latent_dim=args.latent_dim,
    n_epochs=args.n_epochs,
    hidden_dims=args.hidden_dims,
    lr=args.lr,
    batch_size=args.batch_size,
    gen_weight=args.gen_weight,
    recon_weight=args.recon_weight,
    sample_and_save_frequency=args.sample_and_save_frequency,
    dataset=args.dataset,
)
model = DisCoNet(**_model_options)

# The training loader is passed for both positional arguments.
# NOTE(review): the role of the second argument is not visible from this
# script -- confirm that reusing the training loader here is intended.
model.train_model(train_dataloader, train_dataloader)

# Close the wandb run once training has returned.
wandb.finish()