import numpy as np
from copy import deepcopy
from lightgbm import LGBMRegressor
from causalml.inference.meta import BaseSRegressor  

class CausalMLSLearner:
    """Thin wrapper around causalml's ``BaseSRegressor`` plus evaluation metrics.

    The wrapper fits a single-model (S-learner) treatment effect estimator and
    exposes helpers to score the estimated effects either against known
    potential outcomes (``mu0``/``mu1``) or on the subset of rows flagged by
    ``e == 1``.

    Attributes:
        control_name: Label identifying control rows in the treatment vector.
        learner: Base regressor; deep-copied on every ``fit`` call.
        model_: Fitted ``BaseSRegressor`` (``None`` until ``fit`` is called).
        treat_label_: First non-control treatment label seen during ``fit``
            (``None`` until ``fit`` is called).
    """

    def __init__(self, learner=None, control_name=0, random_state=42):
        """Store the configuration.

        Args:
            learner: Base regressor. When ``None``, a default
                ``LGBMRegressor`` is created.
            control_name: Label of the control group in the treatment vector.
            random_state: Seed used only for the default ``LGBMRegressor``.
        """
        self.control_name = control_name
        # Compare against None explicitly: estimators that define __len__
        # (ensembles, pipelines) may be falsy or even raise when their
        # truthiness is evaluated before fitting.
        self.learner = (
            learner
            if learner is not None
            else LGBMRegressor(
                n_estimators=100,
                max_depth=5,
                learning_rate=0.1,
                random_state=random_state,
            )
        )
        self.model_ = None
        self.treat_label_ = None

    @staticmethod
    def _pehe(tau_hat, mu0=None, mu1=None):
        """Return the root-mean-squared error between ``tau_hat`` and ``mu1 - mu0``.

        Entries where either the estimate or the true effect is non-finite are
        ignored. Returns ``nan`` when ``mu0``/``mu1`` is missing or when no
        finite pair remains.
        """
        if mu0 is None or mu1 is None:
            return np.nan
        # Flatten both sides so list inputs and (n, 1) column vectors line up
        # element-wise instead of failing or broadcasting to (n, n).
        tau_hat = np.asarray(tau_hat).ravel()
        tau_true = (np.asarray(mu1) - np.asarray(mu0)).ravel()
        finite = np.isfinite(tau_hat) & np.isfinite(tau_true)
        if not finite.any():
            return np.nan
        return float(np.sqrt(np.mean((tau_hat[finite] - tau_true[finite]) ** 2)))

    @staticmethod
    def _abs_ate_error(tau_hat, mu0=None, mu1=None):
        """Return ``|mean(tau_hat) - mean(mu1 - mu0)|`` (``nan`` if truth is missing)."""
        if mu0 is None or mu1 is None:
            return np.nan
        tau_true = np.asarray(mu1) - np.asarray(mu0)
        return float(abs(np.mean(np.asarray(tau_hat)) - np.mean(tau_true)))

    @staticmethod
    def _rel_ate_error(tau_hat, mu0=None, mu1=None):
        """Return the absolute ATE error divided by ``|mean(mu1 - mu0)|``.

        A small constant (1e-8) is added to the denominator to avoid division
        by zero. Returns ``nan`` if ``mu0`` or ``mu1`` is missing.
        """
        if mu0 is None or mu1 is None:
            return np.nan
        tau_true = np.asarray(mu1) - np.asarray(mu0)
        ate_true = np.mean(tau_true)
        denom = abs(ate_true) + 1e-8
        return float(abs(np.mean(np.asarray(tau_hat)) - ate_true) / denom)

    def fit(self, X, T, Y):
        """Fit a fresh ``BaseSRegressor`` on features ``X``, treatment ``T``, outcome ``Y``.

        Returns:
            ``self``, to allow call chaining.
        """
        self.model_ = BaseSRegressor(
            learner=deepcopy(self.learner), control_name=self.control_name
        )
        self.model_.fit(X=np.asarray(X), treatment=np.asarray(T), y=np.asarray(Y))

        # Remember the first treatment group. Unwrap the numpy scalar rather
        # than forcing int(), so string labels do not crash and non-integral
        # labels are not silently truncated.
        first_group = self.model_.t_groups[0]
        label = first_group.item() if hasattr(first_group, "item") else first_group
        if isinstance(label, bool) or (
            isinstance(label, float) and label.is_integer()
        ):
            # Keep the historical plain-int label for integral numeric values.
            label = int(label)
        self.treat_label_ = label
        return self

    def predict_tau(self, X):
        """Return the estimated treatment effect per row of ``X`` as a 1-D array.

        Requires ``fit`` to have been called. When the underlying model
        returns one column per treatment group, the first column is used,
        i.e. the group stored in ``treat_label_``.
        """
        te = np.asarray(self.model_.predict(np.asarray(X)))
        if te.ndim == 2:
            # Flattening a multi-column result would interleave groups and
            # yield a vector of the wrong length.
            te = te[:, 0]
        return te.reshape(-1)

    def predict_mu(self, X):
        """Return ``(mu0, mu1)``: predicted control and treated outcomes for ``X``.

        Both components are taken from the model's ``return_components=True``
        output for the treatment group stored in ``treat_label_``.
        """
        _, yhat_cs, yhat_ts = self.model_.predict(
            np.asarray(X), return_components=True
        )
        group = self.treat_label_
        mu0 = np.asarray(yhat_cs[group])
        mu1 = np.asarray(yhat_ts[group])
        return mu0, mu1

    def att_abs_error_rct(self, X, T, Y, e):
        """Return the absolute ATT error on the rows where ``e == 1``.

        The reference ATT is the difference in mean outcome between treated
        and control rows inside that subset; the estimate is the mean
        predicted effect over the treated rows of the subset.

        Args:
            X: Feature matrix.
            T: Treatment labels, one per row.
            Y: Observed outcomes, one per row.
            e: Indicator vector; rows with ``e == 1`` are used (presumably the
                randomized-experiment rows -- confirm with the caller).

        Returns:
            The absolute error as a float, or ``nan`` when the subset is empty
            or lacks treated or control rows.
        """
        # Coerce once so lists, pandas objects and (n, 1) columns index the
        # same way as flat numpy arrays.
        X = np.asarray(X)
        T, Y, e = (np.asarray(arr).ravel() for arr in (T, Y, e))

        idx = np.flatnonzero(e == 1)
        if idx.size == 0:
            return np.nan
        t_sub = T[idx]
        treated = idx[t_sub == self.treat_label_]
        control = idx[t_sub == self.control_name]
        if treated.size == 0 or control.size == 0:
            return np.nan
        att_true = float(np.mean(Y[treated]) - np.mean(Y[control]))
        att_hat = float(np.mean(self.predict_tau(X[treated])))
        return abs(att_hat - att_true)

    def policy_risk_rct(self, X, T, Y, e, lam=0.0):
        """Return ``1 - value`` of the policy "treat iff tau_hat > lam" on ``e == 1`` rows.

        The policy value is the mean outcome of treated rows the policy would
        treat, weighted by the share of rows it treats, plus the mean outcome
        of control rows it would not treat, weighted by the remaining share.

        NOTE(review): ``1 - value`` reads as a risk only when outcomes are on
        a [0, 1] scale -- confirm against the data in use.

        Returns:
            The policy risk as a float, or ``nan`` when the subset is empty or
            either policy-consistent group is empty.
        """
        X = np.asarray(X)
        T, Y, e = (np.asarray(arr).ravel() for arr in (T, Y, e))

        idx = np.flatnonzero(e == 1)
        if idx.size == 0:
            return np.nan
        tau_hat = self.predict_tau(X[idx])
        treat_policy = tau_hat > lam  # treat if tau > lam
        t_sub = T[idx]
        y_sub = Y[idx]
        y1 = y_sub[treat_policy & (t_sub == self.treat_label_)]
        y0 = y_sub[~treat_policy & (t_sub == self.control_name)]
        if y1.size == 0 or y0.size == 0:
            return np.nan
        p1 = float(np.mean(treat_policy))
        p0 = 1.0 - p1
        value = float(np.mean(y1)) * p1 + float(np.mean(y0)) * p0
        return 1.0 - value

    def evaluate(self, X, T=None, Y=None, e=None, mu0=None, mu1=None):
        """Compute every metric for which the required inputs were supplied.

        Args:
            X: Feature matrix to predict on.
            T, Y, e: When all three are given, ``"ATT_abs_error_rct"`` and
                ``"policy_risk_rct"`` are added.
            mu0, mu1: When both are given, ``"PEHE"``, ``"ATE_abs_error"`` and
                ``"rel_ATE_error"`` are added.

        Returns:
            Dict mapping metric name to value; empty if no optional inputs
            were supplied.
        """
        out = {}
        tau_hat = self.predict_tau(X)
        if (mu0 is not None) and (mu1 is not None):
            out["PEHE"] = self._pehe(tau_hat, mu0, mu1)
            out["ATE_abs_error"] = self._abs_ate_error(tau_hat, mu0, mu1)
            out["rel_ATE_error"] = self._rel_ate_error(tau_hat, mu0, mu1)
        if (e is not None) and (T is not None) and (Y is not None):
            out["ATT_abs_error_rct"] = self.att_abs_error_rct(X, T, Y, e)
            out["policy_risk_rct"] = self.policy_risk_rct(X, T, Y, e)
        return out
