import numpy as np
import matplotlib.pyplot as plt
import json
import os

# Path of the JSON results file; the script entry point opens it and parses it with json.load().
INPUT_FILENAME = "experiment_results.json"
# Destination path passed to plt.savefig() when plot_data() writes the figure.
PLOT_OUTPUT_FILENAME = "kl_divergence_vs_m_percentage_trials_mean_std_acc_report_saved.png"

def _m_values_to_percentages(m_values):
    """Scale absolute m values to percentages of the largest value.

    Returns a list of floats in which the largest m maps to 100, ``[0.0]``
    when the input is a single value equal to 0, or ``None`` when the
    maximum is 0 for any other input (there is no scale to normalise by).
    """
    largest = float(max(m_values))
    if largest != 0:
        return [(m / largest) * 100 for m in m_values]
    if len(m_values) == 1:
        return [0.0]
    return None


def plot_data(data):
    """Plot mean KL divergence per distribution and print accuracy summaries.

    Args:
        data: Parsed results mapping. Keys read here:
            ``"distribution_results"`` - list of entries, each with
            ``"label"``, ``"m_values"``, ``"mean_kl_values"``,
            ``"std_kl_values"`` and optionally ``"accuracy_summary"``;
            ``"experiment_params"`` - optional mapping whose ``"DELTA_CLIP"``
            (default 1e-6) sets the floor of the shaded std band.

    Returns:
        None. The figure is written to ``PLOT_OUTPUT_FILENAME`` and progress,
        warnings and accuracy summaries are printed to stdout. If no
        distribution results are present, a message is printed and nothing
        is plotted.
    """
    if not data or "distribution_results" not in data or not data["distribution_results"]:
        print("No distribution results found in the data file.")
        return

    exp_params = data.get("experiment_params") or {}
    delta_clip = exp_params.get("DELTA_CLIP", 1e-6)

    plot_results_data = data["distribution_results"]

    plt.figure(figsize=(10, 6))

    # Entries with missing or empty m_values have nothing to draw.
    valid_plot_entries = [entry for entry in plot_results_data if entry.get("m_values")]
    colors = plt.cm.tab10(np.linspace(0, 1, len(valid_plot_entries)))

    print("\n--- Plotting KL Divergence Results ---")
    for color, result_entry in zip(colors, valid_plot_entries):
        label = result_entry.get("label", "<unlabeled>")
        m_vals_abs = result_entry["m_values"]  # Absolute m values
        mean_kl_vals = result_entry.get("mean_kl_values")
        std_kl_vals = result_entry.get("std_kl_values")

        # The x axis is the share of the largest m in this entry, in percent.
        m_vals_percentage = _m_values_to_percentages(m_vals_abs)
        if m_vals_percentage is None:
            print(f"Warning: Max m_value is 0 for {label} with m_values {m_vals_abs}. Skipping this line.")
            continue

        # plt.plot raises on mismatched x/y lengths; skip the entry instead of
        # aborting the whole figure.
        if mean_kl_vals is None or len(mean_kl_vals) != len(m_vals_percentage):
            print(f"Warning: Mean KL data missing or mismatched for {label}. Skipping this line.")
            continue

        mean_kl_vals_np = np.array(mean_kl_vals)
        plt.plot(m_vals_percentage, mean_kl_vals_np, linestyle='-', label=label, color=color)

        if std_kl_vals is None or len(std_kl_vals) != len(mean_kl_vals_np):
            # Keep the mean line and carry on with the remaining entries;
            # the process must not terminate here.
            print(f"Warning: Standard deviation data missing or mismatched for {label}. Plotting mean only.")
            continue

        std_kl_vals_np = np.array(std_kl_vals)
        # Clamp the lower edge of the band to a small positive floor.
        fill_lower_bound = np.maximum(mean_kl_vals_np - std_kl_vals_np, delta_clip / 100)
        fill_upper_bound = mean_kl_vals_np + std_kl_vals_np
        plt.fill_between(m_vals_percentage,
                         fill_lower_bound,
                         fill_upper_bound,
                         color=color, alpha=0.2, interpolate=False)

    plt.xlabel('Percentage of Trees in Ensemble (%)', fontsize=28)
    plt.ylabel('Average KL Divergence', fontsize=28)
    plt.title('Information Leakage vs. %trees', fontsize=28)
    plt.legend(fontsize=28)
    plt.grid(True, which="both", ls="-", alpha=0.5)
    plt.xlim(-2, 100)  # Small negative margin so points at 0% are not clipped
    plt.ylim(0, 2.5)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)
    # Lay out last so the enlarged tick labels are accounted for.
    plt.tight_layout(rect=[0.03, 0.03, 0.97, 0.97])

    try:
        plt.savefig(PLOT_OUTPUT_FILENAME)
        print(f"\nPlot saved as {PLOT_OUTPUT_FILENAME}")
    except Exception as e:
        # Best effort: a failed save must not prevent the summaries below.
        print(f"Error saving plot: {e}")

    print("\n--- Accuracy Summaries (from saved data) ---")
    # Summaries cover every entry, including those skipped in the plot.
    for result_entry in plot_results_data:
        label = result_entry.get("label", "<unlabeled>")
        acc_summary = result_entry.get("accuracy_summary") or {}
        mean_acc = acc_summary.get("mean_training_accuracy")
        std_acc = acc_summary.get("std_dev_training_accuracy")
        num_trials_acc = acc_summary.get("num_successful_trials_for_accuracy", 0)

        if mean_acc is None:
            print(f"  {label}: No accuracy data recorded.")
            continue

        print(f"  {label}:")
        print(f"    Mean training accuracy across {num_trials_acc} successful trials: {mean_acc:.4f}")
        if std_acc is not None:
            print(f"    Std Dev of training accuracy across trials: {std_acc:.4f}")


if __name__ == "__main__":
    # Script entry point: load the results JSON, then plot it.
    # Only the read/parse step is guarded, so an error raised while plotting
    # is not misreported as a problem with the input file.
    try:
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(INPUT_FILENAME, 'r', encoding="utf-8") as f:
            results = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"Error decoding JSON from {INPUT_FILENAME}: {e}")
    except OSError as e:
        print(f"Error reading file {INPUT_FILENAME}: {e}")
    else:
        plot_data(results)
