import tempfile
import os
import pandas as pd
from pathlib import Path
from cdt.utils.R import launch_R_script
import torch
from utils import adj2order, np_to_csv
from tabpfn import TabPFNRegressor
from xgboost import XGBRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor
import numpy as np
import math


def _get_mdl_parameters(args):
    """Extract MDL gate parameters with defaults."""
    lambda_param = getattr(args, "mdl_lambda", 1.0) if args is not None else 1.0
    kappa = getattr(args, "mdl_kappa", 25.0) if args is not None else 25.0
    return float(lambda_param), float(kappa)


def _mdl_gate(n_samples, parent_candidates, parent_count_without_edge, lambda_param, kappa):
    """Compute the MDL gate threshold (per-sample) for a candidate parent."""
    if n_samples <= 0:
        return float("inf")
    admissible = max(parent_candidates - parent_count_without_edge, 1)
    size_term = max(parent_count_without_edge + 1, 1)
    penalty = math.log(admissible) + lambda_param * math.log(size_term) + kappa
    return penalty / n_samples


def _edge_should_be_removed(
    ll_improvement,
    n_samples,
    parent_candidates,
    parent_count_without_edge,
    lambda_param,
    kappa,
    override_threshold=None,
):
    """Decide whether to prune an edge using the MDL gate or a manual override."""
    if not np.isfinite(ll_improvement):
        return True, float("inf")
    delta = ll_improvement / n_samples if n_samples > 0 else float("inf")
    if override_threshold is not None:
        tau = override_threshold
    else:
        tau = _mdl_gate(
            n_samples,
            parent_candidates,
            parent_count_without_edge,
            lambda_param,
            kappa,
        )
    return delta <= tau, tau


# XGBoost
def _compute_xgb_likelihood(X, target_node, parent_set):
    """Compute likelihood for a given parent set using XGBoost."""
    if len(parent_set) == 0:
        return float("-inf"), 1.0
    
    try:
        y = X[:, target_node]
        X_parents = X[:, parent_set]
        
        # Strengthen regularization to reduce overfitting
        model = XGBRegressor(
            objective='reg:squarederror', 
            eval_metric='rmse', 
            n_estimators=50,  # Reduce ensemble size
            max_depth=3,       # Limit tree depth
            reg_alpha=0.1,     # L1 regularization
            reg_lambda=1.0,    # L2 regularization
            random_state=0
        )
        
        # Use cross-validation to avoid overfitting
        from sklearn.model_selection import cross_val_score
        from sklearn.metrics import mean_squared_error
        
        # Compute MSE via 5-fold cross-validation
        cv_scores = cross_val_score(model, X_parents, y, cv=5, scoring='neg_mean_squared_error')
        mse = -cv_scores.mean()  # Convert negative MSE to positive
        
        # Fit the model on the full dataset
        model.fit(X_parents, y)
        y_pred = model.predict(X_parents)
        
        # Use cross-validated MSE for likelihood
        n = len(y)
        log_likelihood = -n/2 * np.log(mse) - n/(2 * mse) * mse
        
        # Estimate uncertainty from cross-validation variance
        uncertainty = mse
        
        return log_likelihood, uncertainty
        
    except Exception:
        return float("-inf"), 1.0

def compute_xgb_edge_confidence_score(X, target_node, parent_set, exclude_node=None):
    """Measure how much ``exclude_node`` helps explain ``target_node`` (XGBoost).

    Scores the target once without ``exclude_node`` and once with the full
    ``parent_set``.  When ``exclude_node`` is None, both scores use
    ``parent_set`` unchanged.

    Returns
    -------
    tuple
        ``(ll_with - ll_without, ll_without, ll_with)``, or three -inf values
        if anything raises.
    """
    try:
        if exclude_node is None:
            reduced = parent_set
        else:
            reduced = [p for p in parent_set if p != exclude_node]
        baseline, _ = _compute_xgb_likelihood(X, target_node, reduced)
        full, _ = _compute_xgb_likelihood(X, target_node, parent_set)
        return full - baseline, baseline, full
    except Exception:
        return float("-inf"), float("-inf"), float("-inf")

def xgb_pruning(init_dag, X, args=None):
    """Prune ``init_dag`` edge by edge using XGBoost log-likelihood gains.

    ``init_dag[i, j] == 1`` encodes an edge i -> j (rows are parents, columns
    are children).  Edges are visited child by child.  For each edge still
    present, the gain from keeping parent ``i`` among the child's current
    parents is scored and passed to ``_edge_should_be_removed``.  Removals take
    effect immediately, so later edges of the same child are tested against
    the reduced parent set.

    The threshold is the MDL gate unless ``args.xgb_confidence_threshold`` is
    present and not None, in which case that fixed value is used.

    Returns
    -------
    tuple
        ``(pruned_dag, adj2order(pruned_dag), tau_stats)``.  ``tau_stats``
        records the threshold mode, every finite threshold applied, and the
        override value (or None).
    """
    n_samples, n_nodes = X.shape
    pruned_dag = init_dag.copy()
    # Column sums of the *initial* DAG (parent counts for a 0/1 matrix).
    parent_candidates = init_dag.sum(axis=0).astype(int)
    lambda_param, kappa = _get_mdl_parameters(args)
    override = None if args is None else getattr(args, "xgb_confidence_threshold", None)
    applied_taus = []

    for child in range(n_nodes):
        for parent in range(n_nodes):
            if pruned_dag[parent, child] != 1:
                continue
            # Always contains `parent` itself at this point.
            parents_now = [p for p in range(n_nodes) if pruned_dag[p, child] == 1]
            gain, _, _ = compute_xgb_edge_confidence_score(
                X, child, parents_now, exclude_node=parent
            )
            remove, tau = _edge_should_be_removed(
                gain,
                n_samples,
                parent_candidates[child],
                len(parents_now) - 1,
                lambda_param,
                kappa,
                override_threshold=override,
            )
            if np.isfinite(tau):
                applied_taus.append(tau)
            if remove:
                pruned_dag[parent, child] = 0

    tau_stats = {
        "mode": "mdl" if override is None else "manual",
        "values": applied_taus,
        "override_value": override,
    }
    return pruned_dag, adj2order(pruned_dag), tau_stats


# RandomForest
def _compute_rf_likelihood(X, target_node, parent_set):
    """Compute likelihood for a given parent set using RandomForest."""
    if len(parent_set) == 0:
        return float("-inf"), 1.0
    
    try:
        y = X[:, target_node]
        X_parents = X[:, parent_set]
        
        # Strengthen regularization to reduce overfitting
        model = RandomForestRegressor(
            n_estimators=50,    # Reduce ensemble size
            max_depth=3,         # Limit tree depth
            min_samples_split=5, # Require minimum samples to split
            min_samples_leaf=2,  # Require minimum samples per leaf
            random_state=0
        )
        
        # Use cross-validation to avoid overfitting
        from sklearn.model_selection import cross_val_score
        from sklearn.metrics import mean_squared_error
        
        # Compute MSE via 5-fold cross-validation
        cv_scores = cross_val_score(model, X_parents, y, cv=5, scoring='neg_mean_squared_error')
        mse = -cv_scores.mean()  # Convert negative MSE to positive
        
        # Fit the model on the full dataset
        model.fit(X_parents, y)
        y_pred = model.predict(X_parents)
        
        # Use cross-validated MSE for likelihood
        n = len(y)
        log_likelihood = -n/2 * np.log(mse) - n/(2 * mse) * mse
        
        # Estimate uncertainty from cross-validation variance
        uncertainty = mse
        
        return log_likelihood, uncertainty
        
    except Exception:
        return float("-inf"), 1.0

def compute_rf_edge_confidence_score(X, target_node, parent_set, exclude_node=None):
    """Measure how much ``exclude_node`` helps explain ``target_node`` (RandomForest).

    Scores the target once without ``exclude_node`` and once with the full
    ``parent_set``.  When ``exclude_node`` is None, both scores use
    ``parent_set`` unchanged.

    Returns
    -------
    tuple
        ``(ll_with - ll_without, ll_without, ll_with)``, or three -inf values
        if anything raises.
    """
    try:
        reduced = (
            parent_set
            if exclude_node is None
            else [node for node in parent_set if node != exclude_node]
        )
        score_reduced, _ = _compute_rf_likelihood(X, target_node, reduced)
        score_full, _ = _compute_rf_likelihood(X, target_node, parent_set)
        return score_full - score_reduced, score_reduced, score_full
    except Exception:
        return float("-inf"), float("-inf"), float("-inf")

def rf_pruning(init_dag, X, args=None):
    """Prune ``init_dag`` edge by edge using RandomForest log-likelihood gains.

    ``init_dag[i, j] == 1`` encodes an edge i -> j (rows are parents, columns
    are children).  Each child's incoming edges are tested in turn against the
    child's *current* parent set, and removals apply immediately.

    The threshold is the MDL gate unless ``args.rf_confidence_threshold`` is
    present and not None, in which case that fixed value is used.

    Returns
    -------
    tuple
        ``(pruned_dag, adj2order(pruned_dag), tau_stats)`` where ``tau_stats``
        records the threshold mode, every finite threshold applied, and the
        override value (or None).
    """
    n_samples, n_nodes = X.shape
    dag = init_dag.copy()
    # Column sums of the initial DAG (parent counts for a 0/1 matrix).
    candidate_counts = init_dag.sum(axis=0).astype(int)
    lambda_param, kappa = _get_mdl_parameters(args)
    manual_tau = None if args is None else getattr(args, "rf_confidence_threshold", None)
    seen_taus = []

    for tgt in range(n_nodes):
        for src in range(n_nodes):
            if pruned_edge_absent := dag[src, tgt] != 1:
                continue
            parents_of_tgt = [p for p in range(n_nodes) if dag[p, tgt] == 1]
            improvement, _, _ = compute_rf_edge_confidence_score(
                X, tgt, parents_of_tgt, exclude_node=src
            )
            drop, tau = _edge_should_be_removed(
                improvement,
                n_samples,
                candidate_counts[tgt],
                len(parents_of_tgt) - 1,
                lambda_param,
                kappa,
                override_threshold=manual_tau,
            )
            if np.isfinite(tau):
                seen_taus.append(tau)
            if drop:
                dag[src, tgt] = 0

    stats = dict(
        mode="manual" if manual_tau is not None else "mdl",
        values=seen_taus,
        override_value=manual_tau,
    )
    return dag, adj2order(dag), stats


# MLP
def _compute_mlp_likelihood(X, target_node, parent_set):
    """Compute likelihood for a given parent set using an MLP."""
    if len(parent_set) == 0:
        return float("-inf"), 1.0
    
    try:
        y = X[:, target_node]
        X_parents = X[:, parent_set]
        
        # Strengthen regularization to reduce overfitting
        model = MLPRegressor(
            hidden_layer_sizes=(50, 25),  # Compact hidden layers
            activation='relu',
            solver='adam',
            alpha=0.1,           # Increase L2 regularization
            batch_size='auto',
            learning_rate='adaptive',
            learning_rate_init=0.001,
            max_iter=200,        # Limit epochs
            early_stopping=True,  # Enable early stopping
            validation_fraction=0.2,
            n_iter_no_change=10,
            random_state=0
        )
        
        # Use cross-validation to avoid overfitting
        from sklearn.model_selection import cross_val_score
        from sklearn.metrics import mean_squared_error
        
        # Compute MSE via 5-fold cross-validation
        cv_scores = cross_val_score(model, X_parents, y, cv=5, scoring='neg_mean_squared_error')
        mse = -cv_scores.mean()  # Convert negative MSE to positive
        
        # Fit the model on the full dataset
        model.fit(X_parents, y)
        y_pred = model.predict(X_parents)
        
        # Use cross-validated MSE for likelihood
        n = len(y)
        log_likelihood = -n/2 * np.log(mse) - n/(2 * mse) * mse
        
        # Estimate uncertainty from cross-validation variance
        uncertainty = mse
        
        return log_likelihood, uncertainty
        
    except Exception:
        return float("-inf"), 1.0

def compute_mlp_edge_confidence_score(X, target_node, parent_set, exclude_node=None):
    """Measure how much ``exclude_node`` helps explain ``target_node`` (MLP).

    Scores the target once without ``exclude_node`` and once with the full
    ``parent_set``.  When ``exclude_node`` is None, both scores use
    ``parent_set`` unchanged.

    Returns
    -------
    tuple
        ``(ll_with - ll_without, ll_without, ll_with)``, or three -inf values
        if anything raises.
    """
    neg_inf = float("-inf")
    try:
        without = parent_set
        if exclude_node is not None:
            without = [p for p in parent_set if p != exclude_node]
        ll_without = _compute_mlp_likelihood(X, target_node, without)[0]
        ll_with = _compute_mlp_likelihood(X, target_node, parent_set)[0]
        return ll_with - ll_without, ll_without, ll_with
    except Exception:
        return neg_inf, neg_inf, neg_inf

def mlp_pruning(init_dag, X, args=None):
    """Prune ``init_dag`` edge by edge using MLP log-likelihood gains.

    ``init_dag[i, j] == 1`` encodes an edge i -> j (rows are parents, columns
    are children).  Each child's incoming edges are tested in turn against the
    child's *current* parent set, and removals apply immediately.

    The threshold is the MDL gate unless ``args.mlp_confidence_threshold`` is
    present and not None, in which case that fixed value is used.

    Returns
    -------
    tuple
        ``(pruned_dag, adj2order(pruned_dag), tau_stats)`` where ``tau_stats``
        records the threshold mode, every finite threshold applied, and the
        override value (or None).
    """
    n_samples, n_nodes = X.shape
    result = init_dag.copy()
    # Column sums of the initial DAG (parent counts for a 0/1 matrix).
    n_candidates = init_dag.sum(axis=0).astype(int)
    lambda_param, kappa = _get_mdl_parameters(args)
    threshold = getattr(args, "mlp_confidence_threshold", None) if args is not None else None
    finite_taus = []

    for node in range(n_nodes):
        for cand in range(n_nodes):
            if result[cand, node] != 1:
                continue
            parents = [p for p in range(n_nodes) if result[p, node] == 1]
            score, _, _ = compute_mlp_edge_confidence_score(
                X, node, parents, exclude_node=cand
            )
            prune, tau = _edge_should_be_removed(
                score,
                n_samples,
                n_candidates[node],
                len(parents) - 1,
                lambda_param,
                kappa,
                override_threshold=threshold,
            )
            if np.isfinite(tau):
                finite_taus.append(tau)
            if prune:
                result[cand, node] = 0

    mode = "mdl" if threshold is None else "manual"
    tau_stats = {"mode": mode, "values": finite_taus, "override_value": threshold}
    return result, adj2order(result), tau_stats



















# TabPFN
def get_tabpfn_device(args):
    """
    Choose the device TabPFN should run on, based on ``args.device``.

    Returns "cpu" in any of these cases:
    - ``args`` has no ``device`` attribute, or it is falsy;
    - a CUDA device is requested but CUDA is unavailable;
    - ``torch.cuda.set_device`` raises ``RuntimeError``.

    Any non-CUDA request is returned unchanged.  A successful CUDA request
    calls ``torch.cuda.set_device``, which changes the process's current CUDA
    device as a side effect.
    """
    requested = args.device if hasattr(args, "device") else None
    if not requested:
        return "cpu"
    if not requested.startswith("cuda"):
        return requested
    if not torch.cuda.is_available():
        return "cpu"
    try:
        torch.cuda.set_device(requested)
    except RuntimeError:
        return "cpu"
    return requested


def compute_real_log_likelihood(X, target_node, parent_set, device="cpu"):
    """Score ``X[:, target_node]`` given the ``parent_set`` columns with TabPFN.

    A ``TabPFNRegressor`` is built from a checkpoint path that is relative to
    the current working directory.  It is fitted on the parent columns and
    queried for its full predictive output on the same rows, so the score is
    in-sample.

    If fitting on ``device`` raises ``RuntimeError`` or a CUDA OOM error and
    ``device`` is not "cpu", fitting is retried once on the CPU.

    Returns
    -------
    tuple
        ``(log_likelihood, uncertainty)``: the negated mean of
        ``criterion.forward(logits, y)`` and the mean of
        ``criterion.variance(logits)``.  The result is ``(-inf, 1.0)`` when
        TabPFN is unavailable, ``parent_set`` is empty, or anything raises.
    """
    if TabPFNRegressor is None or len(parent_set) == 0:
        return float("-inf"), 1.0

    checkpoint = "./tabpfn/tabpfn-v2-regressor.ckpt"
    try:
        target = X[:, target_node]
        features = X[:, parent_set]
        try:
            regressor = TabPFNRegressor(device=device, model_path=checkpoint)
            regressor.fit(features, target)
        except (RuntimeError, torch.cuda.OutOfMemoryError):
            if device == "cpu":
                raise
            # Accelerator path failed (e.g. out of memory): retry once on CPU.
            regressor = TabPFNRegressor(device="cpu", model_path=checkpoint)
            regressor.fit(features, target)

        prediction = regressor.predict(features, output_type="full")
        logits = prediction["logits"]
        criterion = prediction["criterion"]
        target_tensor = torch.tensor(target, dtype=torch.float32, device=logits.device)
        # Presumably per-sample negative log-likelihoods under TabPFN's
        # predictive distribution -- verify against the installed TabPFN.
        nll = criterion.forward(logits, target_tensor)
        log_likelihood = -nll.mean().item()
        uncertainty = criterion.variance(logits).mean().item()
        return log_likelihood, uncertainty
    except Exception:
        return float("-inf"), 1.0


def compute_edge_confidence_score(
    X, target_node, parent_set, exclude_node=None, device="cpu"
):
    """Measure how much ``exclude_node`` helps explain ``target_node`` (TabPFN).

    Scores the target once without ``exclude_node`` and once with the full
    ``parent_set``, using ``compute_real_log_likelihood`` on ``device``.  When
    ``exclude_node`` is None, both scores use ``parent_set`` unchanged.

    Returns
    -------
    tuple
        ``(ll_with - ll_without, ll_without, ll_with)``, or three -inf values
        when TabPFN is unavailable or anything raises.
    """
    failure = (float("-inf"), float("-inf"), float("-inf"))
    if TabPFNRegressor is None:
        return failure
    try:
        if exclude_node is None:
            kept = parent_set
        else:
            kept = [p for p in parent_set if p != exclude_node]
        ll_without = compute_real_log_likelihood(X, target_node, kept, device)[0]
        ll_with = compute_real_log_likelihood(X, target_node, parent_set, device)[0]
        return ll_with - ll_without, ll_without, ll_with
    except Exception:
        return failure


def TabPFN_pruning(init_dag, X, args):
    """Prune a DAG using LL scores from TabPFN.

    ``init_dag[i, j] == 1`` encodes an edge i -> j (rows are parents, columns
    are children).  Each child's incoming edges are tested in turn against the
    child's *current* parent set, and removals apply immediately.

    ``compute_real_log_likelihood`` returns a *mean* per-sample
    log-likelihood.  ``_edge_should_be_removed``, however, expects a *total*
    gain (as produced by the XGB/RF/MLP scorers) and divides it by
    ``n_samples``.  The gain is therefore rescaled by ``n_samples`` here.
    Without this, the gate compared mean-gain / n against penalty / n, which
    in MDL mode pruned practically every edge.

    A manual ``args.confidence_threshold`` is thus interpreted, as for the
    other pruners, as a per-sample log-likelihood gain.

    Returns
    -------
    tuple
        ``(pruned_dag, tau_stats)``.  Unlike the other pruners, no topological
        order is returned.
    """
    n_samples, n_nodes = X.shape
    pruned_dag = init_dag.copy()
    tabpfn_device = get_tabpfn_device(args)
    # Column sums of the initial DAG (parent counts for a 0/1 matrix).
    parent_candidates = init_dag.sum(axis=0).astype(int)
    lambda_param, kappa = _get_mdl_parameters(args)
    override = getattr(args, "confidence_threshold", None)
    tau_values = []

    for j in range(n_nodes):
        for i in range(n_nodes):
            if pruned_dag[i, j] != 1:
                continue
            current_parents = [p for p in range(n_nodes) if pruned_dag[p, j] == 1]
            mean_ll_improvement, _, _ = compute_edge_confidence_score(
                X, j, current_parents, exclude_node=i, device=tabpfn_device
            )
            # Mean -> total, so delta inside _edge_should_be_removed is a
            # per-sample gain on the same scale as the MDL gate.  Non-finite
            # values (NaN / +-inf) are preserved by the multiplication.
            ll_improvement = mean_ll_improvement * n_samples
            k_without = len(current_parents) - 1
            should_remove, tau_value = _edge_should_be_removed(
                ll_improvement,
                n_samples,
                parent_candidates[j],
                k_without,
                lambda_param,
                kappa,
                override_threshold=override,
            )
            if np.isfinite(tau_value):
                tau_values.append(tau_value)
            if should_remove:
                pruned_dag[i, j] = 0

    tau_stats = {
        "mode": "manual" if override is not None else "mdl",
        "values": tau_values,
        "override_value": override,
    }

    return pruned_dag, tau_stats














def cam_pruning(A, X, cutoff, only_pruning=True):
    """Prune adjacency matrix ``A`` with the CAM pruning R script.

    ``X`` and ``A`` are written as CSVs into a temporary directory.  The script
    ``pruning_R_files/cam_pruning.R`` (two directories above this file) is
    then run via ``launch_R_script``, and its ``results.csv`` is read back.

    Parameters
    ----------
    A : array-like
        Candidate adjacency matrix passed to the script as ``{PATH_DAG}``.
    X : array-like
        Data passed to the script as ``{PATH_DATA}``.
    cutoff : float
        Forwarded verbatim as ``{CUTOFF}`` (via ``str(cutoff)``).
    only_pruning : bool
        Only ``True`` is supported.

    Returns
    -------
    tuple
        ``(dag, adj2order(dag))``.

    Raises
    ------
    NotImplementedError
        If ``only_pruning`` is False.  No script path is configured for that
        mode, and the function previously failed later with UnboundLocalError.
    """
    if not only_pruning:
        raise NotImplementedError(
            "cam_pruning only supports only_pruning=True; "
            "no R script is configured for only_pruning=False."
        )
    pruning_path = Path(__file__).parent.parent / "pruning_R_files" / "cam_pruning.R"

    with tempfile.TemporaryDirectory() as save_path:
        # NOTE(review): both CSVs go into the same directory; this assumes
        # np_to_csv chooses distinct file names per call -- verify in utils.
        data_csv_path = np_to_csv(X, save_path)
        dag_csv_path = np_to_csv(A, save_path)

        arguments = {
            "{PATH_DATA}": data_csv_path,
            "{PATH_DAG}": dag_csv_path,
            "{PATH_RESULTS}": os.path.join(save_path, "results.csv"),
            "{ADJFULL_RESULTS}": os.path.join(save_path, "adjfull.csv"),
            "{CUTOFF}": str(cutoff),
            "{VERBOSE}": "FALSE",
        }

        def retrieve_result():
            # Read the pruned adjacency produced by the R script, then delete
            # the files passed to / produced by it.
            result = pd.read_csv(arguments["{PATH_RESULTS}"]).values
            os.remove(arguments["{PATH_RESULTS}"])
            os.remove(arguments["{PATH_DATA}"])
            os.remove(arguments["{PATH_DAG}"])
            return result

        dag = launch_R_script(
            str(pruning_path), arguments, output_function=retrieve_result
        )

    return dag, adj2order(dag)


def Stein_hess(X, eta_G, eta_H, s=None, device=None):
    """
    Estimates the diagonal of the Hessian of log p_X at the provided samples points
    X, using first and second-order Stein identities

    Parameters
    ----------
    X : torch.Tensor
        Samples, unpacked as ``n, d = X.shape``.
    eta_G, eta_H : float
        Ridge terms added to the kernel matrix for the first-order (gradient)
        and second-order systems respectively.
    s : float or tensor, optional
        RBF kernel bandwidth; defaults to the median pairwise distance.
    device : optional
        Device to compute on.  If None, X's current device is used.

    Returns
    -------
    torch.Tensor
        ``(n, d)`` estimate, moved to the CPU.
    """
    n, d = X.shape
    if device is not None:
        X = X.to(device)
    X_diff = X.unsqueeze(1) - X  # X_diff[k, i] = X[k] - X[i]
    # Pairwise distances are needed for both the median heuristic and K;
    # compute them once.
    dist = torch.linalg.vector_norm(X_diff, ord=2, dim=2)
    if s is None:
        s = dist.flatten().median()
    K = torch.exp(-dist**2 / (2 * s**2)) / s

    # Build the ridge identity on K's device/dtype so that device=None works
    # for CUDA inputs and no float32 identity is mixed into float64 math.
    eye = torch.eye(n, dtype=K.dtype, device=K.device)

    nablaK = -torch.einsum("kij,ik->kj", X_diff, K) / s**2
    # Solve the regularized systems instead of forming explicit inverses
    # (better conditioned and cheaper).
    G = torch.linalg.solve(K + eta_G * eye, nablaK)

    nabla2K = torch.einsum("kij,ik->kj", -1 / s**2 + X_diff**2 / s**4, K)
    return (-(G**2) + torch.linalg.solve(K + eta_H * eye, nabla2K)).to("cpu")
