#!/usr/bin/env python3
"""
Batch evaluation results collection script.
Runs eval.py for each evaluated method (mode) across a range of intervals and
collects the resulting metrics into CSV, JSON and Markdown reports.
"""

import os
import subprocess
import json
import pandas as pd
from datetime import datetime
import argparse
import sys

def _parse_eval_output(stdout):
    """Extract the summary metrics that eval.py prints on stdout.

    Each recognised line is identified by a fixed marker (e.g.
    ``"Average PSNR:"``); the first whitespace-separated token after that
    marker is converted with the marker's parser. A line whose value cannot
    be parsed is reported and skipped instead of discarding every metric.

    Args:
        stdout: Captured standard output of eval.py.

    Returns:
        dict mapping result keys (``total_images``, ``valid_pairs``,
        ``clip_score``, ``image_reward``, ``psnr``, ``ssim``, ``lpips``) to
        parsed values; keys whose line never appeared are absent.
    """
    # (marker printed by eval.py, result key, value parser), checked in order.
    markers = (
        ("Total processed images:", "total_images", int),
        ("Successfully paired images:", "valid_pairs", int),
        ("Average CLIP Score:", "clip_score", float),
        ("Average ImageReward Score:", "image_reward", float),
        ("Average PSNR:", "psnr", float),
        ("Average SSIM:", "ssim", float),
        ("Average LPIPS:", "lpips", float),
    )
    results = {}
    for line in stdout.splitlines():
        for marker, key, convert in markers:
            if marker not in line:
                continue
            # Read the value after the marker itself rather than after the
            # first ':' so a prefix such as a log timestamp ("12:00:00 ...")
            # cannot shift the parsed field; trailing units/text are ignored.
            tokens = line.split(marker, 1)[1].split()
            try:
                results[key] = convert(tokens[0])
            except (IndexError, ValueError):
                print(f"Warning: could not parse metric line: {line.strip()}")
            break  # a line carries at most one metric
    return results

def run_evaluation(test_folder, original_folder="results/FLUX-DEV-50", *,
                   batch_size=32, num_workers=16, timeout=1800, eval_script=None):
    """Run eval.py on one folder of generated images and return its metrics.

    Args:
        test_folder: Folder with the images to evaluate (``--test_folder``).
        original_folder: Reference folder passed as ``--original_folder``.
        batch_size: Value passed as ``--batch_size`` (default 32).
        num_workers: Value passed as ``--num_workers`` (default 16).
        timeout: Seconds to wait for the evaluator before giving up
            (default 1800, i.e. 30 minutes).
        eval_script: Path of the evaluation script; defaults to ``eval.py``
            in the same directory as this file.

    Returns:
        dict of parsed metrics (see ``_parse_eval_output``), or ``None`` if
        the subprocess exited non-zero, timed out, or could not be started.
    """
    if eval_script is None:
        eval_script = os.path.join(os.path.dirname(__file__), "eval.py")
    cmd = [
        sys.executable, eval_script,
        "--test_folder", test_folder,
        "--original_folder", original_folder,
        "--batch_size", str(batch_size),
        "--num_workers", str(num_workers),
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        print(f"Evaluation timeout: {test_folder}")
        return None
    except Exception as e:
        # Best-effort boundary: one folder that cannot be evaluated must not
        # abort the whole batch run.
        print(f"Evaluation exception: {test_folder}, error: {e}")
        return None

    if result.returncode != 0:
        print(f"Evaluation failed: {test_folder}")
        print(f"Error message: {result.stderr}")
        return None

    return _parse_eval_output(result.stdout)

def collect_all_results(methods=('taylor', 'hicache'), interval_range=(1, 10)):
    """Evaluate every existing ``results/<method>/interval_<n>`` folder.

    Args:
        methods: Iterable of method names; each maps to ``results/<method>``.
            The default is an immutable tuple instead of a shared list.
        interval_range: Inclusive ``(start, end)`` range of interval numbers.

    Returns:
        list of metric dicts from ``run_evaluation``, each extended with
        ``method``, ``interval`` and ``test_folder``. Missing folders and
        failed or empty evaluations are skipped.
    """
    start, end = interval_range[0], interval_range[1]
    all_results = []

    for method in methods:
        print(f"\nCollecting {method} mode results...")
        for interval in range(start, end + 1):
            test_folder = f"results/{method}/interval_{interval}"

            # isdir (not exists) so a stray file is never handed to the evaluator.
            if not os.path.isdir(test_folder):
                print(f"Skipping non-existent folder: {test_folder}")
                continue

            print(f"  评估 interval_{interval}...")
            results = run_evaluation(test_folder)

            # None (failure) and {} (no metric lines parsed) both count as failed.
            if results:
                results.update(method=method, interval=interval, test_folder=test_folder)
                all_results.append(results)
                print(f"    ✓ 完成 {method} interval_{interval}")
            else:
                print(f"    ✗ 失败 {method} interval_{interval}")

    return all_results

def save_results_to_csv(results, filename):
    """Write the collected results to *filename* as CSV with a fixed column order.

    Keys not listed below (such as ``test_folder``) are dropped; listed
    columns absent from the results are written as empty cells.
    """
    ordered_columns = [
        'method', 'interval',
        'clip_score', 'image_reward', 'psnr', 'ssim', 'lpips',
        'total_images', 'valid_pairs',
    ]
    table = pd.DataFrame(results).reindex(columns=ordered_columns)
    table.to_csv(filename, index=False)
    print(f"结果已保存到: {filename}")

def save_results_to_json(results, filename):
    """Serialize *results* to *filename* as pretty-printed UTF-8 JSON (non-ASCII kept as-is)."""
    with open(filename, mode='w', encoding='utf-8') as out_file:
        json.dump(obj=results, fp=out_file, indent=2, ensure_ascii=False)
    print(f"结果已保存到: {filename}")

def _fmt_metric(value, spec):
    """Format a metric value with *spec*, rendering missing (NaN/None) values as ``N/A``."""
    return "N/A" if pd.isna(value) else format(value, spec)

def generate_markdown_report(results, output_dir="results"):
    """Write a Markdown summary of *results* and return the report's path.

    The report contains one table per method (rows sorted by interval), the
    best value of each metric across all rows, and a first-vs-last interval
    trend per method and metric.

    Args:
        results: List of result dicts with ``method`` and ``interval`` keys
            plus optional metric keys (``clip_score``, ``image_reward``,
            ``psnr``, ``ssim``, ``lpips``). Missing metrics are shown as
            ``N/A`` and excluded from the best/trend sections.
        output_dir: Directory for the report; created if missing. Defaults to
            ``"results"``, the previously hard-coded location.

    Returns:
        Path of the generated ``comprehensive_analysis_<timestamp>.md`` file.
    """
    metrics = ['clip_score', 'image_reward', 'psnr', 'ssim', 'lpips']
    metric_names = ['CLIP Score', 'ImageReward', 'PSNR', 'SSIM', 'LPIPS']

    df = pd.DataFrame(results)
    # Guarantee every expected column exists so an absent metric becomes NaN
    # instead of raising KeyError further down.
    for column in ['method', 'interval', *metrics]:
        if column not in df.columns:
            df[column] = float('nan')

    # One timestamp for both the file name and the "generated at" line.
    now = datetime.now()
    os.makedirs(output_dir, exist_ok=True)
    report_file = os.path.join(output_dir, f"comprehensive_analysis_{now.strftime('%Y%m%d_%H%M%S')}.md")

    methods = df['method'].unique()

    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("# 综合评估分析报告\n\n")
        f.write(f"**生成时间**: {now.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        # One results table per method, rows ordered by interval.
        for method in methods:
            method_data = df[df['method'] == method].sort_values('interval')
            f.write(f"## {method.capitalize()} 方法\n\n")

            f.write("| Interval | CLIP Score | ImageReward | PSNR | SSIM | LPIPS |\n")
            f.write("|----------|------------|-------------|------|------|-------|\n")

            for _, row in method_data.iterrows():
                cells = [
                    f"{row['interval']}",
                    _fmt_metric(row['clip_score'], '.4f'),
                    _fmt_metric(row['image_reward'], '.4f'),
                    _fmt_metric(row['psnr'], '.3f'),
                    _fmt_metric(row['ssim'], '.4f'),
                    _fmt_metric(row['lpips'], '.4f'),
                ]
                f.write("| " + " | ".join(cells) + " |\n")

            f.write("\n")

        # Best value of each metric across all methods/intervals.
        f.write("## 最佳性能分析\n\n")

        for metric, name in zip(metrics, metric_names):
            scores = df[metric].dropna()
            if scores.empty:
                continue  # metric missing from every result: nothing to rank
            # LPIPS is lower-is-better; the other metrics are higher-is-better.
            best_idx = scores.idxmin() if metric == 'lpips' else scores.idxmax()
            f.write(f"- **最佳{name}**: {df.at[best_idx, 'method']} "
                    f"(interval={df.at[best_idx, 'interval']}) = {scores.loc[best_idx]:.4f}\n")

        f.write("\n")

        # Trend: compare the first and last available interval of each method.
        f.write("## 趋势分析\n\n")
        for method in methods:
            method_data = df[df['method'] == method].sort_values('interval')
            f.write(f"### {method.capitalize()} 方法趋势\n")

            for metric, name in zip(metrics, metric_names):
                values = method_data[metric].dropna().values
                if len(values) > 1:
                    first, last = values[0], values[-1]
                    if last > first:
                        trend = "上升"
                    elif last < first:
                        trend = "下降"
                    else:
                        trend = "持平"  # unchanged; previously mislabelled as decreasing
                    f.write(f"- {name}: {first:.4f} → {last:.4f} ({trend})\n")
            f.write("\n")

    print(f"Markdown报告已生成: {report_file}")
    return report_file

def main():
    """Command-line entry point: collect results and write CSV, JSON and Markdown reports."""
    parser = argparse.ArgumentParser(description="批量收集评估结果")
    parser.add_argument("--methods", nargs='+', default=['taylor', 'hicache'],
                       help="要收集的方法列表")
    parser.add_argument("--interval_start", type=int, default=1, help="起始interval")
    parser.add_argument("--interval_end", type=int, default=10, help="结束interval")
    parser.add_argument("--output_dir", default="results", help="输出目录")

    args = parser.parse_args()

    print("="*60)
    print("批量收集评估结果")
    print("="*60)
    print(f"方法: {args.methods}")
    print(f"Interval范围: {args.interval_start}-{args.interval_end}")
    print("="*60)

    # Evaluate every method/interval folder.
    results = collect_all_results(args.methods, (args.interval_start, args.interval_end))

    if not results:
        print("未收集到任何结果！")
        return

    # Make sure the output directory exists.
    os.makedirs(args.output_dir, exist_ok=True)

    # Shared timestamp for the CSV/JSON file names.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    csv_file = os.path.join(args.output_dir, f"evaluation_results_{timestamp}.csv")
    save_results_to_csv(results, csv_file)

    json_file = os.path.join(args.output_dir, f"evaluation_results_{timestamp}.json")
    save_results_to_json(results, json_file)

    # Write the Markdown report into --output_dir as well (it was hard-coded to results/).
    report_file = generate_markdown_report(results, output_dir=args.output_dir)

    print("\n" + "="*60)
    print("收集完成！")
    print(f"总计收集了 {len(results)} 个评估结果")
    print(f"CSV文件: {csv_file}")
    print(f"JSON文件: {json_file}")
    print(f"Markdown报告: {report_file}")
    print("="*60)

if __name__ == "__main__":
    # Run the command-line workflow only when executed as a script, not on import.
    main()