import os
import json
import re
from collections import defaultdict

def parse_evaluation_summary(file_path):
    """Parse the three image-quality metrics from an evaluation summary file.

    The file is scanned for ``<key>: <number>`` entries for the keys
    ``l2_mse``, ``lpips_alex`` and ``brisque_adv``. The first occurrence of
    each key wins.

    Args:
        file_path: Path to an ``evaluation_summary.txt`` file.

    Returns:
        A dict mapping each of the three keys to a float, or ``None`` if the
        file does not exist or any key is missing.
    """
    if not os.path.exists(file_path):
        return None
    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()

    # Accept an optional sign and scientific notation. Small MSE values are
    # commonly printed as e.g. "1.5e-05", which a plain ``[\d.]+`` pattern
    # would silently truncate to 1.5.
    number = r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    keys = ('l2_mse', 'lpips_alex', 'brisque_adv')

    metrics = {}
    for key in keys:
        match = re.search(rf'{key}:\s*{number}', content)
        if match:
            metrics[key] = float(match.group(1))
    # Only a complete set of metrics is usable downstream.
    return metrics if len(metrics) == len(keys) else None

def parse_eval_id_json(file_path):
    """Read ``averages.relative_id_mean`` from an identity-evaluation JSON file.

    Args:
        file_path: Path to a JSON file expected to look like
            ``{"averages": {"relative_id_mean": <value>}}``.

    Returns:
        The stored ``relative_id_mean`` value, or ``None`` in these cases:
        the file is missing or unreadable, it is not valid JSON, it does not
        have the expected nested-object shape, or the key is absent.
    """
    if not os.path.exists(file_path):
        return None
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError;
        # unlike a bare ``except:`` this does not hide KeyboardInterrupt or bugs.
        return None
    if not isinstance(data, dict):
        return None
    averages = data.get('averages')
    if not isinstance(averages, dict):
        return None
    return averages.get('relative_id_mean')

def get_whitebox_metrics(batch_dir, target_model, method, id_result_dir="/id_result"):
    """Collect the white-box reference metrics for one (target model, method) run.

    The run folder is ``<batch_dir>/result_<target_model>_<method>``. If that
    folder does not exist, the first folder (in sorted order) that satisfies
    all of the following is used instead:

    * its name starts with ``result_``;
    * its name contains ``target_model``;
    * its name ends with ``_<method>``.

    Args:
        batch_dir: Directory holding the white-box ``result_*`` folders.
        target_model: Model name, e.g. ``'psp_mix'``.
        method: Attack method suffix used in folder names.
        id_result_dir: Directory searched for ``<model>_<method>.json`` when
            the run folder has no ``eval_id.json``.

    Returns:
        A ``(metrics, rel_id)`` tuple, either element of which may be ``None``.
        ``metrics`` comes from ``evaluation_summary.txt``, falling back to
        ``evaluation_results.json``. ``rel_id`` comes from the first readable
        identity JSON among the candidate paths.
    """
    whitebox_path = os.path.join(batch_dir, f"result_{target_model}_{method}")

    if not os.path.exists(whitebox_path) and os.path.isdir(batch_dir):
        # Sorted so the choice is deterministic when several folders match.
        for folder in sorted(os.listdir(batch_dir)):
            if (folder.startswith('result_') and target_model in folder
                    and folder.endswith(f'_{method}')):
                whitebox_path = os.path.join(batch_dir, folder)
                break

    metrics = parse_evaluation_summary(os.path.join(whitebox_path, 'evaluation_summary.txt'))
    if metrics is None:
        eval_json = os.path.join(whitebox_path, 'evaluation_results.json')
        if os.path.exists(eval_json):
            try:
                with open(eval_json, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError):
                # Corrupt or unreadable file: treat as "no metrics" instead of crashing.
                data = None
            if isinstance(data, dict):
                # Missing keys default to 0. The caller then scores that
                # metric as 0 because the white-box reference is not positive.
                metrics = {k: data.get(k, 0) for k in ('l2_mse', 'lpips_alex', 'brisque_adv')}

    # Identity score: check these sources in priority order and stop at the
    # first one that yields a value.
    id_candidates = [
        os.path.join(whitebox_path, 'eval_id.json'),
        os.path.join(id_result_dir, f"{target_model}_{method}.json"),
    ]
    if target_model == 'psp_mix':
        # psp_mix results may have been stored under the shorter 'psp' name.
        id_candidates.append(os.path.join(id_result_dir, f"psp_{method}.json"))

    rel_id = None
    for candidate in id_candidates:
        rel_id = parse_eval_id_json(candidate)
        if rel_id is not None:
            break

    return metrics, rel_id

def _split_blackbox_name(folder):
    """Split a ``black_box_<source>_<target>`` folder name into ``(source, target)``.

    A trailing ``mix`` token belongs to a two-token target (e.g. ``psp_mix``),
    and a bare ``psp`` source is normalised to ``psp_mix``. Returns ``None``
    for names without the prefix or without both a source and a target token.
    """
    prefix = 'black_box_'
    if not folder.startswith(prefix):
        return None
    parts = folder[len(prefix):].split('_')
    if len(parts) < 2:
        return None
    target_len = 2 if parts[-1] == 'mix' else 1
    target = '_'.join(parts[-target_len:])
    source = '_'.join(parts[:-target_len])
    if source == 'psp':
        source = 'psp_mix'
    return source, target


def _capped_ratio(value, reference):
    """Return ``min(1.0, value / reference)``, or 0 when ``reference`` is missing or not positive."""
    if not reference or reference <= 0:
        return 0
    return min(1.0, value / reference)


def calculate_source_transferability(batch_dirs=None, models=None,
                                     out_path="source_transfer_results.json"):
    """Score how well adversarial examples crafted on each source model transfer.

    Each ``black_box_<source>_<target>`` run is compared with the white-box
    run on the same target. Each of the following metrics is normalised
    against its white-box value and capped at 1.0:

    * L2-MSE
    * LPIPS
    * BRISQUE
    * relative-ID

    The pair score is the mean of those four ratios. A source's transfer
    power is the mean pair score over its targets.

    Args:
        batch_dirs: ``batch_results_<method>`` directories to analyse. Each
            must contain a ``black_box_results`` subfolder. Defaults to none.
        models: Model names considered as sources/targets.
        out_path: Where the JSON report is written.
    """
    if batch_dirs is None:
        # Fill in your result paths here, e.g.
        #     "/batch_results_pgd",
        #     "/batch_results_anti",
        batch_dirs = []
    if models is None:
        models = ['blendface', 'diffae', 'psp_mix', 'simswap', 'stargan', 'styleclip']
    final_report = {}

    for batch_dir in batch_dirs:
        if not os.path.exists(batch_dir):
            continue

        # normpath strips a trailing separator so basename() is never ''.
        method = os.path.basename(os.path.normpath(batch_dir)).replace("batch_results_", "")
        blackbox_dir = os.path.join(batch_dir, "black_box_results")
        if not os.path.exists(blackbox_dir):
            continue

        print(f"\n{'='*80}")
        print(f" METHOD: {method.upper()} | Analyzing Source Transferability Power")
        print(f"{'='*80}")

        # White-box references per target; targets lacking either part are skipped.
        wb_cache = {}
        for m in models:
            wb_met, wb_id = get_whitebox_metrics(batch_dir, m, method)
            if wb_met and wb_id:
                wb_cache[m] = {'metrics': wb_met, 'id': wb_id}

        source_data = defaultdict(list)
        for bb_folder in sorted(os.listdir(blackbox_dir)):
            split = _split_blackbox_name(bb_folder)
            if split is None:
                continue
            bb_source, bb_target = split
            if bb_source == bb_target or bb_target not in wb_cache:
                continue

            bb_path = os.path.join(blackbox_dir, bb_folder)
            bb_met = parse_evaluation_summary(os.path.join(bb_path, 'evaluation_summary.txt'))
            bb_id = parse_eval_id_json(os.path.join(bb_path, 'eval_id.json'))
            if not bb_met or bb_id is None:
                continue

            wb = wb_cache[bb_target]
            l2_n = _capped_ratio(bb_met['l2_mse'], wb['metrics']['l2_mse'])
            lpips_n = _capped_ratio(bb_met['lpips_alex'], wb['metrics']['lpips_alex'])
            brisque_n = _capped_ratio(bb_met['brisque_adv'], wb['metrics']['brisque_adv'])
            id_n = _capped_ratio(bb_id, wb['id'])

            source_data[bb_source].append({
                'target': bb_target,
                'score': (l2_n + lpips_n + brisque_n + id_n) / 4,
                'details': {'l2': l2_n, 'lpips': lpips_n, 'brisque': brisque_n, 'id': id_n},
            })

        method_summary = {}
        for src in models:
            results = source_data.get(src, [])
            if not results:
                continue
            avg_score = sum(r['score'] for r in results) / len(results)
            method_summary[src] = {
                'avg_transfer_power': avg_score,
                'target_count': len(results),
                'detailed_targets': results,
            }
            print(f"  Source: {src:12} | Score: {avg_score:.4f} | (Transferred to {len(results)} models)")

        final_report[method] = method_summary

    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(final_report, f, indent=2)
    print(f"\nDetailed results saved to: {out_path}")

if __name__ == '__main__':
    # Script entry point: run the transferability analysis with its defaults.
    calculate_source_transferability()