import os
import json
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score

def get_fpr_tpr_scores(TP, FP, N):
    """Compute ROC curve points and AUC for two groups of detection scores.

    Scores in ``FP`` are labelled as the negative class (0) and scores in
    ``TP`` as the positive class (1); both groups are then passed to
    scikit-learn's ``roc_curve`` and ``roc_auc_score``.

    Args:
        TP: Sequence of scores for the positive class.
        FP: Sequence of scores for the negative class.
        N: Nominal number of samples per class. Kept for backward
            compatibility; the label vector is sized from the actual
            lengths of ``FP`` and ``TP`` so that unequal groups are
            labelled correctly instead of being mislabelled or rejected.

    Returns:
        A dict with JSON-friendly values: ``'fpr'``, ``'tpr'`` and
        ``'thresholds'`` as plain lists and ``'auc'`` as a float.
    """
    # Coerce to lists so `+` always concatenates (array inputs would
    # otherwise be added element-wise).
    negative_scores = list(FP)
    positive_scores = list(TP)

    # Build labels from the real group sizes rather than trusting N.
    y_true = [0] * len(negative_scores) + [1] * len(positive_scores)
    y_scores = negative_scores + positive_scores

    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    auc_score = roc_auc_score(y_true, y_scores)

    return {
        'fpr': fpr.tolist(),
        'tpr': tpr.tolist(),
        'thresholds': thresholds.tolist(),
        'auc': float(auc_score),
    }

def load_watermarking_data(basepath, file_name):
    """Load a JSON file and its ``_wo_watermark`` companion from one directory.

    The companion name is derived from ``file_name`` by dropping a trailing
    ``.json`` extension (if present) and appending ``_wo_watermark.json``.

    Args:
        basepath: Directory containing both files.
        file_name: Name of the primary JSON file inside ``basepath``.

    Returns:
        A ``(water, wo_water)`` tuple holding the parsed contents of the
        primary file and of the companion file, in that order.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    json_ext = '.json'
    # Strip only the trailing extension; str.replace would also remove any
    # ".json" occurring in the middle of the name and point at the wrong file.
    if file_name.endswith(json_ext):
        stem = file_name[:-len(json_ext)]
    else:
        stem = file_name
    wo_water_filename = f"{stem}_wo_watermark.json"

    with open(os.path.join(basepath, file_name), 'r', encoding='utf-8') as f:
        water = json.load(f)
    with open(os.path.join(basepath, wo_water_filename), 'r', encoding='utf-8') as f:
        wo_water = json.load(f)

    return water, wo_water

def _tpr_at_max_fpr(fpr_array, tpr_array, target_fpr):
    """Return the highest TPR among ROC points with FPR <= target_fpr (0.0 if none)."""
    within_budget = fpr_array <= target_fpr
    if not within_budget.any():
        return 0.0
    return float(np.max(tpr_array[within_budget]))


def get_fpr_tpr(args):
    """Compute ROC statistics for every score file and write result files.

    Reads each ``*.json`` file in the z-score directory (skipping names that
    contain ``prompt`` or ``wo_watermark``), loads it together with its
    ``_wo_watermark`` companion, computes the ROC curve and AUC, and writes:

    * ``fpr_tpr_results{suffix}.json`` - the full per-file ROC data, and
    * ``fpr_tpr_summary{suffix}.txt`` - one row per file with AUC and the
      TPR at FPR budgets of 1%, 5% and 10%.

    Files that fail to load or process are reported on stdout and skipped.

    Args:
        args: Object providing ``dataset_name``, ``model_name`` and
            ``model_size``; an optional truthy ``attack_suffix`` switches the
            input and output directories to their ``*_attacked`` variants.

    Returns:
        None. If the z-score directory does not exist, a warning is printed
        and nothing is written (the result directory is still created).
    """
    attack_suffix = getattr(args, 'attack_suffix', None)
    if attack_suffix:
        result_dir = f"./results/fpr_tpr_attacked/{attack_suffix}/{args.dataset_name}/{args.model_name}_{args.model_size}/"
        zscore_dir = f"./results/zscore_attacked/{attack_suffix}/{args.dataset_name}/{args.model_name}_{args.model_size}/"
        suffix = f"_attacked_{attack_suffix}"
    else:
        result_dir = f"./results/fpr_tpr/{args.dataset_name}/{args.model_name}_{args.model_size}/"
        zscore_dir = f"./results/zscore/{args.dataset_name}/{args.model_name}_{args.model_size}/"
        suffix = ""

    os.makedirs(result_dir, exist_ok=True)

    all_results = {}

    if not os.path.exists(zscore_dir):
        print(f"Warning: zscore directory does not exist: {zscore_dir}")
        return

    # Sort so the result and summary files list entries in a stable order.
    for file_name in sorted(os.listdir(zscore_dir)):
        if 'prompt' in file_name:
            continue
        if 'wo_watermark' in file_name:
            continue
        if not file_name.endswith('.json'):
            continue

        # Best-effort per file: one bad input must not abort the whole run.
        try:
            print(f"Processing {file_name}...")

            water, wo_water = load_watermarking_data(zscore_dir, file_name)

            if len(water) == 0 or len(wo_water) == 0:
                print(f"Warning: Empty data in {file_name}")
                continue

            data_num = len(water)

            roc_results = get_fpr_tpr_scores(water, wo_water, data_num)

            # The filter above guarantees a ".json" extension; strip only that.
            file_key = os.path.splitext(file_name)[0]
            all_results[file_key] = {
                'roc_curve': roc_results,
                'data_info': {
                    'total_watermarked': data_num,
                    # Report the real size of the non-watermarked set.
                    'total_non_watermarked': len(wo_water),
                    'file_name': file_name
                }
            }

        except Exception as e:
            print(f"Error processing {file_name}: {str(e)}")
            continue

    # Persist the full ROC data; previously the path was computed but the
    # results were never written. NOTE: non-finite thresholds (if any) are
    # emitted by json as Infinity/NaN, which Python's json module reads back.
    output_file = os.path.join(result_dir, f"fpr_tpr_results{suffix}.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)

    summary_file = os.path.join(result_dir, f"fpr_tpr_summary{suffix}.txt")
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write("# FPR/TPR Analysis Results\n")
        f.write("# TPR@X%: True Positive Rate when False Positive Rate ≤ X%\n")
        f.write("# Higher TPR@low_FPR indicates better watermark detection performance\n")
        f.write("=" * 80 + "\n")
        f.write("File Name :: AUC :: TPR@1% :: TPR@5% :: TPR@10%\n")
        f.write("=" * 80 + "\n")

        for file_key, results in all_results.items():
            auc = results['roc_curve']['auc']
            fpr_array = np.array(results['roc_curve']['fpr'])
            tpr_array = np.array(results['roc_curve']['tpr'])

            tpr_at_1 = _tpr_at_max_fpr(fpr_array, tpr_array, 0.01)
            tpr_at_5 = _tpr_at_max_fpr(fpr_array, tpr_array, 0.05)
            tpr_at_10 = _tpr_at_max_fpr(fpr_array, tpr_array, 0.10)

            # One value per column announced in the header line.
            f.write(
                f"{file_key} :: {auc:.4f} :: {tpr_at_1:.3f} :: "
                f"{tpr_at_5:.3f} :: {tpr_at_10:.3f}\n"
            )