# Usage: python ROB_probability.py <generator_name> [response_type]   e.g. python ROB_probability.py deepseek_math

import sys
import pandas as pd
import numpy as np

from utils import *

# Settings read by seq_conf() when it collapses per-token probabilities
# into one sequence-level confidence.
CONF_MODE = "quantile"  # reduction over log-probabilities: "min" or "quantile"
Q_VAL = 0.05  # quantile level used when CONF_MODE == "quantile"
USE_TOP_Q = False  # when True, seq_conf() uses the (1 - Q_VAL) quantile instead of Q_VAL
# Confidence threshold for the ROB / Non-ROB split.
# NOTE(review): the functions in this file default conf_thres to a literal 0.35
# rather than reading this constant -- keep the two in sync.
ROB_THRES = 0.35

def seq_conf(token_probs, mode=None, q=None, use_top_q=None):
    """Collapse per-token probabilities into one sequence-level confidence.

    The probabilities are reduced in log space and mapped back with ``exp``:

    * ``"min"``      -> the smallest token probability.
    * ``"quantile"`` -> ``exp(np.quantile(log p, level))``. With numpy's
      default linear interpolation this interpolates geometrically between
      the two neighbouring probabilities.

    Args:
        token_probs: Non-empty sequence of per-token probabilities.
        mode: ``"min"`` or ``"quantile"``. Defaults to the module-level
            ``CONF_MODE``.
        q: Quantile level in [0, 1]. Defaults to the module-level ``Q_VAL``.
        use_top_q: If true, the level ``1 - q`` is used instead of ``q``.
            Defaults to the module-level ``USE_TOP_Q``.

    Returns:
        The confidence as a float.

    Raises:
        ValueError: If ``token_probs`` is empty or ``mode`` is not recognised.
            ``np.quantile`` also raises ValueError for a level outside [0, 1].
    """
    # Resolve defaults at call time so runtime edits of the module constants
    # are honoured, exactly as when they were read directly.
    if mode is None:
        mode = CONF_MODE
    if q is None:
        q = Q_VAL
    if use_top_q is None:
        use_top_q = USE_TOP_Q

    probs = np.asarray(token_probs, dtype=float)
    if probs.size == 0:
        raise ValueError("seq_conf() needs at least one token probability")

    # Floor probabilities at the smallest positive float. Otherwise log(0)
    # = -inf makes the quantile interpolation return NaN (-inf + inf), and NaN
    # silently compares False against any threshold.
    log_p = np.log(np.clip(probs, np.finfo(float).tiny, None))

    if mode == "min":
        return float(np.exp(log_p.min()))
    if mode == "quantile":
        level = 1 - q if use_top_q else q
        return float(np.exp(np.quantile(log_p, level)))
    # Previously an unknown mode fell through and returned None, which only
    # failed later with an opaque TypeError at the threshold comparison.
    raise ValueError(f"Unknown CONF_MODE {mode!r}; expected 'min' or 'quantile'")

def summarize_top_prob_all_tokens(data_by_model, thresholds=range(50, 100, 5), conf_thres=0.35):
    """Summarise ``top_prob`` values per model, split by sequence confidence.

    Each item's ``token_entropy_info`` entries carrying ``pred_prob`` are fed to
    ``seq_conf``. The item is put in group ``'ROB'`` when that confidence is
    strictly greater than ``conf_thres``, otherwise in ``'Non-ROB'``. All of the
    item's ``top_prob`` values are pooled into that group. Items with no
    tokens, no ``pred_prob`` or no ``top_prob`` entries are skipped.

    Args:
        data_by_model: Mapping ``model -> iterable of item dicts``.
        thresholds: Percent cut-offs. For each ``t``, the fraction of pooled
            ``top_prob`` values ``>= t / 100`` is reported under key ``'≥t'``.
            Any iterable is accepted, including one-shot iterators.
        conf_thres: Confidence cut-off separating the two groups.

    Returns:
        ``{model: {'sample_counts': {'Total', 'ROB', 'Non-ROB'},
        'top_prob_stats': {group: {'mean', 'std', '≥t', ...}}}}``.
        Statistics of a group with no values are NaN.
    """
    # Materialise once. A one-shot iterator would otherwise be exhausted by the
    # first group, leaving the second without its '≥t' keys.
    thresholds = tuple(thresholds)
    stats = {}

    for model, items in data_by_model.items():
        group_probs = {'ROB': [], 'Non-ROB': []}
        sample_counts = {'Total': 0, 'ROB': 0, 'Non-ROB': 0}

        for item in items:
            tokens = item.get('token_entropy_info', [])
            if not tokens:
                continue

            token_probs = [t['pred_prob'] for t in tokens if 'pred_prob' in t]
            top_probs = [t['top_prob'] for t in tokens if 'top_prob' in t]
            if not token_probs or not top_probs:
                continue

            conf = seq_conf(token_probs)
            group = 'ROB' if conf > conf_thres else 'Non-ROB'
            group_probs[group].extend(top_probs)

            sample_counts['Total'] += 1
            sample_counts[group] += 1

        top_prob_stats = {}
        for group_name, values in group_probs.items():
            probs = np.asarray(values, dtype=float)
            if probs.size == 0:
                # Nothing landed in this group. Report NaN explicitly rather
                # than letting numpy emit "Mean of empty slice" warnings.
                summary = {'mean': np.nan, 'std': np.nan}
                summary.update({f"≥{thresh}": np.nan for thresh in thresholds})
            else:
                summary = {'mean': np.mean(probs), 'std': np.std(probs)}
                summary.update(
                    {f"≥{thresh}": np.mean(probs >= (thresh / 100)) for thresh in thresholds}
                )
            top_prob_stats[group_name] = summary

        stats[model] = {
            'sample_counts': sample_counts,
            'top_prob_stats': top_prob_stats,
        }

    return stats


import pandas as pd

def print_top_prob_stats_table(results, thresholds=(75, 80, 85)):
    """Display the output of ``summarize_top_prob_all_tokens`` as a table.

    One row is built per (model, group) pair, for groups ``'ROB'`` and
    ``'Non-ROB'``. The row holds the sample count, mean, std and the ``'≥t'``
    fractions for the requested thresholds, rounded to 4 decimals. The frame
    is transposed and shown through ``ace_tools_open``.

    Args:
        results: Mapping returned by ``summarize_top_prob_all_tokens``.
        thresholds: Percent cut-offs to show. Each must have been included in
            the summary, otherwise a KeyError is raised. Any iterable is
            accepted.
    """
    # Materialise once: thresholds are used for every row and again for the
    # column order, which would exhaust a one-shot iterator.
    # (The default is a tuple to avoid a shared mutable default.)
    threshold_cols = [f'≥{t}' for t in tuple(thresholds)]
    rows = []

    for model, model_result in results.items():
        counts = model_result['sample_counts']

        for group in ('ROB', 'Non-ROB'):
            stat = model_result['top_prob_stats'][group]
            row = {
                'Model': model,
                'Group': group,
                'Samples': counts[group],
                'Mean': round(stat['mean'], 4),
                'Std': round(stat['std'], 4),
            }
            for col in threshold_cols:
                row[col] = round(stat[col], 4)
            rows.append(row)

    columns = ['Model', 'Group', 'Samples', 'Mean', 'Std'] + threshold_cols
    # Pass the columns to the constructor. An empty ``results`` then gives an
    # empty but well-formed frame instead of a KeyError on column selection.
    df = pd.DataFrame(rows, columns=columns).T
    print("\n📊 Top Token Probability Statistics by Group")
    import ace_tools_open as tools
    tools.display_dataframe_to_user(name="Top Token Probability Stats", dataframe=df)


def compute_mismatch_ratio_by_rob(data_by_model, conf_thres=0.35, *, display=True):
    """Per model, compute how often ``pred_token`` differs from ``top_token``.

    Each sample's ``pred_prob`` values go through ``seq_conf``. The sample is
    put in ``'ROB'`` when the confidence is strictly greater than
    ``conf_thres``, otherwise in ``'Non-ROB'``. Only tokens that carry both
    ``pred_token`` and ``top_token`` are compared. A sample counts towards
    ``Samples`` only if it has at least one such token. ``'Total'``
    aggregates both groups.

    Args:
        data_by_model: Mapping ``model -> iterable of sample dicts`` with a
            ``token_entropy_info`` list of per-token dicts.
        conf_thres: Confidence cut-off separating the two groups.
        display: When true (the default), also show the table through
            ``ace_tools_open``. Pass False to only get the DataFrame back.

    Returns:
        A DataFrame with one row per model and, for each of Total / ROB /
        Non-ROB, the token count, sample count and mismatch ratio. The ratio
        is 0 (not NaN) when a group has no compared tokens.
    """
    group_names = ('Total', 'ROB', 'Non-ROB')
    result_rows = []

    for model, samples in data_by_model.items():
        stats = {name: {'mismatch': 0, 'count': 0, 'samples': 0} for name in group_names}

        for sample in samples:
            tokens = sample.get('token_entropy_info', [])
            if not tokens:
                continue

            token_probs = [t['pred_prob'] for t in tokens if 'pred_prob' in t]
            if not token_probs:
                continue

            conf = seq_conf(token_probs)
            group = 'ROB' if conf > conf_thres else 'Non-ROB'

            compared = [t for t in tokens if 'pred_token' in t and 'top_token' in t]
            if not compared:
                continue
            n_mismatch = sum(1 for t in compared if t['pred_token'] != t['top_token'])

            # Every compared sample feeds both the overall total and its own group.
            for key in ('Total', group):
                stats[key]['count'] += len(compared)
                stats[key]['mismatch'] += n_mismatch
                stats[key]['samples'] += 1

        row = {'Model': model}
        for name in group_names:
            bucket = stats[name]
            row[f'{name} Tokens'] = bucket['count']
            row[f'{name} Samples'] = bucket['samples']
            row[f'{name} Mismatch Ratio'] = (
                bucket['mismatch'] / bucket['count'] if bucket['count'] else 0
            )
        result_rows.append(row)

    df = pd.DataFrame(result_rows)
    if display:
        import ace_tools_open as tools
        tools.display_dataframe_to_user(name="Mismatch Ratio by ROB Group", dataframe=df)
    return df


if __name__ == "__main__":
    cli_args = sys.argv[1:]
    if not cli_args:
        print("use: python ROB_probability.py [generator_name] [response_type]")
        sys.exit(1)

    generator = cli_args[0]
    # The response type is optional; 'all' is handed to extract_json unchanged.
    response_type = cli_args[1] if len(cli_args) >= 2 else 'all'

    file_map = get_file_map(generator)
    print(f"✅ Generator: {generator}")
    data_by_model = extract_json(file_map, response_type)

    # Top-token probability summary -- currently switched off. Uncomment the two
    # calls below to produce it with these thresholds.
    thresholds = range(75, 90, 5)
    # results = summarize_top_prob_all_tokens(data_by_model, thresholds=thresholds)
    # print_top_prob_stats_table(results, thresholds=thresholds)

    df_result = compute_mismatch_ratio_by_rob(data_by_model)