""" Main script to generate samples used for retraining
	
	The output is just a stack of {w2, x, y} sampled from p(w1), p(x,y | w1,w2).
	Notice that w2 is only provided as a conditioning variable above and can be valued arbitrarily.
	Our procedure here looks like:
	1. Choose some w2 arbitrarily (might as well use uniformly random w2, which happens to be p(w2))
	2. Sample w1 from data
	3. Sample (x,y) from the trained model f_theta(w1,w2) = x',y' ~ p(x,y | w1, w2)
	4. Save (w2, x, y); the w1 used for each sample is saved alongside them
"""
import torch
import pickle
from torch.utils.data import DataLoader
import argparse
import os
from tqdm.auto import tqdm

from cfg.dataloader_pickle import PickleDataset
from cfg.embedding import JointEmbedding2
from cfg.unet import Unet
from cfg.utils import get_named_beta_schedule
from cfg.diffusion import GaussianDiffusion

# =============================================
# =           Loading blocks                  =
# =============================================



def load_data(pkl_loc, batch_size):
	"""Build a shuffled DataLoader over the observational pickle dataset.

	Args:
		pkl_loc: path handed to PickleDataset as `pkl_file`.
		batch_size: number of examples per batch.

	Returns:
		A DataLoader over the dataset with shuffling enabled, `num_workers=8`,
		and the final partial batch kept (`drop_last=False`).
	"""
	loader_options = dict(
		num_workers=8,
		batch_size=batch_size,
		shuffle=True,
		drop_last=False,
	)
	return DataLoader(PickleDataset(pkl_file=pkl_loc), **loader_options)


def _restore_exact(module, weights):
	"""Load `weights` into `module` and insist that no key is missing or unexpected."""
	report = module.load_state_dict(weights)
	assert not report.missing_keys and not report.unexpected_keys


def load_model(diffuser_loc, device, w=0.0):
	"""Rebuild the already-trained diffusion sampler and its conditioning embedding.

	The architecture hyper-parameters are hard-coded here and must match the
	ones the checkpoint was trained with, otherwise restoring the weights fails.

	Args:
		diffuser_loc: path to a checkpoint readable by `torch.load` that holds
			the state dicts under the keys 'net' and 'cemblayer'.
		device: device the two networks are moved to; also passed to GaussianDiffusion.
		w: forwarded to GaussianDiffusion as `w`. The default 0.0 is, per the
			original author's note, truly conditional sampling with no upweighting.

	Returns:
		dict with keys 'diffusion' (the GaussianDiffusion wrapper) and
		'cemblayer' (the JointEmbedding2 module), both switched to eval mode.
	"""
	fp32 = torch.float32

	# Denoising network; weights are filled in from the checkpoint below.
	unet_config = dict(
		in_ch=6,
		mod_ch=64,
		out_ch=6,
		ch_mul=[1, 2, 2, 2],
		num_res_blocks=2,
		cdim=64,
		use_conv=True,
		droprate=0,
		dtype=fp32,
	)
	denoiser = Unet(**unet_config).to(device)

	# Deserialize onto the CPU first; load_state_dict copies into the device-resident modules.
	checkpoint = torch.load(diffuser_loc, map_location='cpu')
	_restore_exact(denoiser, checkpoint['net'])

	diffusion = GaussianDiffusion(
		dtype=fp32,
		model=denoiser,
		betas=get_named_beta_schedule(num_diffusion_timesteps=1000),
		w=w,
		v=1.0,
		device=device,
	)

	# Embedding layer that turns (w1, w2a, w2b) into the conditioning input of the sampler.
	embedder = JointEmbedding2(
		num_labels_0=10,
		num_labels_1=2,
		d_model=64,
		channels=3,
		dim=64,
		hw=32,
	).to(device)
	_restore_exact(embedder, checkpoint['cemblayer'])

	# Inference only: put both networks in eval mode.
	diffusion.model.eval()
	embedder.eval()
	return {'diffusion': diffusion, 'cemblayer': embedder}



# ========================================================
# =           Sample 1/multiple batches                  =
# ========================================================
@torch.no_grad()
def sample_from_batch(batch, diffuser, device, ddim=True, fix='perm'):
	"""Draw one batch of synthetic (x, y) from the trained conditional model.

	Args:
		batch: mapping with tensors under 'W1' and 'W2a'. Only the leading
			dimension of 'W2a' is used (as the batch size); its values, like any
			'W2b' entry, are deliberately ignored in favour of fresh random draws.
		diffuser: dict with a callable under 'cemblayer' and, under 'diffusion',
			an object exposing `ddim_sample` and `sample`.
		device: device the conditioning tensors are moved to.
		ddim: truthy -> use `ddim_sample`; falsy -> use `sample`.
		fix: unused; kept so existing callers keep working.

	Returns:
		Tuple (w1, w2a, w2b, x, y), every tensor moved back to the CPU.
	"""
	n_rows = batch['W2a'].shape[0]

	# Step 1: w2 is drawn uniformly at random: w2a from {0..9}, w2b from {0, 1}.
	w2a = torch.randint(low=0, high=10, size=(n_rows,)).to(device)
	w2b = torch.randint(low=0, high=2, size=(n_rows,)).to(device)

	# Step 2: w1 comes straight from the observed data.
	w1 = batch['W1'].to(device)

	# Step 3: condition the trained model on (w1, w2) and sample from it.
	conditioning = diffuser['cemblayer'](w1, w2a, w2b)
	sampler = diffuser['diffusion']
	out_shape = (n_rows, 6, 32, 32)
	if not ddim:
		generated = sampler.sample(out_shape, cemb=conditioning, disable_tqdm=True)
	else:
		# Positional arguments are presumably (num steps, eta, schedule) - confirm in cfg.diffusion.
		generated = sampler.ddim_sample(out_shape, 50, 0, 'linear', cemb=conditioning, disable_tqdm=True)

	# The six generated channels hold x (first three) followed by y (the rest).
	x = generated[:, :3, :, :]
	y = generated[:, 3:, :, :]

	return tuple(t.cpu() for t in (w1, w2a, w2b, x, y))





def sample_batches(dataloader, diffuser, n_samples, device, ddim=True):
	"""Generate at least `n_samples` synthetic examples, one dataloader batch at a time.

	The dataloader is re-iterated as many times as needed. Whole batches are
	kept, so the result can hold more than `n_samples` rows (it overshoots by
	less than one batch).

	Args:
		dataloader: iterable of batches accepted by `sample_from_batch`.
		diffuser: model dict forwarded to `sample_from_batch`.
		n_samples: minimum number of examples to generate; must be positive.
		device: device forwarded to `sample_from_batch`.
		ddim: forwarded to `sample_from_batch` to pick the sampler.

	Returns:
		dict with keys 'X', 'Y', 'W1', 'W2a', 'W2b', each the concatenation
		(via `torch.cat`) of the per-batch tensors.

	Raises:
		ValueError: if `n_samples` is not positive, or if a full pass over
			`dataloader` produces no samples (which would otherwise loop forever).
	"""
	if n_samples <= 0:
		# Without any batch there is nothing to concatenate; fail with a clear message.
		raise ValueError('n_samples must be positive, got %r' % (n_samples,))

	data = {'X': [], 'Y': [], 'W1': [], 'W2a': [], 'W2b': []}
	count = 0
	# Context manager so the progress bar is closed even if sampling raises.
	with tqdm(total=n_samples) as progress:
		while count < n_samples:
			count_before_pass = count
			for batch in dataloader:
				w1, w2a, w2b, x, y = sample_from_batch(batch, diffuser, device, ddim=ddim)
				data['W1'].append(w1)
				data['W2a'].append(w2a)
				data['W2b'].append(w2b)
				data['X'].append(x)
				data['Y'].append(y)
				produced = x.shape[0]
				count += produced
				progress.update(n=produced)
				if count >= n_samples:
					break
			if count == count_before_pass:
				# An empty dataloader (or only zero-row batches) can never reach n_samples.
				raise ValueError('dataloader produced no samples; cannot reach n_samples=%r' % (n_samples,))

	return {k: torch.cat(v) for k, v in data.items()}


def save_datadict(data_dict, save_dir):
	"""Pickle `data_dict` to `<save_dir>/synthetic_W1W2XY.pkl`.

	`save_dir` (including missing parents) is created when it does not exist;
	an existing file of that name is overwritten.
	"""
	os.makedirs(save_dir, exist_ok=True)
	with open(os.path.join(save_dir, 'synthetic_W1W2XY.pkl'), 'wb') as sink:
		pickle.Pickler(sink).dump(data_dict)

# ===========================================
# =           Main block                    =
# ===========================================


def main():
	"""Command-line entry point: parse options, sample synthetic data, pickle it.

	Reads the flags below from the command line, loads the observational data
	and the trained diffusion model, generates samples, and writes them to
	`synthetic_W1W2XY.pkl` inside `--save_dir`. The integer `--device` is turned
	into a 'cuda:<index>' device string and `--ddim` into a bool.
	"""
	# (flag, add_argument keyword options), registered in this order.
	option_table = (
		('--pkl_loc', dict(type=str, required=True, help='location of pickle training data file')),
		('--diffuser_loc', dict(type=str, required=True, help='location of saved diffusion model')),
		('--n_samples', dict(type=int, default=10_000, help='how many samples to generate')),
		('--batch_size', dict(type=int, default=256, help='batch size')),
		('--device', dict(type=int, required=True, help='which gpu to use')),
		('--ddim', dict(type=int, default=1, help='1 if we want to use ddim sampling, 0 ow')),
		('--w', dict(type=float, default=0.0)),
		('--save_dir', dict(type=str, required=True, help='location of where to save synthetic_W1W2XY.pkl dataset')),
	)
	cli = argparse.ArgumentParser(description='script to generate samples for retraining purposes')
	for flag, spec in option_table:
		cli.add_argument(flag, **spec)
	opts = cli.parse_args()

	use_ddim = bool(opts.ddim)
	device = f'cuda:{opts.device}'

	loader = load_data(opts.pkl_loc, opts.batch_size)
	model_parts = load_model(opts.diffuser_loc, device, w=opts.w)
	samples = sample_batches(loader, model_parts, opts.n_samples, device, use_ddim)
	save_datadict(samples, opts.save_dir)


# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
	main()