import numpy as np
import pandas as pd
from tqdm import tqdm

def prepare_df_conformal(df_input, max_val=None):
    """Append one empty-text "abstention" row per topic and sort the result.

    For every topic, a copy of the topic's first row is appended with
    ``text=''``, ``label='S'``, ``correct=True`` and ``score=max_val + 1e-3``.
    With the default ``max_val`` this score is above every existing score, so
    the abstention row sorts last within its topic.

    Args:
        df_input: frame with (at least) 'topic', 'text', 'score', 'label' and
            'correct' columns. It is not modified.
        max_val: value the abstention score is placed just above. Defaults to
            ``df_input['score'].max()``.

    Returns:
        A new DataFrame holding the rows of ``df_input`` plus the abstention
        rows, sorted by ``['topic', 'score', 'label']``. The abstention rows
        keep the index labels of the rows they were copied from, so the
        resulting index is not unique.
    """
    if max_val is None:
        max_val = df_input['score'].max()

    # add empty generation: one row per topic, copied from the topic's first row.
    # `.copy()` detaches it from `df_input`; `groupby().head()` returns a
    # boolean-indexed slice, and assigning into that slice raises
    # SettingWithCopyWarning (and is ambiguous about writing through).
    abstentions = df_input.groupby('topic').head(1).copy()
    # Whole-column replacement (rather than `.loc[:, col] = ...`) lets the
    # column take the new value's dtype, e.g. a float score over an int column,
    # instead of an in-place write that pandas flags as an incompatible dtype.
    abstentions['text'] = ''
    abstentions['score'] = max_val + 1e-3  # make score the highest
    abstentions['label'] = 'S'
    abstentions['correct'] = True

    df = pd.concat([df_input, abstentions])
    df.sort_values(['topic', 'score', 'label'], inplace=True)
    return df

def get_safe_threshold(df, score_col='score', target=1.0):
    """Return the score at which running accuracy first falls below ``target``.

    Rows are ranked by ``score_col`` in descending order. Among equal scores,
    rows with ``correct == False`` are placed first, which is the pessimistic
    tie-break. Walking down that ranking, the running fraction of correct rows
    is tracked, and the score of the first row where it is ``< target`` is
    returned.

    If the running fraction never drops below ``target``, a value just under
    the smallest score (``min - 1e-8``) is returned instead; for an empty
    frame this is NaN. ``df`` itself is not modified.
    """
    # Highest scores first; ties put incorrect rows ahead of correct ones.
    ranked = df.sort_values([score_col, 'correct'], ascending=[False, True])
    seen_so_far = np.arange(1, len(ranked) + 1)
    running_accuracy = ranked['correct'].cumsum() / seen_so_far
    below = running_accuracy < target

    if not below.any():
        return ranked[score_col].min() - 1e-8

    # Series.argmax is positional, so it indexes `ranked` via iloc.
    first_drop = below.argmax()
    return ranked[score_col].iloc[first_drop]

def processed_filtered_df(df_filtered):
    """Summarize per-topic support of the filtered rows.

    Returns three single-column (``'label'``) DataFrames indexed by topic:

    1. whether every row of the topic, including empty-text rows, has
       ``label == 'S'``;
    2. the same check restricted to rows with non-empty ``text`` (topics with
       no such rows are absent);
    3. the fraction of non-empty-text rows with ``label == 'S'``.
    """
    def per_topic(frame, reduce_labels):
        # One reduced value per topic, returned as a 'label' column frame.
        return frame.groupby('topic')['label'].apply(reduce_labels).to_frame()

    def every_label_supported(labels):
        return (labels == 'S').all()

    def share_supported(labels):
        return (labels == 'S').mean()

    # Empty-text rows are dropped for the statistics that exclude abstentions.
    answered = df_filtered[df_filtered['text'] != '']

    return (
        per_topic(df_filtered, every_label_supported),
        per_topic(answered, every_label_supported),
        per_topic(answered, share_supported),
    )

# Get results

def get_conformal_summary(df_filtered_dict, step=0.05):
    """Build a one-row-per-alpha summary table.

    ``df_filtered_dict`` is looked up with every value of
    ``np.round(np.arange(step, 1, step), 3)``. Each entry is a 5-tuple
    ``(_, df_filtered_no_abstentions, df_all_supported_with_abstentions,
    df_all_supported, df_frac_supported)``; the first element is unused.

    Columns of the result:
        1-alpha: the nominal level.
        coverage: mean of ``df_all_supported_with_abstentions['label']``.
        frac_all_supported: mean of ``df_all_supported['label']``.
        frac_facts_supported: mean of ``df_frac_supported['label']``.
        num_topics_kept: number of distinct topics in
            ``df_filtered_no_abstentions``.
        num_facts_kept: mean rows per topic in ``df_filtered_no_abstentions``.
    """
    columns = ['1-alpha', 'coverage', 'frac_all_supported', 'frac_facts_supported', 'num_topics_kept', 'num_facts_kept']

    def summarize(alpha):
        (_, kept_facts, with_abstentions, all_supported, frac_supported) = df_filtered_dict[alpha]
        return [
            1 - alpha,
            with_abstentions['label'].mean(),
            all_supported['label'].mean(),
            frac_supported['label'].mean(),
            len(kept_facts['topic'].unique()),
            kept_facts.groupby('topic').size().mean(),
        ]

    rows = [summarize(alpha) for alpha in np.round(np.arange(step, 1, step), 3)]
    return pd.DataFrame(rows, columns=columns)


# Getting results by groups

def add_groups_to_filtered_dict(filtered_dict, df_groups):
    """Merge ``df_groups`` onto entries 2 and later of every value, in place.

    Each value of ``filtered_dict`` must be a mutable sequence (e.g. a list).
    Its first two slots are left untouched. Every later frame is replaced by
    an inner index-on-index merge with ``df_groups``, so rows whose index is
    missing from ``df_groups`` are dropped. Returns None.
    """
    for frames in filtered_dict.values():
        for slot, frame in enumerate(frames[2:], start=2):
            frames[slot] = pd.merge(frame, df_groups, left_index=True, right_index=True)

def get_grouped_conformal_summary(group_column, df_filtered_dict, step=0.05):
    """Build a per-(alpha, group) summary table.

    ``df_filtered_dict`` is looked up with every value of
    ``np.round(np.arange(step, 1, step), 3)``. Each entry is a 5-tuple
    ``(_, df_filtered_no_abstentions, df_all_supported_with_abstentions,
    df_all_supported, df_frac_supported)``; the first element is unused. The
    last three frames need a ``'label'`` column plus the group column(s), and
    ``df_filtered_no_abstentions`` needs a ``'topic'`` column plus the group
    column(s).

    Args:
        group_column: list of column names to group by. A bare string or a
            tuple is also accepted.
        df_filtered_dict: mapping from rounded alpha to the 5-tuple above.
        step: spacing of the alpha grid.

    Returns:
        DataFrame with columns ``['1-alpha', *group_column, 'size',
        'coverage', 'frac_all_supported', 'frac_facts_supported',
        'num_topics_kept', 'num_facts_kept']``, one row per (alpha, group).

        Every group present in ``df_all_supported_with_abstentions`` is
        reported, including groups with no rows in the no-abstention frames.
        For those groups ``num_topics_kept`` is 0 and ``frac_all_supported``,
        ``frac_facts_supported`` and ``num_facts_kept`` are NaN. If the alpha
        grid is empty, an empty DataFrame with those columns is returned.
    """
    # Accept a single column name as shorthand; the original required a list
    # because the key list is concatenated with other column lists below.
    keys = [group_column] if isinstance(group_column, str) else list(group_column)
    stat_columns = ['size', 'coverage', 'frac_all_supported', 'frac_facts_supported', 'num_topics_kept', 'num_facts_kept']

    values = []
    for alpha in np.round(np.arange(step, 1, step), 3):
        _, df_filtered_no_abstentions, df_all_supported_with_abstentions, df_all_supported, df_frac_supported = df_filtered_dict[alpha]

        # number of rows per group (the anchor for the left merges below)
        by_group = df_all_supported_with_abstentions.groupby(keys)['label']
        size = by_group.size().reset_index(name='size')

        # get proportion of topics whose facts are all supported
        coverage = by_group.mean().reset_index(name='coverage')

        # get proportion of topics whose facts are all supported, excluding abstentions
        frac_all_supported = df_all_supported.groupby(keys)['label'].mean().reset_index(name='frac_all_supported')

        # get proportion of facts supported, averaged over topics
        frac_facts_supported = df_frac_supported.groupby(keys)['label'].mean().reset_index(name='frac_facts_supported')

        # Vectorized groupby reductions instead of `groupby().apply(...)`:
        # `apply` on an empty frame returns all of its columns, which broke the
        # column renaming when every topic was abstained at this alpha.
        kept = df_filtered_no_abstentions

        # get number of topics kept (dropna=False counts a NaN topic like `unique()` would)
        num_topics_kept = kept.groupby(keys)['topic'].nunique(dropna=False).reset_index(name='num_topics_kept')

        # get number of facts kept per topic, averaged over the group's topics
        num_facts_kept = (
            kept.groupby(keys + ['topic']).size()
            .groupby(level=keys).mean()
            .reset_index(name='num_facts_kept')
        )

        # Left merges keep groups that have no kept topics at this alpha;
        # inner merges silently dropped them together with their size/coverage.
        merged_df = size
        for part in (coverage, frac_all_supported, frac_facts_supported, num_topics_kept, num_facts_kept):
            merged_df = pd.merge(merged_df, part, on=keys, how='left')
        merged_df['num_topics_kept'] = merged_df['num_topics_kept'].fillna(0).astype('int64')
        merged_df.insert(0, '1-alpha', 1 - alpha)

        values.append(merged_df)

    if not values:
        # pd.concat raises on an empty list; mirror the ungrouped summary instead.
        return pd.DataFrame(columns=['1-alpha'] + keys + stat_columns)

    df_summary = pd.concat(values)
    return df_summary.reset_index(drop=True)
