from causalml.inference.tf import DragonNet
import tensorflow as tf
import numpy as np
from causalml.inference.meta.utils import convert_pd_to_np

class DragonNetModel:
    """Wrapper around causalml's ``DragonNet`` with causal-evaluation helpers.

    ``fit``/``predict`` expose individual treatment effect (ITE) estimation;
    ``evaluate`` reports PEHE / ATE errors when the true potential outcomes
    ``mu0``/``mu1`` are known, and otherwise ATT error and policy risk on the
    randomized subset (rows with ``e == 1``).
    """

    def __init__(self, params=None):
        """Configure the underlying DragonNet and seed the random generators.

        Args:
            params: optional dict. Keys read here: ``neurons_per_layer``
                (falls back to ``hidden_layer_size``, then 200),
                ``targeted_reg`` (default True) and ``random_seed`` (default 42).
        """
        if params is None:
            params = {}
        self.params = params

        # Seed first so anything random done while building the network is
        # reproducible as well.
        seed = int(self.params.get("random_seed", 42))
        tf.random.set_seed(seed)
        np.random.seed(seed)

        neurons = self.params.get(
            "neurons_per_layer", self.params.get("hidden_layer_size", 200)
        )
        targeted_reg = self.params.get("targeted_reg", True)
        self.dragon = DragonNet(neurons_per_layer=neurons, targeted_reg=targeted_reg)

    @staticmethod
    def _finite_effects(tau_hat, mu0, mu1):
        """Return ``(tau_hat, tau_true)`` restricted to rows where both are finite.

        ``tau_true`` is ``mu1 - mu0``. All inputs are flattened to 1-D so that
        column vectors of shape ``(n, 1)`` do not broadcast against ``(n,)``.
        Returns ``None`` when ``mu0``/``mu1`` are missing or no row is usable.
        """
        if mu0 is None or mu1 is None:
            return None
        tau_hat = np.asarray(tau_hat, dtype=float).reshape(-1)
        tau_true = (
            np.asarray(mu1, dtype=float) - np.asarray(mu0, dtype=float)
        ).reshape(-1)
        mask = np.isfinite(tau_hat) & np.isfinite(tau_true)
        if not np.any(mask):
            return None
        return tau_hat[mask], tau_true[mask]

    @staticmethod
    def _pehe(tau_hat, mu0=None, mu1=None):
        """Root-mean-squared ITE error (sqrt of PEHE); NaN if not computable."""
        pair = DragonNetModel._finite_effects(tau_hat, mu0, mu1)
        if pair is None:
            return np.nan
        est, true = pair
        return float(np.sqrt(np.mean((est - true) ** 2)))

    @staticmethod
    def _abs_ate_error(tau_hat, mu0=None, mu1=None):
        """Absolute difference between estimated and true ATE; NaN if not computable."""
        pair = DragonNetModel._finite_effects(tau_hat, mu0, mu1)
        if pair is None:
            return np.nan
        est, true = pair
        return float(abs(np.mean(est) - np.mean(true)))

    @staticmethod
    def _rel_ate_error(tau_hat, mu0=None, mu1=None):
        """ATE error relative to ``|true ATE|`` (+1e-8 to avoid division by zero)."""
        pair = DragonNetModel._finite_effects(tau_hat, mu0, mu1)
        if pair is None:
            return np.nan
        est, true = pair
        denom = abs(np.mean(true)) + 1e-8
        return float(abs(np.mean(est) - np.mean(true)) / denom)

    def fit(self, X, treatment, y):
        """Fit the underlying DragonNet (inputs pass through ``convert_pd_to_np``).

        Returns ``self`` so calls can be chained.
        """
        X, treatment, y = convert_pd_to_np(X, treatment, y)
        self.dragon.fit(X, treatment, y)
        return self

    def _parse_return_components(self, obj, n):
        """Split raw model output into ``(y0, y1, ite)`` 1-D float32 arrays.

        * One 1-D array / a single column is interpreted as the ITE itself.
        * Two or more columns (or a list/tuple of >= 2 arrays): the first two
          are the predicted outcomes under control and treatment, and
          ``ite = y1 - y0``; further columns are ignored. causalml's
          ``DragonNet.predict`` documents a 4-column output
          ``[y0, y1, propensity, epsilon]`` — verify against the installed version.
        * Anything else leaves all three outputs as length-``n`` NaN arrays.

        Non-finite entries (e.g. inf) are normalised to NaN.
        """

        def _as1d(a):
            # astype copies, so the in-place NaN fix below never mutates the input.
            return np.asarray(a).astype(np.float32).reshape(-1)

        y0 = np.full(n, np.nan, dtype=np.float32)
        y1 = np.full(n, np.nan, dtype=np.float32)
        ite = np.full(n, np.nan, dtype=np.float32)

        if isinstance(obj, (list, tuple)):
            parts = [_as1d(p) for p in obj]
        else:
            raw = np.asarray(obj)
            if raw.ndim == 1:
                parts = [_as1d(raw)]
            elif raw.ndim == 2:
                parts = [_as1d(raw[:, i]) for i in range(raw.shape[1])]
            else:
                parts = []

        if len(parts) == 1:
            ite = parts[0]
        elif len(parts) >= 2:
            y0, y1 = parts[0], parts[1]
            ite = y1 - y0

        for comp in (y0, y1, ite):
            comp[~np.isfinite(comp)] = np.nan
        return y0, y1, ite

    def predict(self, X):
        """Return the estimated ITE for every row of ``X`` as a 1-D float32 array.

        Entries that could not be computed are NaN.
        """
        X = np.asarray(X)
        preds = self.dragon.predict(X)
        _, _, ite = self._parse_return_components(preds, n=X.shape[0])
        return np.asarray(ite, dtype=np.float32).reshape(-1)

    def att_abs_error_rct(self, X, T, Y, e, tau_hat=None):
        """Absolute ATT error measured on the randomized subset (``e == 1``).

        The true ATT is the treated-minus-control outcome mean within that
        subset; the estimate is the mean predicted ITE over its treated rows.

        Args:
            tau_hat: optional precomputed ITE for every row of ``X``; predicted
                when omitted.

        Returns:
            The absolute error, or NaN when the subset or either arm is empty.

        Raises:
            ValueError: if ``tau_hat`` does not have one entry per row of ``X``.
        """
        X = np.asarray(X)
        T = np.asarray(T)
        Y = np.asarray(Y)
        e = np.asarray(e)
        idx = np.where(e == 1)[0]
        if idx.size == 0:
            return np.nan
        treated_rct = idx[T[idx] == 1]
        control_rct = idx[T[idx] == 0]
        if treated_rct.size == 0 or control_rct.size == 0:
            return np.nan
        att_true = float(np.mean(Y[treated_rct]) - np.mean(Y[control_rct]))
        if tau_hat is None:
            tau_full = self.predict(X)
        else:
            tau_full = np.asarray(tau_hat).reshape(-1)
            if tau_full.shape[0] != X.shape[0]:
                raise ValueError("tau_hat length must match X.shape[0].")
        att_hat = float(np.mean(tau_full[treated_rct]))
        return abs(att_hat - att_true)

    def policy_risk_rct(self, X, T, Y, e, lam=0.0, tau_hat=None):
        """Policy risk of "treat iff ITE > lam", evaluated on the ``e == 1`` subset.

        risk = 1 - (E[Y | pi=1, T=1] * P(pi=1) + E[Y | pi=0, T=0] * P(pi=0)).
        NaN ITEs compare False and are therefore assigned to "don't treat".

        Args:
            lam: treatment threshold on the predicted ITE.
            tau_hat: optional precomputed ITE for every row of ``X``; predicted
                on the subset when omitted.

        Returns:
            The risk, or NaN when the subset is empty or a policy arm that is
            used has no matching observed rows.

        Raises:
            ValueError: if ``tau_hat`` does not have one entry per row of ``X``.
        """
        X = np.asarray(X)
        T = np.asarray(T)
        Y = np.asarray(Y)
        e = np.asarray(e)
        idx = np.where(e == 1)[0]
        if idx.size == 0:
            return np.nan
        if tau_hat is None:
            tau_idx = self.predict(X[idx])
        else:
            tau_full = np.asarray(tau_hat).reshape(-1)
            if tau_full.shape[0] != X.shape[0]:
                raise ValueError("tau_hat length must match X.shape[0].")
            tau_idx = tau_full[idx]
        pi = (tau_idx > lam).astype(int)
        p1 = float(np.mean(pi == 1))
        p0 = 1.0 - p1
        mask_treat = T[idx] == 1
        mask_ctrl = T[idx] == 0
        ytreat_pi1 = Y[idx][(pi == 1) & mask_treat]
        yctrl_pi0 = Y[idx][(pi == 0) & mask_ctrl]
        if p1 > 0 and ytreat_pi1.size == 0:
            return np.nan
        if p0 > 0 and yctrl_pi0.size == 0:
            return np.nan
        ey1_pi1 = float(np.mean(ytreat_pi1)) if p1 > 0 else 0.0
        ey0_pi0 = float(np.mean(yctrl_pi0)) if p0 > 0 else 0.0
        value = ey1_pi1 * p1 + ey0_pi0 * p0
        return 1.0 - value

    def evaluate(self, X, T=None, Y=None, e=None, mu0=None, mu1=None, lam: float = 0.0):
        """Compute the applicable metrics for ``X`` and return them in a dict.

        With ``mu0`` and ``mu1``: ``PEHE``, ``ATE_error``, ``rel_ATE_error``.
        Otherwise, with ``T``, ``Y`` and ``e``: ``ATT_abs_error_rct`` and
        ``policy_risk_rct`` (threshold ``lam``). Otherwise an empty dict.
        """
        X = np.asarray(X)
        ite = self.predict(X)

        metrics = {}
        if (mu0 is not None) and (mu1 is not None):
            metrics["PEHE"] = self._pehe(ite, mu0, mu1)
            metrics["ATE_error"] = self._abs_ate_error(ite, mu0, mu1)
            metrics["rel_ATE_error"] = self._rel_ate_error(ite, mu0, mu1)
        elif (T is not None) and (Y is not None) and (e is not None):
            metrics["ATT_abs_error_rct"] = self.att_abs_error_rct(X, T, Y, e, tau_hat=ite)
            metrics["policy_risk_rct"] = self.policy_risk_rct(X, T, Y, e, lam=lam, tau_hat=ite)
        return metrics
