import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from src.priors.lc_prior import LCPrior

# Global publication style for every figure produced by this script:
# large, bold Times New Roman text for body, titles and axis labels.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 25,  # Adjust the font size as needed
    'font.weight': 'bold',
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})


def plot_hypothesis(ranked_df, metric, figsize=(8, 6), *, random_state=42):
    """
    Plot ``metric`` for interesting hypotheses against the remaining rows.

    For each group this draws a violin plot, an overlaid box plot (fliers
    hidden), horizontally jittered raw points and a text label at the median.
    It then writes ``figures/{metric}_chicago_open_data.png`` (300 dpi) and
    ``figures/{metric}_chicago_open_data.pdf``. The ``figures`` directory is
    created if missing, and the figure is closed after saving.

    :param ranked_df: DataFrame with a boolean ``hypothesis`` column and a
        ``metric`` column.
    :param metric: Column to plot. For ``'rank'``, the non-interesting rows are
        randomly down-sampled to the size of the interesting group (capped at
        the number of non-interesting rows). For any other metric, all of them
        are used.
    :param figsize: Figure size in inches, passed to ``plt.subplots``.
    :param random_state: Seed for the down-sampling and the point jitter, so
        repeated calls produce identical figures.
    :raises ValueError: If either group has no rows to plot.
    """
    from pathlib import Path

    # split into the two groups
    is_interesting = ranked_df['hypothesis']
    true_vals = ranked_df.loc[is_interesting, metric]
    false_vals = ranked_df.loc[~is_interesting, metric]
    if metric == 'rank':
        # Use an equal-sized random comparison group. Sampling without
        # replacement raises if asked for more rows than exist, so cap the size.
        false_vals = false_vals.sample(min(len(true_vals), len(false_vals)),
                                       random_state=random_state)
    if true_vals.empty or false_vals.empty:
        raise ValueError(
            f"cannot plot {metric!r}: need at least one row in each "
            f"hypothesis group (got {len(true_vals)} and {len(false_vals)})"
        )

    groups = [true_vals, false_vals]
    labels = ['Interesting Correlations', 'Random Sample']
    positions = [1, 2]

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # 1) violin
        ax.violinplot(groups,
                      positions=positions,
                      showmeans=False,
                      showmedians=False,
                      widths=0.8)

        # 2) boxplot overlay
        ax.boxplot(groups,
                   positions=positions,
                   widths=0.2,
                   patch_artist=False,
                   showfliers=False)

        # 3) raw points, jittered horizontally so overlapping values stay
        #    visible. A seeded generator keeps the figure reproducible.
        rng = np.random.default_rng(random_state)
        for pos, vals in zip(positions, groups):
            x = rng.normal(loc=pos, scale=0.05, size=len(vals))
            ax.scatter(x, vals, alpha=0.6, s=10)

        # 4) annotate medians
        for pos, vals in zip(positions, groups):
            med = np.median(vals)
            ax.text(pos, med,
                    f"{med:.2f}",
                    ha='center', va='bottom',
                    fontsize=23, fontweight='bold',
                    color='black')

        # 5) finalize
        ax.set_xticks(positions)
        ax.set_xticklabels(labels)
        if metric == 'density_value':
            ax.set_ylabel(r'Density $p(r_{obs})$')
        else:
            ax.set_ylabel(metric)
        fig.tight_layout()

        out_dir = Path('figures')
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / f'{metric}_chicago_open_data.png', dpi=300)
        fig.savefig(out_dir / f'{metric}_chicago_open_data.pdf')
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    print(f"saved at figures/{metric}_chicago_open_data.png")
def compute_precision_at_K(ranked_df, ks=(5, 10, 15)):
    """
    Compute and print precision@K for each cut-off in ``ks``.

    Precision@K is the fraction of the first K rows of ``ranked_df`` whose
    ``hypothesis`` flag is true. ``ranked_df`` must already be sorted
    best-first. When it has fewer than K rows, the fraction is taken over the
    rows that exist.

    :param ranked_df: DataFrame containing the ranked hypotheses, with a
        boolean ``hypothesis`` column
    :param ks: Cut-offs to evaluate, printed in this order
    :return: Precision at the last cut-off in ``ks`` (K=15 by default)
    :raises ValueError: If ``ks`` is empty
    """
    ks = tuple(ks)
    if not ks:
        raise ValueError("ks must contain at least one cut-off")
    precision_at_k = None
    for K in ks:
        top_k = ranked_df.head(K)
        # Mean of a boolean column = share of true rows in the top K.
        precision_at_k = top_k['hypothesis'].mean()
        print(precision_at_k)
    return precision_at_k

def compute_recall_at_K(ranked_df, ks=(5, 10, 15)):
    """
    Compute and print recall@K for each cut-off in ``ks``.

    Recall@K is the number of true ``hypothesis`` rows among the first K rows
    of ``ranked_df``, divided by the total number of true rows in the frame.
    ``ranked_df`` must already be sorted best-first.

    :param ranked_df: DataFrame containing the ranked hypotheses, with a
        boolean ``hypothesis`` column
    :param ks: Cut-offs to evaluate, printed in this order
    :return: Recall at the last cut-off in ``ks`` (NaN if no row is true)
    :raises ValueError: If ``ks`` is empty
    """
    ks = tuple(ks)
    if not ks:
        raise ValueError("ks must contain at least one cut-off")
    total_relevant = ranked_df['hypothesis'].sum()
    recall_at_k = None
    for K in ks:
        hits = ranked_df.head(K)['hypothesis'].sum()
        # Recall is undefined without any relevant rows; report NaN rather
        # than dividing by zero.
        recall_at_k = hits / total_relevant if total_relevant else float('nan')
        print(recall_at_k)
    return recall_at_k

def compute_average_rank(ranked_df):
    """
    Summarize and print the ranks assigned to the interesting hypotheses.

    :param ranked_df: DataFrame with a boolean ``hypothesis`` column and a
        ``rank`` column
    :return: The ``Series.describe()`` summary of the interesting rows' ranks
        (count, mean, std, min, quartiles, max); ``'mean'`` is the average rank
    """
    is_interesting = ranked_df['hypothesis']
    interesting_ranks = ranked_df[is_interesting]['rank']
    rank_summary = interesting_ranks.describe()
    print(rank_summary)
    return rank_summary
if __name__ == "__main__":
    # Prior model used to score each observed correlation.
    prior_model = LCPrior(agent=None)
    # NOTE(review): effect of ignore_zero is defined inside LCPrior — confirm.
    prior_model.ignore_zero = True
    df = pd.read_csv('outputs/chicago_open_correlations_cleaned/chicago_open_correlations_cleaned_gpt-4o_lc_prior_iter_0.csv')
    # Prior density at each row's observed r. bw=0.4 is presumably a
    # smoothing bandwidth; see LCPrior.get_density_at to confirm.
    df['density_value'] = df.apply(
        lambda row: prior_model.get_density_at(row['r_obs'], row['distribution'], bw=0.4),
        axis=1,
    )
    # Lowest density first. Use a stable sort so tied densities keep their
    # CSV order instead of the unspecified order of the default quicksort.
    # Ranks and the precision/recall metrics then no longer depend on tie
    # handling.
    ranked_df = df.sort_values(by='density_value', ascending=True, kind='stable')

    # Alternative baseline: rank by |r_obs| instead of prior density.
    # df['abs_r'] = df['r_obs'].abs()
    # ranked_df = df.sort_values(by='abs_r', ascending=False, kind='stable')

    # reset the index so we can use positional ranks 1…N
    ranked_df = ranked_df.reset_index(drop=True)

    # assign rank = position in the sorted list (1 = lowest density)
    ranked_df['rank'] = ranked_df.index + 1
    ranked_df.to_csv('ranked_df.csv', index=False)
    # plot_hypothesis(ranked_df, 'density_value')
    compute_precision_at_K(ranked_df)
    compute_recall_at_K(ranked_df)
    compute_average_rank(ranked_df)