import numpy as np
import matplotlib.pyplot as plt
import os
import re

def extract_ix_theta(MI_results, MI_labels):
    """Pick out the ``I(X_i; θ)`` entries and return them ordered by ``i``.

    Args:
        MI_results: Indexable collection of values, positionally aligned
            with ``MI_labels``.
        MI_labels: Iterable of string labels; only those beginning with the
            form ``I(X_<digits>; θ)`` are kept.

    Returns:
        A pair ``(indices, values)`` of NumPy arrays: the parsed ``i`` of each
        kept label in ascending order, and the matching entries of
        ``MI_results`` in that same order.
    """
    theta_label = re.compile(r"I\(X_(\d+); θ\)")
    hits = []
    for position, name in enumerate(MI_labels):
        # Cheap substring screen before running the regex.
        if not (name.startswith("I(X_") and "; θ" in name):
            continue
        found = theta_label.match(name)
        if found is None:
            continue
        hits.append((int(found.group(1)), MI_results[position]))

    modality_ids = np.array([modality for modality, _ in hits])
    mi_values = np.array([value for _, value in hits])
    order = np.argsort(modality_ids)
    return modality_ids[order], mi_values[order]

def extract_ix1_xi(MI_results, MI_labels):
    """Pick out the ``I(X_1; X_i)`` entries and return them ordered by ``i``.

    Args:
        MI_results: Indexable collection of values, positionally aligned
            with ``MI_labels``.
        MI_labels: Iterable of string labels; only those beginning with the
            form ``I(X_1; X_<digits>)`` are kept (e.g. ``I(X_1; X_3)``).

    Returns:
        A pair ``(indices, values)`` of NumPy arrays: the parsed ``i`` of each
        kept label in ascending order, and the matching entries of
        ``MI_results`` in that same order.
    """
    pair_label = re.compile(r"I\(X_1; X_(\d+)\)")
    partner_ids, mi_values = [], []
    for row, text in enumerate(MI_labels):
        hit = pair_label.match(text)
        if hit is None:
            continue
        partner_ids.append(int(hit.group(1)))
        mi_values.append(MI_results[row])

    order = np.argsort(partner_ids)
    return np.array(partner_ids)[order], np.array(mi_values)[order]

def plot_overlay(
    npz_files,
    labels=None,
    extract_fn=None,
    ylabel="",
    title="",
    save_pdf=None
):
    """Overlay one curve per ``.npz`` file on a single figure.

    For every file, the arrays stored under the keys ``"MI_results"`` and
    ``"MI_labels"`` are passed to ``extract_fn``; the returned x/y series is
    drawn as a line with circle markers.

    Args:
        npz_files: Iterable of paths to ``.npz`` archives.
        labels: Optional sequence of legend entries, indexed in step with
            ``npz_files``. When ``None``, the name of the directory that
            contains each file is used instead.
        extract_fn: Callable ``(MI_results, MI_labels) -> (x, y)``. Required,
            despite the ``None`` default kept for signature compatibility.
        ylabel: Y-axis label.
        title: Figure title.
        save_pdf: If truthy, path the figure is saved to before it is shown.

    Raises:
        TypeError: If ``extract_fn`` is not callable.
    """
    # Fail up front with a clear message instead of an opaque
    # "'NoneType' object is not callable" after a figure has been created.
    if not callable(extract_fn):
        raise TypeError(
            "plot_overlay() requires a callable extract_fn taking "
            "(MI_results, MI_labels) and returning (x, y)"
        )

    plt.figure(figsize=(7, 4))
    for idx, npz_file in enumerate(npz_files):
        # NOTE(security): allow_pickle=True can execute arbitrary code while
        # unpickling; only load archives from trusted sources.
        # The context manager closes the archive's file handle, which the
        # bare np.load() call would otherwise leave open.
        with np.load(npz_file, allow_pickle=True) as data:
            MI_results = data["MI_results"]
            MI_labels = data["MI_labels"]
        x_indices, y_values = extract_fn(MI_results, MI_labels)
        label = labels[idx] if labels is not None else os.path.basename(os.path.dirname(npz_file))
        plt.plot(x_indices, y_values, marker='o', label=label)
    plt.xlabel("i (Modality index)")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    if save_pdf:
        plt.savefig(save_pdf, dpi=300)
    plt.show()

# ==== USAGE EXAMPLE ====

# Each run pairs a dataset archive with the legend entry for its alpha/beta setting.
_runs = [
    (
        "../output_dir/datasets/multimodal_decay_DAG_alpha1.1_beta1.1/CaseConstructive_0.npz",
        "a=1.1, b=1.1",
    ),
    (
        "../output_dir/datasets/multimodal_decay_DAG_alpha1.2_beta1.2/CaseConstructive_0.npz",
        "a=1.2, b=1.2",
    ),
    (
        "../output_dir/datasets/multimodal_decay_DAG_alpha2_beta1.2/CaseConstructive_0.npz",
        "a=2, b=1.2",
    ),
    (
        "../output_dir/datasets/multimodal_decay_DAG_alpha1.2_beta2/CaseConstructive_0.npz",
        "a=1.2, b=2",
    ),
    (
        "../output_dir/datasets/multimodal_decay_DAG_alpha2_beta2/CaseConstructive_0.npz",
        "a=2, b=2",
    ),
]

npz_files = [path for path, _ in _runs]
labels = [legend for _, legend in _runs]

# One entry per figure, drawn in this order.
_overlay_specs = (
    # I(X_i; θ) vs. i
    dict(
        extract_fn=extract_ix_theta,
        ylabel=r"$I(X_i;\,\theta)$",
        title=r"Mutual Information $I(X_i;\theta)$ vs $i$",
        save_pdf="ix_theta_overlay.pdf",
    ),
    # I(X_1; X_i) vs. i
    dict(
        extract_fn=extract_ix1_xi,
        ylabel=r"$I(X_1;\,X_i)$",
        title=r"Mutual Information $I(X_1; X_i)$ vs $i$",
        save_pdf="ix1_xi_overlay.pdf",
    ),
)

for _spec in _overlay_specs:
    plot_overlay(npz_files, labels=labels, **_spec)
