import os
import numpy as np
from scipy.ndimage import uniform_filter1d
import matplotlib.pyplot as plt

def break_into_pairs(ranking):
    """Split a ranking into all of its ordered element pairs.

    Every element is paired with each element that appears after it in
    ``ranking``, giving ``(ranking[i], ranking[j])`` for every ``i < j``.
    Pairs are produced in order of ``i`` first and then ``j``.
    A ranking of length n yields n * (n - 1) / 2 pairs; an empty or
    single-element ranking yields an empty list.
    """
    pairs = []
    for position, earlier in enumerate(ranking):
        for later in ranking[position + 1:]:
            pairs.append((earlier, later))
    return pairs

def generate_pl_ranking(indices, X, theta):
    """Sample a complete ranking of ``indices`` from a Plackett-Luce model.

    Items are drawn one at a time without replacement. At each step the next
    item is picked from the remaining pool with probability proportional to
    ``exp(X[item] @ theta)``.

    Args:
        indices: Iterable of row indices into ``X`` that should be ranked.
        X: Array indexed by the entries of ``indices``. Presumably shape
            (n_items, d) with ``theta`` of shape (d,); TODO confirm.
        theta: Parameter vector combined with ``X[pool]`` via ``@``.

    Returns:
        A list containing every entry of ``indices`` exactly once, in the
        order it was drawn. Returns ``[]`` when ``indices`` is empty.
    """
    ranking = []
    pool = list(indices)
    while pool:
        logits = X[pool] @ theta
        # Softmax trick: shifting by the max leaves the probabilities
        # unchanged. Without it, large logits overflow exp() to inf, inf/inf
        # yields NaN, and np.random.choice rejects the distribution.
        scores = np.exp(logits - np.max(logits))
        probs = scores / np.sum(scores)
        # Sample a position in the pool and pop it. This avoids the value
        # search done by list.remove and returns the caller's own elements.
        position = np.random.choice(len(pool), p=probs)
        ranking.append(pool.pop(position))
    return ranking

def smooth_curve(data, window_size=50):
    """Smooth ``data`` with a moving average of width ``window_size``.

    The average is taken along the last axis. At the boundaries the
    edge sample is repeated (``mode="nearest"``), so the result has the
    same shape as the input.

    Note: scipy allocates the output with the input's dtype. Integer input
    therefore produces truncated integer averages; pass floats for a true
    mean.
    """
    edge_mode = "nearest"
    smoothed = uniform_filter1d(data, window_size, mode=edge_mode)
    return smoothed

def plot_side_by_side(
    result_dict1,
    result_dict2,
    title1,
    title2,
    path,
    ylabel,
    T,
    K_list,
    update_frequency,
    xtick_step=40,
):
    """Plot two result dictionaries as side-by-side error-bar panels and save the figure.

    For each ``K`` in ``K_list``, ``result_dict1[K]`` and ``result_dict2[K]``
    are unpacked as ``(mean, std)`` and drawn on the left and right panel
    respectively. Each curve is plotted against the rounds
    ``0, update_frequency, 2 * update_frequency, ...`` (up to ``T``). Curves
    are truncated to the shortest of that grid and the four mean/std series,
    so mismatched lengths do not raise. Error bars show half a standard
    deviation.

    Args:
        result_dict1, result_dict2: Mappings ``K -> (mean, std)``; mean/std
            may be arrays or plain sequences.
        title1, title2: Titles for the left and right panel.
        path: Output file path. Missing parent directories are created.
        ylabel: Label for the shared y-axis (shown on the left panel).
        T: Last round; sets the x-axis range.
        K_list: Keys to plot, one curve per key in each panel.
        update_frequency: Spacing between consecutive x values.
        xtick_step: Spacing between x-axis ticks (default 40).

    The y-axis is fixed to [-0.05, 1.05], so values are expected to lie in
    [0, 1].
    """
    x_vals = np.arange(0, T + 1, update_frequency)
    shrink_factor = 0.5  # error bars show 0.5 * std

    fig, axs = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

    for K in K_list:
        mean1, std1 = (np.asarray(a) for a in result_dict1[K])
        mean2, std2 = (np.asarray(a) for a in result_dict2[K])

        # Truncate everything to a common length so the x grid and the
        # series always line up, even if a run recorded fewer points.
        min_len = min(len(x_vals), len(mean1), len(std1), len(mean2), len(std2))
        x_plot = x_vals[:min_len]

        for ax, mean, std in ((axs[0], mean1, std1), (axs[1], mean2, std2)):
            ax.errorbar(x_plot, mean[:min_len], yerr=shrink_factor * std[:min_len],
                        label=f"K={K}", fmt='-o', capsize=4, elinewidth=0.8)

    axs[0].set_title(title1)
    axs[0].set_xlabel("Round")
    axs[0].set_ylabel(ylabel)
    axs[0].legend()

    axs[1].set_title(title2)
    axs[1].set_xlabel("Round")
    axs[1].legend()

    for ax in axs:
        ax.set_xlim([-0.2, T + 1])
        ax.set_ylim([-0.05, 1.05])
        ax.set_xticks(range(0, T + 1, xtick_step))
        ax.grid(True, linestyle='--', linewidth=0.5)

    fig.tight_layout()
    directory = os.path.dirname(path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if directory:
        os.makedirs(directory, exist_ok=True)
    # Save and close this specific figure, not whatever pyplot considers current.
    fig.savefig(path)
    plt.close(fig)
    

def _plot_sparse_errorbars(ax, method_results, x_vals, max_points, shrink_factor=0.5):
    """Draw one error-bar curve per method on ``ax`` using at most ``max_points`` evenly spaced samples."""
    for method, (mean, std) in method_results.items():
        mean = np.asarray(mean)
        std = np.asarray(std)
        # Clip to the shortest of x grid / mean / std. Without this, a series
        # longer than the round grid (or a std shorter than its mean) would
        # index out of bounds.
        n = min(len(x_vals), len(mean), len(std))
        sparse_idx = np.linspace(0, n - 1, min(max_points, n)).astype(int)
        ax.errorbar(x_vals[sparse_idx], mean[sparse_idx],
                    yerr=shrink_factor * std[sparse_idx],
                    label=method, fmt='-o', capsize=3)


def plot_selector_regret_subplots_by_method(
    results_pl, results_rb, path, T, K_list, update_frequency, max_points=10
):
    """Plot PL and RB regret curves per method in a 2 x len(K_list) grid and save the figure.

    Row 0 shows ``results_pl[K]`` and row 1 shows ``results_rb[K]``; there is
    one column per ``K`` in ``K_list``. Each results entry maps a method name
    to ``(mean, std)``. Every curve is downsampled to at most ``max_points``
    evenly spaced points on the round grid
    ``0, update_frequency, ... (up to T)``, with error bars of half a
    standard deviation. Rows share their y-axis.

    Args:
        results_pl, results_rb: Mappings ``K -> {method: (mean, std)}``.
        path: Output file path. Missing parent directories are created.
        T: Last round; sets the x grid.
        K_list: Column keys, in display order.
        update_frequency: Spacing between consecutive x values.
        max_points: Maximum number of markers drawn per curve.
    """
    x_vals = np.arange(0, T + 1, update_frequency)
    num_plots = len(K_list)

    # squeeze=False keeps axs 2-D with shape (2, num_plots), even for one column.
    fig, axs = plt.subplots(2, num_plots, figsize=(5 * num_plots, 6),
                            sharey="row", squeeze=False)

    rows = (("PL Regret", results_pl), ("RB Regret", results_rb))
    for col_idx, K in enumerate(K_list):
        for row_idx, (row_label, results) in enumerate(rows):
            ax = axs[row_idx, col_idx]
            _plot_sparse_errorbars(ax, results[K], x_vals, max_points)
            if row_idx == 0:
                ax.set_title(f"K={K}")
            ax.grid(True, linestyle='--')
            ax.set_xlabel("Round")
            if col_idx == 0:
                ax.set_ylabel(row_label)
            ax.legend(fontsize=9)

    fig.tight_layout()
    directory = os.path.dirname(path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if directory:
        os.makedirs(directory, exist_ok=True)
    fig.savefig(path)
    plt.close(fig)