import sys
import subprocess
import json
import pandas as pd
import traceback
from pathlib import Path
from datetime import datetime

# Base path: the parent of the directory that contains this file.
_this_file = Path(__file__)
project_root = _this_file.parent.parent

def run_batch(competitions):
    """Run the test script for each competition and write a summary CSV.

    Every competition is executed in its own subprocess
    (``experiment/run_mlebench_test.py <comp_id>`` under ``project_root``) so
    that a failing run does not abort the batch. After each run the
    competition's ``execution_metrics_r3.json`` and ``grading_report_r3.json``
    files are read if they exist, and the cumulative results are rewritten to
    ``output/batch_mlebench_summary_<timestamp>.csv`` so that an interrupted
    batch still leaves a usable file.

    Args:
        competitions: Sized sequence of competition ID strings.

    Returns:
        None. Results are printed and saved to the summary CSV.
    """
    results = []
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Column order of the summary CSV, shared by the incremental and final saves.
    cols = [
        "Competition",
        "Status",
        "Score",
        "gold_threshold",
        "Time(min)",
        "Tokens",
        "Submission Found",
        "Error",
        "Timestamp",
    ]

    # Output CSV path for the batch summary.
    output_dir = project_root / "output"
    output_dir.mkdir(parents=True, exist_ok=True)
    summary_csv = output_dir / f"batch_mlebench_summary_{timestamp}.csv"

    print(f"🚀 开始批量测试，共 {len(competitions)} 个竞赛任务")
    print(f"📂 结果将保存至: {summary_csv}")

    # Loop-invariant paths.
    script_path = project_root / "experiment" / "run_mlebench_test.py"
    competitions_root = project_root / "experiment" / "mlebench" / "competitions"

    for idx, comp_id in enumerate(competitions, 1):
        print(f"\n[{idx}/{len(competitions)}] {'='*40}")
        print(f"▶️ 正在运行竞赛: {comp_id}")
        print(f"{'='*40}")

        start_time = datetime.now()
        status = "Success"
        error_msg = ""
        score = None
        gold_threshold = None
        exec_time = 0
        total_tokens = 0
        # Reset for every competition. Without this, a missing metrics file
        # either raises UnboundLocalError (first iteration) or silently reuses
        # the previous competition's value.
        submission_found = False

        try:
            # Run the single-competition script in a subprocess for isolation;
            # the child inherits the current environment.
            cmd = [sys.executable, str(script_path), comp_id]

            # check=True turns a non-zero exit code into CalledProcessError,
            # which is caught below so the batch keeps going.
            subprocess.run(cmd, check=True)

        except subprocess.CalledProcessError as e:
            status = "Failed"
            error_msg = f"Script execution failed with code {e.returncode}"
            print(f"❌ 竞赛 {comp_id} 脚本执行失败: {e}")
        except Exception as e:
            status = "Error"
            error_msg = str(e)
            print(f"❌ 竞赛 {comp_id} 发生未知错误: {e}")
            traceback.print_exc()

        # Try to read the result files regardless of the outcome above: a
        # failed script may still have produced partial metrics.
        try:
            comp_result_dir = competitions_root / comp_id
            metrics_file = comp_result_dir / "execution_metrics_r3.json"
            report_file = comp_result_dir / "grading_report_r3.json"

            # Execution metrics (time, tokens, submission flag).
            if metrics_file.exists():
                with open(metrics_file, 'r', encoding='utf-8') as f:
                    metrics = json.load(f)
                # `or` guards against explicit JSON nulls, which would
                # otherwise break round() / .get() further down.
                exec_time = metrics.get("execution_time_minutes") or 0
                token_stats = metrics.get("token_stats") or {}
                total_tokens = token_stats.get("total_tokens", 0)
                submission_found = metrics.get("submission_found", False)

            # Grading report (score).
            if report_file.exists():
                with open(report_file, 'r', encoding='utf-8') as f:
                    report = json.load(f)
                score = report.get("score")
                gold_threshold = report.get("gold_threshold")
            elif status == "Success":
                status = "No Report"
                error_msg = "Execution finished but no grading report found"

        except Exception as read_err:
            # Best effort: a broken result file must not abort the batch.
            print(f"⚠️ 读取结果文件时出错: {read_err}")
            if not error_msg:
                error_msg = f"Read results failed: {read_err}"

        # Fall back to the wall-clock duration measured here when the metrics
        # file did not provide an execution time.
        duration = (datetime.now() - start_time).total_seconds() / 60.0
        if not exec_time:
            exec_time = duration

        entry = {
            "Competition": comp_id,
            "Status": status,
            "Score": score,
            "gold_threshold": gold_threshold,
            "Time(min)": round(exec_time, 2),
            "Tokens": total_tokens,
            "Submission Found": submission_found,
            "Error": error_msg,
            "Timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
        results.append(entry)

        # Save after every competition so an interruption loses nothing.
        try:
            pd.DataFrame(results).reindex(columns=cols).to_csv(summary_csv, index=False)
        except Exception as e:
            print(f"⚠️ 保存中间结果失败: {e}")

        # Per-competition brief.
        print(f"\n🏁 竞赛 {comp_id} 结束")
        print(f"   状态: {status}")
        print(f"   得分: {score}")
        print(f"   耗时: {entry['Time(min)']} min")
        print(f"   Tokens: {total_tokens}")
        print(f"   Submission Found: {submission_found}")
        if error_msg:
            print(f"   错误信息: {error_msg}")

    # Final summary once every competition has run.
    print("\n\n" + "="*60)
    print("📊 BATCH EXECUTION SUMMARY")
    print("="*60)

    # reindex fixes the column order and guarantees every column exists,
    # even when `results` is empty.
    df = pd.DataFrame(results).reindex(columns=cols)

    print(df.to_string(index=False))

    df.to_csv(summary_csv, index=False)
    print(f"\n✅ 汇总结果已保存至: {summary_csv}")

if __name__ == "__main__":
    # Competitions to run in this batch. Uncomment an entry to enable it.
    competitions_list = [
        # "detecting-insults-in-social-commentary",
        # "new-york-city-taxi-fare-prediction",
        # "plant-pathology-2020-fgvc7",
        # "mlsp-2013-birds",
        # "google-quest-challenge",
        # "us-patent-phrase-to-phrase-matching",
        # "petfinder-pawpularity-score",
        # "tensorflow-speech-recognition-challenge",
        # "tgs-salt-identification-challenge",
        # "ventilator-pressure-prediction",
        # "stanford-covid-vaccine",
        # "predict-volcanic-eruptions-ingv-oe",
        # "nfl-player-contact-detection",
        # "bms-molecular-translation",
        # "lmsys-chatbot-arena",
        # "text-normalization-challenge-english-language",
        # "denoising-dirty-documents",

        # "detecting-insults-in-social-commentary",
        # "plant-pathology-2020-fgvc7",
        # "google-quest-challenge",
        # "us-patent-phrase-to-phrase-matching",
        # "petfinder-pawpularity-score",
        # "stanford-covid-vaccine",
        "denoising-dirty-documents",
        "predict-volcanic-eruptions-ingv-oe",
        "lmsys-chatbot-arena",
        "text-normalization-challenge-english-language",
    ]

    if competitions_list:
        run_batch(competitions_list)
    else:
        # Nothing is enabled above: ask the user to configure the list.
        print("请在脚本中配置 competitions_list")
    
    # try:
    #     start_time = datetime.now()
    #     status = "Success"
    #     error_msg = ""
    #     score = None
    #     gold_threshold = None
    #     exec_time = 0
    #     total_tokens = 0
    #     comp_result_dir = project_root / "experiment" / "mlebench" / "competitions" / "nfl-player-contact-detection"
    #     metrics_file = comp_result_dir / "execution_metrics.json"
    #     report_file = comp_result_dir / "grading_report_autods.json"

    #     # 读取执行指标 (Time, Tokens)
    #     if metrics_file.exists():
    #         with open(metrics_file, 'r', encoding='utf-8') as f:
    #             metrics = json.load(f)
    #             exec_time = metrics.get("execution_time_minutes", 0)
    #             token_stats = metrics.get("token_stats", {})
    #             total_tokens = token_stats.get("total_tokens", 0)
        
    #     # 读取评分报告 (Score)
    #     if report_file.exists():
    #         with open(report_file, 'r', encoding='utf-8') as f:
    #             report = json.load(f)
    #             # 尝试从不同字段获取分数
    #             score = report.get("score") 
    #             gold_threshold = report.get("gold_threshold")
                
        

    # except Exception as read_err:
    #     print(f"⚠️ 读取结果文件时出错: {read_err}")
    #     if not error_msg:
    #         error_msg = f"Read results failed: {read_err}"

    # # 记录本轮结果
    # end_time = datetime.now()
    # duration = (end_time - start_time).total_seconds() / 60.0
    
    # # 如果 exec_time 为0（未读取到），可以用外部计时补充
    # if exec_time == 0:
    #     exec_time = duration

    # entry = {
    #     "Competition": "nfl-player-contact-detection",
    #     "Status": status,
    #     "Score": score,
    #     "gold_threshold": gold_threshold,
    #     "Time(min)": round(exec_time, 2),
    #     "Tokens": total_tokens,
    #     "Error": error_msg,
    #     "Timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # }
    # print(entry)
