import numpy as np
import json, os
from collections import defaultdict
import pandas as pd


# Sliding-window length and step used by detect_unstable_regions (its default
# arguments); both values are also embedded in the results directory name.
WINDOW_SIZE = 10
STRIDE = 5

# Sequence-confidence settings read by seq_conf:
#   CONF_MODE  - "min" or "quantile"; selects how token log-probs are reduced
#   Q_VAL      - quantile level used when CONF_MODE == "quantile"
#   USE_TOP_Q  - when True, the 1 - Q_VAL quantile is used instead of Q_VAL
CONF_MODE = "quantile"
Q_VAL = 0.05
USE_TOP_Q = False


def seq_conf(token_probs, *, mode=None, q_val=None, use_top_q=None):
    """Reduce per-token probabilities to a single sequence confidence.

    The reduction is done in log space and mapped back with ``exp``:

    * ``"min"``      - the smallest token probability.
    * ``"quantile"`` - the ``q`` quantile of the token log-probabilities,
      where ``q`` is ``q_val`` (or ``1 - q_val`` when ``use_top_q`` is true).

    Args:
        token_probs: Sequence of token probabilities.
        mode: Reduction mode; defaults to the module-level ``CONF_MODE``.
        q_val: Quantile level; defaults to the module-level ``Q_VAL``.
        use_top_q: Whether to use the upper quantile; defaults to the
            module-level ``USE_TOP_Q``.

    Returns:
        The confidence as a NumPy float.

    Raises:
        ValueError: If the mode is neither ``"min"`` nor ``"quantile"``.
    """
    # Module constants are looked up lazily so explicit arguments fully
    # override them.
    mode = CONF_MODE if mode is None else mode
    log_p = np.log(np.asarray(token_probs, dtype=float))

    if mode == "min":
        return np.exp(log_p.min())

    if mode == "quantile":
        q_val = Q_VAL if q_val is None else q_val
        use_top_q = USE_TOP_Q if use_top_q is None else use_top_q
        q = 1 - q_val if use_top_q else q_val
        return np.exp(np.quantile(log_p, q))

    # Previously an unknown mode fell through and returned None, which only
    # surfaced later as an unrelated comparison error.
    raise ValueError(
        f"Unknown confidence mode {mode!r}; expected 'min' or 'quantile'"
    )

def detect_unstable_regions(ent_list, window_size=WINDOW_SIZE, stride=STRIDE):
    """Find spans whose windowed mean entropy exceeds the sequence mean.

    A window of ``window_size`` values slides over ``ent_list`` in steps of
    ``stride``. Every window whose mean is strictly greater than the mean of
    the whole sequence is marked unstable; consecutive unstable windows are
    merged into one region. Only full windows are evaluated, so a sequence
    shorter than ``window_size`` yields no regions.

    Args:
        ent_list: Sequence of entropy values.
        window_size: Number of values per window (must be >= 1).
        stride: Step between window starts (must be >= 1).

    Returns:
        A tuple ``(regions, spike_flags)`` where ``regions`` is a list of
        ``(start, end)`` index pairs (``end`` inclusive) ordered by start,
        and ``spike_flags`` is a boolean array of ``len(ent_list)`` that is
        True at every index covered by an unstable window.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is smaller than 1.
    """
    # A non-positive stride would never advance the window.
    if window_size < 1 or stride < 1:
        raise ValueError(
            f"window_size and stride must be >= 1, "
            f"got window_size={window_size}, stride={stride}"
        )

    n_values = len(ent_list)
    # Explicit dtype: np.array([]) would silently become float64.
    spike_flags = np.zeros(n_values, dtype=bool)
    unstable_regions = []
    if n_values == 0:
        return unstable_regions, spike_flags

    global_mean = np.mean(ent_list)
    region_start = None
    region_end = None

    for i in range(0, n_values - window_size + 1, stride):
        if np.mean(ent_list[i:i + window_size]) > global_mean:
            if region_start is None:
                region_start = i
            spike_flags[i:i + window_size] = True
            region_end = i + window_size - 1
        elif region_start is not None:
            unstable_regions.append((region_start, region_end))
            region_start = region_end = None

    # Close a region still open after the last window. It ends at the last
    # flagged index so the regions always agree with spike_flags, even when
    # the windows do not reach the end of the sequence.
    if region_start is not None:
        unstable_regions.append((region_start, region_end))

    return unstable_regions, spike_flags



def _safe_pct(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage, or NaN if undefined."""
    return numerator / denominator * 100 if denominator else float('nan')


def compute_label_stats_by_model(df, group_name):
    """Build per-model, per-label summary rows and print a short report.

    For every distinct value in ``df['Model']`` one row is produced for each
    of ``ALL``, ``Hallucinated_Solution``, ``Creative_Solution`` and
    ``Typical_Solution``. A non-empty row holds the sample count, its share
    of the model's samples (percent) and the mean / population standard
    deviation of ``Token_Length``, ``Entropy``, ``HESR`` and ``Jaccard``.
    A label with no samples gets a row with ``Sample_Count`` 0 only.

    Args:
        df: DataFrame with ``Model`` and ``Label`` columns plus the metric
            columns; a missing metric column is reported as NaN.
        group_name: Value stored in each row's ``Group`` field and printed
            in the report header.

    Returns:
        A list of row dicts; empty when ``df`` has no rows or no ``Model``
        column.
    """
    labels = ['ALL', 'Hallucinated_Solution', 'Creative_Solution', 'Typical_Solution']
    metrics = ['Token_Length', 'Entropy', 'HESR', 'Jaccard']

    rows = []
    print(f"===== Group: {group_name} =====")
    # A frame built from zero records has no columns at all.
    if df.empty or 'Model' not in df.columns:
        return rows

    for model in df['Model'].unique():
        print(f'Evaluation Model: {model}')
        counts = dict.fromkeys(labels, 0)
        df_model = df[df['Model'] == model]

        for label in labels:
            sub = df_model if label == 'ALL' else df_model[df_model['Label'] == label]
            row = {
                'Group': group_name,
                'Model': model,
                'Label': label,
                'Sample_Count': len(sub),
            }
            if sub.empty:
                rows.append(row)
                continue

            row['Sample_ratio'] = round(len(sub) / len(df_model) * 100, 2)
            counts[label] = len(sub)

            for metric in metrics:
                if metric in sub.columns:
                    row[f'{metric}_Mean'] = round(np.mean(sub[metric]), 4)
                    row[f'{metric}_Std'] = round(np.std(sub[metric]), 4)
                else:
                    row[f'{metric}_Mean'] = float('nan')
                    row[f'{metric}_Std'] = float('nan')
            rows.append(row)

        correct = counts['Creative_Solution'] + counts['Typical_Solution']
        print(f"ALL: {counts['ALL']}", end=', ')
        print(f"Hallucination: {counts['Hallucinated_Solution']}", end=', ')
        print(f"Creative: {counts['Creative_Solution']}", end=', ')
        print(f"Typical: {counts['Typical_Solution']}")
        # Ratios print as nan when a model has no correct samples, instead
        # of raising ZeroDivisionError.
        print(f"Cor: {_safe_pct(correct, counts['ALL']):.2f}%, Novel/Cor: {_safe_pct(counts['Creative_Solution'], correct):.2f}%")

    return rows



# For every confidence threshold and every (group, generation model) pair:
# load each evaluation model's JSON, derive per-sample entropy statistics,
# compare each evaluation model's unstable-token mask with the generation
# model's mask (Jaccard), and write per-label summaries to a CSV file.
for ROB_THRES in [0.2, 0.25, 0.3, 0.35, 0.4]:
    for group, gen in [['Qwen', 'math'], ['Qwen', 'inst'], ['Deepseek', 'math'], ['Deepseek', 'rl']]:
        models = ['base', 'math', 'Inst', 'rl'] if group == "Deepseek" else ['base', 'math', 'inst', 'Distill']
        data_dir = f'{group.lower()}_data'
        result_dir = f"results/win_{WINDOW_SIZE}_stride_{STRIDE}_qval{Q_VAL}_robth_{ROB_THRES}"
        os.makedirs(result_dir, exist_ok=True)
        output_path = f'{result_dir}/{group}-gen_{gen}_HE_analysis.csv'

        print(f"ROB Threshold: {ROB_THRES}")
        print(f'Generation model: {group}-{gen}')

        # One list per evaluation model, index-aligned with the samples in
        # its JSON file; an empty dict marks a sample without usable data.
        all_results = {model: [] for model in models}

        for model in models:
            json_path = os.path.join(data_dir, f'{group}-gen_{gen}-eval_{model}.json')
            if not os.path.exists(json_path):
                continue

            with open(json_path, 'r', encoding='utf-8') as f:
                samples = json.load(f)

            for sample in samples:
                token_info = sample.get('token_entropy_info')
                # Missing or empty token info: keep a placeholder so the
                # per-sample alignment across models is preserved.
                if not token_info:
                    all_results[model].append({})
                    continue

                ent_list = [t['entropy'] for t in token_info]
                token_probs = [t["pred_prob"] for t in token_info]
                conf = seq_conf(token_probs)
                total_len = len(ent_list)
                _, hesm = detect_unstable_regions(ent_list)

                all_results[model].append({
                    'Model': f"{group}-{model}",
                    'Label': sample['evaluation']['final_decision'],
                    'Entropy_Group': 'ROB' if conf > ROB_THRES else 'NonROB',
                    'Entropy': np.mean(ent_list),
                    'Token_Length': total_len,
                    'HESM': hesm,
                    'HESR': round(int(hesm.sum()) / total_len, 4),
                })

        results = []
        gen_rows = all_results[gen]  # generation model
        for model in models:
            eval_rows = all_results[model]  # evaluation model
            for g, e in zip(gen_rows, eval_rows):
                if not g or not e:
                    continue

                set_g = {i for i, flagged in enumerate(g['HESM']) if flagged}
                set_e = {i for i, flagged in enumerate(e['HESM']) if flagged}

                # Jaccard is only defined when both masks flag something;
                # NaN keeps the column present and is skipped by the means.
                if set_g and set_e:
                    e['Jaccard'] = len(set_g & set_e) / len(set_g | set_e)
                else:
                    e['Jaccard'] = np.nan

                results.append(e)

        # Without any rows the DataFrame would have no columns and the
        # group filters below would fail.
        if not results:
            print(f'no samples found for {group}-gen_{gen}, skip {output_path}\n\n')
            continue

        df_all = pd.DataFrame(results)

        stats_all = compute_label_stats_by_model(df_all, 'ALL')
        stats_rob = compute_label_stats_by_model(df_all[df_all['Entropy_Group'] == 'ROB'], 'ROB')
        stats_nonrob = compute_label_stats_by_model(df_all[df_all['Entropy_Group'] == 'NonROB'], 'NonROB')

        # Insert an empty row between groups so they are visually separated
        # in the CSV.
        final_rows = []
        prev_group = None
        for row in stats_all + stats_rob + stats_nonrob:
            if prev_group is not None and row['Group'] != prev_group:
                final_rows.append({})
            final_rows.append(row)
            prev_group = row['Group']

        df_stats = pd.DataFrame(final_rows)
        df_stats.to_csv(output_path, index=False)
        print(f'save results to {output_path}\n\n')