#!/usr/bin/env python
# coding: utf-8

"""Script to generate visualizations and tables from experiment results."""

import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Plot appearance shared by every figure in this script. The theme is applied
# first because sns.set_theme() rewrites rcParams; the size override must
# come after it to take effect.
sns.set_theme(style="whitegrid")
plt.rcParams.update({"figure.figsize": (10, 6)})  # default size unless a plot overrides it

# --- Configuration ---
# Root folder of the experiment outputs; generated plots and tables are
# written to sub-folders beneath it.
RESULTS_DIR = "/home/ubuntu/ecam_project/results"
PLOTS_DIR = os.path.join(RESULTS_DIR, "plots")
TABLES_DIR = os.path.join(RESULTS_DIR, "tables")

# Make sure both output folders exist before anything is saved into them
for _output_dir in (PLOTS_DIR, TABLES_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# --- Causal Discovery Visualization ---
def visualize_causal_discovery(results_file):
    """Plot and tabulate the causal discovery results stored in a CSV file.

    Rows whose ``dataset`` column equals ``"synthetic"`` are summarised in a
    two-panel bar chart (mean SHD and mean F1 per graph type and model) that
    is saved under ``PLOTS_DIR``. Rows whose ``dataset`` equals
    ``"tuebingen"`` are averaged per model and written as a CSV table under
    ``TABLES_DIR``.

    A missing file, an empty file, or a file lacking the required columns is
    reported on standard output and skipped instead of raising.

    Args:
        results_file: Location of the results CSV, passed straight to
            ``pandas.read_csv``.

    Returns:
        None. The outputs are the files written to disk and the progress
        messages printed to standard output.
    """
    print("\n--- Visualizing Causal Discovery Results ---")
    try:
        df = pd.read_csv(results_file)
    except FileNotFoundError:
        print(f"Error: Causal discovery results file not found at {results_file}")
        return
    except pd.errors.EmptyDataError:
        print(f"Error: Causal discovery results file is empty: {results_file}")
        return

    # Without the dataset column the rows cannot be split at all; report it
    # the same way as the other unusable-input cases instead of a KeyError.
    if "dataset" not in df.columns:
        print(f"Error: Causal discovery results file has no 'dataset' column: {results_file}")
        return

    df_synthetic = df[df["dataset"] == "synthetic"].copy()
    df_tuebingen = df[df["dataset"] == "tuebingen"].copy()

    # --- Synthetic Data Plot (All Models) ---
    if df_synthetic.empty:
        print("No synthetic data results found in the causal discovery file.")
    else:
        _plot_synthetic_discovery(df_synthetic)

    # --- Tuebingen Data Summary (Table - All Models) ---
    if df_tuebingen.empty:
        print("No Tuebingen data results found in the causal discovery file.")
    else:
        _save_tuebingen_summary(df_tuebingen)


def _plot_synthetic_discovery(df_synthetic):
    """Save a bar chart of mean SHD and mean F1 per graph type and model."""
    required = ["graph_type", "model", "shd", "f1"]
    missing = [col for col in required if col not in df_synthetic.columns]
    if missing:
        print(f"Error: Synthetic causal discovery results are missing columns: {missing}")
        return

    # Only the two plotted metrics are aggregated, so the plot does not
    # depend on columns it never shows.
    synthetic_summary = (
        df_synthetic.groupby(["graph_type", "model"])[["shd", "f1"]].mean().reset_index()
    )

    # Groups whose mean is NaN (e.g. every run of a model failed) are dropped
    synthetic_summary = synthetic_summary.dropna(subset=["shd", "f1"])
    if synthetic_summary.empty:
        print("No valid results found for synthetic data after filtering NaNs.")
        return

    print("Generating plot for Causal Discovery (Synthetic Data - All Models)...")
    fig, axes = plt.subplots(1, 2, figsize=(15, 6), sharey=False)
    fig.suptitle("Causal Discovery Performance on Synthetic Data")

    # SHD Plot
    sns.barplot(ax=axes[0], x="graph_type", y="shd", hue="model", data=synthetic_summary, palette="viridis")
    axes[0].set_title("Structural Hamming Distance (SHD) - Lower is Better")
    axes[0].set_xlabel("Graph Type")
    axes[0].set_ylabel("Average SHD")
    axes[0].legend(title="Model")

    # F1 Score Plot
    sns.barplot(ax=axes[1], x="graph_type", y="f1", hue="model", data=synthetic_summary, palette="viridis")
    axes[1].set_title("F1 Score - Higher is Better")
    axes[1].set_xlabel("Graph Type")
    axes[1].set_ylabel("Average F1 Score")
    axes[1].set_ylim(0, 1)  # F1 score is between 0 and 1
    axes[1].legend(title="Model")

    # Leave room at the top for the figure-level title
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plot_path = os.path.join(PLOTS_DIR, "causal_discovery_synthetic_all_models.png")
    # Save and close this specific figure rather than pyplot's implicit
    # "current" figure, so unrelated open figures cannot be affected.
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"Saved plot to {plot_path}")


def _save_tuebingen_summary(df_tuebingen):
    """Write the per-model mean pairwise accuracy and F1 as a CSV table."""
    # 'correct_direction' is reported under the name 'Pairwise Accuracy'
    renamed = df_tuebingen.rename(columns={"correct_direction": "Pairwise Accuracy"})

    required = ["model", "Pairwise Accuracy", "f1"]
    missing = [col for col in required if col not in renamed.columns]
    if missing:
        print(f"Error: Tuebingen causal discovery results are missing columns: {missing}")
        return

    tuebingen_summary = renamed.groupby("model")[["Pairwise Accuracy", "f1"]].mean().reset_index()
    # A model is kept only if both of its averaged metrics are available
    tuebingen_summary = tuebingen_summary.dropna()
    if tuebingen_summary.empty:
        print("No valid results found for Tuebingen data after filtering NaNs.")
        return

    print("Generating table for Causal Discovery (Tübingen Data - All Models)...")
    table_path = os.path.join(TABLES_DIR, "causal_discovery_tuebingen_summary.csv")
    tuebingen_summary.to_csv(table_path, index=False, float_format='%.3f')
    print(f"Saved table to {table_path}")

# --- Intervention Effect Estimation Visualization ---
def visualize_intervention(results_file):
    """Plot the intervention effect estimation results stored in a CSV file.

    Two figures are saved under ``PLOTS_DIR``:

    * a box plot of the per-graph average MSE for the two MSE columns
      (``mse_regr`` and ``mse_ecam``), on a logarithmic y axis;
    * side-by-side scatter plots of ``pred_regr`` and ``pred_ecam`` against
      ``true_mean_outcome`` for a sample of at most 500 rows.

    A missing file, an empty file, missing columns, or the absence of any
    usable rows is reported on standard output and the affected plot is
    skipped instead of raising.

    Args:
        results_file: Location of the results CSV, passed straight to
            ``pandas.read_csv``.

    Returns:
        None. The outputs are the files written to disk and the progress
        messages printed to standard output.
    """
    print("\n--- Visualizing Intervention Effect Estimation Results ---")
    try:
        df = pd.read_csv(results_file)
    except FileNotFoundError:
        print(f"Error: Intervention results file not found at {results_file}")
        return
    except pd.errors.EmptyDataError:
        print(f"Error: Intervention results file is empty: {results_file}")
        return

    id_columns = ["graph_idx", "intervened_node", "target_node", "true_mean_outcome"]
    mse_columns = ["mse_regr", "mse_ecam"]
    # melt() would fail with a KeyError on absent columns; report them in the
    # same way as the other unusable-input cases.
    missing = [col for col in id_columns + mse_columns if col not in df.columns]
    if missing:
        print(f"Error: Intervention results file is missing required columns: {missing}")
        return

    # Long format: one row per (original row, model) holding that model's MSE
    df_melt = df.melt(id_vars=id_columns, value_vars=mse_columns,
                      var_name="model", value_name="mse")
    df_melt["model"] = df_melt["model"].replace({"mse_regr": "Regression (True Graph)", "mse_ecam": "ECAM (Learned Graph)"})
    df_melt = df_melt.dropna(subset=["mse"])

    if df_melt.empty:
        print("No valid intervention estimation results found after melting/filtering.")
        return

    _plot_intervention_mse(df_melt)
    _plot_intervention_scatter(df)


def _plot_intervention_mse(df_melt):
    """Save a box plot comparing the per-graph average MSE of both models."""
    avg_mse_per_graph = df_melt.groupby(["graph_idx", "model"])["mse"].mean().reset_index()

    print("Generating plot for Intervention Estimation MSE Distribution (Regression vs ECAM)...")
    fig, ax = plt.subplots(figsize=(10, 7))
    # NOTE(review): newer seaborn releases warn when `palette` is given
    # without `hue`; left unchanged because the installed version is unknown.
    sns.boxplot(ax=ax, data=avg_mse_per_graph, x="model", y="mse", palette="viridis")
    ax.set_title("Distribution of Average MSE for Intervention Estimation")
    ax.set_xlabel("Model")
    ax.set_ylabel("Average Mean Squared Error (MSE) per Graph")
    ax.set_yscale("log")  # Use log scale if MSE varies widely
    fig.tight_layout()
    plot_path = os.path.join(PLOTS_DIR, "intervention_estimation_mse_comparison.png")
    # Operate on this specific figure rather than pyplot's implicit current one
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"Saved plot to {plot_path}")


def _plot_intervention_scatter(df):
    """Save predicted-vs-true scatter plots for a sample of the result rows."""
    value_columns = ["true_mean_outcome", "pred_regr", "pred_ecam"]
    missing = [col for col in value_columns if col not in df.columns]
    if missing:
        print(f"Error: Intervention results are missing columns for the scatter plots: {missing}")
        return

    # Keep only rows where both models produced a prediction
    df_pred = df.dropna(subset=["pred_regr", "pred_ecam"])
    if df_pred.empty:
        # Without this guard the axis limits below would be NaN and
        # set_xlim()/set_ylim() would raise.
        print("No valid predictions found for the intervention scatter plots.")
        return

    # Sample a subset of points to avoid overplotting if data is large
    sample_size = min(500, len(df_pred))
    df_sample = df_pred.sample(n=sample_size, random_state=42)

    # Common axis range over true and predicted values. pandas skips NaN
    # here, unlike Python's min()/max(), whose result with NaN depends on
    # argument order.
    sampled_values = df_sample[value_columns]
    low = sampled_values.min().min()
    high = sampled_values.max().max()
    if not (np.isfinite(low) and np.isfinite(high)):
        print("No finite values found for the intervention scatter plots.")
        return

    # Add some padding to limits; fall back to a fixed margin when all
    # values coincide so the two limits are never identical.
    padding = (high - low) * 0.05 or 1.0
    lims = [low - padding, high + padding]

    print("Generating scatter plots for Intervention Estimation (Predicted vs True - Sample)...")
    fig, axes = plt.subplots(1, 2, figsize=(16, 8), sharex=True, sharey=True)
    fig.suptitle("Intervention Estimation: Predicted vs True (Sample)")

    panels = (
        (axes[0], "pred_regr", "Regression (True Graph)"),
        (axes[1], "pred_ecam", "ECAM (Learned Graph)"),
    )
    for ax, pred_column, title in panels:
        sns.scatterplot(ax=ax, data=df_sample, x="true_mean_outcome", y=pred_column, alpha=0.6)
        ax.set_title(title)
        ax.set_xlabel("True Mean Outcome")
        ax.set_ylabel("Predicted Mean Outcome")
        ax.grid(True)
        # Dashed diagonal marks where the prediction equals the true value
        ax.plot(lims, lims, 'r--', alpha=0.75, zorder=0)
        ax.set_xlim(lims)
        ax.set_ylim(lims)

    # Leave room at the top for the figure-level title
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plot_path = os.path.join(PLOTS_DIR, "intervention_estimation_scatter_comparison.png")
    fig.savefig(plot_path)
    plt.close(fig)
    print(f"Saved plot to {plot_path}")

# --- Counterfactual Reasoning Summary ---
def summarize_counterfactual(results_file):
    """Summarise the counterfactual reasoning results stored in a CSV file.

    When the file has a ``true_cf_value`` column with at least one non-null
    entry, its ``describe()`` statistics are written as a two-column CSV
    table under ``TABLES_DIR``. The mean of ``mse_ecam`` and of
    ``mse_baseline`` is printed for whichever of those columns is present.

    Args:
        results_file: Location of the results CSV, passed straight to
            ``pandas.read_csv``.

    Returns:
        None. A missing or empty file is reported on standard output and
        nothing else is done.
    """
    print("\n--- Summarizing Counterfactual Reasoning Results ---")
    try:
        df = pd.read_csv(results_file)
    except FileNotFoundError:
        print(f"Error: Counterfactual results file not found at {results_file}")
        return
    except pd.errors.EmptyDataError:
        print(f"Error: Counterfactual results file is empty: {results_file}")
        return

    # Summary statistics of the true counterfactual values
    has_true_values = "true_cf_value" in df.columns and df["true_cf_value"].notna().any()
    if not has_true_values:
        print("No valid true counterfactual values found in the results file.")
    else:
        print("Generating summary table for True Counterfactual Values...")
        # Turn the describe() Series into a two-column (Statistic, Value) frame
        stats_table = (
            df["true_cf_value"]
            .describe()
            .rename_axis("Statistic")
            .reset_index(name="Value")
        )
        out_path = os.path.join(TABLES_DIR, "counterfactual_true_values_summary.csv")
        stats_table.to_csv(out_path, index=False, float_format="%.3f")
        print(f"Saved summary table to {out_path}")

    # Overall mean MSE per model; prints "nan" when every entry is NaN
    for column, label in (("mse_ecam", "ECAM"), ("mse_baseline", "Baseline")):
        if column not in df.columns:
            continue
        mean_mse = df[column].mean()
        print(f"Overall Average MSE {label}: {mean_mse:.4f}")

# --- Main Execution ---
if __name__ == "__main__":
    # Each step reads one CSV located at RESULTS_DIR/<sub-folder>/<file name>;
    # the steps run in the order listed.
    steps = (
        (visualize_causal_discovery, "causal_discovery", "causal_discovery_results.csv"),
        (visualize_intervention, "intervention", "intervention_estimation_results.csv"),
        (summarize_counterfactual, "counterfactual", "counterfactual_estimation_results.csv"),
    )
    for run_step, subfolder, file_name in steps:
        run_step(os.path.join(RESULTS_DIR, subfolder, file_name))

    print("\nVisualization and summary generation complete.")

