#!/usr/bin/env python
# coding: utf-8

"""Script to run counterfactual reasoning experiments."""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import pickle
import torch
import torch.nn as nn # Added
import networkx as nx
from sklearn.linear_model import LinearRegression # Added
from sklearn.metrics import mean_squared_error

# Add project root to path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

# Import ECAM components
try:
    from src.ecam import ECAM
    from src.counterfactual import CounterfactualModule # Still a placeholder, not used directly
    from src.graph_learner import GraphLearner
    ecam_available = True
except ImportError as e:
    print(f"Warning: ECAM modules not found or import failed: {e}")
    ECAM = None
    CounterfactualModule = None
    GraphLearner = None
    ecam_available = False

# Import SCM generation function
try:
    from scripts.generate_synthetic_data import generate_scm_data
except ImportError as e:
    print(f"Warning: generate_scm_data function not found: {e}")
    generate_scm_data = None

# Import NOTEARS for baseline graph estimation
try:
    from notears import linear as notears_linear
    notears_available = True
except ImportError:
    print("Warning: notears library not found. NOTEARS baseline graph estimation unavailable.")
    notears_linear = None
    notears_available = False

# --- Helper Functions ---

def _parents(graph_adj, node):
    """Return the parent indices of ``node`` in adjacency matrix ``graph_adj``.

    ``graph_adj[i, j]`` marks an edge ``i -> j``. Any non-zero entry counts as
    an edge, so both binary (0/1) matrices and weighted matrices (such as
    ``nx.to_numpy_array`` produces when edges carry a ``weight`` attribute)
    are handled.
    """
    return np.flatnonzero(graph_adj[:, node])


def _parent_contribution(values, parents, node, weights):
    """Linear parent term ``sum_p weights[(p, node)] * values[..., p]``.

    ``values`` may be a single sample (1-D, one entry per node) or a batch
    (2-D, samples in rows). Missing ``(parent, node)`` weights count as 0.
    """
    return sum(weights.get((p, node), 0) * values[..., p] for p in parents)


def fit_linear_scm(data, graph_adj):
    """Fit linear structural coefficients for a given graph structure.

    Every node with parents is regressed on its parents without an intercept,
    i.e. assuming zero-mean exogenous noise.

    Args:
        data (np.ndarray): observational samples, shape (n_samples, n_nodes).
        graph_adj (np.ndarray): adjacency matrix; a non-zero ``graph_adj[i, j]``
            means ``i -> j``.

    Returns:
        tuple: ``(weights, models)``. ``weights`` maps ``(parent, child)`` to the
        fitted coefficient. ``models`` maps each node index to its fitted
        ``LinearRegression``, or ``None`` for root nodes and failed fits.
    """
    weights = {}  # {(parent, child): weight}
    models = {}  # {node: fitted regression model or None}

    for j in range(data.shape[1]):
        parents = _parents(graph_adj, j)
        if parents.size == 0:
            models[j] = None  # Root node, no parents to fit
            continue

        model = LinearRegression(fit_intercept=False)  # Assuming noise is zero-mean
        try:
            # data[:, parents] is always 2-D because `parents` is an index array.
            model.fit(data[:, parents], data[:, j])
        except ValueError as e:
            print(f"  Linear SCM Fit Error (Target: {j}, Parents: {parents}): {e}")
            models[j] = None  # Mark as failed
            continue

        models[j] = model
        for parent_node, coef in zip(parents, model.coef_):
            weights[(parent_node, j)] = coef

    return weights, models


def estimate_noise_from_scm(data, graph_adj, weights, scm_type="linear"):
    """Estimate exogenous noise (abduction) from observed data and an SCM.

    Args:
        data (np.ndarray): observed samples, shape (n_samples, n_nodes).
        graph_adj (np.ndarray): adjacency matrix; non-zero ``[i, j]`` means ``i -> j``.
        weights (dict): ``(parent, child) -> weight``; missing pairs count as 0.
        scm_type (str): functional form; only ``"linear"`` is implemented for
            nodes with parents.

    Returns:
        np.ndarray: noise estimates with the same shape as ``data``. The output
        is floating point even for integer ``data``, so residuals are not
        truncated.

    Raises:
        ValueError: if a node has parents and ``scm_type`` is not ``"linear"``.
    """
    data = np.asarray(data)
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    noise_estimates = np.zeros(data.shape, dtype=out_dtype)

    # Each residual uses only *observed* parent values, so nodes can be
    # processed in any order -- no topological sort is required.
    for node in range(data.shape[1]):
        parents = _parents(graph_adj, node)
        observed_values = data[:, node]

        if parents.size == 0:
            # Root node: the noise is the observation itself.
            noise_estimates[:, node] = observed_values
        elif scm_type == "linear":
            noise_estimates[:, node] = observed_values - _parent_contribution(data, parents, node, weights)
        else:
            # TODO: add non-linear SCM support.
            raise ValueError(f"Noise estimation for SCM type '{scm_type}' not implemented")

    return noise_estimates


def predict_from_scm(graph_adj, weights, noise_sample, intervention_details=None, scm_type="linear"):
    """Predict node values from an SCM, noise and optional hard interventions.

    Nodes are evaluated in topological order of ``graph_adj``. An intervened
    node is clamped to its intervention value (its parents are ignored).
    Every other node is its parent contribution plus its noise.

    Args:
        graph_adj (np.ndarray): adjacency matrix; non-zero ``[i, j]`` means ``i -> j``.
        weights (dict): ``(parent, child) -> weight``; missing pairs count as 0.
        noise_sample (np.ndarray): exogenous noise, one entry per node.
        intervention_details (dict | None): ``{node: value}`` hard interventions.
            Every listed node is clamped, not just the first.
        scm_type (str): functional form; only ``"linear"`` is implemented for
            nodes with parents.

    Returns:
        np.ndarray: predicted value for every node.

    Raises:
        ValueError: if a non-intervened node has parents and ``scm_type`` is
            not ``"linear"``.
        networkx.NetworkXUnfeasible: if ``graph_adj`` contains a cycle.
    """
    interventions = dict(intervention_details or {})
    X_pred = np.zeros(graph_adj.shape[0])

    for node in nx.topological_sort(nx.DiGraph(graph_adj)):
        if node in interventions:
            X_pred[node] = interventions[node]
            continue

        parents = _parents(graph_adj, node)
        if parents.size == 0:
            X_pred[node] = noise_sample[node]
        elif scm_type == "linear":
            # X_pred already holds every parent (topological order).
            X_pred[node] = _parent_contribution(X_pred, parents, node, weights) + noise_sample[node]
        else:
            # TODO: add non-linear SCM support.
            raise ValueError(f"Prediction for SCM type '{scm_type}' not implemented")

    return X_pred

def calculate_true_counterfactual(graph_true, weights_true, noise_sample, query, scm_type="linear"):
    """Return the ground-truth counterfactual value of ``query["target"]``.

    Performs the action and prediction steps of Pearl's three-step procedure
    on the known SCM. The abduction step is the caller's job: ``noise_sample``
    must already hold the exogenous noise of the factual unit.

    Args:
        graph_true (nx.DiGraph): true causal graph. It is turned into an
            adjacency matrix whose rows/columns follow the sorted node labels.
        weights_true (dict): true edge weights keyed by ``(parent, child)``.
        noise_sample (np.ndarray): exogenous noise, one entry per node position.
        query (dict): ``{"intervention": {node: value}, "target": node}``.
        scm_type (str): SCM functional form forwarded to ``predict_from_scm``.

    Returns:
        The counterfactual value of the target node.
    """
    ordered_nodes = sorted(graph_true.nodes())
    adjacency = nx.to_numpy_array(graph_true, nodelist=ordered_nodes)
    counterfactual_world = predict_from_scm(
        adjacency,
        weights_true,
        noise_sample,
        intervention_details=query["intervention"],
        scm_type=scm_type,
    )
    target = query["target"]
    return counterfactual_world[target]


def estimate_counterfactual_scm(obs_data_sample, query, estimated_graph_adj, scm_type="linear",
                                obs_data=None, weights_est=None):
    """Estimate a counterfactual with Abduction-Action-Prediction on an estimated SCM.

    All three steps use the *estimated* SCM (``estimated_graph_adj`` plus linear
    weights), which makes this a realistic baseline with no access to the
    true SCM.

    Args:
        obs_data_sample (np.ndarray): the factual observed sample (shape: n_nodes).
        query (dict): ``{"intervention": {node: value}, "target": node}``.
        estimated_graph_adj (np.ndarray): adjacency matrix of the estimated graph.
        scm_type (str): type of SCM (only ``"linear"`` is supported downstream).
        obs_data (np.ndarray | None): full observational dataset, shape
            (n_samples, n_nodes). It is used to fit the estimated SCM's weights
            when ``weights_est`` is not given.
        weights_est (dict | None): pre-fitted ``(parent, child) -> weight`` map
            for ``estimated_graph_adj``. Passing it avoids refitting per query.

    Returns:
        float: estimated counterfactual value. ``np.nan`` is returned if neither
        ``weights_est`` nor ``obs_data`` is given, or if any step fails.
    """
    if weights_est is None:
        if obs_data is None:
            # Weights cannot be estimated without data; keep the legacy skip.
            print("SCM-based baseline counterfactual estimation requires estimated weights - skipping.")
            return np.nan
        try:
            weights_est, _ = fit_linear_scm(obs_data, estimated_graph_adj)
        except Exception as e:
            print(f"Error fitting SCM for baseline graph: {e}")
            return np.nan

    # Step 1: Abduction using the ESTIMATED SCM
    try:
        factual = np.asarray(obs_data_sample).reshape(1, -1)
        noise_est_sample = estimate_noise_from_scm(factual, estimated_graph_adj, weights_est, scm_type)[0]
    except Exception as e:
        print(f"Error in baseline noise estimation: {e}")
        return np.nan

    # Steps 2 & 3: Action & Prediction using the ESTIMATED graph and weights
    try:
        predicted_values = predict_from_scm(estimated_graph_adj, weights_est, noise_est_sample,
                                            query["intervention"], scm_type)
    except Exception as e:
        print(f"Error in baseline prediction: {e}")
        return np.nan
    return predicted_values[query["target"]]

def estimate_counterfactual_ecam(obs_data_sample, query, obs_data, d_model=64, graph_reg=0.01, scm_type="linear",
                                 true_graph=None, true_weights=None):
    """Estimate a counterfactual using an SCM built on ECAM's learned graph.

    Uses the Abduction(True SCM) - Action - Prediction(Estimated SCM) framework.
    Noise for ``obs_data_sample`` is recovered from the TRUE SCM. Action and
    prediction use the graph from ECAM's ``GraphLearner``, with linear weights
    refit on ``obs_data``.

    Args:
        obs_data_sample (np.ndarray): factual sample, one value per node.
        query (dict): ``{"intervention": {node: value}, "target": node}``.
        obs_data (np.ndarray): observational data used to fit the ECAM SCM
            weights, shape (n_samples, n_nodes).
        d_model (int): node-feature dimension for ``GraphLearner``.
        graph_reg (float): regularization passed to ``GraphLearner``.
        scm_type (str): SCM functional form (only ``"linear"`` is implemented).
        true_graph (nx.DiGraph | np.ndarray | None): true causal graph, as a
            networkx graph or adjacency matrix. Defaults to the module-level
            ``G_true`` that ``main`` sets.
        true_weights (dict | None): true ``(parent, child) -> weight`` map.
            Defaults to the module-level ``weights_true``.

    Returns:
        float: estimated counterfactual value of the target. ``np.nan`` is
        returned if ECAM or the true SCM is unavailable, or if any step fails.
    """
    if not ecam_available or GraphLearner is None:
        print("ECAM GraphLearner not available.")
        return np.nan

    # Fall back to the true SCM that main() stores in module globals.
    if true_graph is None:
        true_graph = G_true
    if true_weights is None:
        true_weights = weights_true
    if true_graph is None or true_weights is None:
        print("True SCM not available for ECAM abduction - skipping.")
        return np.nan

    n_nodes = obs_data_sample.shape[0]

    # --- Learn graph using GraphLearner (random node features, threshold 0.5, no self-loops) ---
    node_features = torch.randn(1, n_nodes, d_model)
    try:
        graph_learner = GraphLearner(d_model, graph_reg)
        graph_learner.eval()
        with torch.no_grad():
            G_ecam_raw = graph_learner(node_features)  # G shape (1, n_nodes, n_nodes)
            adj_matrix_ecam_raw = G_ecam_raw.squeeze(0).detach().cpu().numpy()
            pred_adj_ecam = (adj_matrix_ecam_raw > 0.5).astype(int)
            np.fill_diagonal(pred_adj_ecam, 0)
    except Exception as e:
        print(f"Error running ECAM GraphLearner for counterfactual: {e}")
        return np.nan

    # --- Fit SCM parameters based on ECAM graph and observational data ---
    try:
        weights_ecam, _ = fit_linear_scm(obs_data, pred_adj_ecam)
    except Exception as e:
        print(f"Error fitting SCM for ECAM graph: {e}")
        return np.nan

    # --- Abduction: estimate noise using the TRUE SCM (standard CF definition) ---
    # estimate_noise_from_scm indexes an adjacency *matrix*, so an nx graph is
    # converted first (nodes in sorted-label order, as calculate_true_counterfactual
    # does). Entries are then binarized in case edges carry weight attributes.
    try:
        if isinstance(true_graph, nx.Graph):
            adj_true = nx.to_numpy_array(true_graph, nodelist=sorted(true_graph.nodes()))
        else:
            adj_true = np.asarray(true_graph)
        adj_true = (adj_true != 0).astype(int)
        noise_est_sample = estimate_noise_from_scm(obs_data_sample.reshape(1, -1), adj_true, true_weights, scm_type)[0]
    except Exception as e:
        print(f"Error in ECAM noise estimation (using true SCM): {e}")
        return np.nan

    # --- Action & Prediction using ECAM's estimated graph and weights ---
    # A cyclic ECAM graph makes predict_from_scm raise; that is reported as NaN.
    try:
        predicted_values = predict_from_scm(pred_adj_ecam, weights_ecam, noise_est_sample, query["intervention"], scm_type)
        return predicted_values[query["target"]]
    except Exception as e:
        print(f"Error in ECAM prediction: {e}")
        return np.nan

# --- True SCM of the graph currently being processed ---
# main() overwrites both per graph; estimate_counterfactual_ecam() reads them
# for its abduction step. Both start unset.
G_true, weights_true = None, None

# --- Main Experiment ---

def _learn_ecam_scm(n_nodes, obs_data, d_model, graph_reg):
    """Learn one ECAM graph for ``n_nodes`` nodes and fit linear weights on it.

    ``GraphLearner`` is fed random node features. Its output is thresholded at
    0.5 and self-loops are removed.

    Returns:
        tuple: ``(adjacency, weights)``, or ``(None, None)`` when ECAM is
        unavailable or any step fails.
    """
    if not ecam_available or GraphLearner is None:
        return None, None
    node_features = torch.randn(1, n_nodes, d_model)
    try:
        graph_learner = GraphLearner(d_model, graph_reg)
        graph_learner.eval()
        with torch.no_grad():
            G_ecam_raw = graph_learner(node_features)
        adj_matrix_ecam_raw = G_ecam_raw.squeeze(0).detach().cpu().numpy()
        pred_adj_ecam = (adj_matrix_ecam_raw > 0.5).astype(int)
        np.fill_diagonal(pred_adj_ecam, 0)
        weights_ecam, _ = fit_linear_scm(obs_data, pred_adj_ecam)
    except Exception as e:
        print(f"Error pre-calculating ECAM graph/weights: {e}")
        return None, None
    return pred_adj_ecam, weights_ecam


def main(args):
    """Run the counterfactual-estimation experiment on the synthetic graphs.

    For each ``graph_{k}.pkl`` / ``obs_data_{k}.npy`` pair (k = 1..args.num_graphs)
    under ``<project_root>/data/synthetic`` this:

      1. loads the true SCM and abducts noise for every observational sample;
      2. learns one ECAM graph with linear weights for that graph;
      3. picks up to ``args.num_queries`` samples and issues
         ``args.num_cf_per_sample`` random single-node interventions per sample;
      4. records the true counterfactual, the ECAM estimate and squared errors.

    Results are written to ``<args.output_dir>/counterfactual_estimation_results.csv``
    and average squared errors are printed.

    Args:
        args (argparse.Namespace): parsed command-line arguments.
    """
    global G_true, weights_true  # estimate_counterfactual_ecam() reads these
    results = []
    synthetic_data_dir = os.path.join(project_root, "data", "synthetic")

    if not os.path.exists(synthetic_data_dir):
        print(f"Error: Synthetic data directory not found: {synthetic_data_dir}")
        return

    num_graphs_processed = 0
    for i in range(args.num_graphs):
        graph_idx = i + 1
        graph_file = os.path.join(synthetic_data_dir, f"graph_{graph_idx}.pkl")
        obs_file = os.path.join(synthetic_data_dir, f"obs_data_{graph_idx}.npy")

        if not (os.path.exists(graph_file) and os.path.exists(obs_file)):
            print(f"Warning: Data files for graph {graph_idx} not found. Skipping.")
            continue

        print(f"\n--- Processing Graph {graph_idx} ---")
        num_graphs_processed += 1

        # Load into locals first, so a failed load cannot leave the globals
        # half-updated (e.g. a new graph paired with stale weights).
        try:
            # NOTE: pickle.load can execute arbitrary code; only load trusted,
            # locally generated files here.
            with open(graph_file, "rb") as f:
                graph_data = pickle.load(f)
            graph, weights = graph_data["graph"], graph_data["weights"]
            n_nodes = graph.number_of_nodes()
            obs_data = np.load(obs_file)
        except Exception as e:
            print(f"Error loading data for graph {graph_idx}: {e}")
            continue
        G_true, weights_true = graph, weights
        # Assuming linear SCM for now based on generation script default
        scm_type = "linear"

        if n_nodes < 2:
            # Queries need an intervention node distinct from the target node.
            print(f"Warning: graph {graph_idx} has fewer than 2 nodes. Skipping.")
            continue

        # Step 1: Abduction (Estimate noise for all observational samples using TRUE SCM)
        try:
            noise_estimates = estimate_noise_from_scm(obs_data, nx.to_numpy_array(G_true, nodelist=sorted(G_true.nodes())), weights_true, scm_type)
        except Exception as e:
            print(f"Error estimating noise for graph {graph_idx}: {e}")
            continue

        # Generate and evaluate counterfactual queries for a subset of samples
        num_samples_to_query = min(args.num_queries, obs_data.shape[0])
        query_indices = np.random.choice(obs_data.shape[0], num_samples_to_query, replace=False)

        print(f"  Generating {args.num_cf_per_sample} counterfactuals for {num_samples_to_query} samples...")

        # ECAM graph and weights are fixed per graph_idx. Check acyclicity once
        # here, because prediction needs a topological order.
        pred_adj_ecam, weights_ecam = _learn_ecam_scm(n_nodes, obs_data, args.d_model, args.graph_reg)
        ecam_is_dag = pred_adj_ecam is not None and nx.is_directed_acyclic_graph(nx.DiGraph(pred_adj_ecam))
        if pred_adj_ecam is not None and not ecam_is_dag:
            print(f"  Warning: ECAM graph for graph {graph_idx} contains cycles; ECAM estimates will be NaN.")

        # TODO: load a baseline graph (e.g. NOTEARS) from discovery results and
        # evaluate it with estimate_counterfactual_scm; until then it is NaN.

        for sample_idx in query_indices:
            factual_obs = obs_data[sample_idx, :]
            noise_sample = noise_estimates[sample_idx, :]

            for cf_idx in range(args.num_cf_per_sample):
                # Randomly select intervention node, target node, and intervention value
                interv_node = np.random.randint(0, n_nodes)
                target_node = np.random.choice([n for n in range(n_nodes) if n != interv_node])
                # Sample intervention value (shift from factual)
                interv_value = factual_obs[interv_node] + np.random.normal(0, 1.0)

                cf_query = {
                    "intervention": {interv_node: interv_value},
                    "target": target_node
                }

                # Calculate True Counterfactual; skip the query if it cannot be computed
                try:
                    true_cf_value = calculate_true_counterfactual(G_true, weights_true, noise_sample, cf_query, scm_type)
                except Exception as e:
                    print(f"Error computing true counterfactual (graph {graph_idx}, sample {sample_idx}, query {cf_idx}): {e}")
                    continue
                if np.isnan(true_cf_value):
                    continue

                pred_baseline = np.nan  # Baseline not wired in yet (see TODO above)

                pred_ecam = np.nan
                if ecam_is_dag:
                    try:
                        # Abduction used the TRUE SCM (noise_sample); prediction uses the ECAM SCM.
                        pred_ecam = predict_from_scm(pred_adj_ecam, weights_ecam, noise_sample, cf_query["intervention"], scm_type)[cf_query["target"]]
                    except Exception as e:
                        print(f"Error in ECAM CF estimation (DAG check passed): {e}")
                        pred_ecam = np.nan

                mse_baseline = (pred_baseline - true_cf_value)**2 if not np.isnan(pred_baseline) else np.nan
                mse_ecam = (pred_ecam - true_cf_value)**2 if not np.isnan(pred_ecam) else np.nan

                results.append({
                    "graph_idx": graph_idx,
                    "sample_idx": sample_idx,
                    "cf_query_idx": cf_idx,
                    "interv_node": interv_node,
                    "interv_value": interv_value,
                    "target_node": target_node,
                    "true_cf_value": true_cf_value,
                    "pred_baseline": pred_baseline,
                    "mse_baseline": mse_baseline,
                    "pred_ecam": pred_ecam,
                    "mse_ecam": mse_ecam,
                })

    if num_graphs_processed == 0:
        print("No synthetic data graphs were processed.")
        return

    # Save results
    results_df = pd.DataFrame(results)
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "counterfactual_estimation_results.csv")
    results_df.to_csv(output_path, index=False)
    print(f"\nCounterfactual estimation results saved to {output_path}")

    # Print summary statistics (pandas mean() skips NaN entries)
    if not results_df.empty:
        avg_mse_baseline = results_df["mse_baseline"].mean()
        avg_mse_ecam = results_df["mse_ecam"].mean()
        print("\nSummary:")
        print(f"  Avg MSE Baseline: {avg_mse_baseline:.4f}")
        print(f"  Avg MSE ECAM: {avg_mse_ecam:.4f}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Counterfactual Reasoning Experiments")
    parser.add_argument("--num_graphs", type=int, default=10, help="Number of synthetic graphs to process")
    parser.add_argument("--num_queries", type=int, default=100, help="Number of observational samples to generate counterfactuals for")
    parser.add_argument("--num_cf_per_sample", type=int, default=5, help="Number of counterfactual queries per sample")
    # Default relative to the project root instead of a machine-specific absolute path.
    parser.add_argument("--output_dir", type=str, default=os.path.join(project_root, "results", "counterfactual"), help="Directory to save results")
    # ECAM GraphLearner hyper-parameters
    parser.add_argument("--d_model", type=int, default=64, help="Dimension for ECAM GraphLearner")
    parser.add_argument("--graph_reg", type=float, default=0.01, help="Regularization for ECAM GraphLearner")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for NumPy and PyTorch (omit for a nondeterministic run)")

    args = parser.parse_args()
    if args.seed is not None:
        # Query sampling uses NumPy's global RNG; ECAM node features use torch's.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
    main(args)

