import numpy as np
from copy import deepcopy
from sklearn.base import clone
from lightgbm import LGBMRegressor
from econml.metalearners import TLearner


class EconMLTLearner:
    """T-learner CATE estimator wrapping ``econml.metalearners.TLearner``.

    Two outcome models are handed to EconML's ``TLearner``; the estimated
    individual treatment effect is obtained from ``TLearner.effect``.
    Besides fitting and prediction, the class offers evaluation metrics
    against ground-truth potential outcomes (PEHE, absolute/relative ATE
    error) and against the randomized subset of a dataset (ATT error and
    policy risk).

    Parameters
    ----------
    model_control, model_treated : regressor or None
        Outcome models for the control and treated arms. When ``None``, an
        unfitted clone of a default ``LGBMRegressor`` is used. The objects
        are deep-copied when ``fit`` is called, so the instances stored on
        ``self`` are never fitted in place.
    categories :
        Passed through unchanged to ``TLearner``.
    allow_missing : bool
        Passed through unchanged to ``TLearner``.
    random_state : int
        Seed given to the default ``LGBMRegressor``.
    """

    def __init__(self,
                 model_control=None,
                 model_treated=None,
                 categories='auto',
                 allow_missing=False,
                 random_state=42):

        base = LGBMRegressor(
            n_estimators=100, max_depth=5, learning_rate=0.1, random_state=random_state
        )
        # Explicit ``is None`` checks: some scikit-learn estimators define
        # ``__len__`` (e.g. ensembles), so relying on truthiness via ``or``
        # can raise on an unfitted model or discard the caller's model.
        self.model_control = model_control if model_control is not None else clone(base)
        self.model_treated = model_treated if model_treated is not None else clone(base)

        self.categories = categories
        self.allow_missing = allow_missing
        self.random_state = random_state
        self._econml = None  # fitted econml TLearner; set by fit()

    # ------------------------------------------------------------------
    # Ground-truth metrics (require both potential-outcome surfaces)
    # ------------------------------------------------------------------
    @staticmethod
    def _true_tau(mu0, mu1):
        """Return ``mu1 - mu0`` as a flat 1-D array.

        Both inputs are flattened *before* subtracting so that a column
        vector and a 1-D vector never broadcast into an (n, n) matrix.
        """
        return np.ravel(np.asarray(mu1)) - np.ravel(np.asarray(mu0))

    @staticmethod
    def _pehe(tau_hat, mu0=None, mu1=None):
        """Root-mean-squared error between estimated and true effects.

        Returns ``np.nan`` when either potential-outcome array is missing.
        """
        if mu0 is None or mu1 is None:
            return np.nan
        tau_true = EconMLTLearner._true_tau(mu0, mu1)
        # Flatten the estimate too; an (n, 1) estimate against an (n,) truth
        # would otherwise broadcast and produce a meaningless PEHE.
        tau_hat = np.ravel(np.asarray(tau_hat))
        return float(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))

    @staticmethod
    def _abs_ate_error(tau_hat, mu0=None, mu1=None):
        """Absolute difference between the estimated and true mean effect.

        Returns ``np.nan`` when either potential-outcome array is missing.
        """
        if mu0 is None or mu1 is None:
            return np.nan
        tau_true = EconMLTLearner._true_tau(mu0, mu1)
        return float(abs(np.mean(tau_hat) - np.mean(tau_true)))

    @staticmethod
    def _rel_ate_error(tau_hat, mu0=None, mu1=None):
        """Absolute ATE error divided by ``|true ATE|``.

        A ``1e-8`` term is added to the denominator to avoid division by
        zero when the true ATE is 0. Returns ``np.nan`` when either
        potential-outcome array is missing.
        """
        if mu0 is None or mu1 is None:
            return np.nan
        tau_true = EconMLTLearner._true_tau(mu0, mu1)
        denom = abs(np.mean(tau_true)) + 1e-8
        return float(abs(np.mean(tau_hat) - np.mean(tau_true)) / denom)

    # ------------------------------------------------------------------
    # Metrics on the randomized (e == 1) subset
    # ------------------------------------------------------------------
    @staticmethod
    def _rct_subset(X, T, Y, e):
        """Convert inputs to arrays and locate the randomized rows.

        Returns ``(X, T, Y, idx)``: ``X`` and ``Y`` as arrays, ``T``
        flattened to 1-D, and ``idx`` the integer positions where
        ``e == 1``. Converting here makes list inputs work (``list == 1``
        is a plain ``False``) and makes ``X[idx]`` select rows, not columns,
        for DataFrame-like inputs.
        """
        X = np.asarray(X)
        T = np.asarray(T).reshape(-1)
        Y = np.asarray(Y)
        e = np.asarray(e).reshape(-1)
        idx = np.flatnonzero(e == 1)
        return X, T, Y, idx

    def att_abs_error_rct(self, X, T, Y, e):
        """Absolute ATT error measured on the randomized subset.

        The reference ATT is the difference in mean observed outcome between
        treated and control units with ``e == 1``; the estimate is the mean
        predicted effect over those treated units.

        Returns ``np.nan`` if there are no randomized units, or if either
        arm is empty among them.
        """
        X, T, Y, idx = self._rct_subset(X, T, Y, e)
        if idx.size == 0:
            return np.nan
        treated_rct = idx[T[idx] == 1]
        control_rct = idx[T[idx] == 0]
        if treated_rct.size == 0 or control_rct.size == 0:
            return np.nan

        att_true = float(np.mean(Y[treated_rct]) - np.mean(Y[control_rct]))
        tau_hat_treated = self.predict_tau(X[treated_rct])
        att_hat = float(np.mean(tau_hat_treated))
        return abs(att_hat - att_true)

    def policy_risk_rct(self, X, T, Y, e, lam=0.0):
        """Policy risk ``1 - value`` of the rule "treat iff tau_hat > lam".

        Computed on the randomized subset (``e == 1``). The policy value is
        ``p1 * E[Y | pi=1, T=1] + p0 * E[Y | pi=0, T=0]``, where ``p1`` is
        the share of units the policy treats and ``p0 = 1 - p1``.

        An arm the policy never selects (weight 0) contributes nothing.
        When an arm has positive weight but no observed unit matches it,
        its conditional mean is not identifiable and ``np.nan`` is returned.
        ``np.nan`` is also returned when there are no randomized units.
        """
        X, T, Y, idx = self._rct_subset(X, T, Y, e)
        if idx.size == 0:
            return np.nan

        tau_hat = np.asarray(self.predict_tau(X[idx])).reshape(-1)
        pi = (tau_hat > lam).astype(int)
        p1 = float(np.mean(pi == 1))
        p0 = 1.0 - p1

        t_rct = T[idx]
        y_rct = Y[idx]
        y_treat_pi1 = y_rct[(pi == 1) & (t_rct == 1)]
        y_ctrl_pi0 = y_rct[(pi == 0) & (t_rct == 0)]

        value = 0.0
        for weight, outcomes in ((p1, y_treat_pi1), (p0, y_ctrl_pi0)):
            if weight == 0:
                # Arm never chosen by the policy; skipping avoids nan * 0.
                continue
            if outcomes.size == 0:
                # Arm chosen but no matching observed units.
                return np.nan
            value += weight * float(np.mean(outcomes))
        return 1.0 - value

    # ------------------------------------------------------------------
    # Fitting and prediction
    # ------------------------------------------------------------------
    def fit(self, X_train, T_train, Y_train):
        """Fit a fresh econml ``TLearner`` on the training data.

        ``T_train`` is flattened to 1-D; the metrics in this class assume a
        0/1 treatment encoding. The configured models are deep-copied so the
        instances stored on ``self`` stay unfitted. Returns ``self``.
        """
        X = np.asarray(X_train)
        T = np.asarray(T_train).reshape(-1)
        Y = np.asarray(Y_train)

        self._econml = TLearner(
            models=(deepcopy(self.model_control), deepcopy(self.model_treated)),
            categories=self.categories,
            allow_missing=self.allow_missing
        )
        self._econml.fit(Y, T, X=X)
        return self

    def _check_fitted(self):
        """Raise ``NotFittedError`` (an ``AttributeError``) if ``fit`` was not called."""
        if getattr(self, "_econml", None) is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                f"{type(self).__name__} is not fitted yet; call fit() first."
            )

    def predict_tau(self, X):
        """Return the estimated treatment effect for each row of ``X`` as a 1-D array."""
        self._check_fitted()
        return self._econml.effect(np.asarray(X)).reshape(-1)

    def predict_mu(self, X):
        """Return ``(mu0, mu1)`` predictions from the fitted per-arm models.

        NOTE(review): ``models[0]`` / ``models[1]`` are assumed to be the
        T == 0 / T == 1 models (TLearner's sorted category order) — confirm
        when a non-default ``categories`` is used.
        """
        self._check_fitted()
        X = np.asarray(X)
        mu0 = self._econml.models[0].predict(X)
        mu1 = self._econml.models[1].predict(X)
        return mu0, mu1

    def evaluate(self, X, T=None, Y=None, e=None, mu0=None, mu1=None):
        """Compute every metric for which the required inputs are given.

        - ``mu0`` and ``mu1`` → ``"PEHE"``, ``"ATE_error"``, ``"rel_ATE_error"``.
        - ``T``, ``Y`` and ``e`` → ``"ATT_abs_error_rct"``, ``"policy_risk_rct"``.

        Returns a dict mapping metric name to value.
        """
        metrics = {}
        X = np.asarray(X)
        tau_hat = self.predict_tau(X)

        if (mu0 is not None) and (mu1 is not None):
            metrics["PEHE"] = self._pehe(tau_hat, mu0, mu1)
            metrics["ATE_error"] = self._abs_ate_error(tau_hat, mu0, mu1)
            metrics["rel_ATE_error"] = self._rel_ate_error(tau_hat, mu0, mu1)

        if (e is not None) and (T is not None) and (Y is not None):
            metrics["ATT_abs_error_rct"] = self.att_abs_error_rct(X, T, Y, e)
            metrics["policy_risk_rct"] = self.policy_risk_rct(X, T, Y, e)

        return metrics
