import torch
import numpy as np


def create_atd(do_var, target_var, var_dims, delta, do_dim=None):
    """Build a finite-difference estimator of an average treatment derivative.

    The returned ``param_fn`` takes the observed values of variable ``do_var``
    from ``data``. It shifts them by ``delta``, either on every component or
    only on component ``do_dim``. It then calls ``generator.do`` once with the
    unshifted values and once with the shifted ones. The result is the
    difference quotient of the ``target_var`` columns of the two outputs.

    Args:
        do_var: Index of the variable that is intervened on.
        target_var: Index of the variable whose response is measured.
        var_dims: Per-variable widths. Variable ``i`` occupies columns
            ``sum(var_dims[:i])`` up to ``sum(var_dims[:i + 1])`` of ``data``
            and of the generator output.
        delta: Finite-difference step; must be non-zero.
        do_dim: If given, only this component of ``do_var`` is shifted.
            Otherwise every component is shifted by ``delta``.

    Returns:
        ``param_fn(z, generator, device, data)``, which returns
        ``(cntf1 - cntf0) / delta`` restricted to the target columns.

    Raises:
        ValueError: If ``delta`` is zero (the quotient would be NaN/inf).
    """
    if delta == 0:
        raise ValueError("delta must be non-zero")

    # int(): np.sum over an empty slice (variable index 0 with a list/tuple
    # ``var_dims``) yields the float 0.0, which is not a valid slice index.
    target_start = int(np.sum(var_dims[:target_var]))
    target_end = int(np.sum(var_dims[: (target_var + 1)]))

    do_start = int(np.sum(var_dims[:do_var]))
    do_end = int(np.sum(var_dims[: (do_var + 1)]))

    def param_fn(z, generator, device, data):
        if do_dim is not None:
            d = torch.zeros(var_dims[do_var])
            d[do_dim] = delta
        else:
            d = torch.ones(var_dims[do_var]) * delta

        # Move the observed values to ``device`` *before* adding the shift so
        # the addition never mixes tensors living on different devices.
        zero = data[:, do_start:do_end].to(device)
        one = zero + d.to(device)
        cntf0 = generator.do(z, zero, do_var, data=data)[:, target_start:target_end]
        cntf1 = generator.do(z, one, do_var, data=data)[:, target_start:target_end]
        return (cntf1 - cntf0) / delta

    return param_fn


def create_uniform_gauss_atd(do_var, target_var, var_dims, delta, d1_value, d0_value):
    """Build a finite-difference treatment-derivative estimator at sampled treatments.

    For each row of ``z``, the returned ``param_fn`` draws one treatment value:

    - First it samples from a Gaussian whose mean is the midpoint of
      ``[d0_value, d1_value]`` and whose std is half the interval width.
    - Any draw that lands inside the interval is replaced by a uniform draw
      from that interval. Draws outside it keep their Gaussian value.

    It then calls ``generator.do`` at the treatment and at treatment +
    ``delta``, and returns the difference quotient of the ``target_var``
    columns.

    Sampling uses torch's global RNG (no generator is passed to
    ``torch.randn``/``torch.rand``).

    Args:
        do_var: Index of the variable that is intervened on.
        target_var: Index of the variable whose response is measured.
        var_dims: Per-variable widths, used to locate the target columns of the
            generator output.
        delta: Finite-difference step; must be non-zero.
        d1_value: Upper end of the treatment interval.
        d0_value: Lower end of the treatment interval.

    Returns:
        ``param_fn(z, generator, device, data)``, which returns
        ``(cntf1 - cntf0) / delta`` restricted to the target columns.

    Raises:
        ValueError: If ``delta`` is zero (the quotient would be NaN/inf).
    """
    if delta == 0:
        raise ValueError("delta must be non-zero")

    # int(): np.sum over an empty slice (target_var == 0 with a list/tuple
    # ``var_dims``) yields the float 0.0, which is not a valid slice index.
    target_start = int(np.sum(var_dims[:target_var]))
    target_end = int(np.sum(var_dims[: (target_var + 1)]))

    def param_fn(z, generator, device, data):
        batch = z.shape[0]
        mean = (d1_value + d0_value) / 2
        std = (d1_value - d0_value) / 2
        # NOTE(review): treatments have a single column, so this presumably
        # assumes var_dims[do_var] == 1 — confirm against the generator.
        treatments = torch.randn(size=(batch, 1)) * std + mean
        within_interval = (treatments[:, 0] <= d1_value) & (treatments[:, 0] >= d0_value)
        uniform_noises = torch.rand(size=(batch, 1)) * (d1_value - d0_value) + d0_value
        treatments[within_interval, :] = uniform_noises[within_interval, :]
        one = treatments + delta
        zero = treatments
        cntf1 = generator.do(z, one.to(device), do_var, data=data)[:, target_start:target_end]
        cntf0 = generator.do(z, zero.to(device), do_var, data=data)[:, target_start:target_end]
        return (cntf1 - cntf0) / delta

    return param_fn


def create_discrete_cntf(do_var, target_var, var_dims, value):
    """Build a function returning the counterfactual target under ``do(value)``.

    Args:
        do_var: Index of the variable that is intervened on.
        target_var: Index of the variable whose counterfactual is returned.
        var_dims: Per-variable widths, used to locate the target columns of the
            generator output.
        value: Scalar intervention value. It is wrapped in a one-element float
            tensor (``+ 0.`` forces a float) before being passed to
            ``generator.do``.

    Returns:
        ``param_fn(z, generator, device, data=None)``, which returns the
        ``target_var`` columns of ``generator.do(z, value, do_var, data=data)``.
    """
    # int(): np.sum over an empty slice (target_var == 0 with a list/tuple
    # ``var_dims``) yields the float 0.0, which is not a valid slice index.
    target_start = int(np.sum(var_dims[:target_var]))
    target_end = int(np.sum(var_dims[: (target_var + 1)]))

    def param_fn(z, generator, device, data=None):
        val = torch.Tensor([value + 0.])
        # Pass ``data`` by keyword, matching the other generator.do calls.
        cntf = generator.do(z, val.to(device), do_var, data=data)[:, target_start:target_end]
        return cntf

    return param_fn


def create_discrete_ate(do_var, target_var, var_dims):
    """Build a function returning a per-sample effect for a two-level treatment.

    The treatment is encoded as the one-hot vectors ``[1., 0.]`` (control) and
    ``[0., 1.]`` (treated). The returned value is the difference between the
    ``target_var`` columns of the two counterfactuals.

    Args:
        do_var: Index of the variable that is intervened on.
        target_var: Index of the variable whose response is measured.
        var_dims: Per-variable widths, used to locate the target columns of the
            generator output.

    Returns:
        ``param_fn(z, generator, device, data)``, which returns
        ``cntf([0, 1]) - cntf([1, 0])`` restricted to the target columns.
    """
    # int(): np.sum over an empty slice (target_var == 0 with a list/tuple
    # ``var_dims``) yields the float 0.0, which is not a valid slice index.
    target_start = int(np.sum(var_dims[:target_var]))
    target_end = int(np.sum(var_dims[: (target_var + 1)]))

    def param_fn(z, generator, device, data):
        control = torch.Tensor([1., 0.]).to(device)
        treated = torch.Tensor([0., 1.]).to(device)
        # Pass ``data`` by keyword, matching the other generator.do calls.
        cntf0 = generator.do(z, control, do_var, data=data)[:, target_start:target_end]
        cntf1 = generator.do(z, treated, do_var, data=data)[:, target_start:target_end]
        return cntf1 - cntf0

    return param_fn

