import numpy as np
from copy import deepcopy
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error
from sklearn.base import clone

from econml.dr import DRLearner 

class EconMLDRLearner:
    """Doubly-robust CATE estimator built on ``econml.dr.DRLearner``.

    ``fit`` trains the econml learner and, separately, a full-sample outcome
    model on ``[X, T]`` plus a full-sample propensity model on ``X``.
    ``evaluate`` reports PEHE / absolute / relative ATE error when the true
    potential-outcome means ``mu0`` and ``mu1`` are supplied, and otherwise
    falls back to RCT-subset metrics (ATT absolute error and policy risk).
    """

    def __init__(self,
                 model_regression=None,
                 model_propensity=None,
                 model_final=None,
                 cv=3,
                 use_poly2=True,
                 min_propensity=1e-6,
                 prop_clip=1e-3,
                 random_state=42,
                 discrete_outcome=False,
                 multitask_model_final=False):
        """Store configuration; no fitting happens here.

        Args:
            model_regression: Outcome model. Defaults to
                ``RandomForestRegressor(n_estimators=100, max_depth=5,
                min_samples_leaf=5, random_state=random_state)``.
            model_propensity: Propensity model. Defaults to
                ``LogisticRegression(max_iter=1000)``.
            model_final: Final-stage model passed to ``DRLearner``. Defaults
                to the same random-forest configuration as
                ``model_regression``.
            cv: Forwarded to ``DRLearner``.
            use_poly2: If True, ``fit`` passes a degree-2
                ``PolynomialFeatures(include_bias=False)`` featurizer to
                ``DRLearner``.
            min_propensity: Forwarded to ``DRLearner``.
            prop_clip: Stored only; not referenced by any method of this
                class.
            random_state: Seed for the default forests and for ``DRLearner``.
            discrete_outcome: Forwarded to ``DRLearner``.
            multitask_model_final: Forwarded to ``DRLearner``.
        """
        # Test ``is None`` instead of using ``or``: sklearn ensembles define
        # ``__len__`` as ``len(self.estimators_)``, so truth-testing an
        # unfitted RandomForest/GradientBoosting model raises AttributeError.
        if model_regression is None:
            model_regression = RandomForestRegressor(
                n_estimators=100, max_depth=5, min_samples_leaf=5,
                random_state=random_state,
            )
        if model_propensity is None:
            model_propensity = LogisticRegression(max_iter=1000)
        if model_final is None:
            model_final = RandomForestRegressor(
                n_estimators=100, max_depth=5, min_samples_leaf=5,
                random_state=random_state,
            )
        self.model_regression = model_regression
        self.model_propensity = model_propensity
        self.model_final = model_final
        self.cv = cv
        self.use_poly2 = use_poly2
        self.min_propensity = min_propensity
        self.prop_clip = prop_clip  # NOTE(review): currently unused in this class
        self.random_state = random_state
        self.discrete_outcome = discrete_outcome
        self.multitask_model_final = multitask_model_final

        self._econml = None        # fitted DRLearner, set by fit()
        self._outcome_full = None  # outcome model fit on all of [X, T] -> Y
        self._prop_full = None     # propensity model fit on all of X -> T (unused here)

    @staticmethod
    def _stack_XT(X, T):
        """Return ``X`` with the treatment appended as trailing column(s)."""
        T = np.asarray(T)
        if T.ndim == 1:
            T = T.reshape(-1, 1)
        return np.column_stack([X, T])

    def _fit_full_aux_models(self, X, T, Y):
        """Fit fresh clones of the outcome and propensity models on all data."""
        self._outcome_full = clone(self.model_regression)
        self._outcome_full.fit(self._stack_XT(X, T), Y)

        self._prop_full = clone(self.model_propensity)
        self._prop_full.fit(X, T)

    def _factual_rmse(self, X, T, Y):
        """RMSE of the full-sample outcome model on the observed (X, T) -> Y."""
        mt = self._outcome_full.predict(self._stack_XT(X, T))
        return float(np.sqrt(mean_squared_error(Y, mt)))

    @staticmethod
    def _aligned_effects(tau_hat, mu0, mu1):
        """Return ``(tau_hat, tau_true)`` as flat float arrays.

        ``tau_true = mu1 - mu0``. Both sides are flattened so that an
        ``(n, 1)`` effect column compared against ``(n,)`` ground truth does
        not silently broadcast into an ``(n, n)`` matrix.

        Raises:
            ValueError: if the sizes differ and ``tau_true`` is not a scalar.
        """
        tau_hat = np.asarray(tau_hat, dtype=float).ravel()
        tau_true = (np.asarray(mu1, dtype=float).ravel()
                    - np.asarray(mu0, dtype=float).ravel())
        if tau_true.size not in (1, tau_hat.size):
            raise ValueError(
                f"tau_hat has {tau_hat.size} entries but mu1 - mu0 has {tau_true.size}"
            )
        return tau_hat, tau_true

    @staticmethod
    def _pehe(tau_hat, mu0=None, mu1=None):
        """Root-mean-squared error between ``tau_hat`` and ``mu1 - mu0``.

        Returns NaN when either ground-truth array is missing.
        """
        if mu0 is None or mu1 is None:
            return np.nan
        tau_hat, tau_true = EconMLDRLearner._aligned_effects(tau_hat, mu0, mu1)
        return float(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))

    @staticmethod
    def _abs_ate_error(tau_hat, mu0=None, mu1=None):
        """``|mean(tau_hat) - mean(mu1 - mu0)|``; NaN without ground truth."""
        if mu0 is None or mu1 is None:
            return np.nan
        tau_hat, tau_true = EconMLDRLearner._aligned_effects(tau_hat, mu0, mu1)
        return float(abs(np.mean(tau_hat) - np.mean(tau_true)))

    @staticmethod
    def _rel_ate_error(tau_hat, mu0=None, mu1=None):
        """Absolute ATE error divided by ``|true ATE|`` (+1e-8 to avoid /0).

        Returns NaN when either ground-truth array is missing.
        """
        if mu0 is None or mu1 is None:
            return np.nan
        tau_hat, tau_true = EconMLDRLearner._aligned_effects(tau_hat, mu0, mu1)
        denom = abs(np.mean(tau_true)) + 1e-8
        return float(abs(np.mean(tau_hat) - np.mean(tau_true)) / denom)

    def att_abs_error_rct(self, X, T, Y, e):
        """Absolute error of the estimated ATT on the RCT subset (``e == 1``).

        The reference ATT is the difference in mean outcomes between treated
        (``T == 1``) and control (``T == 0``) rows of the subset; the estimate
        is the mean predicted effect over the treated rows.

        Returns NaN when ``e`` is None, the subset is empty, or the subset
        lacks treated or control rows.
        """
        if e is None:
            return np.nan
        X = np.asarray(X)
        T = np.asarray(T)
        Y = np.asarray(Y)
        e = np.asarray(e)

        idx = np.where(e == 1)[0]
        if len(idx) == 0:
            return np.nan
        treated_rct = idx[T[idx] == 1]
        control_rct = idx[T[idx] == 0]
        # A difference in means needs both arms.
        if len(treated_rct) == 0 or len(control_rct) == 0:
            return np.nan
        att_true = float(np.mean(Y[treated_rct]) - np.mean(Y[control_rct]))
        att_hat = float(np.mean(self.predict_tau(X[treated_rct])))
        return abs(att_hat - att_true)

    def policy_risk_rct(self, X, T, Y, e, lam=0.0):
        """Policy risk of "treat iff tau_hat > lam" on the RCT subset.

        ``risk = 1 - [P(pi=1) * E[Y | pi=1, T=1] + P(pi=0) * E[Y | pi=0, T=0]]``
        over rows with ``e == 1``. A term whose policy weight is zero is
        dropped (so treat-all / treat-none policies still get a value); NaN is
        returned when ``e`` is None, the subset is empty, or a term with
        positive weight has no matching rows.

        NOTE(review): ``1 - value`` is only a meaningful risk if Y is bounded
        by 1 -- presumably binary outcomes; confirm with callers.
        """
        if e is None:
            return np.nan
        X = np.asarray(X)
        T = np.asarray(T)
        Y = np.asarray(Y)
        e = np.asarray(e)

        idx = np.where(e == 1)[0]
        if len(idx) == 0:
            return np.nan
        # Exactly one effect per unit; an (m, 1) column would otherwise
        # broadcast against the (m,) masks below.
        tau_hat = np.asarray(self.predict_tau(X[idx])).reshape(len(idx))
        treat = tau_hat > lam
        t_rct = T[idx]
        y_rct = Y[idx]

        p1 = float(np.mean(treat))
        p0 = 1.0 - p1
        value = 0.0
        if p1 > 0:
            y_treat = y_rct[treat & (t_rct == 1)]
            if len(y_treat) == 0:
                return np.nan
            value += p1 * float(np.mean(y_treat))
        if p0 > 0:
            y_ctrl = y_rct[~treat & (t_rct == 0)]
            if len(y_ctrl) == 0:
                return np.nan
            value += p0 * float(np.mean(y_ctrl))
        return 1.0 - value

    def fit(self, X_train, T_train, Y_train):
        """Fit the econml ``DRLearner`` and the full-sample auxiliary models.

        The configured estimators are deep-copied, so the instances held on
        ``self`` are never fitted in place. Returns ``self``.
        """
        X_tr = np.asarray(X_train)
        T_tr = np.asarray(T_train)
        Y_tr = np.asarray(Y_train)
        featurizer = PolynomialFeatures(degree=2, include_bias=False) if self.use_poly2 else None

        self._econml = DRLearner(
            model_propensity=deepcopy(self.model_propensity),
            model_regression=deepcopy(self.model_regression),
            model_final=deepcopy(self.model_final),
            featurizer=featurizer,
            cv=self.cv,
            min_propensity=self.min_propensity,
            discrete_outcome=self.discrete_outcome,
            multitask_model_final=self.multitask_model_final,
            random_state=self.random_state,
        )
        self._econml.fit(Y_tr, T_tr, X=X_tr, W=None)
        self._fit_full_aux_models(X_tr, T_tr, Y_tr)
        return self

    def predict_tau(self, X):
        """Return the fitted learner's effect estimates for ``X``.

        Raises:
            sklearn.exceptions.NotFittedError: if ``fit`` has not been called
                (subclass of AttributeError, as the old failure was).
        """
        if self._econml is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError("EconMLDRLearner is not fitted yet; call fit() first.")
        return self._econml.effect(np.asarray(X))

    def evaluate(self, X, T=None, Y=None, e=None, mu0=None, mu1=None):
        """Compute evaluation metrics for the fitted learner.

        With ``mu0`` and ``mu1`` given, returns ``PEHE``, ``ATE_error`` and
        ``rel_ATE_error``. Otherwise returns ``ATT_abs_error_rct`` and
        ``policy_risk_rct``; these are NaN when ``T``, ``Y`` or ``e`` is
        missing.
        """
        X = np.asarray(X)
        metrics = {}
        if (mu0 is not None) and (mu1 is not None):
            tau_hat = self.predict_tau(X)
            metrics["PEHE"] = self._pehe(tau_hat, mu0, mu1)
            metrics["ATE_error"] = self._abs_ate_error(tau_hat, mu0, mu1)
            metrics["rel_ATE_error"] = self._rel_ate_error(tau_hat, mu0, mu1)
        elif T is None or Y is None or e is None:
            # Neither ground-truth effects nor RCT data: nothing to score.
            metrics["ATT_abs_error_rct"] = np.nan
            metrics["policy_risk_rct"] = np.nan
        else:
            T = np.asarray(T)
            Y = np.asarray(Y)
            metrics["ATT_abs_error_rct"] = self.att_abs_error_rct(X, T, Y, e)
            metrics["policy_risk_rct"] = self.policy_risk_rct(X, T, Y, e)
        return metrics

