import os
import shutil
import sys
import asyncio
from pathlib import Path

# Directory that contains this script. It and its "experiment" sub-directory
# are appended to sys.path so the `experiment.*` imports below can be resolved.
project_root = Path(__file__).parent
sys.path.extend([str(project_root), str(project_root / "experiment")])
import json
import traceback
from metagpt.logs import logger
import time
from datetime import datetime
import pandas as pd

from experiment.mlebench.registry_helper import load_competition
from experiment.mlebench.utils import get_module_dir
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.utils.recovery_util import save_history
from metagpt.context import Context


def run_mlebench_test(competition_id: str = "spaceship-titanic"):
    """
    Run a test of the Data Science Agent on a specific MLE-bench competition.

    Steps performed:
      1. Load the competition through ``load_competition``.
      2. Recreate ``workspace/data`` and copy the competition's public data
         (or the local fallback data directory) into it.
      3. Build the agent prompt from ``experiment/mlebench/instructions.txt``
         and save it as ``full_instructions.txt`` in the competition directory.
      4. Run the DataInterpreter agent and save timing / token metrics to
         ``execution_metrics.json``.
      5. Copy the produced ``submission.csv`` and the workspace artifacts next
         to the competition data, then grade the submission and save the
         report as ``grading_report_di.json``.

    Args:
        competition_id: Identifier of the competition to run.

    Returns:
        None. The function returns early (after printing a message) when the
        competition, its data, or the prompt cannot be prepared.

    Raises:
        Exceptions raised while the agent runs, while workspace artifacts are
        copied, or while grading are not caught here and propagate to the caller.
    """
    print(f"🚀 Starting MLE-bench test for competition: {competition_id}")

    # Setup paths (all anchored at the project root).
    workspace_dir = project_root / "workspace"
    data_dir = workspace_dir / "data"
    competition_dir = project_root / "experiment" / "mlebench" / "competitions" / competition_id

    # Ensure workspace exists
    workspace_dir.mkdir(parents=True, exist_ok=True)

    # 1. Load Competition Info (loaded once, inside the error handling so a bad
    #    competition id produces the friendly message instead of a traceback).
    try:
        competition = load_competition(competition_id)
        print(f"✅ Loaded competition: {competition.name}")
    except Exception as e:
        print(f"❌ Failed to load competition '{competition_id}': {e}")
        return

    # 2. Clean up previous run & create fresh data dir.
    #    data_dir lives inside workspace_dir, so removing the workspace removes both.
    if data_dir.exists():
        shutil.rmtree(workspace_dir)
        print("🧹 Cleaned up previous workspace and data directories.")
    data_dir.mkdir(parents=True, exist_ok=True)

    # Copy competition data to workspace/data
    print(f"📂 Copying data from {competition.public_dir} to {data_dir}...")
    try:
        if competition.public_dir.exists():
            data_source = competition.public_dir
        else:
            print(f"❌ Data directory not found: {competition.public_dir}")
            data_source = project_root / "experiment" / "mlebench" / competition_id / "data"
            if not data_source.exists():
                print("❌ Could not find data source. Aborting.")
                return
            print(f"⚠️  Using local fallback data path: {data_source}")
        shutil.copytree(data_source, data_dir, dirs_exist_ok=True)
    except Exception as e:
        print(f"❌ Error copying data: {e}")
        return

    # 3. Prepare Prompt
    instructions_path = project_root / "experiment" / "mlebench" / "instructions.txt"
    # Read the sample from the directory the data was actually copied from, so
    # the fallback case does not try to list a public_dir that does not exist.
    submission_sample_str = read_submission_sample(str(data_source))
    try:
        with open(instructions_path, "r", encoding="utf-8") as f:
            template = f.read()
        prompt = template.format(
            competition_name=competition.name,
            description=competition.description,
            submission_example=submission_sample_str,
        )
        print("📝 Prompt prepared.")
    except Exception as e:
        print(f"❌ Error preparing prompt: {e}")
        return

    # Save the full prompt as full_instructions.txt in the competition directory.
    description_file = competition_dir / "full_instructions.txt"
    try:
        # Create the directory first; previously it was only created later,
        # when the metrics were written.
        competition_dir.mkdir(parents=True, exist_ok=True)
        with open(description_file, "w", encoding="utf-8") as f:
            f.write(prompt)
        print(f"📝 Full description saved to {description_file}")
    except Exception as e:
        print(f"❌ Failed to save full description: {e}")
        return

    # 4. Run Agent
    print("\n🤖 Initializing MetaGPT DataInterpreter...")

    # Ensure submission directory exists
    (data_dir / "submission").mkdir(parents=True, exist_ok=True)

    print("▶️  Agent acting...")

    start_time = time.time()
    # The helper returns the counters; assigning them inside a nested function
    # without `nonlocal` would leave the values used for the metrics at 0.
    prompt_tokens, completion_tokens, total_tokens = asyncio.run(
        _run_data_interpreter(prompt)
    )
    elapsed_time_minutes = (time.time() - start_time) / 60
    print(f"【程序运行时间】 {elapsed_time_minutes:.2f} 分钟")

    # Save execution metrics
    metrics = {
        "execution_time_minutes": elapsed_time_minutes,
        "agent": "MetaGPT DataInterpreter",
        "token_usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        },
    }
    metrics_file = competition_dir / "execution_metrics.json"
    try:
        competition_dir.mkdir(parents=True, exist_ok=True)
        with open(metrics_file, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)
        print(f"📊 Execution metrics saved to {metrics_file}")
    except Exception as e:
        print(f"❌ Failed to save metrics: {e}")

    print("✅ Agent execution finished.")

    # 5. Verify Submission
    submission_target = competition.public_dir.parent.parent / "submission.csv"
    submission_path = data_dir / "submission" / "submission.csv"
    if submission_path.exists():
        print(f"🎉 Success! Submission file found at: {submission_path}")
        try:
            shutil.copy(submission_path, submission_target)
            print("✅ Submission file copied to competition directory.")
        except Exception as e:
            print(f"❌ Failed to copy submission file: {e}")
    else:
        print(f"⚠️  Submission file NOT found at: {submission_path}")
        # Every alternative that exists is copied in order, so a later match
        # overwrites an earlier one.
        for p in (data_dir / "submission.csv", workspace_dir / "submission.csv"):
            if p.exists():
                print(f"⚠️  Found submission at alternative path: {p}")
                shutil.copy(p, submission_target)
                print("✅ Submission file copied to competition directory.")

    # Copy everything in the workspace except the data folder into
    # <public_dir parent>/workspace/di.
    workspace_files_dir = competition.public_dir.parent / "workspace" / "di"
    workspace_files_dir.mkdir(parents=True, exist_ok=True)
    for item in workspace_dir.iterdir():
        if item.name == "data":
            continue
        if item.is_file():
            shutil.copy2(item, workspace_files_dir / item.name)
        elif item.is_dir():
            dest_dir = workspace_files_dir / item.name
            # copytree needs a fresh destination; drop any copy from an earlier run.
            if dest_dir.exists():
                shutil.rmtree(dest_dir)
            shutil.copytree(item, dest_dir)
    print(f"✅ Workspace files copied to {workspace_files_dir}")

    # Grade the submission (imported lazily, only once a run reached this point).
    from experiment.mlebench.grade import grade_csv
    from experiment.mlebench.registry import registry

    # Anchored at project_root instead of the current working directory.
    # NOTE(review): assumes competition.public_dir.parent.parent (where the
    # submission was copied above) is this same competition directory — confirm.
    graded_submission = competition_dir / "submission.csv"
    competition = registry.get_competition(competition_id)
    report = grade_csv(graded_submission, competition)
    report_dict = report.to_dict()
    print("Competition report:")
    print(json.dumps(report_dict, indent=4))

    # Save the grading report into the competition directory.
    report_path = graded_submission.parent / "grading_report_di.json"
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report_dict, f, indent=4, ensure_ascii=False)
        print(f"Grading report saved to {report_path}")
    # NOTE: data_dir is intentionally left in place after the run; it is wiped
    # at the start of the next run (step 2).


async def _run_data_interpreter(prompt: str) -> tuple:
    """Run a fresh DataInterpreter on *prompt* and return its token usage.

    Returns:
        ``(prompt_tokens, completion_tokens, total_tokens)`` read from the
        agent context's cost manager.

    Raises:
        Whatever ``DataInterpreter.run`` or ``save_history`` raises; the token
        counters are still logged before the exception propagates.
    """
    di = DataInterpreter(context=Context())
    try:
        rsp = await di.run(prompt)
        logger.info(rsp)
        save_history(role=di)
    finally:
        # Log usage even when the agent fails part-way through.
        cost_manager = di.context.cost_manager
        prompt_tokens = cost_manager.total_prompt_tokens
        completion_tokens = cost_manager.total_completion_tokens
        total_tokens = prompt_tokens + completion_tokens
        logger.info(f"Total Prompt Tokens: {prompt_tokens}")
        logger.info(f"Total Completion Tokens: {completion_tokens}")
        logger.info(f"Total Tokens: {total_tokens}")
    return prompt_tokens, completion_tokens, total_tokens

def read_submission_sample(public_dir: str) -> str:
    """
    Describe the sample submission CSV found in ``public_dir``.

    Looks for files in ``public_dir`` whose name contains "submission"
    (case-insensitive) and ends with ".csv" (case-insensitive). When several
    match, the alphabetically first one is used so the result does not depend
    on directory listing order.

    Cells of the first three rows are cleaned for display: missing values
    become "NaN", newlines / carriage returns are escaped, and values longer
    than 80 characters are truncated with "...".

    Args:
        public_dir: Directory to search for the sample submission file.

    Returns:
        A string with the CSV's shape and its first three rows, a fixed
        message when no matching file exists, or an error message when the
        directory or the file cannot be read. The function does not raise for
        I/O or parsing problems.
    """
    try:
        submission_files = sorted(
            name for name in os.listdir(public_dir)
            if 'submission' in name.lower() and name.lower().endswith('.csv')
        )
    except OSError as e:
        # Missing / unreadable directory: report it like any other read failure.
        return f"读取 submission 文件时出错: {str(e)}"

    if not submission_files:
        return "没有发现示例提交文件。"

    file_path = os.path.join(public_dir, submission_files[0])
    max_len = 80  # maximum number of characters shown per cell

    def safe_format(val):
        """Render one cell on a single line, truncated to ``max_len`` characters."""
        if pd.isna(val):
            return "NaN"
        s = str(val)
        # Escape line breaks so every row stays on one line.
        s = s.replace('\n', '\\n').replace('\r', '\\r')
        # Truncate and mark the cut with an ellipsis.
        if len(s) > max_len:
            s = s[:max_len] + "..."
        return s

    try:
        df = pd.read_csv(file_path)
        total_rows, total_cols = df.shape

        # Work on a copy of the first three rows only.
        df_sample = df.head(3).copy()
        for col in df_sample.columns:
            df_sample[col] = df_sample[col].apply(safe_format)

        head_str = df_sample.to_string(
            index=True,
            line_width=200,
            max_colwidth=max_len + 5,  # slightly above max_len so "..." is kept
            justify='left'
        )

        return (
            f"输出形状为 ({total_rows}, {total_cols})，前3行示例如下（必须严格按照此格式）:\n{head_str}"
        )

    except Exception as e:
        return f"读取 submission 文件时出错: {str(e)}"

if __name__ == "__main__":
    # Script entry point: edit the competition id below to benchmark another task.
    run_mlebench_test(competition_id="stanford-covid-vaccine")
