"""
Actual code for sampling interventional distributions
Recall that the formula is 
P_C(N) := SUM_x [ P(X|C) SUM_c' P(N|C', X) P(C')]


So need to sample:
 	P(X|C) (diffusion)
 	P(C') (bernoulli from data)
 	P(N|C', X) (bernoulli from classifier output)
"""

import torch
import cxray.train_pX_lucidrains2 as lr2
import cxray.train_pN_givenCX_calibrated as tpn
import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision import transforms
import cxray.cxray_dataset as cxds
from torch.utils.data import DataLoader
from collections import defaultdict, Counter
from tqdm.auto import tqdm
import argparse

import random

# =============================================
# =               Load models                 =
# =============================================

def load_pX_givenC(ckpt, device):
	"""Build the class-conditional diffusion model for P(X | C) and load its EMA weights.

	Args:
		ckpt: path to a checkpoint whose 'ema' entry holds the EMA state dict.
		device: torch device the EMA wrapper is moved to.

	Returns:
		The `lr2.EMA` wrapper, switched to eval mode and placed on `device`.
	"""
	# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
	state = torch.load(ckpt, map_location='cpu')
	unet = lr2.Unet(
		dim=64,
		channels=1,
		dim_mults=(1, 2, 4, 8),
		num_classes=2,
		cond_drop_prob=0.2,
	)
	diffusion = lr2.GaussianDiffusion(
		unet,
		image_size=128,
		timesteps=1000,
		sampling_timesteps=100,
	)
	wrapper = lr2.EMA(diffusion, beta=0.995, update_every=10)
	wrapper.load_state_dict(state['ema'])
	return wrapper.eval().to(device)

def load_pX_givenC_obs(root_dir):
	"""Group the validation-split images by their C label.

	This is the "observational" stand-in for P(X | C). Sampling X given C then
	amounts to picking a random stored image with that label.

	Args:
		root_dir: dataset root passed to `cxds.CXRayDataset`.

	Returns:
		dict mapping each label in {0, 1} to a list of per-example image tensors.
		Any other label value raises KeyError.
	"""
	val_set = cxds.CXRayDataset(root_dir, split='val', transform=transforms.ToTensor())
	loader = DataLoader(val_set, batch_size=512, num_workers=8, shuffle=False, drop_last=False)
	images_by_label = {label: [] for label in (0, 1)}
	for batch in tqdm(loader):
		labels, images = batch['C'], batch['X']
		for label, image in zip(labels, images):
			images_by_label[label.item()].append(image)
	return images_by_label



def load_pN_givenCX(ckpt, device):
	"""Load the classifier modelling P(N | C, X) and attach its softmax temperature.

	Args:
		ckpt: path to a checkpoint holding the classifier weights under 'net'
			and, optionally, a temperature under 't'.
		device: torch device the classifier is moved to.

	Returns:
		The `tpn.Classifier` in eval mode on `device`. It carries a `temperature`
		attribute taken from the checkpoint, or 1.0 when none is stored.
	"""
	# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
	state = torch.load(ckpt, map_location='cpu')
	model = tpn.Classifier(hdim=64)
	model.load_state_dict(state['net'])
	model = model.eval().to(device)

	# sample_N_givenCX divides the logits by this value before the softmax.
	if 't' in state:
		print("SETTING TEMP TO ", state['t'])
		model.temperature = state['t']
	else:
		model.temperature = 1.0
	return model


def load_pC(root_dir):
	"""Estimate P(C = 1) as the fraction of training examples whose label C equals 1.

	Args:
		root_dir: dataset root passed to `cxds.CXRayDataset`.

	Returns:
		float in [0, 1]: the empirical frequency of C == 1 in the 'train' split.

	Raises:
		ValueError: if the training split contains no examples.
	"""
	dataset = cxds.CXRayDataset(root_dir, split='train', transform=transforms.ToTensor())
	dataloader = DataLoader(dataset, batch_size=512, num_workers=8, shuffle=False, drop_last=False)

	n_total = 0
	n_c1 = 0
	for batch in dataloader:
		labels = batch['C']
		# Count per batch with tensor ops instead of one .item() call per example.
		n_total += labels.numel()
		n_c1 += int((labels == 1).sum())

	if n_total == 0:
		raise ValueError(f"No training examples found under {root_dir!r}; cannot estimate P(C)")
	return n_c1 / n_total




# ==============================================
# =           Sample block                     =
# ==============================================

@torch.no_grad()
def sample_X_givenC(C, ema):
	"""Draw one image from the diffusion model for each label in `C`.

	Useful on its own for producing example images (e.g. figures for the paper).

	Args:
		C: label tensor, passed to the sampler as `classes`.
		ema: EMA wrapper as returned by `load_pX_givenC`.

	Returns:
		Whatever `ema.ema_model.sample` returns for these classes.
	"""
	sampler = ema.eval().ema_model
	# NOTE(review): cond_scale=1 was flagged "CHECK" by the author; confirm the intended guidance scale.
	return sampler.sample(classes=C, cond_scale=1)

def sample_X_givenC_obs(C, split_by_c):
	"""For each label in `C`, pick a stored image with that label uniformly at random.

	Args:
		C: iterable tensor of labels; each element must support `.item()`.
		split_by_c: dict from label to a list of image tensors, as built by
			`load_pX_givenC_obs`.

	Returns:
		The chosen images stacked along a new leading dim, moved to `C.device`.
	"""
	picks = [random.choice(split_by_c[label.item()]) for label in C]
	return torch.stack(picks).to(C.device)


@torch.no_grad()
def sample_N_givenCX(C, X, model):
	"""Sample N ~ Bernoulli(P(N = 1 | C, X)) using the temperature-scaled classifier.

	Args:
		C: labels, passed as the classifier's second argument.
		X: images, passed as the classifier's first argument.
		model: callable classifier returning per-class logits along dim 1
			(column 1 is taken as N = 1). It must have a `temperature` attribute,
			as set by `load_pN_givenCX`.

	Returns:
		Long tensor of 0/1 samples, one per row of the classifier output.
	"""
	model = model.eval()
	# Divide the logits by the stored temperature before the softmax.
	logits = torch.div(model(X, C), model.temperature)
	p_n1 = F.softmax(logits, dim=1)[:, 1]
	# A uniform draw below p gives a Bernoulli(p) sample; bool -> long directly.
	return (torch.rand_like(p_n1) < p_n1).long()


@torch.no_grad()
def sample_C(p_C, bsz, device):
	"""Draw `bsz` i.i.d. Bernoulli(p_C) labels.

	Args:
		p_C: probability that C = 1.
		bsz: number of samples to draw.
		device: device the result is moved to. The uniforms are drawn on CPU first.

	Returns:
		Long tensor of shape (bsz,) with entries in {0, 1}.
	"""
	uniforms = torch.rand(bsz)
	is_one = uniforms < p_C
	return is_one.long().to(device)

def _estimate_N_counts(c, sample_X, classifier, p_c1, device, num_samples, batch_size):
	"""Draw `num_samples` Monte-Carlo samples of N under do(C = c); return {0: n0, 1: n1}."""
	counts = {0: 0, 1: 0}
	drawn = 0
	# Context manager so the progress bar is closed even if sampling raises.
	with tqdm(total=num_samples) as pbar:
		while drawn < num_samples:
			cur_batch = min(batch_size, num_samples - drawn)
			# Intervention: every X in the batch is generated with C fixed to c.
			c_in = torch.full((cur_batch,), c, dtype=torch.long, device=device)
			x = sample_X(c_in)
			# C' is drawn from its marginal, independently of the intervened c.
			c_prime = sample_C(p_c1, cur_batch, device)
			n = sample_N_givenCX(c_prime, x, classifier)
			# n holds only 0/1 values, so its sum is the count of N = 1.
			n_ones = int(n.sum())
			counts[1] += n_ones
			counts[0] += n.numel() - n_ones
			drawn += cur_batch
			pbar.update(cur_batch)
	return counts


@torch.no_grad()
def sample_N_doC(pX_ckpt, pN_ckpt, root_dir, device, num_samples, batch_size, observational):
	"""Estimate p(N | do(C = c)) for c in {0, 1} by Monte-Carlo sampling.

	Implements P_{do(C)}(N) = SUM_x P(x | C) SUM_{c'} P(N | c', x) P(c') as follows.
	For each sample, X is drawn given the intervened C, C' is drawn from its
	marginal, and N is drawn from the classifier given (C', X).

	Args:
		pX_ckpt: diffusion checkpoint for P(X | C). Only loaded when `observational` is 0.
		pN_ckpt: classifier checkpoint for P(N | C, X).
		root_dir: dataset root. The 'train' split gives P(C); in observational
			mode the 'val' split supplies the pool of X images.
		device: torch device, e.g. 'cuda:0'.
		num_samples: number of Monte-Carlo samples per value of C.
		batch_size: number of samples drawn per step.
		observational: 0 -> sample X from the diffusion model; any other value ->
			sample X uniformly from validation images with the matching label.

	Returns:
		{c: {0: p_c(N=0), 1: p_c(N=1)}} for c in (0, 1). The inner dict is empty
		when num_samples <= 0.

	Raises:
		ValueError: if batch_size is not positive (the sampling loop could never finish).
	"""
	if batch_size <= 0:
		raise ValueError(f"batch_size must be positive, got {batch_size}")

	# P(X | C): either the generative model or the empirical pool of val images.
	# Loaded exactly once, and the diffusion checkpoint only when it is used.
	if observational == 0:
		ema = load_pX_givenC(pX_ckpt, device)
		sample_X = lambda c_in: sample_X_givenC(c_in, ema)
	else:
		split_by_c = load_pX_givenC_obs(root_dir)
		sample_X = lambda c_in: sample_X_givenC_obs(c_in, split_by_c)

	classifier = load_pN_givenCX(pN_ckpt, device)
	p_c1 = load_pC(root_dir)

	output = {}
	for c in range(2):
		counts = _estimate_N_counts(c, sample_X, classifier, p_c1, device, num_samples, batch_size)
		total = sum(counts.values())
		output[c] = {k: v / total for k, v in counts.items()} if total else {}
	return output




# ==================================
# =           Main block           =
# ==================================

def main():
	"""Command-line entry point: estimate p(N | do(C)) for C in {0, 1} and print the result."""
	parser = argparse.ArgumentParser()
	# Checkpoints, data root and CUDA device index are all mandatory.
	for flag, kind in (('--pX_ckpt', str), ('--pN_ckpt', str), ('--root_dir', str), ('--device', int)):
		parser.add_argument(flag, type=kind, required=True)
	parser.add_argument('--num_samples', default=5000, type=int)
	parser.add_argument('--batch_size', type=int, default=512)
	parser.add_argument('--observational', type=int, default=0)
	args = parser.parse_args()

	result = sample_N_doC(
		args.pX_ckpt,
		args.pN_ckpt,
		args.root_dir,
		f'cuda:{args.device}',
		args.num_samples,
		args.batch_size,
		args.observational,
	)
	print(result)


if __name__ == '__main__':
	main()


