import torch
import numpy as np

@torch.no_grad()
def compute_pehe_score(yf_hat_valid, ycf_hat_valid, treatment, mu0, mu1):
	"""Return the root-mean-squared error between predicted and true effects.

	The true per-unit effect is ``mu1 - mu0`` (flattened). The predicted effect
	is ``ycf_hat_valid - yf_hat_valid`` (flattened), with its sign flipped for
	units where ``treatment > 0``. Presumably this orients every prediction as
	(treated outcome - control outcome) when ``yf_hat_valid`` / ``ycf_hat_valid``
	are the factual / counterfactual predictions; confirm against the caller.

	Args:
		yf_hat_valid: Tensor of predicted factual outcomes.
		ycf_hat_valid: Tensor of predicted counterfactual outcomes.
		treatment: Tensor of treatment indicators; squeezed before use.
		mu0: Tensor of per-unit outcomes under control (assumed ground truth).
		mu1: Tensor of per-unit outcomes under treatment (assumed ground truth).

	Returns:
		``sqrt(mean((eff_pred - eff) ** 2))`` as a NumPy scalar.
	"""
	true_effect = (mu1.cpu().numpy() - mu0.cpu().numpy()).reshape(-1)
	is_treated = treatment.squeeze().cpu().numpy() > 0
	pred_effect = (ycf_hat_valid.reshape(-1) - yf_hat_valid.reshape(-1)).cpu().numpy()
	# For treated units, ycf - yf is (control - treated); negate to align signs.
	pred_effect[is_treated] = -pred_effect[is_treated]
	squared_err = np.square(pred_effect - true_effect)
	return np.sqrt(np.mean(squared_err))

def policy_range(n, res=10):
	"""Return up to ``res`` roughly evenly spaced points from 0 to ``n``.

	An evenly spaced grid ``0, step, 2*step, ...`` is built, and ``n`` is appended
	if the grid does not already end there. If the grid holds more than ``res``
	points, randomly chosen interior points are removed until ``res`` remain.
	The removal uses the global ``np.random`` state. The two endpoints are never
	removed, so at least two points survive when ``n > 0``.

	Args:
		n: Upper end of the range (inclusive); expected to be non-negative.
		res: Target number of points; must be non-zero.

	Returns:
		A sorted list whose first element is 0 and whose last element is ``n``.
	"""
	# Clamp the step: for n < res the integer step would be 0 and range() raises.
	step = max(1, int(float(n) / float(res)))
	# Use a list, not a range object: it is appended to and deleted from below.
	n_range = list(range(0, int(n + 1), step))
	if n_range[-1] != n:
		n_range.append(n)

	# With two or fewer points there is no interior point to drop, and
	# np.random.randint(0) would raise, so never shrink below the endpoints.
	while len(n_range) > max(res, 2):
		k = np.random.randint(len(n_range) - 2) + 1
		del n_range[k]

	return n_range


def policy_val(t, yf, eff_pred, compute_policy_curve=False):
	"""Estimate the value of the policy "treat iff the predicted effect is > 0".

	The value is ``p * treat_value + (1 - p) * control_value``, where:

	- ``p`` is the fraction of units the policy treats.
	- ``treat_value`` is the mean of ``yf`` over units with ``t > 0`` whose
	  policy decision equals ``t``.
	- ``control_value`` is the mean of ``yf`` over units with ``t < 1`` whose
	  policy decision equals ``t``.

	A mean over an empty group counts as 0. The "decision equals t" matching
	presumably assumes a binary 0/1 ``t``; confirm with callers.

	Args:
		t: Array of treatment indicators.
		yf: Array of observed outcomes, aligned with ``t``.
		eff_pred: Array of predicted individual effects, aligned with ``t``.
		compute_policy_curve: Accepted but currently ignored.

	Returns:
		``(policy_value, policy_curve)``, where ``policy_curve`` is always an
		empty list, or ``(nan, nan)`` if ``eff_pred`` contains any NaN.
	"""
	if np.isnan(eff_pred).any():
		return np.nan, np.nan

	treat = eff_pred > 0
	agrees = treat == t
	# Units whose observed assignment coincides with the policy's decision.
	treated_hits = agrees * (t > 0)
	control_hits = agrees * (t < 1)

	def mean_or_zero(mask):
		# An empty group contributes 0 instead of a NaN mean.
		return 0 if np.sum(mask) == 0 else np.mean(yf[mask])

	treat_rate = np.mean(treat)
	value = (treat_rate * mean_or_zero(treated_hits)
			 + (1 - treat_rate) * mean_or_zero(control_hits))
	return value, []

@torch.no_grad()
def compute_policy_risk(yf_hat, ycf_hat, treatment, e, yf):
	"""Return the policy risk ``1 - policy_value`` of the predicted effects.

	The predicted effect is ``ycf_hat - yf_hat``, with its sign flipped for
	units where ``treatment > 0``. That matches the other metrics in this
	module, presumably orienting each effect as (treated - control). Only units
	with ``e > 0`` are passed to ``policy_val``. Presumably ``e`` marks an
	evaluation/randomized subsample; confirm against the data loader.

	Args:
		yf_hat: Tensor of predicted factual outcomes.
		ycf_hat: Tensor of predicted counterfactual outcomes.
		treatment: Tensor of treatment indicators.
		e: Tensor whose positive entries select the units to evaluate.
		yf: Tensor of observed factual outcomes.

	Returns:
		``1 - policy_value``. This is NaN when ``policy_val`` reports NaN,
		e.g. if any selected predicted effect is NaN.
	"""
	assert len(yf_hat) == len(ycf_hat) == len(treatment) == len(e) == len(yf)

	# Flatten everything to 1-D so the boolean masks below line up whether
	# callers pass (n,) or (n, 1) tensors; mixing the two used to raise
	# IndexError during boolean indexing.
	t = treatment.cpu().numpy().reshape(-1)
	e = e.cpu().numpy().reshape(-1)
	yf = yf.cpu().numpy().reshape(-1)
	# ycf_hat - yf_hat is a fresh tensor, so mutating its NumPy view is safe.
	eff_pred = (ycf_hat - yf_hat).cpu().numpy().reshape(-1)
	eff_pred[t > 0] = -eff_pred[t > 0]

	in_eval = e > 0
	policy_value, _ = policy_val(t[in_eval], yf[in_eval], eff_pred[in_eval], False)
	return 1 - policy_value

@torch.no_grad()
def compute_ate(yf_hat, ycf_hat, treatment, mu0, mu1):
	"""Return the absolute error of the predicted average treatment effect.

	The predicted effect is ``ycf_hat - yf_hat``, with its sign flipped where
	the flattened ``treatment`` is positive. The result is
	``|mean(predicted effect) - mean(mu1 - mu0)|``.

	Args:
		yf_hat: Tensor of predicted factual outcomes.
		ycf_hat: Tensor of predicted counterfactual outcomes.
		treatment: Tensor of treatment indicators; flattened before use.
		mu0: Tensor of per-unit outcomes under control (assumed ground truth).
		mu1: Tensor of per-unit outcomes under treatment (assumed ground truth).

	Returns:
		The absolute ATE bias as a NumPy scalar.
	"""
	assert len(yf_hat) == len(ycf_hat) == len(treatment) == len(mu0) == len(mu1)
	treated = treatment.cpu().numpy().reshape(-1) > 0
	effect_hat = (ycf_hat - yf_hat).cpu().numpy()
	# For treated units, ycf - yf has the opposite orientation; negate it.
	effect_hat[treated] = -effect_hat[treated]
	true_ate = np.mean((mu1 - mu0).cpu().numpy())
	return np.abs(np.mean(effect_hat) - true_ate)


