#!/usr/bin/env python
# coding: utf-8

"""Script to run causal discovery experiments using ECAM and baseline models."""

import os
import sys
import time
import argparse
import numpy as np
import pandas as pd
import networkx as nx
import pickle
import torch
import torch.nn as nn # Added import for nn.Linear in run_ecam
from sklearn.metrics import f1_score, precision_recall_fscore_support
from scipy.linalg import expm

# Add project root to path to import custom modules
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)

# Causal-learn imports (PC, GES)
try:
    from causallearn.search.ConstraintBased.PC import pc
    from causallearn.search.ScoreBased.GES import ges
    from causallearn.utils.GraphUtils import GraphUtils
    from causallearn.utils.cit import fisherz, chisq, gsq
    from causallearn.graph.GeneralGraph import GeneralGraph
    from causallearn.graph.GraphNode import GraphNode
    causal_learn_available = True
except ImportError:
    print("Warning: causal-learn library not found or specific modules failed to import. PC and GES baselines will be unavailable.")
    pc = None
    ges = None
    causal_learn_available = False

# NOTEARS import
try:
    # Assuming notears is installed from the local libs directory or user site
    from notears import utils as notears_utils
    from notears import linear as notears_linear
    notears_available = True
except ImportError:
    print("Warning: notears library not found. NOTEARS baseline will be unavailable.")
    notears_linear = None
    notears_available = False

# ECAM imports
try:
    from src.ecam import ECAM
    from src.graph_learner import GraphLearner # Assuming ECAM uses this internally
    ecam_available = True
except ImportError:
    print("Warning: ECAM source files not found or import failed. ECAM model will be unavailable.")
    ECAM = None
    ecam_available = False

# --- Helper Functions --- 

def _is_arrow_endpoint(endpoint):
    """Return True if *endpoint* is an arrowhead mark.

    causal-learn's ``get_endpoint`` returns an ``Endpoint`` enum member (or
    ``None`` when there is no edge), which never compares equal to the plain
    string ``"ARROW"``; compare by member name instead. A bare ``"ARROW"``
    string is also accepted.
    """
    if endpoint is None:
        return False
    return getattr(endpoint, "name", endpoint) == "ARROW"


def adj_matrix_from_causallearn_graph(graph_obj, n_nodes):
    """Convert a causal-learn graph object to a dense 0/1 adjacency matrix.

    Works with graphs such as ``cg.G`` from PC or ``record['G']`` from GES.
    ``adj[i, j] == 1`` means the edge between ``i`` and ``j`` has an
    arrowhead at ``j`` (i.e. ``i -> j``). A bidirected edge therefore sets
    both entries; an undirected (tail-tail) edge sets neither.

    Args:
        graph_obj: Object exposing ``get_nodes()``, ``get_edge(a, b)`` and
            ``get_endpoint(a, b)`` (endpoint mark at ``b``), as causal-learn's
            ``GeneralGraph`` does.
        n_nodes: Size of the square output matrix.

    Returns:
        ``np.ndarray`` of shape ``(n_nodes, n_nodes)`` with dtype ``int``.
    """
    adj_matrix = np.zeros((n_nodes, n_nodes), dtype=int)
    nodes = list(graph_obj.get_nodes())
    names = [node.get_name() for node in nodes]

    # Default: a node's position in get_nodes() is its matrix index.
    node_name_to_index = {name: pos for pos, name in enumerate(names)}
    # causal-learn names nodes 'X1', 'X2', ...; map them to 0-based column
    # indices so the matrix lines up with the data even if get_nodes() is
    # reordered.
    if all(name.startswith('X') for name in names):
        try:
            node_name_to_index = {name: int(name[1:]) - 1 for name in names}
        except ValueError:
            print("Warning: Could not convert node names like 'Xn' to zero-based indices. Adjacency matrix might be incorrect.")

    for pos_i, node_i in enumerate(nodes):
        for pos_j, node_j in enumerate(nodes):
            if pos_i == pos_j or graph_obj.get_edge(node_i, node_j) is None:
                continue
            # An arrowhead at node_j's end means node_i -> node_j. Every
            # ordered pair is visited, so the reverse direction is handled
            # when the loop reaches (node_j, node_i).
            if not _is_arrow_endpoint(graph_obj.get_endpoint(node_i, node_j)):
                continue
            idx_i = node_name_to_index.get(node_i.get_name(), pos_i)
            idx_j = node_name_to_index.get(node_j.get_name(), pos_j)
            if 0 <= idx_i < n_nodes and 0 <= idx_j < n_nodes:
                adj_matrix[idx_i, idx_j] = 1
            else:
                print(f"Warning: Index out of bounds ({idx_i}, {idx_j}) for node names {node_i.get_name()}, {node_j.get_name()}")

    return adj_matrix

def calculate_shd(true_adj, pred_adj):
    """Calculate the Structural Hamming Distance (SHD) between two graphs.

    Every unordered node pair ``{i, j}`` contributes at most 1:
      * 1 if an edge is present in one skeleton but not the other
        (missing or extra edge);
      * 1 if both graphs connect the pair but with different orientation
        (reversed, or directed vs. bidirected).
    Any non-zero entry counts as an edge, so weighted matrices (including
    negative weights) are handled. The diagonal is ignored.

    Args:
        true_adj: Ground-truth ``(n, n)`` adjacency/weight matrix.
        pred_adj: Predicted ``(n, n)`` matrix, or ``None``.

    Returns:
        SHD as ``int``, or ``np.nan`` if ``pred_adj`` is ``None`` or the
        shapes differ.
    """
    if pred_adj is None or true_adj.shape != pred_adj.shape:
        return np.nan

    true_bin = np.asarray(true_adj) != 0
    pred_bin = np.asarray(pred_adj) != 0
    true_skeleton = true_bin | true_bin.T
    pred_skeleton = pred_bin | pred_bin.T

    # Strict upper triangle: each unordered pair exactly once, no diagonal.
    upper = np.triu(np.ones(true_bin.shape, dtype=bool), k=1)

    skeleton_errors = (true_skeleton != pred_skeleton) & upper
    # Pair present in both skeletons but (i->j, j->i) marks disagree.
    orientation_mismatch = (true_bin != pred_bin) | (true_bin.T != pred_bin.T)
    orientation_errors = true_skeleton & pred_skeleton & orientation_mismatch & upper

    return int(skeleton_errors.sum() + orientation_errors.sum())

def calculate_f1(true_adj, pred_adj):
    """Calculate precision, recall and F1 for edge presence (skeleton only).

    Orientation is ignored: node pair ``{i, j}`` is "positive" if either
    ``i -> j`` or ``j -> i`` is non-zero. Each unordered pair is counted
    once and the diagonal is excluded. Any non-zero entry is an edge, so
    weighted matrices with negative weights are handled (summing ``A + A.T``
    could cancel them).

    Metrics follow the binary, ``zero_division=0`` convention: a ratio whose
    denominator is zero is reported as ``0.0``.

    Args:
        true_adj: Ground-truth ``(n, n)`` adjacency/weight matrix.
        pred_adj: Predicted ``(n, n)`` matrix, or ``None``.

    Returns:
        ``(precision, recall, f1)`` as floats, or three ``np.nan`` if
        ``pred_adj`` is ``None`` or the shapes differ.
    """
    if pred_adj is None or true_adj.shape != pred_adj.shape:
        return np.nan, np.nan, np.nan

    true_bin = np.asarray(true_adj) != 0
    pred_bin = np.asarray(pred_adj) != 0

    # Strict upper triangle: one entry per unordered node pair, no diagonal.
    pairs = np.triu_indices(true_bin.shape[0], k=1)
    true_skeleton = (true_bin | true_bin.T)[pairs]
    pred_skeleton = (pred_bin | pred_bin.T)[pairs]

    tp = int(np.sum(true_skeleton & pred_skeleton))
    fp = int(np.sum(~true_skeleton & pred_skeleton))
    fn = int(np.sum(true_skeleton & ~pred_skeleton))

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
    return precision, recall, f1

# --- Baseline Model Implementations/Wrappers --- 

def run_pc(data, alpha=0.05):
    """Run causal-learn's PC algorithm and return a 0/1 adjacency matrix.

    The Fisher-z independence test is used, i.e. *data* is treated as
    continuous. Returns ``None`` when causal-learn is unavailable or when the
    search or graph conversion raises (the error is printed).
    """
    if not causal_learn_available:
        return None
    num_vars = data.shape[1]
    try:
        search_result = pc(
            data,
            alpha=alpha,
            indep_test=fisherz,
            stable=True,
            uc_rule=0,
            uc_priority=2,
            verbose=False,
        )
        # The learned graph lives on the .G attribute of the search result.
        return adj_matrix_from_causallearn_graph(search_result.G, num_vars)
    except Exception as exc:
        print(f"Error running PC: {exc}")
        return None

# Canonical causal-learn GES score names, keyed by lower-case spelling.
# causal-learn compares score_func against these names exactly, so e.g.
# "local_score_bic" would be rejected with "Unknown function!".
_GES_SCORE_FUNCS = {
    name.lower(): name
    for name in (
        "local_score_BIC",
        "local_score_BIC_from_cov",
        "local_score_BDeu",
        "local_score_CV_general",
        "local_score_marginal_general",
        "local_score_CV_multi",
        "local_score_marginal_multi",
    )
}


def run_ges(data, score_func="local_score_BIC", parameters=None):
    """Run causal-learn's GES algorithm and return a 0/1 adjacency matrix.

    Args:
        data: Sample matrix; ``data.shape[1]`` is the number of variables.
        score_func: causal-learn score name. It is matched case-insensitively
            against the known names, so the earlier default spelling
            ``"local_score_bic"`` still works. Unknown names are passed
            through unchanged.
        parameters: Score parameters forwarded to ``ges``. Defaults to
            ``{"kfold": 5, "lambda": 0.01}`` (keys used by causal-learn's
            cross-validated scores; presumably ignored for BIC — confirm).

    Returns:
        ``(n_nodes, n_nodes)`` int matrix, or ``None`` when causal-learn is
        unavailable or GES raises (the error is printed).
    """
    if not causal_learn_available:
        return None
    n_nodes = data.shape[1]
    if parameters is None:
        parameters = {"kfold": 5, "lambda": 0.01}
    if isinstance(score_func, str):
        score_func = _GES_SCORE_FUNCS.get(score_func.lower(), score_func)
    try:
        # No verbose= argument: passing it raised an error with the installed ges().
        record = ges(data, score_func=score_func, maxP=None, parameters=parameters)
        # record["G"] is a GeneralGraph object.
        return adj_matrix_from_causallearn_graph(record["G"], n_nodes)
    except Exception as e:
        print(f"Error running GES: {e}")
        return None

def run_notears_linear(data, lambda1=0.1, loss_type="l2"):
    """Run linear NOTEARS and return a binary adjacency matrix.

    ``notears_linear`` thresholds its weight matrix internally via
    ``w_threshold=0.3``. Every remaining entry with non-zero magnitude becomes
    an edge. Returns ``None`` (after printing a message) when NOTEARS is not
    installed or the optimisation raises.
    """
    if not notears_available:
        print("NOTEARS library not available, skipping.")
        return None
    try:
        samples = data.astype(np.float64)  # NOTEARS is run on float64 input
        weights = notears_linear.notears_linear(
            samples,
            lambda1=lambda1,
            loss_type=loss_type,
            max_iter=100,
            h_tol=1e-8,
            rho_max=1e+16,
            w_threshold=0.3,
        )
        edge_mask = np.abs(weights) > 0
        return edge_mask.astype(int)
    except Exception as exc:
        print(f"Error running NOTEARS Linear: {exc}")
        return None

# TODO: Implement Attn+Mask baseline
def run_attn_mask(data):
    """Placeholder for the Attn+Mask baseline; always returns ``None``.

    *data* is accepted only so the signature matches the other ``run_*``
    wrappers; it is ignored.
    """
    notice = "Attn+Mask baseline not implemented yet."
    print(notice)
    return None

# TODO: Implement CATT baseline
def run_catt(data):
    """Placeholder for the CATT baseline; always returns ``None``.

    *data* is accepted only so the signature matches the other ``run_*``
    wrappers; it is ignored.
    """
    notice = "CATT baseline not implemented yet."
    print(notice)
    return None

# TODO: Implement ECAM wrapper for causal discovery
def run_ecam(data, d_model=64, n_heads=4, graph_reg=0.01, epochs=10, lr=1e-3, seed=None):
    """Run ECAM's GraphLearner as a causal-discovery baseline (placeholder).

    NOTE: no training happens yet. A fresh ``GraphLearner`` is evaluated once,
    untrained, on random node features. Only the number of columns of *data*
    is used, so the output is a random reference point, not a learned graph.

    Args:
        data: 2-D sample matrix; ``data.shape[1]`` is the number of nodes.
        d_model: Feature dimension of each node's (random) representation,
            also passed to ``GraphLearner``.
        n_heads: Currently unused; kept for interface compatibility.
        graph_reg: Passed to ``GraphLearner`` as its second argument.
        epochs: Currently unused (no training loop yet).
        lr: Currently unused (no training loop yet).
        seed: If not ``None``, seeds torch's global RNG via ``torch.manual_seed``
            so the random features and the learner's initial weights are
            reproducible across runs.

    Returns:
        ``(n_nodes, n_nodes)`` 0/1 int matrix with a zero diagonal, or ``None``
        if ECAM is unavailable or the forward pass raises (error printed).
    """
    print("ECAM wrapper for causal discovery - Basic Implementation")
    if not ecam_available:
        return None

    _, n_nodes = data.shape  # requires 2-D input

    if seed is not None:
        torch.manual_seed(seed)

    # Placeholder input: one random d_model-dim vector per node, batch of 1.
    # TODO: replace with a real embedding of the samples plus a training loop.
    node_features = torch.randn(1, n_nodes, d_model)
    print("Using random node features as input to GraphLearner.")

    try:
        graph_learner = GraphLearner(d_model, graph_reg)
        graph_learner.eval()
        with torch.no_grad():
            # assumes output is (1, n_nodes, n_nodes) edge scores — TODO confirm
            G = graph_learner(node_features)
        adj_matrix = G.squeeze(0).detach().cpu().numpy()

        pred_adj = (adj_matrix > 0.5).astype(int)  # simple fixed threshold
        np.fill_diagonal(pred_adj, 0)  # no self-loops
        return pred_adj
    except Exception as e:
        print(f"Error running ECAM for discovery: {e}")
        return None

# --- Main Experiment --- 

def _time_model(model_func, data):
    """Call one ``run_*`` wrapper on *data*; return ``(pred_adj, seconds)``."""
    start = time.perf_counter()
    pred_adj = model_func(data)
    return pred_adj, time.perf_counter() - start


def _run_synthetic_experiments(args):
    """Evaluate every model on freshly generated synthetic DAGs.

    ``args.num_graphs`` graphs are generated per (node count, graph type,
    edge prob) configuration, each with ``args.num_samples`` samples.
    Returns a list of result dicts, one per (graph, model).
    """
    import itertools

    results = []
    print("\n--- Running on Synthetic Data ---")
    node_counts = [10, 20]  # Example node counts
    graph_types = ["er", "sf"]
    # Edge probability for ER; forwarded unchanged for SF, where the
    # generator may interpret it differently (e.g. as m) — TODO confirm.
    edge_probs = [0.2]

    try:
        from scripts.generate_synthetic_data import generate_random_dag, generate_weights, generate_scm_data
    except ImportError:
        print("Error: Could not import synthetic data generation functions. Skipping synthetic data experiments.")
        return results

    # TODO: add "Attn+Mask" (run_attn_mask) and "CATT" (run_catt) once implemented.
    models = {
        "PC": run_pc,
        "GES": run_ges,
        "NOTEARS_Linear": run_notears_linear,
        "ECAM": run_ecam,  # Basic (untrained) implementation
    }

    for n, graph_type, p in itertools.product(node_counts, graph_types, edge_probs):
        for graph_idx in range(args.num_graphs):
            print(f"\nProcessing Synthetic Graph: n={n}, type={graph_type}, p={p}, graph_idx={graph_idx}")
            try:
                G_true = generate_random_dag(n, p, graph_type=graph_type)
                true_adj = nx.to_numpy_array(G_true, nodelist=range(n))
                weights = generate_weights(G_true)
                data = generate_scm_data(G_true, weights, n_samples=args.num_samples, noise_scale=0.1)
            except Exception as e:
                print(f"Error generating data for graph {graph_idx}: {e}. Skipping.")
                continue

            for name, model_func in models.items():
                pred_adj, duration = _time_model(model_func, data)
                if pred_adj is None:
                    print(f"  {name}: Failed or not implemented.")
                    shd, prec, rec, f1 = np.nan, np.nan, np.nan, np.nan
                else:
                    shd = calculate_shd(true_adj, pred_adj)
                    prec, rec, f1 = calculate_f1(true_adj, pred_adj)
                    print(f"  {name}: SHD={shd}, F1={f1:.3f}, Precision={prec:.3f}, Recall={rec:.3f}, Time={duration:.2f}s")

                results.append({
                    "dataset": "synthetic",
                    "graph_type": graph_type,
                    "n_nodes": n,
                    "p_or_m": p,
                    "graph_idx": graph_idx,
                    "model": name,
                    "shd": shd,
                    "f1": f1,
                    "precision": prec,
                    "recall": rec,
                    "time": duration,
                })
    return results


def _load_tuebingen_pair(file_path, filename, pair_index):
    """Read one Tuebingen pair CSV and standardise its columns.

    Returns the data array, or ``None`` (after printing a warning) if the file
    cannot be read or does not have exactly two columns.
    """
    try:
        data = pd.read_csv(file_path).values
        if data.shape[1] != 2:
            print(f"Warning: Tuebingen file {filename} has shape {data.shape}, expected 2 columns. Skipping.")
            return None
        std = np.std(data, axis=0)
        # Standardise only when *every* column varies; otherwise a constant
        # column would be divided by zero and turn into NaN/inf.
        if np.all(std != 0):
            data = (data - np.mean(data, axis=0)) / std
        else:
            print(f"Warning: Data for pair {pair_index} has zero standard deviation. Skipping normalization.")
    except Exception as e:
        print(f"Error reading or processing Tuebingen file {filename}: {e}. Skipping.")
        return None
    return data


def _format_metric(value, enabled, spec=""):
    """Format *value* with *spec* for logging, or 'N/A' if disabled or NaN."""
    if not enabled or np.isnan(value):
        return 'N/A'
    return format(value, spec)


def _run_tuebingen_experiments():
    """Run the bivariate baselines on every processed Tuebingen pair CSV.

    Ground-truth loading is not implemented yet, so all quality metrics are
    recorded as NaN; only run times are meaningful. Returns a list of result
    dicts, one per (pair, model).
    """
    results = []
    print("\n--- Running on Tuebingen Data ---")
    tuebingen_dir = os.path.join(project_root, "data", "real_world", "tuebingen", "datasets")
    # TODO: load ground truth (e.g. by parsing causal_description.Rmd) and
    # enable evaluation.
    print("Warning: Tuebingen ground truth loading not implemented. Evaluation metrics will be NaN.")
    can_evaluate_tuebingen = False

    processed_files = []
    if os.path.exists(tuebingen_dir):
        # Sorted so runs process pairs in a deterministic order.
        processed_files = sorted(
            f for f in os.listdir(tuebingen_dir)
            if f.startswith("causal_tubingen") and f.endswith(".csv")
        )
        print(f"Found {len(processed_files)} processed Tuebingen CSV files.")
    else:
        print(f"Warning: Tuebingen data directory not found: {tuebingen_dir}")

    # Bivariate-capable baselines only.
    models = {
        "PC": run_pc,
        "GES": run_ges,
        "NOTEARS_Linear": run_notears_linear,
    }
    # Placeholder ground truth (no edges) until real labels are loaded.
    true_adj = np.zeros((2, 2), dtype=int)

    for filename in processed_files:
        try:
            pair_index = int(filename.split("tubingen")[-1].split(".")[0])
        except ValueError:
            print(f"Warning: Could not parse pair index from {filename}. Skipping.")
            continue

        data = _load_tuebingen_pair(os.path.join(tuebingen_dir, filename), filename, pair_index)
        if data is None:
            continue

        print(f"\nProcessing Tuebingen Pair: index={pair_index}")

        for name, model_func in models.items():
            pred_adj, duration = _time_model(model_func, data)

            shd = calculate_shd(true_adj, pred_adj)
            prec, rec, f1 = calculate_f1(true_adj, pred_adj)
            if pred_adj is not None and can_evaluate_tuebingen:
                correct_direction = int(np.array_equal(pred_adj, true_adj))
            else:
                correct_direction = np.nan

            if pred_adj is not None:
                shd_str = _format_metric(shd, can_evaluate_tuebingen)
                f1_str = _format_metric(f1, can_evaluate_tuebingen, ".3f")
                cd_str = _format_metric(correct_direction, can_evaluate_tuebingen)
                print(f"  {name}: SHD={shd_str}, F1={f1_str}, CorrectDir={cd_str}, Time={duration:.2f}s")
            else:
                print(f"  {name}: Failed or not implemented for pair {pair_index}.")

            # Without ground truth (or without a prediction) metrics are NaN.
            if pred_adj is None or not can_evaluate_tuebingen:
                shd, prec, rec, f1, correct_direction = np.nan, np.nan, np.nan, np.nan, np.nan

            results.append({
                "dataset": "tuebingen",
                "graph_type": "pair",
                "n_nodes": 2,
                "p_or_m": pair_index,  # pair index kept for reference
                "graph_idx": pair_index,
                "model": name,
                "shd": shd,
                "f1": f1,
                "precision": prec,
                "recall": rec,
                "correct_direction": correct_direction,
                "time": duration,
            })
    return results


def _save_results(results, base_output_dir):
    """Write *results* to ``<base_output_dir>/causal_discovery/causal_discovery_results.csv``.

    Creates the directory if needed and overwrites any previous file.
    """
    results_df = pd.DataFrame(results)
    output_dir = os.path.join(base_output_dir, "causal_discovery")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "causal_discovery_results.csv")
    results_df.to_csv(output_path, index=False)
    print(f"\nCausal discovery results saved to {output_path}")


def main(args):
    """Run all causal-discovery experiments and save a combined results CSV.

    Args:
        args: Parsed CLI namespace providing ``num_samples``, ``num_graphs``
            and ``output_dir``.
    """
    results = _run_synthetic_experiments(args)
    results.extend(_run_tuebingen_experiments())
    _save_results(results, args.output_dir)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Causal Discovery Experiments")
    parser.add_argument("--num_samples", type=int, default=1000, help="Number of samples for synthetic data")
    parser.add_argument("--num_graphs", type=int, default=5, help="Number of synthetic graphs per configuration")
    # Default under the project root instead of a machine-specific absolute
    # path, so the script runs from any checkout location.
    parser.add_argument("--output_dir", type=str, default=os.path.join(project_root, "results"), help="Base directory to save results")
    # TODO: add arguments for selecting models and dataset paths.
    args = parser.parse_args()

    main(args)

