import train_pX_lucidrains2 as lr2
import torch
import argparse
from tqdm.auto import tqdm
import os
import math
import numpy as np
from torchvision.utils import save_image


# ==================================================
# =           Model loading block                  =
# ==================================================


def load_model(checkpoint, version, ddim, ddim_numsteps):
	"""Build the class-conditional diffusion model and restore its EMA weights.

	Args:
		checkpoint: path to a file readable by ``torch.load`` whose ``'ema'``
			entry is a state dict for ``lr2.EMA``.
		version: accepted for CLI compatibility; not used here.
		ddim: when 0, sampling uses all 1000 timesteps; any other value uses
			``ddim_numsteps`` sampling timesteps.
		ddim_numsteps: number of sampling timesteps when ``ddim`` is non-zero.

	Returns:
		The ``lr2.EMA`` wrapper (around a 1-channel, 128x128,
		2-class ``GaussianDiffusion`` model) with the checkpoint's EMA weights
		loaded. The checkpoint is read with ``map_location='cpu'``; move the
		wrapper to the target device before sampling.
	"""
	# ddim == 0 means "no step skipping": sample with the full schedule length.
	sampling_steps = 1000 if ddim == 0 else ddim_numsteps

	unet = lr2.Unet(
		dim=64,
		channels=1,
		dim_mults=(1, 2, 4, 8),
		num_classes=2,
		cond_drop_prob=0.2,
	)
	diffusion = lr2.GaussianDiffusion(
		unet,
		image_size=128,
		timesteps=1000,
		sampling_timesteps=sampling_steps,
	)

	# NOTE(review): torch.load unpickles arbitrary objects; only load
	# checkpoints from trusted sources.
	state = torch.load(checkpoint, map_location='cpu')
	ema = lr2.EMA(diffusion, beta=0.995, update_every=10)
	ema.load_state_dict(state['ema'])
	return ema




# ==================================================
# =           Sampling block                       =
# ==================================================
@torch.no_grad()
def sample(ema, save_loc, device, batch_size, num_samples, c=None, pfx='',
		c_prob=0.5397, cond_scale=1.0):
	"""Generate ``num_samples`` images with the EMA model and save them as PNGs.

	Images are produced in batches of at most ``batch_size`` and written to
	``save_loc`` as ``<pfx>_img_<index:05d>.png``, with indices counting up
	from 0.

	Args:
		ema: EMA wrapper exposing ``.to(device)`` and
			``.ema_model.sample(classes=..., cond_scale=...)`` (as returned by
			``load_model``).
		save_loc: output directory; created if it does not exist.
		device: torch device the model and class labels are moved to.
		batch_size: maximum number of images per sampling call; must be >= 1.
		num_samples: total number of images to generate (<= 0 generates none).
		c: class label used for every image (0 or 1), or None to draw each
			label at random with P(label == 1) == ``c_prob``.
		pfx: filename prefix.
		c_prob: probability of drawing label 1 when ``c`` is None.
			NOTE(review): 0.5397 presumably is the class-1 frequency of the
			training data — confirm.
		cond_scale: guidance scale forwarded to ``ema_model.sample``; the
			default 1.0 keeps the original behavior.

	Raises:
		ValueError: if ``batch_size`` < 1 (would otherwise loop forever) or
			``c`` is not None, 0 or 1. Checked before any file-system or
			device side effects.
	"""
	if batch_size < 1:
		raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
	if c is not None and c not in (0, 1):
		raise ValueError('c must be None, 0 or 1, got %r' % (c,))

	os.makedirs(save_loc, exist_ok=True)

	# Full batches of batch_size followed by one trailing partial batch.
	batch_sizes = [min(batch_size, num_samples - start)
				for start in range(0, num_samples, batch_size)]

	ema = ema.to(device)
	img_count = 0
	for bsz in tqdm(batch_sizes):
		if c is None:
			# Uniform draw so that P(label == 1) == c_prob. A standard-normal
			# draw (torch.randn) would give Phi(c_prob) ~= 0.705 instead.
			classes = (torch.rand(bsz) < c_prob).long()
		else:
			classes = torch.full((bsz,), int(c), dtype=torch.long)
		classes = classes.to(device)

		images = ema.ema_model.sample(classes=classes, cond_scale=cond_scale).cpu()

		for img in images:
			filename = os.path.join(save_loc, '%s_img_%05d.png' % (pfx, img_count))
			save_image(img, filename)
			img_count += 1






# =========================================
# =              MAIN BLOCK               =
# =========================================

def main():
	"""Command-line entry point: load an EMA checkpoint and sample both classes.

	For each class label in (0, 1), ``--num_samples`` images are written to
	``<save_loc>/c<label>/`` with filenames prefixed ``c<label>``.
	"""
	# (flag, add_argument keyword arguments) — registered in this order.
	option_specs = [
		('--ckpt', dict(type=str, required=True)),
		('--num_samples', dict(type=int, default=10_000)),
		('--version', dict(type=int)),
		('--batch_size', dict(type=int, default=512)),
		('--ddim', dict(type=int, default=1)),
		('--ddim_numsteps', dict(type=int, default=100)),
		('--device', dict(type=int, required=True)),
		('--save_loc', dict(type=str, required=True)),
	]
	cli = argparse.ArgumentParser()
	for flag, spec in option_specs:
		cli.add_argument(flag, **spec)
	opts = cli.parse_args()

	ema = load_model(opts.ckpt, opts.version, opts.ddim, opts.ddim_numsteps)
	# --device is used as a CUDA device index; there is no CPU option here.
	device = f'cuda:{opts.device}'

	for label in (0, 1):
		tag = f'c{label}'
		print(f'Working on C={label}')
		sample(ema, os.path.join(opts.save_loc, tag), device, opts.batch_size,
			opts.num_samples, c=label, pfx=tag)


if __name__ == '__main__':
	main()