import os
import json
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import comb
from scipy.optimize import brentq
import matplotlib.ticker as mticker


def compute_A_B(n, s_star):
    """Sum binomial counts over subset sizes 2..``s_star`` of an ``n``-set.

    Parameters
    ----------
    n : int
        Size of the ground set.
    s_star : int
        Largest subset size included in the sums (sizes below 2 contribute
        nothing, so ``s_star < 2`` yields ``A == B == 0``).

    Returns
    -------
    tuple of (int, int, int)
        ``A = sum_{s=2}^{s_star} C(n, s)``,
        ``B = sum_{s=2}^{s_star} C(n, s) * (2**s - 2)``, and the search
        count ``n + A`` (the ``n`` singletons plus every subset counted in A).
    """
    sizes = range(2, s_star + 1)
    binoms = [comb(n, size, exact=True) for size in sizes]
    A = sum(binoms)
    B = sum(count * (2**size - 2) for count, size in zip(binoms, sizes))
    return A, B, n + A

def estimate_p_from_avg_n_star(n_star_avg):
    """Estimate p from an average n* as the midpoint of two reciprocal bounds.

    The estimate is ``0.5 * (1/(2**(m+1) - 2) + 1/(2**m - 2))`` with
    ``m = n_star_avg``, clamped to at most 1.

    Parameters
    ----------
    n_star_avg : float
        Average n* value.

    Returns
    -------
    float or int
        The clamped estimate; exactly ``1`` when ``n_star_avg <= 1``.
    """
    # For m <= 1 the formula is undefined or meaningless: m == 1 divides by
    # zero in the upper term, m == 0 divides by zero in the lower term, and
    # 0 < m < 1 makes the upper term negative (a negative "probability").
    # All of these map to the clamp value 1.
    if n_star_avg <= 1:
        return 1
    lower = 1 / (2 ** (n_star_avg + 1) - 2)
    upper = 1 / (2 ** n_star_avg - 2)
    return min(1, 0.5 * (lower + upper))


def _build_panel_axes(fig):
    """Create the fixed layout (three panels on top, two below) and return the five axes."""
    axs = []

    # Top row: three equally sized panels with equal gaps.
    top_y = 0.58
    top_height = 0.35
    top_width = 0.26
    top_gap = 0.09
    for col in range(3):
        left = 0.07 + col * (top_width + top_gap)
        axs.append(fig.add_axes([left, top_y, top_width, top_height]))

    # Bottom row: two panels, centred as a pair and then shifted right by
    # `center_offset`.
    bottom_y = 0.10
    bottom_height = 0.33
    bottom_width = 0.26
    middle_gap = 0.09
    outer_margin = (1.0 - 2 * bottom_width - middle_gap) / 2  # = 0.195
    center_offset = 0.05
    left1 = outer_margin + center_offset
    left2 = left1 + bottom_width + middle_gap
    axs.append(fig.add_axes([left1, bottom_y, bottom_width, bottom_height]))
    axs.append(fig.add_axes([left2, bottom_y, bottom_width, bottom_height]))
    return axs


def _draw_n_star_panel(ax, panel_idx, n_star_values, p_empirical, p_bound, title, ratio):
    """Draw one n* histogram on `ax` plus a twin axis with the two p reference lines."""
    n_star_max = max(n_star_values)
    n_star_avg = np.mean(n_star_values)

    # Unit-width bins centred on the integers 1..max; with density=True the
    # bar heights are the relative frequencies of the binned values.
    counts, bins = np.histogram(
        n_star_values, bins=np.arange(1, n_star_max + 2) - 0.5, density=True
    )
    bin_centers = (bins[:-1] + bins[1:]) / 2

    # Upper y-limit: the first multiple of 0.05 strictly above the tallest bar.
    # The grid must reach past 1.0, otherwise a single-valued distribution
    # (tallest bar == 1.0) leaves no candidate and indexing [0] raises.
    # np.arange values depend only on start/step, so for y_max < 0.95 the
    # result is identical to a grid stopping at 1.0.
    y_max = max(counts)
    ylim_grid = np.arange(0, max(1.0, y_max + 0.1), 0.05)
    freq_ylim = ylim_grid[ylim_grid > y_max][0]

    ax.bar(bin_centers, counts, width=1, color='lightgray', edgecolor='gray')
    ax.set_yticks([0.0, 0.1, 0.2, 0.3])
    ax.set_ylim(0, freq_ylim)

    # Only the first panel carries legend labels, so the shared figure legend
    # lists each series once (label=None behaves like an unlabelled artist).
    ax.axvline(n_star_avg, color='gray', linestyle='--', linewidth=3, zorder=0,
               label='average n*' if panel_idx == 0 else None)

    # Put the average's value on whichever side of the dashed line has more room.
    margin = 0.1
    if n_star_avg - 1 > n_star_max - n_star_avg:
        avg_text_x = max(1, n_star_avg - margin)
        h_align = 'right'
    else:
        avg_text_x = min(n_star_max, n_star_avg + margin)
        h_align = 'left'
    ax.text(
        avg_text_x, freq_ylim * 0.9,
        f"{n_star_avg:.2f}",
        rotation=0,
        verticalalignment='center',
        horizontalalignment=h_align,
        color='dimgray',
        fontsize=23
    )

    ax.set_xlabel(f"n* (only {int(ratio*100):02d}%)", fontsize=24)

    # The formatter must be installed BEFORE blanking tick labels:
    # set_yticklabels([]) works by installing its own formatter, so a later
    # set_major_formatter call would make the hidden labels reappear.
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f'))
    # Panels 0 and 3 are the left-most of each row and keep the frequency axis.
    if panel_idx % 3 == 0:
        ax.set_ylabel("Frequency", fontsize=24)
    else:
        ax.set_ylabel("")
        ax.set_yticklabels([])

    ax.set_title(title, fontsize=30, pad=15)
    # set_xticks installs a fixed locator; use every other integer when the
    # range is wide, to avoid crowding.
    tick_step = 2 if n_star_max > 10 else 1
    ax.set_xticks(np.arange(1, n_star_max + 1, tick_step))
    ax.tick_params(axis='both', labelsize=20)

    ax2 = ax.twinx()
    ax2.tick_params(axis='both', labelsize=20)
    ax2.set_yticks([0.0, 0.05, 0.1, 0.15])
    ax2.set_ylim(-0.01, 0.18)
    # Same ordering constraint as for the left axis (formatter first).
    ax2.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.2f'))
    # Panels 2 and 4 are the right-most of each row and keep the p axis.
    if panel_idx % 3 == 2 or panel_idx == 4:
        ax2.set_ylabel("p", fontsize=28, rotation=0, labelpad=10)
    else:
        ax2.set_ylabel("")
        ax2.set_yticklabels([])

    first = panel_idx == 0
    ax2.axhline(y=p_empirical, color='red', linestyle='-.', linewidth=3,
                label='Empirical p' if first else None)
    ax2.axhline(y=p_bound, color='black', linestyle=':', linewidth=3,
                label='Polynomial Bound of p' if first else None)

    # Numeric values just above each line, right-aligned near the last bin edge.
    text_x = bins[-1] + 0.2
    ax2.text(text_x, p_empirical + 0.004, f"{p_empirical:.3f}",
             color='red', horizontalalignment='right', fontsize=23)
    ax2.text(text_x, p_bound + 0.004, f"{p_bound:.3f}",
             color='black', horizontalalignment='right', fontsize=23)


def _add_shared_legend(fig):
    """Gather labelled artists from every axes of `fig` (first occurrence per label) into one legend."""
    unique = {}
    for ax in fig.axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            unique.setdefault(label, handle)
    if not unique:
        return
    fig.legend(
        list(unique.values()),
        list(unique.keys()),
        loc='lower center',
        bbox_transform=fig.transFigure,
        bbox_to_anchor=(0.55, -0.07),
        ncol=3,
        frameon=True,
        fontsize=22
    )


def plot_multiple_n_star_distributions(n_star_sets, p_empiricals, p_bounds, titles, reduced_ratios,
                                       save_path="dd.png"):
    """Plot up to five n* histograms with their p reference lines and save the figure.

    Side effect: sets ``plt.rcParams["font.family"]`` globally to
    "Times New Roman".

    Parameters
    ----------
    n_star_sets : sequence of sequences of int
        n* values for each panel.
    p_empiricals : sequence of float
        Empirical p drawn (red, dash-dot) on each panel's twin axis.
    p_bounds : sequence of float
        Bound on p drawn (black, dotted) on each panel's twin axis.
    titles : sequence of str
        Panel titles.
    reduced_ratios : sequence of float
        Fraction shown as a whole percentage in each panel's x label.
    save_path : str, optional
        Output image path (default ``"dd.png"``), written at 300 dpi.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure (left open so callers may inspect or close it).

    Raises
    ------
    ValueError
        If more datasets are given than the five-panel layout holds.
    """
    plt.rcParams["font.family"] = "Times New Roman"
    fig = plt.figure(figsize=(16.5, 10))
    axs = _build_panel_axes(fig)

    num_plots = len(n_star_sets)
    if num_plots > len(axs):
        raise ValueError(f"layout holds at most {len(axs)} panels, got {num_plots}")

    for i in range(num_plots):
        _draw_n_star_panel(axs[i], i, n_star_sets[i], p_empiricals[i], p_bounds[i],
                           titles[i], reduced_ratios[i])

    # Hide layout slots that received no data rather than rendering empty axes.
    for ax in axs[num_plots:]:
        ax.set_visible(False)

    _add_shared_legend(fig)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    return fig


    
    
def _model_config(target_folder):
    """Return ``(model_name, n)`` for the first known model key found in `target_folder`.

    ``n`` is the value later passed to ``compute_A_B`` and used in the bound
    ``1 / (2**n - 2)``.

    Raises
    ------
    ValueError
        If no known model key occurs in the folder path.
    """
    # Checked in order; the first substring match wins.
    known_models = (
        ("gpt2-xs", "GPT2-xs", 2 * 6 + 2),
        ("pythia-14m", "Pythia-14m", 4 + 2),
        ("pythia-1b", "Pythia-1b", 8 + 2),
        ("vit", "ViT-tiny", 8),
        ("deit", "DeiT-tiny", 8),
    )
    for key, model_name, n in known_models:
        if key in target_folder:
            return model_name, n
    raise ValueError(f"Unrecognised model in folder name: {target_folder!r}")


def _load_n_star_by_batch(results_path):
    """Read each result JSON under `results_path` and collect the longest path length per key.

    Returns a dict mapping each integer key to a list with one entry (the
    maximum path length for that key) per result file containing that key.
    """
    result_bowl = {}
    for result_dir in sorted(os.listdir(results_path)):
        # The index is parsed from the text after the first "R" in the name's
        # stem (presumably entries named like "R<index>"), and selects the
        # matching "C<index:06d>.json" inside that entry.
        data_idx = int(result_dir.split(".")[0].split("R")[1])
        c_file = "C{:06d}.json".format(data_idx)
        with open(os.path.join(results_path, result_dir, c_file), "r") as f:
            curr_data = json.load(f)
        for b_idx, b_paths in curr_data.items():
            result_bowl.setdefault(int(b_idx), []).append(max(len(path) for path in b_paths))
    return result_bowl


def main(target_folders=None):
    """Collect n* statistics for each results folder and plot them as one figure.

    For every folder: read the per-key maximum path lengths (n*), estimate p
    from their mean, compute the bound ``1 / (2**n - 2)``, and compute the
    fraction of the full search count that the observed n* values required.

    Parameters
    ----------
    target_folders : list of str, optional
        Experiment folders, each containing a ``results`` subdirectory.
        Defaults to the five collections used for the current figure.

    Raises
    ------
    ValueError
        If a folder name does not identify a known model.
    """
    if target_folders is None:
        target_folders = [
            "jobs_arxiv_for_test/250512_collection/20250505_0612_debug_gpt2-xs_lama_trex_exStWd",
            "jobs_arxiv_for_test/250512_collection/20250505_0612_debug_pythia-1b_lama_trex_exStWd",
            "jobs_arxiv_for_test/250512_collection/20250505_0612_debug_pythia-14m_lama_trex_exStWd",
            "jobs_arxiv_for_test/250512_collection/vit_tiny_patch16_224_converted",
            "jobs_arxiv_for_test/250512_collection/deit_tiny_patch16_224_converted",
        ]

    n_star_list = []
    empirical_ps = []
    bound_ps = []
    titles = []
    reduced_ratios = []
    for target_folder in target_folders:
        model_name, n = _model_config(target_folder)
        result_bowl = _load_n_star_by_batch(os.path.join(target_folder, "results"))

        # Flatten in sorted-key order. Building a flat list (instead of a 2-D
        # array) also works when keys have different numbers of entries.
        n_star_values = np.asarray(
            [value for b_idx in sorted(result_bowl) for value in result_bowl[b_idx]]
        )

        est_p = estimate_p_from_avg_n_star(n_star_values.mean())
        bound_p = 1 / (2**n - 2)

        # Search count for the observed n* versus the full search up to size n.
        # The full count is the same for every sample, so compute it once, and
        # memoise per distinct n* value since values repeat heavily.
        full_search = compute_A_B(n, n)[2]
        values = n_star_values.tolist()
        executed_by_size = {s: compute_A_B(n, s)[2] for s in set(values)}
        executed = sum(executed_by_size[s] for s in values)

        n_star_list.append(n_star_values)
        empirical_ps.append(est_p)
        bound_ps.append(bound_p)
        titles.append("{}".format(model_name))
        reduced_ratios.append(executed / (full_search * len(values)))

    plot_multiple_n_star_distributions(n_star_list, empirical_ps, bound_ps, titles, reduced_ratios)
        

# Run only when executed as a script, not when imported (e.g. to reuse helpers).
if __name__ == "__main__":
    main()