from copy import deepcopy
import numpy as np
from sklearn.linear_model import LogisticRegression, LinearRegression
from xgboost import XGBRegressor
from sklearn.model_selection import KFold

class BaseSRegressor:
    """S-learner: a single outcome model fitted on covariates plus treatment.

    The treatment indicator is appended as the last feature column. The
    effect estimate for a row is the model's prediction with that column
    set to 1 minus its prediction with the column set to 0.

    Parameters
    ----------
    learner : regressor with ``fit``/``predict``, optional
        Template for the outcome model; it is deep-copied, so the caller's
        object is never fitted. Defaults to a fresh ``XGBRegressor()``.
    """

    def __init__(self, learner=None):
        # A None sentinel instead of `XGBRegressor()` in the signature avoids
        # one estimator instance being created at class-definition time and
        # shared as the default for every call.
        if learner is None:
            learner = XGBRegressor()
        self.model = deepcopy(learner)

    def fit(self, X, a, y):
        """Fit the outcome model on the stacked features ``[X, a]`` -> ``y``."""
        self.model.fit(np.column_stack((X, a)), y)

    def predict(self, X):
        """Return ``f(X, a=1) - f(X, a=0)`` for every row of ``X``."""
        n = X.shape[0]
        tau1 = self.model.predict(np.column_stack((X, np.ones(n))))
        tau0 = self.model.predict(np.column_stack((X, np.zeros(n))))
        return tau1 - tau0

class BaseTRegressor:
    """T-learner: separate outcome models for treated and control units.

    Parameters
    ----------
    learner : regressor with ``fit``/``predict``, optional
        Template deep-copied into ``model_t`` (fitted on ``a == 1`` rows) and
        ``model_c`` (fitted on ``a == 0`` rows). Defaults to a fresh
        ``XGBRegressor()``.
    """

    def __init__(self, learner=None):
        # None sentinel: avoid a shared estimator instance built at
        # class-definition time as the default argument.
        if learner is None:
            learner = XGBRegressor()
        self.model_t = deepcopy(learner)
        self.model_c = deepcopy(learner)

    def fit(self, X, a, y):
        """Fit one outcome model per treatment arm (``a == 1`` / ``a == 0``)."""
        treated = a == 1
        control = a == 0
        self.model_t.fit(X[treated], y[treated])
        self.model_c.fit(X[control], y[control])

    def predict(self, X, return_element=False):
        """Return ``mu_1(X) - mu_0(X)``, or the pair ``(mu_1(X), mu_0(X))``.

        Parameters
        ----------
        return_element : bool
            When true, return the two arm predictions instead of their difference.
        """
        tau1 = self.model_t.predict(X)
        tau0 = self.model_c.predict(X)
        if return_element:
            return tau1, tau0
        return tau1 - tau0

class BaseXRegressor:
    """X-learner (Kuenzel, Sekhon, Bickel & Yu, 2019).

    Stage 1 fits per-arm outcome models. Stage 2 fits effect models on
    imputed individual effects: for treated rows ``y - mu_0(x)``, for control
    rows ``mu_1(x) - y``. The final estimate blends the two effect models
    with a propensity-score weight.

    Parameters
    ----------
    outcome_learner : regressor with ``fit``/``predict``, optional
        Template for the per-arm outcome models. Defaults to ``XGBRegressor()``.
    effect_learner : regressor with ``fit``/``predict``, optional
        Template for the two effect models. Defaults to ``LinearRegression()``.
    p_learner : classifier with ``fit``/``predict_proba``, optional
        Template for the propensity model. Defaults to ``LogisticRegression()``.
    """

    def __init__(self, outcome_learner=None, effect_learner=None, p_learner=None):
        # None sentinels: avoid estimator instances created once at
        # class-definition time and shared as default arguments.
        if outcome_learner is None:
            outcome_learner = XGBRegressor()
        if effect_learner is None:
            effect_learner = LinearRegression()
        if p_learner is None:
            p_learner = LogisticRegression()
        self.outcome_t = deepcopy(outcome_learner)
        self.outcome_c = deepcopy(outcome_learner)
        self.model_t = deepcopy(effect_learner)
        self.model_c = deepcopy(effect_learner)
        self.ps = deepcopy(p_learner)

    def fit(self, X, a, y):
        """Fit outcome models, effect models on imputed effects, and the propensity model."""
        treated = a == 1
        control = a == 0
        self.outcome_t.fit(X[treated], y[treated])
        self.outcome_c.fit(X[control], y[control])
        # Imputed individual effects: observed outcome vs. the other arm's model.
        delta1 = y[treated] - self.outcome_c.predict(X[treated])
        delta0 = self.outcome_t.predict(X[control]) - y[control]
        self.model_t.fit(X[treated], delta1)
        self.model_c.fit(X[control], delta0)
        self.ps.fit(X, a)

    def predict(self, X):
        """Return ``p(X) * tau_c(X) + (1 - p(X)) * tau_t(X)``.

        Follows Kuenzel et al.: ``tau = g * tau_0 + (1 - g) * tau_1`` with
        ``g`` the estimated propensity. When few units are treated (small
        ``p``), the weight goes to ``tau_t``. Its imputed effects rely on the
        control-outcome model, which was fitted on the larger arm.
        """
        tau_t = self.model_t.predict(X)
        tau_c = self.model_c.predict(X)
        p = self.ps.predict_proba(X)[:, 1]
        return p * tau_c + (1 - p) * tau_t

class BaseDRRegressor:
    """DR-learner: regress cross-fitted AIPW pseudo-outcomes on covariates.

    Parameters
    ----------
    outcome_learner : regressor with ``fit``/``predict``, optional
        Template for the per-arm outcome models. Defaults to ``XGBRegressor()``.
    effect_learner : regressor with ``fit``/``predict``, optional
        Template for the final effect model. Defaults to ``LinearRegression()``.
    p_learner : classifier with ``fit``/``predict_proba``, optional
        Template for the propensity model. Defaults to ``LogisticRegression()``.
    """

    def __init__(self, outcome_learner=None, effect_learner=None, p_learner=None):
        # None sentinels: avoid estimator instances created once at
        # class-definition time and shared as default arguments.
        if outcome_learner is None:
            outcome_learner = XGBRegressor()
        if effect_learner is None:
            effect_learner = LinearRegression()
        if p_learner is None:
            p_learner = LogisticRegression()
        self.outcome_t = deepcopy(outcome_learner)
        self.outcome_c = deepcopy(outcome_learner)
        self.model = deepcopy(effect_learner)
        self.ps = deepcopy(p_learner)

    @staticmethod
    def _pseudo_outcome(a, y, p, tau1, tau0):
        """Return the AIPW pseudo-outcome ``(a - p)/(p(1-p)) * (y - mu_a) + mu_1 - mu_0``."""
        tau = a * tau1 + (1 - a) * tau0
        return (a - p) / p / (1 - p) * (y - tau) + tau1 - tau0

    def fit(self, X, a, y, p_trim=0.05, n_splits=2):
        """Cross-fit the nuisance models, then fit the effect model on pseudo-outcomes.

        Parameters
        ----------
        X, a, y : covariates, 0/1 treatment, outcome (indexable by integer arrays).
        p_trim : estimated propensities are clipped to ``[p_trim, 1 - p_trim]``.
        n_splits : number of cross-fitting folds (``KFold`` without shuffling).

        After fitting, ``outcome_t``, ``outcome_c`` and ``ps`` hold the models
        from the last fold.
        """
        pseudo = np.zeros(len(a))
        kf = KFold(n_splits=n_splits)
        for train_id, test_id in kf.split(X):
            # Nuisance models are trained on one part and evaluated on the
            # held-out part, so no row's pseudo-outcome uses its own fit.
            train_id_1 = train_id[a[train_id] == 1]
            train_id_0 = train_id[a[train_id] == 0]
            self.outcome_t.fit(X[train_id_1], y[train_id_1])
            self.outcome_c.fit(X[train_id_0], y[train_id_0])
            tau1 = self.outcome_t.predict(X[test_id])
            tau0 = self.outcome_c.predict(X[test_id])

            self.ps.fit(X[train_id], a[train_id])
            p = self.ps.predict_proba(X[test_id])[:, 1]
            # Trim to keep the inverse-propensity weights bounded.
            p = np.clip(p, p_trim, 1 - p_trim)

            pseudo[test_id] = self._pseudo_outcome(a[test_id], y[test_id], p, tau1, tau0)
        self.model.fit(X, pseudo)

    def predict(self, X):
        """Return the effect model's prediction for ``X``."""
        return self.model.predict(X)

class SurrogateRegressor:
    """Cross-fitted DR-learner whose outcome ``y`` is only used where ``r == 1``.

    An imputation model ``mu`` is trained on ``[a, s, X]`` (treatment,
    surrogates, covariates) using only the rows with ``r == 1``. A response
    model ``rho`` is fitted to predict ``r`` from the same features. Each
    held-out row gets the imputed outcome ``r / rho * (y - mu) + mu``. That
    outcome enters the usual AIPW pseudo-outcome, built from per-arm
    regressions of ``mu`` on ``X`` and a propensity model for ``a``. The
    effect model regresses the pseudo-outcome on ``X``.

    Parameters
    ----------
    outcome_learner : regressor with ``fit``/``predict``, optional
        Template for the imputation and per-arm models. Defaults to
        ``XGBRegressor()``.
    effect_learner : regressor with ``fit``/``predict``, optional
        Template for the final effect model. Defaults to ``LinearRegression()``.
    p_learner : classifier with ``fit``/``predict_proba``, optional
        Template for both the propensity and the response model. Defaults to
        ``LogisticRegression()``.
    """

    def __init__(self, outcome_learner=None, effect_learner=None, p_learner=None):
        # None sentinels: avoid estimator instances created once at
        # class-definition time and shared as default arguments.
        if outcome_learner is None:
            outcome_learner = XGBRegressor()
        if effect_learner is None:
            effect_learner = LinearRegression()
        if p_learner is None:
            p_learner = LogisticRegression()
        self.impute = deepcopy(outcome_learner)
        self.outcome_t = deepcopy(outcome_learner)
        self.outcome_c = deepcopy(outcome_learner)
        self.model = deepcopy(effect_learner)
        self.ps = deepcopy(p_learner)
        self.rho = deepcopy(p_learner)

    @staticmethod
    def _impute_outcome(y, mu, r, rho):
        """Return ``r / rho * (y - mu) + mu``, reading ``y`` only where ``r != 0``.

        Unobserved outcomes (``r == 0``) are often stored as NaN. A plain
        ``0 * (NaN - mu)`` would be NaN and poison the pseudo-outcome, so those
        rows are masked to ``mu`` explicitly.
        """
        correction = np.where(r != 0, r / rho * (y - mu), 0.0)
        return correction + mu

    def fit(self, X, a, s, y, r, p_trim=0.05, n_splits=2):
        """Cross-fit all nuisance models, then fit the effect model on pseudo-outcomes.

        Parameters
        ----------
        X : covariates (indexable by integer arrays).
        a : 0/1 treatment array.
        s : surrogate variable(s), stacked as columns next to ``a`` and ``X``.
        y : outcome array; entries with ``r == 0`` do not affect the result
            and may be NaN.
        r : 0/1 array, 1 where ``y`` is observed.
        p_trim : lower bound for ``rho``; propensities are clipped to
            ``[p_trim, 1 - p_trim]``.
        n_splits : number of cross-fitting folds (``KFold`` without shuffling).
        """
        pseudo = np.zeros(len(a))
        # Loop-invariant feature matrix for the imputation/response models.
        X0 = np.column_stack((a, s, X))
        kf = KFold(n_splits=n_splits)
        for train_id, test_id in kf.split(X):
            # Imputation model: trained only on training rows with observed y.
            train_id_r = train_id[r[train_id] == 1]
            self.impute.fit(X0[train_id_r], y[train_id_r])
            mu = self.impute.predict(X0)

            self.rho.fit(X0[train_id], r[train_id])
            rho = self.rho.predict_proba(X0[test_id])[:, 1]
            rho = np.maximum(rho, p_trim)

            # Per-arm regressions of the imputed mean mu on covariates only.
            train_id_1 = train_id[a[train_id] == 1]
            train_id_0 = train_id[a[train_id] == 0]
            self.outcome_t.fit(X[train_id_1], mu[train_id_1])
            self.outcome_c.fit(X[train_id_0], mu[train_id_0])
            tau1 = self.outcome_t.predict(X[test_id])
            tau0 = self.outcome_c.predict(X[test_id])

            self.ps.fit(X[train_id], a[train_id])
            p = self.ps.predict_proba(X[test_id])[:, 1]
            p = np.clip(p, p_trim, 1 - p_trim)

            tau = a[test_id] * tau1 + (1 - a[test_id]) * tau0
            y0 = self._impute_outcome(y[test_id], mu[test_id], r[test_id], rho)
            pseudo[test_id] = (a[test_id] - p) / p / (1 - p) * (y0 - tau) + tau1 - tau0
        self.model.fit(X, pseudo)

    def predict(self, X):
        """Return the effect model's prediction for ``X``."""
        return self.model.predict(X)

    def predict_nuisance(self, X, a, s):
        """Return the fitted nuisance predictions as a dict.

        Keys: ``'mu'`` (imputation model), ``'rho'`` (response probability),
        ``'tau'`` (arm-specific regression evaluated at ``a``), ``'pi'``
        (propensity). The per-arm and propensity models are those left over
        from the last cross-fitting fold.
        """
        X0 = np.column_stack((a, s, X))
        out = {'mu': self.impute.predict(X0),
               'rho': self.rho.predict_proba(X0)[:, 1],
               'tau': a * self.outcome_t.predict(X) + (1 - a) * self.outcome_c.predict(X),
               'pi': self.ps.predict_proba(X)[:, 1]}
        return out
