# causal_forest.py
from __future__ import annotations

import numpy as np
from typing import Optional, Tuple, Union
from econml.dml import CausalForestDML

# Type alias for array inputs. A single-member Union is just that member, so
# this is simply np.ndarray; note that GRF._to_numpy also accepts anything
# exposing ``.values`` or convertible by np.asarray.
ArrayLike = np.ndarray

class GRF:
    """Wrapper around econml's ``CausalForestDML`` with CATE evaluation helpers.

    Inputs are converted to contiguous NumPy arrays before fitting and
    predicting. ``evaluate`` reports:

    * ``PEHE``, ``ATE_error`` and ``rel_ATE_error`` when the true
      potential-outcome means ``mu0``/``mu1`` are supplied;
    * otherwise ``ATT_abs_error_rct`` and ``policy_risk_rct``, computed on the
      units flagged by ``e == 1`` (treated here as the randomized subset).
    """

    def __init__(self, **cf_kwargs):
        """Build the underlying estimator.

        Args:
            **cf_kwargs: Passed unchanged to ``CausalForestDML``.
        """
        self._cf = CausalForestDML(**cf_kwargs)
        # Flipped by fit(); predict_tau() refuses to run until then.
        self._fitted = False

    @staticmethod
    def _to_numpy(x: ArrayLike) -> np.ndarray:
        """Return ``x`` as a C-contiguous NumPy array.

        Objects exposing a ``.values`` attribute (presumably pandas
        Series/DataFrames) are unwrapped first.
        """
        if hasattr(x, "values"):
            x = x.values
        return np.ascontiguousarray(np.asarray(x))

    @staticmethod
    def _true_tau(mu0, mu1) -> Optional[np.ndarray]:
        """Return the flattened true effect ``mu1 - mu0``, or None if either is missing.

        Each input is flattened *before* subtracting so that an ``(n, 1)``
        column and an ``(n,)`` vector cannot broadcast into an ``(n, n)``
        matrix.
        """
        if mu0 is None or mu1 is None:
            return None
        mu0 = np.asarray(mu0, dtype=float).reshape(-1)
        mu1 = np.asarray(mu1, dtype=float).reshape(-1)
        return mu1 - mu0

    @staticmethod
    def _pehe(tau_hat, mu0=None, mu1=None):
        """Root mean squared error between predicted and true CATE.

        Args:
            tau_hat: Predicted effects, one per unit.
            mu0, mu1: True mean outcomes under control / treatment.

        Returns:
            The RMSE as a float, or NaN when ``mu0`` or ``mu1`` is missing.
        """
        tau_true = GRF._true_tau(mu0, mu1)
        if tau_true is None:
            return np.nan
        # Flatten so a column-shaped prediction also compares element-wise.
        tau_hat = np.asarray(tau_hat, dtype=float).reshape(-1)
        return float(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))

    @staticmethod
    def _abs_ate_error(tau_hat, mu0=None, mu1=None):
        """Return ``|mean(tau_hat) - mean(mu1 - mu0)|``, or NaN if mu0/mu1 is missing."""
        tau_true = GRF._true_tau(mu0, mu1)
        if tau_true is None:
            return np.nan
        return float(abs(np.mean(tau_hat) - np.mean(tau_true)))

    @staticmethod
    def _rel_ate_error(tau_hat, mu0=None, mu1=None):
        """Return the absolute ATE error divided by ``|true ATE|``.

        A 1e-8 term in the denominator avoids division by zero when the true
        ATE is 0. Returns NaN if ``mu0`` or ``mu1`` is missing.
        """
        tau_true = GRF._true_tau(mu0, mu1)
        if tau_true is None:
            return np.nan
        denom = abs(np.mean(tau_true)) + 1e-8
        return float(abs(np.mean(tau_hat) - np.mean(tau_true)) / denom)

    def fit(
        self,
        X: ArrayLike,
        t: ArrayLike,
        y: ArrayLike,
        sample_weight: Optional[ArrayLike] = None,
    ) -> "GRF":
        """Fit the causal forest.

        Args:
            X: Feature matrix.
            t: Treatment assignment; flattened to 1-D.
            y: Outcome; flattened to 1-D.
            sample_weight: Optional per-unit weights; flattened to 1-D.

        Returns:
            ``self``, to allow chaining.
        """
        X = self._to_numpy(X)
        t = self._to_numpy(t).reshape(-1)
        y = self._to_numpy(y).reshape(-1)
        if sample_weight is not None:
            sample_weight = self._to_numpy(sample_weight).reshape(-1)
        self._cf.fit(y, t, X=X, sample_weight=sample_weight)
        self._fitted = True
        return self

    def predict_tau(self, X: ArrayLike) -> np.ndarray:
        """Return the forest's effect estimates (``CausalForestDML.effect``) for ``X``.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if not self._fitted:
            raise RuntimeError("GRF.predict_tau() called before fit()")
        return self._cf.effect(self._to_numpy(X))

    @staticmethod
    def _rct_split(T, e):
        """Return positional indices of units with ``e == 1`` and their treatment values.

        Inputs are converted with ``np.asarray`` and flattened, so pandas
        objects are indexed positionally and column vectors work.
        """
        rct_idx = np.flatnonzero(np.asarray(e).reshape(-1) == 1)
        t_rct = np.asarray(T).reshape(-1)[rct_idx]
        return rct_idx, t_rct

    def att_abs_error_rct(self, X, T, Y, e):
        """Absolute error of the estimated ATT on the ``e == 1`` subset.

        The reference ATT is the difference in mean outcome between treated
        and control units with ``e == 1``. The estimate is the mean predicted
        effect over those treated units.

        Returns:
            ``|att_hat - att_true|`` as a float, or NaN if the subset has no
            treated or no control units.
        """
        rct_idx, t_rct = self._rct_split(T, e)
        treated_rct = rct_idx[t_rct == 1]
        control_rct = rct_idx[t_rct == 0]
        if treated_rct.size == 0 or control_rct.size == 0:
            return np.nan

        Y = np.asarray(Y).reshape(-1)
        att_true = float(np.mean(Y[treated_rct]) - np.mean(Y[control_rct]))

        # np.asarray so row selection is positional even for DataFrames.
        tau_hat_treated = self.predict_tau(np.asarray(X)[treated_rct])
        att_hat = float(np.mean(tau_hat_treated))
        return abs(att_hat - att_true)

    def policy_risk_rct(self, X, T, Y, e, lam=0.0):
        """Policy risk of "treat iff predicted effect > lam" on the ``e == 1`` subset.

        Computes ``1 - (E[Y | pi=1, T=1] * P(pi=1) + E[Y | pi=0, T=0] * P(pi=0))``
        over units with ``e == 1``. NOTE(review): the ``1 - value`` form
        presumably assumes outcomes bounded in [0, 1]; confirm for your data.

        Args:
            lam: Threshold on the predicted effect above which the policy treats.

        Returns:
            The risk as a float, or NaN when the subset is empty or a policy
            arm has no matching observed units.
        """
        rct_idx, t_rct = self._rct_split(T, e)
        if rct_idx.size == 0:
            return np.nan

        y_rct = np.asarray(Y).reshape(-1)[rct_idx]
        tau_hat = np.asarray(self.predict_tau(np.asarray(X)[rct_idx])).reshape(-1)
        treat = tau_hat > lam

        p1 = float(np.mean(treat))
        p0 = 1.0 - p1

        # Only units whose observed arm agrees with the policy are informative.
        y_treat_pi1 = y_rct[treat & (t_rct == 1)]
        y_ctrl_pi0 = y_rct[~treat & (t_rct == 0)]
        if y_treat_pi1.size == 0 or y_ctrl_pi0.size == 0:
            return np.nan

        value = float(np.mean(y_treat_pi1)) * p1 + float(np.mean(y_ctrl_pi0)) * p0
        return 1.0 - value

    def evaluate(self, X, T=None, Y=None, e=None, mu0=None, mu1=None):
        """Compute evaluation metrics for the fitted forest.

        If both ``mu0`` and ``mu1`` are given, returns ``PEHE``, ``ATE_error``
        and ``rel_ATE_error``. Otherwise returns ``ATT_abs_error_rct`` and
        ``policy_risk_rct``, which require ``T``, ``Y`` and ``e``.

        Raises:
            ValueError: if ``mu0``/``mu1`` are not both given and any of
                ``T``, ``Y``, ``e`` is missing.
        """
        metrics = {}
        if (mu0 is not None) and (mu1 is not None):
            tau_hat = self.predict_tau(X)
            metrics["PEHE"] = self._pehe(tau_hat, mu0, mu1)
            metrics["ATE_error"] = self._abs_ate_error(tau_hat, mu0, mu1)
            metrics["rel_ATE_error"] = self._rel_ate_error(tau_hat, mu0, mu1)
            return metrics

        missing = [name for name, val in (("T", T), ("Y", Y), ("e", e)) if val is None]
        if missing:
            raise ValueError(
                "evaluate() needs mu0 and mu1, or T, Y and e for RCT metrics; "
                f"missing: {', '.join(missing)}"
            )

        X = np.asarray(X)
        metrics["ATT_abs_error_rct"] = self.att_abs_error_rct(X, T, Y, e)
        metrics["policy_risk_rct"] = self.policy_risk_rct(X, T, Y, e)
        return metrics




