#!/usr/bin/env python
# coding: utf-8

"""Script to run intervention effect estimation experiments."""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import pickle
import torch
import torch.nn as nn # Added import
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import networkx as nx

# Put the parent of this script's directory at the front of sys.path so the
# `src` package imported below can be resolved when the script is run directly.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

# Import ECAM components. The import is optional: on failure a warning is
# printed, every ECAM name is bound to None and `ecam_available` is set to
# False, so the rest of the script can check before using them.
try:
    from src.ecam import ECAM
    from src.intervention import InterventionModule
    from src.graph_learner import GraphLearner
    ecam_available = True
except ImportError:
    print("Warning: ECAM source files not found or import failed.")
    ECAM = None
    InterventionModule = None
    GraphLearner = None
    ecam_available = False

# --- Helper Functions ---

def estimate_intervention_effect_regression(obs_data, target_node, intervention_details, graph_adj):
    """
    Estimate E[Y | do(X=x)] by linear regression of the target on its parents.

    The model ``Y ~ Parents(Y)`` is fitted on the observational data. The
    prediction is then made with every parent held at its observational mean,
    except that the intervened node, when it is one of those parents, is set
    to the intervention value.

    Args:
        obs_data: Observational samples, array-like of shape (n_samples, n_nodes).
        target_node: Column index of the outcome variable Y.
        intervention_details: Mapping with keys ``"node"`` (index of the
            intervened variable) and ``"value"`` (the value it is set to).
        graph_adj: Array-like adjacency matrix in which a non-zero entry
            ``[i, j]`` denotes an edge ``i -> j`` (binary or weighted), or
            ``None`` to regress on every node other than the target.

    Returns:
        The predicted value of the target under the intervention. If the
        target has no parents, its observational mean is returned. ``np.nan``
        is returned when the regression cannot be fitted or evaluated.
    """
    obs_data = np.asarray(obs_data)
    _, n_nodes = obs_data.shape
    intervened_node = intervention_details["node"]
    intervention_value = intervention_details["value"]

    # Identify parents of the target node based on the graph
    if graph_adj is None:
        # If no graph, use all other nodes as potential parents (less accurate)
        parents = np.array([i for i in range(n_nodes) if i != target_node], dtype=int)
    else:
        # Any non-zero entry is an edge: weighted adjacency matrices (e.g. from
        # nx.to_numpy_array on a graph with edge weights) rarely contain exact 1s.
        parents = np.flatnonzero(np.asarray(graph_adj)[:, target_node])
        # A self-loop must not make the target a regressor of itself.
        parents = parents[parents != target_node]

    if parents.size == 0:
        # If no parents, predict the mean of the target node from observational data
        return np.mean(obs_data[:, target_node])

    # Prepare data for regression: Y ~ Parents(Y). Indexing with an index
    # array always yields a 2D design matrix, even for a single parent.
    X_train = obs_data[:, parents]
    y_train = obs_data[:, target_node]

    model = LinearRegression()
    try:
        model.fit(X_train, y_train)
    except ValueError as e:
        print(f"  Regression Fit Error (Target: {target_node}, Parents: {parents}): {e}")
        return np.nan  # Cannot fit model

    # Prediction point: observational mean of each parent ...
    pred_input_values = X_train.mean(axis=0)

    # ... with the intervened node overridden when it is a direct parent.
    # If it is not a direct parent, the parent means are used unchanged.
    parent_list = parents.tolist()
    if intervened_node in parent_list:
        pred_input_values[parent_list.index(intervened_node)] = intervention_value

    try:
        # predict() expects a 2D array: one row holding one value per parent.
        return model.predict(pred_input_values.reshape(1, -1))[0]
    except Exception as e:  # best-effort: report and fall back to NaN
        print(f"  Regression Predict Error (Target: {target_node}, Parents: {parents}): {e}")
        return np.nan


def estimate_intervention_effect_ecam(obs_data, target_node, intervention_details, d_model=64, graph_reg=0.01):
    """
    Estimate an intervention effect by regression on a GraphLearner-produced graph.

    A fresh ``GraphLearner(d_model, graph_reg)`` is run in eval mode on random
    placeholder features of shape (1, n_nodes, d_model). Its output is
    thresholded at 0.5, the diagonal is zeroed, and the resulting adjacency
    matrix is passed to ``estimate_intervention_effect_regression``.

    NOTE(review): ``obs_data`` only contributes its node count to the graph
    step, and the learner is constructed anew on every call, so the graph is
    not fitted to the data here -- confirm this is the intended behavior.

    Returns ``np.nan`` if GraphLearner is unavailable or raises.
    """
    learner_missing = not ecam_available or GraphLearner is None
    if learner_missing:
        print("ECAM GraphLearner not available.")
        return np.nan

    _, num_vars = obs_data.shape

    # Random tensor standing in for real node features.
    placeholder_features = torch.randn(1, num_vars, d_model)

    try:
        learner = GraphLearner(d_model, graph_reg)
        learner.eval()
        with torch.no_grad():
            raw_graph = learner(placeholder_features)  # expected shape (1, n_nodes, n_nodes)
        edge_scores = raw_graph.squeeze(0).detach().cpu().numpy()

        # Binarize the scores and drop self-loops.
        edge_mask = edge_scores > 0.5
        pred_adj_ecam = edge_mask.astype(int)
        np.fill_diagonal(pred_adj_ecam, 0)
    except Exception as e:
        print(f"Error running ECAM GraphLearner for intervention: {e}")
        return np.nan
    else:
        # Reuse the regression estimator with the graph produced above.
        return estimate_intervention_effect_regression(
            obs_data, target_node, intervention_details, pred_adj_ecam
        )


# --- Main Experiment ---

def _nan_safe_mean(values):
    """Return the mean of the non-NaN entries of *values*, or NaN (no warning) if none exist."""
    arr = np.asarray(values, dtype=float)
    valid = arr[~np.isnan(arr)]
    return valid.mean() if valid.size else np.nan


def _load_graph_dataset(graph_file, obs_file, interv_file):
    """Load (true adjacency matrix, observational data, intervention records) for one graph."""
    # NOTE: pickle.load can execute arbitrary code; only point this script at
    # trusted, locally generated data files.
    with open(graph_file, "rb") as f:
        G_true = pickle.load(f)["graph"]
    # Ensure nodes are sorted for consistent adjacency matrix
    true_adj = nx.to_numpy_array(G_true, nodelist=sorted(G_true.nodes()))

    obs_data = np.load(obs_file)

    with open(interv_file, "rb") as f:
        all_interv_data = pickle.load(f)  # List of dicts
    return true_adj, obs_data, all_interv_data


def _evaluate_intervention(obs_data, true_adj, interv_info, use_ecam):
    """Score both estimators on every non-intervened target node of one intervention."""
    intervened_node = interv_info["intervened_node"]
    true_interv_outcome_data = interv_info["data"]  # Shape (n_samples_interv, n_nodes)
    intervention_details = {"node": intervened_node, "value": interv_info["intervened_value"]}

    rows = []
    for target_node in range(true_adj.shape[0]):
        if target_node == intervened_node:
            continue

        # Ground truth: empirical mean of the target under the intervention
        true_mean_outcome = np.mean(true_interv_outcome_data[:, target_node])

        # Baseline: regression on the parents given by the true graph
        pred_regr = estimate_intervention_effect_regression(
            obs_data, target_node, intervention_details, true_adj
        )

        # ECAM: regression on the parents given by the ECAM graph.
        # NOTE: estimate_intervention_effect_ecam builds its graph anew on each
        # call, so every target node may be evaluated against a different graph.
        if use_ecam:
            pred_ecam = estimate_intervention_effect_ecam(obs_data, target_node, intervention_details)
        else:
            pred_ecam = np.nan

        rows.append({
            "target_node": target_node,
            "true_mean_outcome": true_mean_outcome,
            "pred_regr": pred_regr,
            # NaN predictions propagate to NaN errors through the arithmetic.
            "mse_regr": (pred_regr - true_mean_outcome) ** 2,
            "pred_ecam": pred_ecam,
            "mse_ecam": (pred_ecam - true_mean_outcome) ** 2,
        })
    return rows


def main(args):
    """
    Run the intervention-effect estimation experiment and write a CSV of results.

    For each graph index in ``1..args.num_graphs`` whose data files exist under
    ``<project_root>/data/synthetic``, both estimators are evaluated on every
    recorded intervention and every non-intervened target node. One row per
    (graph, intervention, target) is written to
    ``<args.output_dir>/intervention_estimation_results.csv``.

    Args:
        args: Namespace with ``num_graphs`` (int) and ``output_dir`` (str).
    """
    results = []
    synthetic_data_dir = os.path.join(project_root, "data", "synthetic")

    if not os.path.exists(synthetic_data_dir):
        print(f"Error: Synthetic data directory not found: {synthetic_data_dir}")
        print("Please run generate_synthetic_data.py first.")
        return

    # Decide once whether ECAM can run, instead of reporting it per target node.
    use_ecam = ecam_available and GraphLearner is not None
    if not use_ecam:
        print("ECAM GraphLearner not available; ECAM estimates will be NaN.")

    num_graphs_processed = 0
    for graph_idx in range(1, args.num_graphs + 1):
        graph_file = os.path.join(synthetic_data_dir, f"graph_{graph_idx}.pkl")
        obs_file = os.path.join(synthetic_data_dir, f"obs_data_{graph_idx}.npy")
        interv_file = os.path.join(synthetic_data_dir, f"interv_data_{graph_idx}.pkl")

        if not all(os.path.exists(p) for p in (graph_file, obs_file, interv_file)):
            print(f"Warning: Data files for graph {graph_idx} not found. Skipping.")
            continue

        print(f"\n--- Processing Graph {graph_idx} ---")

        try:
            true_adj, obs_data, all_interv_data = _load_graph_dataset(graph_file, obs_file, interv_file)
        except Exception as e:  # unpickling can raise almost anything; skip this graph
            print(f"Error loading data for graph {graph_idx}: {e}")
            continue

        # Count a graph only once its data has actually been loaded, so the
        # "nothing processed" check below also covers graphs that fail to load.
        num_graphs_processed += 1

        # Iterate through interventions performed during data generation
        for interv_info in all_interv_data:
            intervened_node = interv_info["intervened_node"]
            intervened_value = interv_info["intervened_value"]

            intervention_results = _evaluate_intervention(obs_data, true_adj, interv_info, use_ecam)

            # Aggregate squared errors for this intervention (NaNs ignored)
            avg_mse_regr = _nan_safe_mean([r["mse_regr"] for r in intervention_results])
            avg_mse_ecam = _nan_safe_mean([r["mse_ecam"] for r in intervention_results])
            # Printed as a single line so estimator diagnostics cannot split it.
            print(
                f"  Intervention: do({intervened_node}={intervened_value:.2f})"
                f" -> Avg MSE Regr: {avg_mse_regr:.4f}, Avg MSE ECAM: {avg_mse_ecam:.4f}"
            )

            # Append detailed results, one row per target node
            results.extend(
                {
                    "graph_idx": graph_idx,
                    "intervened_node": intervened_node,
                    "intervened_value": intervened_value,
                    **r,
                }
                for r in intervention_results
            )

    if num_graphs_processed == 0:
        print("No synthetic data graphs were processed.")
        return

    # Save results
    results_df = pd.DataFrame(results)
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "intervention_estimation_results.csv")
    results_df.to_csv(output_path, index=False)
    print(f"\nIntervention estimation results saved to {output_path}")

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the experiment.
    parser = argparse.ArgumentParser(description="Run Intervention Effect Estimation Experiments")
    parser.add_argument(
        "--num_graphs",
        type=int,
        default=10,
        help="Number of synthetic graphs to process (should match generated data)",
    )
    # Default to a directory inside the project instead of a machine-specific
    # absolute path, so the script also runs on other checkouts and users.
    parser.add_argument(
        "--output_dir",
        type=str,
        default=os.path.join(project_root, "results", "intervention"),
        help="Directory to save results (default: <project_root>/results/intervention)",
    )

    args = parser.parse_args()
    main(args)

