import numpy as np
from copy import deepcopy
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_squared_error
from sklearn.base import clone
from econml.metalearners import XLearner
from lightgbm import LGBMRegressor

class EconMLXLearner:
    """Wrapper around econml's ``XLearner`` with causal-evaluation helpers.

    Builds an ``econml.metalearners.XLearner`` from two outcome models, two
    CATE models and a propensity model. It also provides metrics for:

    * settings where true potential-outcome means ``mu0``/``mu1`` are known
      (PEHE, absolute/relative ATE error), and
    * an RCT subset flagged by ``e == 1`` (ATT error, policy risk).

    Parameters
    ----------
    model_mu0, model_mu1 : estimator, optional
        Outcome models passed to ``XLearner(models=...)``. Default: a clone
        of ``LGBMRegressor(random_state=random_state)``.
    cate_model_0, cate_model_1 : estimator, optional
        Models passed to ``XLearner(cate_models=...)``. Same default as above.
    propensity_model : classifier, optional
        Default: ``LogisticRegression(max_iter=1000, random_state=random_state)``.
    categories, allow_missing :
        Forwarded unchanged to ``XLearner``.
    random_state : int or None
        Seed for the default estimators created here. User-supplied models
        are used as-is.
    """

    def __init__(self,
                 model_mu0=None,
                 model_mu1=None,
                 cate_model_0=None,
                 cate_model_1=None,
                 propensity_model=None,
                 categories='auto',
                 allow_missing=False,
                 random_state=42):

        # random_state used to be stored but ignored; seed the defaults with it.
        base = LGBMRegressor(random_state=random_state)
        # `is None` instead of `or`: estimator truthiness is not meaningful
        # (an object defining __len__ could evaluate as falsy).
        self.model_mu0 = clone(base) if model_mu0 is None else model_mu0
        self.model_mu1 = clone(base) if model_mu1 is None else model_mu1
        self.cate_model_0 = clone(base) if cate_model_0 is None else cate_model_0
        self.cate_model_1 = clone(base) if cate_model_1 is None else cate_model_1
        self.propensity_model = (
            LogisticRegression(max_iter=1000, random_state=random_state)
            if propensity_model is None else propensity_model
        )
        self.categories = categories
        self.allow_missing = allow_missing
        self.random_state = random_state
        self._econml = None  # set by fit()

    # ------------------------------------------------------------------ #
    # Metrics that need the true mu0 / mu1
    # ------------------------------------------------------------------ #
    @staticmethod
    def _flat_taus(tau_hat, mu0, mu1):
        """Return ``(tau_hat, tau_true)`` as 1-D float arrays, or ``None``.

        ``None`` is returned when ``mu0`` or ``mu1`` is missing. Each input
        is flattened separately so that an ``(n,)`` estimate never broadcasts
        against an ``(n, 1)`` truth into an ``(n, n)`` matrix.
        """
        if mu0 is None or mu1 is None:
            return None
        tau_true = (np.asarray(mu1, dtype=float).reshape(-1)
                    - np.asarray(mu0, dtype=float).reshape(-1))
        return np.asarray(tau_hat, dtype=float).reshape(-1), tau_true

    @staticmethod
    def _pehe(tau_hat, mu0=None, mu1=None):
        """Root mean squared error between estimated and true CATE.

        Returns NaN if ``mu0`` or ``mu1`` is missing.
        """
        pair = EconMLXLearner._flat_taus(tau_hat, mu0, mu1)
        if pair is None:
            return np.nan
        tau_est, tau_true = pair
        return float(np.sqrt(np.mean((tau_est - tau_true) ** 2)))

    @staticmethod
    def _abs_ate_error(tau_hat, mu0=None, mu1=None):
        """``|mean(tau_hat) - mean(mu1 - mu0)|``; NaN if ``mu0``/``mu1`` missing."""
        pair = EconMLXLearner._flat_taus(tau_hat, mu0, mu1)
        if pair is None:
            return np.nan
        tau_est, tau_true = pair
        return float(abs(tau_est.mean() - tau_true.mean()))

    @staticmethod
    def _rel_ate_error(tau_hat, mu0=None, mu1=None, eps=1e-8):
        """Absolute ATE error divided by ``|true ATE| + eps``.

        ``eps`` sits in the denominator so a zero true ATE gives a large but
        finite value rather than inf/NaN. Returns NaN if ``mu0``/``mu1`` is
        missing.
        """
        pair = EconMLXLearner._flat_taus(tau_hat, mu0, mu1)
        if pair is None:
            return np.nan
        tau_est, tau_true = pair
        ate_true = tau_true.mean()
        return float(abs(tau_est.mean() - ate_true) / (abs(ate_true) + eps))

    def _factual_rmse(self, X, T, Y):
        """RMSE of the fitted outcome models on the observed (factual) outcomes.

        For each unit, the prediction of the arm it actually received
        (``T == 1`` -> mu1, otherwise mu0) is compared to ``Y``.
        """
        mu0, mu1 = self.predict_mu(X)
        observed_pred = np.where(np.asarray(T).reshape(-1) == 1, mu1, mu0)
        return float(np.sqrt(mean_squared_error(Y, observed_pred)))

    # ------------------------------------------------------------------ #
    # Metrics on the RCT subset (e == 1)
    # ------------------------------------------------------------------ #
    def att_abs_error_rct(self, X, T, Y, e):
        """Absolute error of the estimated ATT on the RCT subset.

        The reference ATT is the difference in mean observed outcome between
        treated and control units with ``e == 1``. The estimate is the mean
        predicted CATE over the treated RCT units. Returns NaN if either RCT
        arm is empty.
        """
        X = np.asarray(X)
        T = np.asarray(T).reshape(-1)
        Y = np.asarray(Y).reshape(-1)
        in_rct = np.asarray(e).reshape(-1) == 1
        treated = in_rct & (T == 1)
        control = in_rct & (T == 0)
        if not treated.any() or not control.any():
            return np.nan
        att_true = float(Y[treated].mean() - Y[control].mean())
        att_hat = float(np.mean(self.predict_tau(X[treated])))
        return abs(att_hat - att_true)

    def policy_risk_rct(self, X, T, Y, e, lam=0.0):
        """Policy risk ``1 - value`` on the RCT subset for the policy tau_hat > lam.

        The policy value is estimated as
        ``p1 * E[Y | T=1, pi=1] + (1 - p1) * E[Y | T=0, pi=0]``,
        where ``p1`` is the fraction of RCT units the policy treats. If no unit
        agrees with the policy in an arm, that arm's mean is taken as 0 so the
        (possibly zero-weight) term does not turn the whole risk into NaN.
        Returns NaN if the RCT subset is empty.
        """
        X = np.asarray(X)
        T = np.asarray(T).reshape(-1)
        Y = np.asarray(Y).reshape(-1)
        rct = np.flatnonzero(np.asarray(e).reshape(-1) == 1)
        if rct.size == 0:
            return np.nan
        treat = np.asarray(self.predict_tau(X[rct])).reshape(-1) > lam
        t_rct, y_rct = T[rct], Y[rct]
        p1 = float(np.mean(treat))
        agree_treated = treat & (t_rct == 1)
        agree_control = ~treat & (t_rct == 0)
        ey1 = float(np.mean(y_rct[agree_treated])) if agree_treated.any() else 0.0
        ey0 = float(np.mean(y_rct[agree_control])) if agree_control.any() else 0.0
        value = ey1 * p1 + ey0 * (1.0 - p1)
        return 1.0 - value

    # ------------------------------------------------------------------ #
    # Fitting / prediction
    # ------------------------------------------------------------------ #
    def fit(self, X_train, T_train, Y_train):
        """Fit a fresh econml ``XLearner`` on copies of the configured models.

        Returns ``self`` so calls can be chained.
        """
        X = np.asarray(X_train)
        T = np.asarray(T_train)
        Y = np.asarray(Y_train)
        # Deep copies keep the configured (unfitted) models reusable across fits.
        self._econml = XLearner(
            models=(deepcopy(self.model_mu0), deepcopy(self.model_mu1)),
            cate_models=(deepcopy(self.cate_model_0), deepcopy(self.cate_model_1)),
            propensity_model=deepcopy(self.propensity_model),
            categories=self.categories,
            allow_missing=self.allow_missing
        )
        self._econml.fit(Y, T, X=X)
        return self

    def _require_fitted(self):
        """Raise ``AttributeError`` with a clear message if ``fit`` was not called."""
        if getattr(self, "_econml", None) is None:
            raise AttributeError("EconMLXLearner is not fitted; call fit() first.")

    def predict_tau(self, X):
        """Return the fitted XLearner's ``effect(X)`` as a 1-D array.

        Uses econml's default T0/T1 arguments.
        """
        self._require_fitted()
        return self._econml.effect(np.asarray(X)).reshape(-1)

    def predict_mu(self, X):
        """Return ``(mu0, mu1)`` predictions from the XLearner's outcome models.

        NOTE(review): assumes econml keeps the fitted outcome models in
        ``.models`` with index 0 = control and 1 = treated. Verify this
        against the installed econml version.
        """
        self._require_fitted()
        X = np.asarray(X)
        mu0 = self._econml.models[0].predict(X)
        mu1 = self._econml.models[1].predict(X)
        return mu0, mu1

    def evaluate(self, X, T=None, Y=None, e=None, mu0=None, mu1=None):
        """Compute every metric whose inputs are available.

        * ``mu0`` and ``mu1`` given -> "PEHE", "ATE_abs_error", "rel_ATE_error".
        * ``T``, ``Y`` and ``e`` given -> "ATT_abs_error_rct", "policy_risk_rct".

        Returns a dict mapping metric name to value.
        """
        metrics = {}
        X = np.asarray(X)
        tau_hat = self.predict_tau(X)

        if (mu0 is not None) and (mu1 is not None):
            metrics["PEHE"] = self._pehe(tau_hat, mu0, mu1)
            metrics["ATE_abs_error"] = self._abs_ate_error(tau_hat, mu0, mu1)
            metrics["rel_ATE_error"] = self._rel_ate_error(tau_hat, mu0, mu1)

        if (e is not None) and (T is not None) and (Y is not None):
            metrics["ATT_abs_error_rct"] = self.att_abs_error_rct(X, T, Y, e)
            metrics["policy_risk_rct"] = self.policy_risk_rct(X, T, Y, e)

        return metrics
