import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LinearRegression
from utils import MLP2, Exponential_regression
import torch
# Torch device used by this module: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def aipw_borrow(data, ps_model, mu0_model, mu1_model, batch_size, dataset):
    """Estimate a treatment effect with and without borrowing, and compare them.

    Two per-sample score vectors are built and averaged:

    * ``psi`` -- the "borrowing" score, using nuisance models fitted inside
      this function on all rows of ``data`` (``tau_borrow = mean(psi)``).
    * ``chi`` -- the reference score, using the pre-fitted ``ps_model`` and
      ``mu0_model`` passed by the caller (``tau_exp = mean(chi)``).

    Args:
        data: Mapping with keys ``'X'``, ``'A'``, ``'Y'`` and ``'R'``.
            ``X`` must be a 2-D array (``X.shape[1]`` is read) and ``A`` must
            support ``A == 0`` as a row mask for ``X`` and ``Y``.
            NOTE(review): presumably X = covariates, A = 0/1 treatment,
            Y = outcome, R = 0/1 trial-membership indicator -- confirm
            against the caller.
        ps_model: Fitted classifier exposing ``predict_proba(X)``; column 1 is
            used as the treatment propensity for the reference estimator.
        mu0_model: Fitted regressor exposing ``predict(X)``; used as the
            control-outcome model of the reference estimator.
        mu1_model: Fitted regressor exposing ``predict(X)``; used as the
            treated-outcome model by *both* estimators.
        batch_size: Currently unused -- the ``'NSW'`` MLP is fitted with a
            hard-coded batch size of 32. Kept for interface compatibility.
        dataset: Selects the control-outcome model fitted here:
            ``'NSW'`` (``MLP2``), ``'exp'`` (``Exponential_regression``) or
            ``'linear'`` (``LinearRegression``).

    Returns:
        Tuple ``(tau_borrow, tau_exp, bias, se_1, bv, se, CI)`` where

        * ``bias`` is ``|tau_borrow - tau_exp|``,
        * ``se_1`` is ``sqrt(var(chi) / n)``,
        * ``bv`` is ``bias**2 + var(psi) / n``,
        * ``se`` is ``sqrt(var(psi) / n)``,
        * ``CI`` is ``[tau_borrow - 1.96 * se, tau_borrow + 1.96 * se]``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    X = data['X']
    A = data['A']
    Y = data['Y']
    R = data['R']
    n = X.shape[0]

    # Model of R given X; column 1 of predict_proba is the probability of the
    # second class (R == 1 when R is coded 0/1). Clipped to keep the weights
    # in the score bounded.
    pi_model_ = LogisticRegression(C=1, solver='lbfgs', max_iter=1000).fit(X, R)
    pi = np.clip(pi_model_.predict_proba(X)[:, 1], 0.1, 0.9)

    # Marginal mean of R, used to normalise both scores.
    q = np.mean(R)

    # Model of A given X fitted on every row; clipped because it is used as a
    # denominator below.
    ps_model_ = LogisticRegression(C=1, solver='lbfgs', max_iter=1000).fit(X, A)
    ps_ = np.clip(ps_model_.predict_proba(X)[:, 1], 0.1, 0.9)

    # Control-outcome model fitted on all rows with A == 0.
    control = A == 0
    if dataset == 'NSW':
        mu0_model_ = MLP2(X.shape[1], 16, 1).to(device)
        # NOTE(review): the `batch_size` argument is ignored here; 32 is
        # hard-coded. Confirm whether the caller's value should be used.
        mu0_model_.fit(X[control], Y[control], batch_size=32)
    elif dataset == "exp":
        mu0_model_ = Exponential_regression()
        mu0_model_.fit(X[control], Y[control])
    elif dataset == 'linear':
        mu0_model_ = LinearRegression()
        mu0_model_.fit(X[control], Y[control])
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on `mu0_model_`.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'NSW', 'exp' or 'linear'."
        )

    mu1_ = mu1_model.predict(X)
    mu0_ = mu0_model_.predict(X)

    # Borrowing score: weighted residual correction plus the outcome-model
    # contrast on rows with R == 1.
    psi = (
        pi / q * (R * A * (Y - mu1_) / ps_ - (1 - A) * (Y - mu0_) / (1 - ps_))
        + R / q * (mu1_ - mu0_)
    )
    tau_borrow = np.mean(psi)

    # Reference score: same structure, but every term is restricted to
    # R == 1 and the caller-supplied nuisance models are used.
    mu0_rct = mu0_model.predict(X)
    mu1_rct = mu1_
    ps_rct = np.clip(ps_model.predict_proba(X)[:, 1], 0.1, 0.9)

    chi = (
        R / q * (A * (Y - mu1_rct) / ps_rct - (1 - A) * (Y - mu0_rct) / (1 - ps_rct))
        + R / q * (mu1_rct - mu0_rct)
    )
    tau_exp = np.mean(chi)

    bias = np.abs(tau_borrow - tau_exp)

    # Variances of the two sample means (score variance divided by n).
    variance = np.var(psi) / n
    variance_1 = np.var(chi) / n
    se_1 = np.sqrt(variance_1)

    bv = bias ** 2 + variance
    se = np.sqrt(variance)

    # `se` is already the standard error of the mean, so it must not be
    # divided by sqrt(n) a second time when forming the 95% normal interval.
    half_width = 1.96 * se
    CI = [tau_borrow - half_width, tau_borrow + half_width]

    return tau_borrow, tau_exp, bias, se_1, bv, se, CI


