import torch
import numpy as np
from experiment.data import MultipleSim_Separable, MultipleSim_NonSeparable, data_preprocess
from torch.utils.data import DataLoader


def pehe(y_cf_list, tau_true):
    """
    Compute the PEHE (precision in estimation of heterogeneous effect) term.

    y_cf_list: sequence of counterfactual outcome tensors; the last entry is
        used as the reference that every other entry is contrasted against
    tau_true: true treatment effects, indexed as [sample, treatment level]

    Returns the SUM (not the mean) of squared differences between the
    estimated and the true effects, as a 0-dim tensor.
    """
    n_levels = tau_true.shape[1]
    tau_hat = torch.zeros_like(tau_true)
    for level in range(n_levels):
        contrast = y_cf_list[level] - y_cf_list[-1]
        if contrast.shape[1] > n_levels:
            # Contrast has more columns than treatment levels (binary-outcome
            # case in the original notes): the effect is read from column
            # level + 1.
            column = contrast[:, level + 1]
        else:
            # Continuous-outcome case: read the effect from the last axis.
            column = contrast[..., level]
        tau_hat[:, level] = column
    squared_error = (tau_hat - tau_true).flatten().square()
    return squared_error.sum()


def ate_error(y_cf_list, ate_true, type='absolute'):
    """
    Compute the per-treatment-level error of the estimated ATE.

    y_cf_list: sequence of counterfactual outcome tensors; the last entry is
        used as the reference that every other entry is contrasted against
    ate_true: true average treatment effects, one entry per treatment level
    type: which error to report
        'absolute': absolute error |estimate - truth|
        any other value: squared error (estimate - truth)^2

    Returns a tensor shaped like ate_true holding one error per level.
    """
    n_levels = ate_true.shape[0]
    ate_hat = torch.zeros_like(ate_true)
    for level in range(n_levels):
        contrast = y_cf_list[level] - y_cf_list[-1]
        if contrast.shape[1] > n_levels:
            # More columns than treatment levels: average column level + 1.
            ate_hat[level] = contrast[:, level + 1].flatten().mean()
        else:
            # Otherwise average over every entry of the contrast.
            ate_hat[level] = contrast.flatten().mean()
    gap = ate_hat - ate_true
    return torch.abs(gap) if type == 'absolute' else torch.square(gap)


def log_likelihood_bin(y, p_pred):
    """
    Element-wise Bernoulli log-likelihood of binary outcomes.

    y: tensor of observed outcomes (0 or 1)
    p_pred: tensor of predicted probabilities P(Y = 1); values must lie
        strictly inside (0, 1), otherwise the logarithms produce inf/nan

    Returns y * log(p_pred) + (1 - y) * log(1 - p_pred), element-wise.
    """
    # The (1 - y) term must use the probability of the negative class,
    # 1 - p_pred; using log(p_pred) in both terms would make the result
    # independent of y.
    return y * torch.log(p_pred) + (1 - y) * torch.log(1 - p_pred)


def marginal_effect_eval(net, epsilon, epsilon_idx, y, y_pred, z_sd, **kwargs):
    """
    Evaluate estimated marginal effects against simulated counterfactual data.

    A counterfactual data set is simulated with `epsilon_idx`, the model's
    counterfactual prediction is obtained from `net.get_y_cf`, and finite
    differences (divided by `epsilon`) are compared between prediction and
    truth.

    net: model exposing get_y_cf(a_cf, z_sd, classification_flag)
    epsilon: size of the perturbation used as the finite-difference step
    epsilon_idx: forwarded to the simulator as its last positional argument
    y, y_pred: factual outcome and its prediction; both must already be
        scaled back to the original scale
    z_sd: forwarded unchanged to net.get_y_cf
    kwargs: must contain "input_size", "sample_size", "confounder_size",
        "seed", "train_size", "val_size", "sim_type" ('separable' or
        'nonseparable'), "classification_flag" and "y_scalar" (an object with
        inverse_transform, used to un-scale the model prediction)

    Returns (marginal_effect_est, marginal_effect_true, error_percent,
    error_mean, dir_cor).

    Raises ValueError if sim_type is not recognised or the counterfactual
    test loader yields no batch.
    """
    input_size = kwargs["input_size"]
    sample_size = kwargs["sample_size"]
    confounder_size = kwargs["confounder_size"]
    seed = kwargs["seed"]
    train_size, val_size = kwargs["train_size"], kwargs["val_size"]
    sim_type = kwargs["sim_type"]

    # load counterfactual data
    if sim_type == 'separable':
        data_cf = MultipleSim_Separable(input_size, sample_size, confounder_size, seed, epsilon_idx)
    elif sim_type == 'nonseparable':
        data_cf = MultipleSim_NonSeparable(input_size, sample_size, confounder_size, seed, epsilon_idx)
    else:
        # Fail loudly instead of hitting an unbound `data_cf` further down.
        raise ValueError(
            "sim_type must be 'separable' or 'nonseparable', got {!r}".format(sim_type)
        )

    # Only the test split and the outcome scaler of the counterfactual data
    # are needed; the test split is given the same size as the validation one.
    _, _, test_set, _, y_scalar_cf, _, _, _ = data_preprocess(data_cf,
                                                              partition_seed=1,
                                                              x_scale=False, y_scale=True,
                                                              train_size=train_size,
                                                              val_size=val_size,
                                                              test_size=val_size)
    test_data = DataLoader(test_set, batch_size=val_size, shuffle=False)

    # calculate y_cf_pred
    # batch_size equals the requested test size, so a single batch is
    # expected; if the loader yielded several, only the last one would be kept.
    y_cf = None
    y_cf_pred = None
    for y_cf_batch, a_cf in test_data:
        y_cf_pred = net.get_y_cf(a_cf, z_sd, kwargs["classification_flag"])

        # bring both truth and prediction back to the original scale
        y_cf = y_scalar_cf.inverse_transform(y_cf_batch.detach().numpy())
        y_cf_pred = kwargs["y_scalar"].inverse_transform(y_cf_pred.detach().numpy())

    if y_cf_pred is None:
        raise ValueError("counterfactual test set produced no batches")

    # estimated marginal effect
    marginal_effect_est = (y_cf_pred - y_pred)/epsilon

    # true marginal effect
    marginal_effect_true = (y_cf - y)/epsilon

    # percentage of error (relative to the true effect; entries where the
    # true effect is exactly zero are divided by zero)
    error_percent = abs((marginal_effect_true - marginal_effect_est)/marginal_effect_true).mean()

    # mean of absolute error
    error_mean = abs(marginal_effect_true - marginal_effect_est).mean()

    # accuracy of direction: number of entries where estimate and truth share
    # the same sign, divided by val_size
    dir_cor = ((marginal_effect_est * marginal_effect_true) > 0).sum()/val_size

    return marginal_effect_est, marginal_effect_true, error_percent, error_mean, dir_cor
