import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

# Color sequence taken from Matplotlib's active property cycle at import time.
COLOR_CYCLE = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# Opacity used for scatter markers, histograms and density fills.
ALPHA = 0.8

# Font sizes shared by the plotting helpers.
AXIS_LABEL_FONTSIZE = 22
TICK_FONTSIZE = 15
LEGEND_FONTSIZE = 16
TITLE_FONTSIZE = 16

# Default figure size (width, height), in inches, passed to plt.figure.
FIGSIZE = (8, 6)

def save_runtime_scatter_plot(
    kind_runtime,
    baseline_runtime,
    save_path="scatter.pdf",
    labels=("QUASAR", "Python"),
    improved_only=False,
    truncate=None,
    log_scale=False,
    improvement_threshold=0.975,
):
    """Scatter per-program runtimes of one execution kind against a baseline and save it.

    Only program IDs present in both mappings are plotted. Runtimes are divided
    by 1000 and labelled in seconds (presumably the inputs are milliseconds —
    confirm against callers).

    Args:
        kind_runtime: Mapping from program ID to runtime of the evaluated kind (y-axis).
        baseline_runtime: Mapping from program ID to baseline runtime (x-axis).
        save_path: Output path passed to ``plt.savefig``.
        labels: ``(kind_label, baseline_label)`` used in the axis titles.
        improved_only: If True, keep only programs whose ratio ``kind / baseline``
            is below ``improvement_threshold``. Programs with a zero baseline
            runtime have no defined ratio and are skipped.
        truncate: If given, drop programs where either runtime (after the /1000
            conversion) exceeds this value.
        log_scale: If True, use logarithmic x and y axes.
        improvement_threshold: Ratio cut-off used by ``improved_only``
            (default 0.975: the kind must be more than 2.5% faster).

    Returns:
        None. Prints a message and returns early when nothing is left to plot.
    """
    shared_keys = sorted(set(kind_runtime.keys()) & set(baseline_runtime.keys()))
    if not shared_keys:
        print("No shared program IDs found for scatter plot.")
        return

    kind_values = []
    baseline_values = []

    for k in shared_keys:
        kv = kind_runtime[k] / 1000
        bv = baseline_runtime[k] / 1000
        if truncate is not None and (kv > truncate or bv > truncate):
            continue
        # A zero baseline has no speed-up ratio; checking it first also keeps
        # the division below from raising ZeroDivisionError.
        if improved_only and (bv == 0 or kv / bv >= improvement_threshold):
            continue
        kind_values.append(kv)
        baseline_values.append(bv)

    if not kind_values:
        print("No data points remain after truncation.")
        return

    plt.figure(figsize=FIGSIZE)
    plt.scatter(baseline_values, kind_values, alpha=ALPHA, label='Programs')
    # The y = x diagonal spans the full data range of both axes.
    min_val = min(min(kind_values), min(baseline_values))
    max_val = max(max(kind_values), max(baseline_values))
    plt.plot([min_val, max_val], [min_val, max_val], '--', color='gray', label='Equal Runtime')
    plt.ylabel(f'{labels[0]} runtime (s)', fontsize=AXIS_LABEL_FONTSIZE)
    plt.xlabel(f'{labels[1]} runtime (s)', fontsize=AXIS_LABEL_FONTSIZE)
    if log_scale:
        plt.xscale("log")
        plt.yscale("log")
    plt.legend(fontsize=LEGEND_FONTSIZE)
    plt.xticks(fontsize=TICK_FONTSIZE)
    plt.yticks(fontsize=TICK_FONTSIZE)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def save_runtime_multi_baselines_scatter_plot(kind_runtime, baseline_runtimes, save_path, kind_label, baseline_labels, truncate=None):
    """Scatter one kind's runtimes against several baselines on one figure and save it.

    Each baseline gets its own color and legend entry. Runtimes are divided by
    1000 and labelled in seconds.

    Args:
        kind_runtime: Mapping from program ID to runtime of the evaluated kind (y-axis).
        baseline_runtimes: Sequence of mappings from program ID to baseline runtime.
        save_path: Output path passed to ``plt.savefig``.
        kind_label: Name of the evaluated kind, used in the y-axis label.
        baseline_labels: Legend label for each entry of ``baseline_runtimes``.
        truncate: If given, drop programs where either runtime (after the /1000
            conversion) exceeds this value.

    Raises:
        ValueError: If ``baseline_runtimes`` and ``baseline_labels`` differ in length.
    """
    if len(baseline_runtimes) != len(baseline_labels):
        raise ValueError("Number of baseline runtimes and baseline labels must match.")

    plt.figure(figsize=FIGSIZE)
    all_kind_vals, all_base_vals = [], []

    for i, (baseline_runtime, label) in enumerate(zip(baseline_runtimes, baseline_labels)):
        shared_keys = sorted(set(kind_runtime.keys()) & set(baseline_runtime.keys()))
        if not shared_keys:
            print(f"No shared program IDs found for scatter plot with baseline: {label}")
            continue

        kind_values = []
        baseline_values = []

        for k in shared_keys:
            kv = kind_runtime[k] / 1000  # Convert to seconds
            bv = baseline_runtime[k] / 1000
            if truncate is not None and (kv > truncate or bv > truncate):
                continue
            kind_values.append(kv)
            baseline_values.append(bv)

        if not kind_values:
            print(f"No data points remain after truncation for baseline: {label}")
            continue

        plt.scatter(
            baseline_values, kind_values,
            alpha=ALPHA,
            label=f'{label}',
            # Wrap around the palette so more baselines than colors don't raise IndexError.
            color=COLOR_CYCLE[i % len(COLOR_CYCLE)],
        )
        all_kind_vals.extend(kind_values)
        all_base_vals.extend(baseline_values)

    # Draw the y = x reference line across the range of every plotted point.
    if all_kind_vals and all_base_vals:
        min_val = min(min(all_kind_vals), min(all_base_vals))
        max_val = max(max(all_kind_vals), max(all_base_vals))
        plt.plot([min_val, max_val], [min_val, max_val], '--', color='gray', label='Equal Runtime')

    # Axis labels, grid, legend
    plt.xlabel('Baseline actions runtime (s)', fontsize=AXIS_LABEL_FONTSIZE)
    plt.ylabel(f'{kind_label} runtime (s)', fontsize=AXIS_LABEL_FONTSIZE)
    plt.legend(fontsize=LEGEND_FONTSIZE)
    plt.xticks(fontsize=TICK_FONTSIZE)
    plt.yticks(fontsize=TICK_FONTSIZE)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def save_interaction_scatter_plot(before_vals, after_vals, save_path, truncate=None):
    """Scatter interaction counts before vs. after and save the figure.

    Args:
        before_vals: Sequence of "before" interaction counts (x-axis).
        after_vals: Sequence of "after" interaction counts (y-axis), paired
            positionally with ``before_vals``; unpaired trailing values are ignored.
        save_path: Output path passed to ``plt.savefig``.
        truncate: If given, drop pairs where either value exceeds it.

    Returns:
        None. Prints a message and returns early when no pair survives truncation.
    """
    kept = []
    for pair in zip(before_vals, after_vals):
        b, a = pair
        if truncate is not None and not (b <= truncate and a <= truncate):
            continue
        kept.append(pair)

    if len(kept) == 0:
        print("No points to plot after applying truncation.")
        return

    xs = [b for b, _ in kept]
    ys = [a for _, a in kept]
    # The y = x reference spans the combined range of both axes.
    lo = min(min(xs), min(ys))
    hi = max(max(xs), max(ys))

    plt.figure(figsize=FIGSIZE)
    plt.scatter(xs, ys, alpha=ALPHA, label='Programs')
    plt.plot([lo, hi], [lo, hi], '--', color='gray', label='Equal Interaction Count')

    for set_label, text in ((plt.xlabel, "Before Interaction Count"),
                            (plt.ylabel, "After Interaction Count")):
        set_label(text, fontsize=AXIS_LABEL_FONTSIZE)
    plt.legend(fontsize=LEGEND_FONTSIZE)
    plt.xticks(fontsize=TICK_FONTSIZE)
    plt.yticks(fontsize=TICK_FONTSIZE)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def save_interaction_hist2d_plot(before_vals, after_vals, before_label, after_label, save_path, truncate=None, improved_only=False):
    """Plot a 2D histogram of interaction counts (before vs. after) and save it.

    Pairs are binned into unit-width cells and colored on a log scale. Ticks sit
    at cell centers and are labelled 0, 1, 2, ... (this assumes the counts are
    non-negative integers — confirm against callers).

    Args:
        before_vals: Sequence of "before" interaction counts (x-axis).
        after_vals: Sequence of "after" interaction counts (y-axis), paired
            positionally with ``before_vals``.
        before_label: Prefix for the x-axis label.
        after_label: Prefix for the y-axis label.
        save_path: Output path passed to ``plt.savefig``.
        truncate: If given, drop pairs where either value exceeds it.
        improved_only: If True, drop pairs whose before and after values are
            equal (both decreases and increases are kept).

    Returns:
        None. Prints a message and returns early when no pair survives filtering.
    """
    points = [
        (b, a)
        for b, a in zip(before_vals, after_vals)
        if (truncate is None or (b <= truncate and a <= truncate))
        and not (improved_only and b == a)
    ]

    if not points:
        print("No points to plot after applying truncation.")
        return

    before, after = zip(*points)
    max_val = max(max(before), max(after))

    # Unit-width bin edges 0, 1, ..., max_val + 1, so value v lands in cell [v, v + 1].
    bins = np.arange(0, max_val + 1.5)
    n_cells = len(bins) - 1
    # Extend the axes to the last bin edge so the cell holding max_val is visible.
    # This also avoids singular (0, 0) limits when every value is 0.
    upper = bins[-1]
    centers = np.arange(n_cells) + 0.5
    tick_labels = np.arange(n_cells)

    plt.figure(figsize=FIGSIZE)
    plt.hist2d(before, after, bins=[bins, bins], cmap="Greens", norm=matplotlib.colors.LogNorm())
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=TICK_FONTSIZE)
    cbar.set_label(r'$\log$(# Programs)', fontsize=LEGEND_FONTSIZE)
    plt.plot([0, upper], [0, upper], '--', color='gray', label='Equal Interaction Count')
    plt.xlabel(f"{before_label} Interaction Count", fontsize=AXIS_LABEL_FONTSIZE)
    plt.ylabel(f"{after_label} Interaction Count", fontsize=AXIS_LABEL_FONTSIZE)
    plt.legend(fontsize=LEGEND_FONTSIZE)
    plt.xticks(ticks=centers, labels=tick_labels, fontsize=TICK_FONTSIZE)
    plt.yticks(ticks=centers, labels=tick_labels, fontsize=TICK_FONTSIZE)
    plt.xlim(0, upper)
    plt.ylim(0, upper)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def save_improvement_factor(impr_factor, save_path, label="QUASAR / Python", color="green", bins=50, log_scale=False, title=None):
    """Plot a histogram of improvement factors with mean and parity markers, then save it.

    Args:
        impr_factor: Sequence of per-program improvement factors.
        save_path: Output path passed to ``plt.savefig``.
        label: X-axis label.
        color: Histogram fill color.
        bins: Passed through to ``plt.hist``.
        log_scale: If True, use a logarithmic y-axis (and say so in the y label).
        title: Optional figure title.
    """
    plt.figure(figsize=FIGSIZE)

    average = np.mean(impr_factor)
    hist_style = dict(alpha=ALPHA, color=color, histtype="stepfilled", linewidth=1.5)
    plt.hist(impr_factor, bins=bins, **hist_style)

    # Two vertical references: the sample mean and the 1x break-even point.
    plt.axvline(average, linestyle="--", linewidth=2, color="red", label=f"mean={average:.2f}")
    plt.axvline(1, linestyle="--", color="black", linewidth=1.5, label="Parity (1x)")

    y_label = "Frequency (log scale)" if log_scale else "Frequency"
    if log_scale:
        plt.yscale("log")

    plt.xlabel(label, fontsize=AXIS_LABEL_FONTSIZE)
    plt.ylabel(y_label, fontsize=AXIS_LABEL_FONTSIZE)
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=TICK_FONTSIZE)
    if title is not None:
        plt.title(title, fontsize=TITLE_FONTSIZE)
    plt.legend(loc="upper left", fontsize=LEGEND_FONTSIZE)
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    
def save_ratio_histogram(ratios, save_path, label="QUASAR / Python", color="blue", bins=50, log_scale=False, title=None):
    """Overlay ratio histograms for one or more series and save the figure.

    Each series gets a histogram and a dashed mean line. A single black parity
    line marks a ratio of 1.

    Args:
        ratios: Mapping from series name to a sequence of ratios.
        save_path: Output path passed to ``plt.savefig``.
        label: X-axis label.
        color: Fill color applied to every series. Pass ``None`` to give each
            series its own color from ``COLOR_CYCLE``.
        bins: Passed through to ``plt.hist``.
        log_scale: If True, use a logarithmic y-axis (and say so in the y label).
        title: Optional figure title.
    """
    plt.figure(figsize=FIGSIZE)
    # With several series, each one needs its own legend entry and mean label.
    multi_series = len(ratios) > 1

    for i, (kind, values) in enumerate(ratios.items()):
        # Resolve the color per series without overwriting ``color``; otherwise,
        # with color=None, every later series would reuse the first palette entry.
        series_color = color if color is not None else COLOR_CYCLE[i % len(COLOR_CYCLE)]
        mean_ratio = np.mean(values)

        plt.hist(
            values,
            bins=bins,
            alpha=ALPHA,
            color=series_color,
            histtype="stepfilled",
            linewidth=1.5,
            label=str(kind) if multi_series else None,
        )

        mean_label = f"{kind} mean={mean_ratio:.2f}" if multi_series else f"mean={mean_ratio:.2f}"
        # Match the mean line to its histogram only when the series have distinct colors.
        mean_color = series_color if (multi_series and color is None) else "red"
        plt.axvline(mean_ratio, linestyle="--", linewidth=2, color=mean_color, label=mean_label)

    # Drawn once for the whole figure so it gets a single legend entry.
    plt.axvline(1, linestyle="--", color="black", linewidth=1.5, label="Parity (1x)")

    if log_scale:
        plt.yscale("log")
        plt.ylabel("Frequency (log scale)", fontsize=AXIS_LABEL_FONTSIZE)
    else:
        plt.ylabel("Frequency", fontsize=AXIS_LABEL_FONTSIZE)

    plt.xlabel(label, fontsize=AXIS_LABEL_FONTSIZE)
    plt.xticks(fontsize=TICK_FONTSIZE)
    plt.yticks(fontsize=TICK_FONTSIZE)
    if title is not None:
        plt.title(title, fontsize=TITLE_FONTSIZE)
    plt.legend(loc="upper left", fontsize=LEGEND_FONTSIZE)
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def save_runtime_histogram(all_runtimes, save_path, truncate=None, title="Execution Time Comparison (Histogram)"):
    """Overlay normalized runtime histograms, one per execution kind, and save the figure.

    Args:
        all_runtimes: Mapping from execution kind to a sequence of runtimes
            (axis is labelled in ms).
        save_path: Output path passed to ``plt.savefig``.
        truncate: If given, drop runtimes above this value before plotting.
        title: Figure title.

    Kinds with no remaining runtimes are skipped. Colors are indexed by the
    kind's position in ``all_runtimes``, so they stay stable when kinds are skipped.
    """
    plt.figure(figsize=(10, 6))

    for i, (kind, runtimes) in enumerate(all_runtimes.items()):
        if truncate is not None:
            runtimes = [r for r in runtimes if r <= truncate]
        # len() rather than truthiness, so numpy arrays are accepted as well.
        if len(runtimes) == 0:
            continue

        color = COLOR_CYCLE[i % len(COLOR_CYCLE)]
        mean_runtime = np.mean(runtimes)

        plt.hist(
            runtimes,
            bins=80,
            alpha=ALPHA,
            label=f"{kind} (mean={mean_runtime:.0f}ms)",
            color=color,
            histtype="step",
            linewidth=1.5,
            density=True
        )

        plt.axvline(mean_runtime, linestyle="--", linewidth=2, color=color, label="_nolegend_")

    plt.xlabel("Execution Time (ms)")
    # density=True normalizes each histogram to unit area, so the y-axis is a density, not a count.
    plt.ylabel("Density")
    plt.title(title)
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def save_runtime_density_plot(all_runtimes, save_path, truncate=None, title="Execution Time Comparison (Density Plot)"):
    """Plot one kernel-density estimate per execution kind and save the figure.

    Args:
        all_runtimes: Mapping from execution kind to a sequence of runtimes
            (axis is labelled in ms).
        save_path: Output path passed to ``plt.savefig``.
        truncate: If given, drop runtimes above this value before plotting.
        title: Figure title.

    Returns:
        None. Prints a message and returns early when no runtime survives truncation.
    """
    records = [
        {"execution": kind, "runtime": rt}
        for kind, runtimes in all_runtimes.items()
        for rt in runtimes
        if truncate is None or rt <= truncate
    ]
    # An empty record list gives a DataFrame with no "runtime"/"execution"
    # columns for kdeplot to use, so stop before creating a figure.
    if not records:
        print("No runtimes to plot after applying truncation.")
        return

    df = pd.DataFrame(records)

    plt.figure(figsize=(10, 6))
    # common_norm=False scales each kind's curve independently of the others.
    sns.kdeplot(data=df, x="runtime", hue="execution", fill=True, common_norm=False, alpha=ALPHA, linewidth=1.5)
    plt.xlabel("Execution Time (ms)")
    plt.ylabel("Density")
    plt.title(title)
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def save_runtime_cdf(all_runtimes, save_path, title="Execution Time Comparison (CDF)", truncate=None):
    """Plot the empirical CDF of runtimes for each execution kind and save the figure.

    Args:
        all_runtimes: Mapping from execution kind to a sequence of runtimes
            (axis is labelled in ms).
        save_path: Output path passed to ``plt.savefig``.
        title: Figure title.
        truncate: If given, drop runtimes above this value before computing the
            CDF, matching the other runtime plots.

    Kinds with no remaining runtimes are skipped. Colors are indexed by the
    kind's position in ``all_runtimes``.
    """
    plt.figure(figsize=(10, 6))

    for i, (kind, runtimes) in enumerate(all_runtimes.items()):
        color = COLOR_CYCLE[i % len(COLOR_CYCLE)]
        sorted_runtimes = np.sort(np.asarray(runtimes))
        if truncate is not None:
            sorted_runtimes = sorted_runtimes[sorted_runtimes <= truncate]
        n = sorted_runtimes.size
        if n == 0:
            continue
        # Empirical CDF: the k-th smallest value (1-based) has cumulative
        # probability k / n, so the curve reaches 1.0 at the largest runtime.
        cdf = np.arange(1, n + 1) / n
        plt.plot(sorted_runtimes, cdf, label=kind, color=color, linewidth=1.8)

    plt.xlabel("Execution Time (ms)")
    plt.ylabel("Cumulative Probability")
    plt.title(title)
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def save_runtime_boxplot(all_runtimes, save_path, log_scale=False, truncate=None, title="Execution Time Comparison (Box Plot)"):
    """Draw one box per execution kind (outliers shown) and save the figure.

    Args:
        all_runtimes: Mapping from execution kind to a sequence of runtimes
            (axis is labelled in ms).
        save_path: Output path passed to ``plt.savefig``.
        log_scale: If True, use a logarithmic y-axis (and say so in the y label).
        truncate: If given, drop runtimes above this value before plotting.
        title: Figure title.
    """
    plt.figure(figsize=(10, 6))
    labels = list(all_runtimes.keys())
    data = list(all_runtimes.values())
    if truncate is not None:
        data = [[r for r in runtimes if r <= truncate] for runtimes in data]

    # Boxes sit at x = 1..N by default. Setting the tick labels explicitly avoids
    # boxplot's ``labels`` keyword, which Matplotlib 3.9 deprecated in favor of
    # ``tick_labels``; this form works on both older and newer releases.
    plt.boxplot(data, showfliers=True)
    plt.xticks(range(1, len(labels) + 1), labels)
    if log_scale:
        plt.yscale("log")
        plt.ylabel("Execution Time (ms - log scale)")
    else:
        plt.ylabel("Execution Time (ms)")
    plt.title(title)
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
