import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


def break_ties(group):
    """
    Break ties in ``trajectory_value`` by adding tiny, distinct random offsets.

    Every row of ``group`` receives a different positive offset
    ``k * |first trajectory_value| * 1e-8`` (``k`` in ``1..n``; the scale
    falls back to ``1e-8`` when the first value is 0). Which row gets which
    offset is random and driven by the global ``np.random`` state.

    Args:
        group: DataFrame with a numeric ``trajectory_value`` column.
            # NOTE(review): the scale is taken from the first row only, so
            # this presumably expects all rows to share one value - confirm.

    Returns:
        ``group`` itself when it has at most one row; otherwise a new
        DataFrame (the input is left untouched) whose ``trajectory_value``
        column is float64 with the offsets applied.
    """
    n = len(group)
    if n <= 1:
        # Nothing to disambiguate; also avoids iloc[0] on an empty frame.
        return group

    # Work on a copy: mutating the object handed to a groupby.apply callback
    # is unsupported by pandas, and callers should not see side effects.
    result = group.copy()
    # The offsets are fractional, so an integer column must become float
    # before they can be stored without a dtype-incompatible assignment.
    result["trajectory_value"] = result["trajectory_value"].astype("float64")

    shuffled_idx = np.random.permutation(result.index)
    # Scalar magnitude of the first trajectory_value in this group.
    base_value = abs(result["trajectory_value"].iloc[0])
    if base_value == 0:
        base_value = 1
    increment = base_value * 1e-8  # small relative to the magnitude
    # Small distinct offsets, one per row.
    offsets = np.arange(1, n + 1) * increment
    # Both random calls are kept so seeded runs draw the same numbers as before.
    np.random.shuffle(offsets)
    result.loc[shuffled_idx, "trajectory_value"] += offsets
    return result


def weighted_threshold_discrete(df, value_col, weight_col, X):
    """
    Find a threshold v (an actual value from the dataset) at a weighted quantile.

    Rows are sorted by ``value_col`` and v is the value of the first row whose
    cumulative weight fraction reaches ``X``. If no row reaches ``X``
    (``X`` above 1, or floating-point round-off in the last cumulative
    fraction), the largest value is used instead of silently falling back to
    the smallest one.

    Args:
        df: Input DataFrame; it is not modified.
        value_col: Name of the column holding the values to threshold.
        weight_col: Name of the column holding the per-row weights.
        X: Target cumulative weight fraction (compared directly against
            fractions of the total weight, so 0.3 means 30%).

    Returns:
        Tuple ``(v, actual_frac)`` where ``actual_frac`` is the fraction of
        the total weight carried by rows whose value is strictly greater
        than v.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("cannot compute a threshold from an empty DataFrame")

    df_sorted = df.sort_values(value_col).reset_index(drop=True)
    df_sorted["cum_weight"] = df_sorted[weight_col].cumsum()
    total_weight = df_sorted[weight_col].sum()
    df_sorted["cum_frac"] = df_sorted["cum_weight"] / total_weight

    reached = df_sorted["cum_frac"] >= X
    if reached.any():
        idx = reached.idxmax()  # first index meeting the condition
    else:
        # idxmax() on an all-False mask would return the first row, i.e. the
        # smallest value; the closest achievable threshold is the last row.
        idx = df_sorted.index[-1]
    v = df_sorted.loc[idx, value_col]

    # Fraction of weight actually lying strictly above the chosen threshold.
    actual_weight_above = df_sorted.loc[df_sorted[value_col] > v, weight_col].sum()
    actual_frac = actual_weight_above / total_weight

    return v, actual_frac


def find_threshold_for_generative_abstention(df, value_col, X, total_weight):
    """
    Find the value threshold at which the accumulated token count reaches ``X``.

    Rows are visited in ascending ``value_col`` order while a running token
    count is maintained per ``sample_index``: the first row seen for a sample
    adds ``output_length - token_index``; a later row of the same sample adds
    ``previous_token_index - token_index`` only when its ``token_index`` is
    smaller than the smallest one seen so far for that sample. The threshold
    is the value of the first row whose running count divided by
    ``total_weight`` is at least ``X``. If no row reaches ``X``, the last
    (largest-value) row is used instead of silently falling back to the
    first one.

    Args:
        df: DataFrame with ``sample_index``, ``token_index``,
            ``output_length`` and ``value_col`` columns; it is not modified.
        value_col: Name of the column to sort by and take the threshold from.
        X: Target fraction of ``total_weight``.
        total_weight: Denominator used to turn the running token count into
            a fraction.
            # NOTE(review): presumably the total number of output tokens
            # across all samples - confirm against the caller.

    Returns:
        Tuple ``(v, cum_frac)``: the threshold value and the fraction
        accumulated at the selected row.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("cannot compute a threshold from an empty DataFrame")

    df_sorted = df.sort_values(value_col).reset_index(drop=True)
    earliest_token = {}  # sample_index -> smallest token_index seen so far
    cumsum_values = []
    value = 0
    # Iterating the three needed columns avoids building a Series per row.
    for s, t, row_total in zip(
        df_sorted["sample_index"],
        df_sorted["token_index"],
        df_sorted["output_length"],
    ):
        if s not in earliest_token:
            value += row_total - t
            earliest_token[s] = t
        else:
            prev_t = earliest_token[s]
            if prev_t > t:
                # Only the extra span between the new, earlier index and the
                # previous one has not been counted yet.
                value += prev_t - t
                earliest_token[s] = t
        cumsum_values.append(value)

    df_sorted["token_cumsum"] = cumsum_values
    df_sorted["cum_frac"] = df_sorted["token_cumsum"] / total_weight

    reached = df_sorted["cum_frac"] >= X
    if reached.any():
        idx = reached.idxmax()  # first index meeting the condition
    else:
        # idxmax() on an all-False mask would pick row 0 (the lowest value);
        # clamp to the last row, where the accumulated fraction is largest.
        idx = df_sorted.index[-1]
    v = df_sorted.loc[idx, value_col]
    return v, df_sorted.loc[idx, "cum_frac"]


def select_row_min_token_abstained(group, T):
    """
    Pick one representative row of ``group`` relative to threshold ``T``.

    If any row has ``trajectory_value`` strictly below ``T``, the one with
    the smallest ``token_index`` among them is returned. Otherwise a copy of
    the row with the largest ``token_index`` is returned with its
    ``token_index`` incremented by one.
    """
    under_threshold = group[group["trajectory_value"] < T]

    if under_threshold.empty:
        # No row falls below T: take the last token's row and move its
        # index one step past it (on a copy, so ``group`` is untouched).
        last_label = group["token_index"].idxmax()
        fallback = group.loc[last_label].copy()
        fallback["token_index"] += 1
        return fallback

    earliest_label = under_threshold["token_index"].idxmin()
    return under_threshold.loc[earliest_label]


def get_exact_percentile_rows(data, T, ratio, column):
    """
    Select ``floor(len(data) * (1 - ratio))`` rows at or above threshold ``T``.

    All rows with ``column`` strictly greater than ``T`` are preferred. If
    they are too few, the shortfall is filled by sampling (fixed seed 42)
    from the rows equal to ``T``; if they are too many, only the first
    ``target`` of them (in ``data`` order) are kept.

    Args:
        data: Input DataFrame; it is not modified.
        T: Threshold value compared against ``column``.
        ratio: Fraction of rows to leave out; ``1 - ratio`` of the rows are kept.
        column: Name of the column compared against ``T``.

    Returns:
        DataFrame of the selected rows. When tie rows are sampled in, the
        result has a fresh 0..n-1 index; otherwise the original index labels
        are kept.

    Raises:
        ValueError: From ``DataFrame.sample`` when there are not enough rows
            equal to ``T`` to fill the shortfall.
    """
    # Floor with a small tolerance: 1 - ratio is inexact in binary floating
    # point, e.g. 10 * (1 - 0.9) == 0.9999999999999998, which a bare int()
    # would truncate to 0 instead of 1. Clamp at 0 so an out-of-range ratio
    # cannot turn head() into "all but the last k rows".
    target_n = max(int(len(data) * (1 - ratio) + 1e-9), 0)
    above_T = data[data[column] > T]
    equal_T = data[data[column] == T]
    needed = target_n - len(above_T)
    if needed > 0:
        sampled_equal_T = equal_T.sample(n=needed, random_state=42)
        selected = pd.concat([above_T, sampled_equal_T], ignore_index=True)
    else:
        selected = above_T.head(target_n)
    return selected


def add_saved_token_stats(df, baselines):
    """
    Attach "full vs. baseline average" saved-token statistics to ``df``.

    For every ``abstention_ratio`` that has both a ``"full"`` row and at
    least one row whose ``model`` is in ``baselines``, the saved tokens of
    the ``"full"`` model are expressed as a percentage of the mean saved
    tokens of the baselines.

    Args:
        df: DataFrame with ``model``, ``abstention_ratio`` and
            ``saved_tokens`` columns; it is not modified.
        baselines: Collection of model names averaged as the reference.

    Returns:
        A new DataFrame: ``df`` left-merged on ``abstention_ratio`` with the
        columns ``full_saved_tokens``, ``avg_other_saved_tokens`` and
        ``full_vs_avg_pct``. ``full_vs_avg_pct`` is only populated on rows
        whose ``model`` is ``"full"`` and is NaN everywhere else.
    """
    full_df = df.loc[
        df["model"] == "full", ["abstention_ratio", "saved_tokens"]
    ].rename(columns={"saved_tokens": "full_saved_tokens"})
    other_df = (
        df[df["model"].isin(baselines)]
        .groupby("abstention_ratio", as_index=False)["saved_tokens"]
        .mean()
        .rename(columns={"saved_tokens": "avg_other_saved_tokens"})
    )
    # Inner join: only ratios present for both "full" and the baselines.
    result = pd.merge(full_df, other_df, on="abstention_ratio", how="inner")
    result["full_vs_avg_pct"] = (
        result["full_saved_tokens"] / result["avg_other_saved_tokens"] * 100
    )

    merged = df.merge(result, on="abstention_ratio", how="left")
    # Keep the percentage on "full" rows only. A vectorised mask replaces the
    # former row-wise apply and also works when the frame is empty.
    merged["full_vs_avg_pct"] = merged["full_vs_avg_pct"].where(
        merged["model"] == "full"
    )
    return merged


def draw_graph_main_results(df, title_string, save_path=None):
    """
    Plot accuracy against abstention ratio, one line per model.

    Every model except ``"no_abstention"`` is drawn with seaborn; the
    ``"no_abstention"`` rows are drawn as a dashed grey line. Each point of
    the ``"full"`` model is labelled with its ``full_vs_avg_pct`` value
    formatted as a percentage; points whose value is missing (NaN/None) are
    left unlabelled.

    Args:
        df: DataFrame with ``model``, ``abstention_ratio``, ``accuracy`` and
            ``full_vs_avg_pct`` columns.
        title_string: Text appended to the plot title.
        save_path: If truthy, the figure is saved there (300 dpi) and the
            path is printed; otherwise the figure is shown interactively.
    """
    plt.figure(figsize=(8, 6))
    sns.lineplot(
        data=df[df["model"] != "no_abstention"],
        x="abstention_ratio",
        y="accuracy",
        hue="model",
        marker="o"
    )
    no_abst_df = df[df["model"] == "no_abstention"]
    plt.plot(
        no_abst_df["abstention_ratio"],
        no_abst_df["accuracy"],
        color="grey",
        marker="o",
        label="no_abstention",
        linewidth=2,
        linestyle=(0, (5, 5)),
        markersize=4
    )
    full_df = df[df["model"] == "full"]
    for _, row in full_df.iterrows():
        pct = row["full_vs_avg_pct"]
        if pd.isna(pct):
            # No baseline at this abstention ratio: skip instead of
            # stamping a meaningless "nan%" label on the plot.
            continue
        plt.text(
            row["abstention_ratio"],
            row["accuracy"],
            f'{pct:.1f}%',  # format e.g. 95.3%
            fontsize=9,
            color="black",
            ha="right",
            va="bottom"
        )
    plt.title("Accuracy vs. Abstention Ratio per Model for " + title_string)
    plt.xlabel("Abstention Ratio")
    plt.ylabel("Accuracy")
    plt.legend(title="Model")
    plt.grid(True, linestyle="--", alpha=0.6)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved plot: {save_path}")
    else:
        plt.show()
    plt.close()


def draw_graph_saved_tokens(df, title_string, abstention_acc, save_path=None):
    """
    Plot accuracy against saved tokens, one line per model.

    Models other than ``"no_abstention"`` are drawn with seaborn, while the
    ``"no_abstention"`` rows are overlaid as a dashed grey line. The title
    includes ``title_string`` and ``abstention_acc``. When ``save_path`` is
    truthy the figure is written there at 300 dpi and the path is printed;
    otherwise the figure is shown. The figure is closed in both cases.
    """
    baseline_name = "no_abstention"
    plt.figure(figsize=(8, 6))

    # One mask splits the frame into regular models and the baseline.
    is_baseline = df["model"] == baseline_name
    sns.lineplot(
        data=df[~is_baseline],
        x="saved_tokens",
        y="accuracy",
        hue="model",
        marker="o",
    )

    baseline_rows = df[is_baseline]
    baseline_style = dict(
        color="grey",
        marker="o",
        label=baseline_name,
        linewidth=2,
        linestyle=(0, (5, 5)),
        markersize=4,
    )
    plt.plot(
        baseline_rows["saved_tokens"],
        baseline_rows["accuracy"],
        **baseline_style,
    )

    heading = "Accuracy vs. Saved for " + title_string
    heading = heading + ". Abstention reward for accuracy calc: " + str(abstention_acc)
    plt.title(heading)
    plt.xlabel("Saved tokens")
    plt.ylabel("Accuracy")
    plt.legend(title="Model")
    plt.grid(True, linestyle="--", alpha=0.6)
    plt.tight_layout()

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved plot: {save_path}")
    plt.close()


def draw_graph_abstention_rate_reward(df, title_string, save_path=None):
    """
    Plot reward against abstention rate, one line per model.

    Each model other than ``"no_abstention"`` gets its own colour from the
    ``tab10`` palette (colours repeat when there are more models than
    palette entries). Every point is annotated with its
    ``abstention_reward`` value, two decimals, unless that value is missing
    (None/NaN). The ``"no_abstention"`` rows are drawn as a dashed grey line.

    Args:
        df: DataFrame with ``model``, ``abstention_ratio``, ``reward`` and
            ``abstention_reward`` columns.
        title_string: Text appended to the plot title.
        save_path: If truthy, the figure is saved there (300 dpi) and the
            path is printed; otherwise the figure is shown interactively.
    """
    plt.figure(figsize=(12, 8))

    # Get unique models (excluding no_abstention)
    models = [m for m in df["model"].unique() if m != "no_abstention"]
    colors = plt.cm.tab10.colors
    # Wrap around the palette so more models than colours cannot raise.
    model_colors = {m: colors[i % len(colors)] for i, m in enumerate(models)}

    # Plot each model with annotations
    for model in models:
        model_df = df[df["model"] == model]
        plt.plot(
            model_df["abstention_ratio"],
            model_df["reward"],
            marker="o",
            label=model,
            color=model_colors[model]
        )
        # Annotate with abstention_reward values
        for _, row in model_df.iterrows():
            abstention_reward = row["abstention_reward"]
            # pandas stores a missing value in a numeric column as NaN, which
            # an "is not None" check lets through; isna() covers both cases.
            if pd.isna(abstention_reward):
                continue
            plt.annotate(
                f'{abstention_reward:.2f}',
                (row["abstention_ratio"], row["reward"]),
                textcoords="offset points",
                xytext=(0, 5),
                ha='center',
                fontsize=6,
                color=model_colors[model],
                alpha=0.8
            )

    # Plot no_abstention baseline
    no_abst_df = df[df["model"] == "no_abstention"]
    plt.plot(
        no_abst_df["abstention_ratio"],
        no_abst_df["reward"],
        color="grey",
        marker="o",
        label="no_abstention",
        linewidth=2,
        linestyle=(0, (5, 5)),
        markersize=4
    )

    plt.title("Reward vs. Abstention Rate: " + title_string + "\n(annotations show threshold T used as r_⊥)")
    plt.xlabel("Abstention Rate")
    plt.ylabel("Reward")
    plt.legend(title="Model", loc="best")
    plt.grid(True, linestyle="--", alpha=0.6)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved plot: {save_path}")
    else:
        plt.show()
    plt.close()