import matplotlib.pyplot as plt
import numpy as np
import os
import scienceplots

# Apply the SciencePlots "science" style together with "no-latex", so figures
# render without needing a LaTeX installation.
try:
    plt.style.use(['science', 'no-latex'])
except Exception as e:
    # Best-effort: if the style cannot be applied, fall back to a plain serif
    # configuration instead of aborting the script.
    print(f"Warning: Could not apply science style: {e}")
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Computer Modern Roman'],
        'font.size': 9,  # Standard size for paper figures
        'axes.labelsize': 10,
        'legend.fontsize': 9,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
    })

# Which experiment to plot; selects the statistics files used below.
# Options: "sharp", "pentagon", "star", "circle", "line", "wall"
PLOT = "pentagon"

def make_monotonic_full(costs):
    """Return the running minimum (best value seen so far) of *costs*.

    Element k of the result is the smallest value among the first k + 1
    inputs. Comparisons use ``<``, so a NaN entry never becomes the new
    minimum: it repeats the previous best, and leading NaNs yield ``inf``.

    Args:
        costs: Iterable of comparable numbers.

    Returns:
        numpy.ndarray with one entry per input (empty input -> empty array).
    """
    def _best_so_far(values):
        # Lazily yield the minimum seen up to and including each value.
        best = float('inf')
        for value in values:
            best = value if value < best else best
            yield best

    return np.array(list(_best_so_far(costs)))

# --- Configuration ---
# data_files is a list of tuples: (mean_file_path, std_file_path, label_for_legend).
# Paths are resolved relative to this script's directory. Set std_file_path to
# None to skip the std band for that dataset.

# Optimizer key used in the statistics file names -> legend label.
ALGORITHM_LABELS = {
    "cma": "CMA",
    "hc": "HC",
    "random": "RS",
}

# Optimizers with saved statistics for each plot type, in legend/colour order.
# "circle" has no CMA statistics listed, so CMA is omitted for it.
PLOT_ALGORITHMS = {
    "sharp": ("cma", "hc", "random"),
    "pentagon": ("cma", "hc", "random"),
    "star": ("cma", "hc", "random"),
    "circle": ("hc", "random"),
    "line": ("cma", "hc", "random"),
    "wall": ("cma", "hc", "random"),
}

if PLOT not in PLOT_ALGORITHMS:
    # Fail early with a clear message; otherwise data_files would be undefined
    # and the plotting loop would later raise a confusing NameError.
    raise ValueError(
        f"Unknown PLOT {PLOT!r}; expected one of {sorted(PLOT_ALGORITHMS)}"
    )

data_files = [
    (
        f"../plots/statistics/{PLOT}_{algo}_mean.npy",
        f"../plots/statistics/{PLOT}_{algo}_std.npy",
        ALGORITHM_LABELS[algo],
    )
    for algo in PLOT_ALGORITHMS[PLOT]
]

# Plotting Toggles
use_log = True  # Use a log scale for the y-axis
use_monotonic = False  # Plot the running minimum (best-so-far) of the mean cost instead of the raw mean
cost_upper_bound = 1e6  # Mean costs above this are masked out (NaN); the band's upper edge is clipped to it
show_std = True  # Shade a band of 1.96 * std around each curve when a std file is available
min_std_ratio = 0.1  # Floor for the band's lower edge as a fraction of the mean (lower >= mean * ratio)

# The output file name is built from PLOT when saving: "<PLOT>_cost_comparison.pdf".
# --- End Configuration ---

# Distinct line colours, cycled when there are more datasets than entries.
colors = ["#6a51a3", "#e6550d", "#31a354", "#d94801", "#8c6bb1", "#7fcdbb"]

# Half-width of the shaded band in units of the loaded std (1.96 -> two-sided 95%
# under a normal model).
# NOTE(review): the band is a 95% CI of the *mean* only if the *_std.npy files hold
# standard errors (std / sqrt(n_runs)); if they hold the raw std across runs, it is a
# ~95% spread of individual runs instead — confirm against the script that saves them.
ci_multiplier = 1.96

# data_files paths are resolved relative to this script, not the working directory.
script_dir = os.path.dirname(__file__)

plt.figure(figsize=(4, 3))  # width, height in inches
for i, (mean_file_path, std_file_path, label) in enumerate(data_files):
    # --- Mean curve ---
    mean_full_path = os.path.join(script_dir, mean_file_path)
    if not os.path.exists(mean_full_path):
        print(f"Warning: Mean file not found at {mean_full_path}, skipping.")
        continue

    try:
        costs = np.load(mean_full_path)
    except Exception as e:  # best-effort: an unreadable file skips only that dataset
        print(f"Error loading {mean_full_path}: {e}, skipping.")
        continue

    # Values above cost_upper_bound become NaN, which leaves a gap in the line.
    costs_masked = np.where(costs > cost_upper_bound, np.nan, costs)

    if use_monotonic:
        # Treat masked entries as cost_upper_bound so the running minimum is well defined.
        costs_plot = make_monotonic_full(np.nan_to_num(costs_masked, nan=cost_upper_bound))
    else:
        costs_plot = costs_masked

    x_values = np.arange(len(costs_plot))
    line = plt.plot(
        x_values,
        costs_plot,
        label=label,
        linewidth=2.0,
        color=colors[i % len(colors)],
        alpha=0.9,
    )[0]

    # --- Optional std band ---
    if not show_std or std_file_path is None:
        continue

    std_full_path = os.path.join(script_dir, std_file_path)
    if not os.path.exists(std_full_path):
        print(f"Standard deviation file not found at {std_full_path}")
        continue

    try:
        std_data = np.load(std_full_path)
        if len(std_data) != len(costs_plot):
            print(f"Warning: Standard deviation data length ({len(std_data)}) doesn't match mean data length ({len(costs_plot)}) for {label}")
            continue

        half_width = ci_multiplier * std_data
        # Floor the lower edge at min_std_ratio * mean so the band cannot plunge
        # toward zero; this applies identically on linear and log scales.
        lower_bound = np.maximum(costs_plot - half_width, costs_plot * min_std_ratio)
        if use_log:
            # Keep the lower edge strictly positive so a log axis can draw it.
            lower_bound = np.maximum(lower_bound, 1e-10)
        # Clip the upper edge to the same bound used to mask outlier costs.
        upper_bound = np.minimum(costs_plot + half_width, cost_upper_bound)

        # The "(95% CI)" label suffix lets the legend step drop band entries.
        plt.fill_between(
            x_values,
            lower_bound,
            upper_bound,
            color=line.get_color(),
            alpha=0.2,
            label=f"{label} (95% CI)",
        )
        print(f"Added 95% confidence interval band for {label}")
    except Exception as e:  # best-effort: a bad std file must not abort the figure
        print(f"Error loading or plotting std data for {label}: {e}")

# --- Final Plot Formatting ---
plt.xlabel('steps', fontsize=13)
plt.ylabel('lowest cost', fontsize=13)

# Off by default; set to True to add a title describing the active plot options.
show_title = False
if show_title:
    title_parts = []
    if use_monotonic:
        title_parts.append("Lowest")
    title_parts.append("Cost over Iterations")
    if show_std:
        title_parts.append("with 95% Confidence Intervals")
    if use_log:
        title_parts.append("(log scale)")
    plt.title(' '.join(title_parts), fontsize=15, pad=12)

if use_log:
    plt.yscale('log')

# Legend: keep only the main curves. Shaded bands carry a "(95% CI)" suffix in
# their label and are filtered out; frameon=False removes the legend border.
handles, labels = plt.gca().get_legend_handles_labels()
unique_handles = [h for h, l in zip(handles, labels) if "(95% CI)" not in l]
unique_labels = [l for l in labels if "(95% CI)" not in l]
plt.legend(unique_handles, unique_labels, fontsize=11, loc='best', frameon=False)

plt.grid(False)
plt.tight_layout(pad=1.5)

# Include the plot type in the name so different PLOT values write different files.
output_filename = f"{PLOT}_cost_comparison.pdf"
output_path = os.path.join(os.path.dirname(__file__), output_filename)
try:
    plt.savefig(output_path, format='pdf', dpi=300)
    print(f"Plot saved to {output_path}")
except Exception as e:  # best-effort: still show the figure if saving fails
    print(f"Error saving plot to {output_path}: {e}")

# Display the figure interactively.
plt.show()