import os
import shutil
import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
sys.path.append(str(project_root / "experiment"))
import json
import traceback
import globals
import logger
import time
from datetime import datetime
import utils
import pandas as pd

# Project-local imports (resolved through the sys.path entries appended above)


from agent.datascience import DSagent
from mlebench.registry_helper import load_competition
from mlebench.utils import get_module_dir
from mlebench.grade import grade_csv
from mlebench.registry import registry

def run_mlebench_test(competition_id: str = "spaceship-titanic") -> None:
    """
    Run the Data Science Agent end-to-end on one MLE-bench competition.

    Steps, in order:
      1. Point the global workspace configuration at ``<project_root>/workspace``
         and (re)create the notebook client.
      2. Load the competition, wipe any previous workspace and copy the
         competition data into ``workspace/data`` (falling back to
         ``experiment/mlebench/<competition_id>/data`` when the competition's
         public directory does not exist).
      3. Build the prompt from ``experiment/mlebench/instructions.txt`` and save
         it next to the competition's other artifacts.
      4. Run the agent, then write execution metrics as JSON.
      5. If a ``submission.csv`` was produced, copy it and the workspace files
         (everything except ``data``) next to the competition data and grade it.

    Failures in setup steps are printed and end the run early (the function
    returns ``None`` in every case). Exceptions raised by the agent itself are
    not caught here. The workspace ``data`` directory is left in place after
    the run.

    Args:
        competition_id: MLE-bench competition identifier passed to
            ``load_competition``.
    """
    print(f"🚀 Starting MLE-bench test for competition: {competition_id}")

    # Setup paths
    workspace_dir = project_root / "workspace"
    data_dir = workspace_dir / "data"
    # Per-competition output directory (instructions, metrics, grading report)
    competition_dir = project_root / "experiment" / "mlebench" / "competitions" / competition_id

    # Ensure workspace exists
    workspace_dir.mkdir(parents=True, exist_ok=True)

    # Update global configuration to use this specific workspace
    # This ensures ExecuteNbCode uses the correct absolute path
    utils.module_config.workspace_path = str(workspace_dir.resolve())
    globals.execute_nb_code.set_nb_client()
    print(f"🔧 Updated workspace path to: {utils.module_config.workspace_path}")

    # 1. Load Competition Info
    try:
        competition = load_competition(competition_id)
        print(f"✅ Loaded competition: {competition.name}")
    except Exception as e:
        print(f"❌ Failed to load competition '{competition_id}': {e}")
        return

    # 2. Clean up previous run & create fresh data dir
    if data_dir.exists():
        # data_dir lives inside workspace_dir, so removing the workspace
        # removes the data directory as well.
        shutil.rmtree(workspace_dir)
        print("🧹 Cleaned up previous workspace and data directories.")
    data_dir.mkdir(parents=True, exist_ok=True)

    # Copy competition data to workspace/data
    print(f"📂 Copying data from {competition.public_dir} to {data_dir}...")
    # Remember which directory the data really came from: the sample
    # submission is read from it below, and public_dir may not exist.
    data_source = competition.public_dir
    try:
        if not data_source.exists():
            print(f"❌ Data directory not found: {competition.public_dir}")
            data_source = project_root / "experiment" / "mlebench" / competition_id / "data"
            if not data_source.exists():
                print("❌ Could not find data source. Aborting.")
                return
            print(f"⚠️  Using local fallback data path: {data_source}")
        shutil.copytree(data_source, data_dir, dirs_exist_ok=True)
    except Exception as e:
        print(f"❌ Error copying data: {e}")
        return

    # 3. Prepare Prompt
    instructions_path = project_root / "experiment" / "mlebench" / "instructions.txt"
    submission_sample_str = read_submission_sample(str(data_source))
    try:
        with open(instructions_path, "r", encoding="utf-8") as f:
            template = f.read()
        prompt = template.format(
            competition_name=competition.name,
            description=competition.description,
            submission_example=submission_sample_str
        )
        print("📝 Prompt prepared.")
    except Exception as e:
        print(f"❌ Error preparing prompt: {e}")
        return

    # Save full instructions to a file (best effort: a failure is only reported)
    competition_dir.mkdir(parents=True, exist_ok=True)
    desc_file = competition_dir / "full_instructions.txt"
    try:
        with open(desc_file, "w", encoding="utf-8") as f:
            f.write(prompt)
        print(f"📝 Full description saved to {desc_file}")
    except Exception as e:
        print(f"❌ Failed to save full description: {e}")

    # 4. Run Agent
    print("\n🤖 Initializing Data Science Agent...")
    agent = DSagent()
    run_id = f"test_mlebench_{competition_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}"

    # Ensure submission directory exists
    (data_dir / "submission").mkdir(parents=True, exist_ok=True)

    print("▶️  Agent acting...")

    start_time = time.time()
    timestamp = datetime.now().strftime("%Y%m%d%H%M")
    logger.init(log_filename=f'mlebench_{competition_id}_{timestamp}.log')
    logger.info(
        "------------------------------【程序开始】------------------------------"
    )
    logger.debug(
        f"【API 配置】: {utils.api_config.config_name},",
    )
    extra_context = {
        "submission_example": submission_sample_str
    }
    print(f"🧠 Extra context for agent prepared.{extra_context}")
    agent.act(requirement=prompt, run_id=run_id, complexity="auto", extra_context=extra_context)
    elapsed_time_minutes = (time.time() - start_time) / 60
    logger.debug(f"【程序运行时间】 {elapsed_time_minutes:.2f} 分钟")

    # Check for submission file in multiple locations (first match wins)
    possible_submission_paths = [
        data_dir / "submission" / "submission.csv",
        data_dir / "submission.csv",
        workspace_dir / "submission.csv"
    ]
    found_submission_path = next(
        (p for p in possible_submission_paths if p.exists()), None
    )
    submission_found = found_submission_path is not None

    # Save execution metrics
    metrics = {
        "execution_time_minutes": elapsed_time_minutes,
        "token_stats": logger._token_stats,
        "agent": "r3",
        "competition_id": competition_id,
        "submission_found": submission_found
    }
    metrics_file = competition_dir / "execution_metrics_r3.json"

    try:
        competition_dir.mkdir(parents=True, exist_ok=True)
        with open(metrics_file, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)
        print(f"📊 Execution metrics saved to {metrics_file}")
        print("📊 Execution metrics:", metrics)
    except Exception as e:
        print(f"❌ Failed to save metrics: {e}")

    # Print the accumulated LLM token usage
    logger.log_token_summary()
    logger.info(
        "------------------------------【程序结束】------------------------------"
    )
    print("✅ Agent execution finished.")

    # 5. Verify Submission
    if not submission_found:
        print(f"⚠️  Submission file NOT found. Checked locations: {[str(p) for p in possible_submission_paths]}")
        return

    print(f"🎉 Success! Submission file found at: {found_submission_path}")
    try:
        shutil.copy(found_submission_path, competition.public_dir.parent.parent / "submission.csv")
        print("✅ Submission file copied to competition directory.")
    except Exception as e:
        print(f"❌ Failed to copy submission file: {e}")

    # Copy everything in the workspace except the data folder into
    # "workspace-r3" under the parent of the competition's public directory.
    # Best effort: a failure here must not prevent grading below.
    workspace_files_dir = competition.public_dir.parent / "workspace-r3"
    try:
        # Remove the target directory first so stale files never survive
        if workspace_files_dir.exists():
            shutil.rmtree(workspace_files_dir)
        workspace_files_dir.mkdir(parents=True, exist_ok=True)
        for item in workspace_dir.iterdir():
            if item.name == "data":
                continue
            if item.is_file():
                shutil.copy2(item, workspace_files_dir / item.name)
            elif item.is_dir():
                shutil.copytree(item, workspace_files_dir / item.name)
        print(f"✅ Workspace files copied to {workspace_files_dir}")
    except Exception as e:
        print(f"❌ Failed to copy workspace files: {e}")

    try:
        report = grade_csv(found_submission_path, competition)
        print("Competition report:")
        print(json.dumps(report.to_dict(), indent=4))

        # Save grading report
        report_path = competition_dir / "grading_report_r3.json"
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(report.to_dict(), f, indent=4, ensure_ascii=False)
        print(f"Grading report saved to {report_path}")
    except Exception as e:
        print(f"❌ Grading failed: {e}")
        traceback.print_exc()

def read_submission_sample(public_dir: str) -> str:
    """
    Describe the sample submission CSV found in ``public_dir``.

    Looks for files whose name contains ``submission`` and ends in ``.csv``
    (both checks case-insensitive), picks the first one in sorted name order,
    and returns a text description holding the table shape and its first
    three rows.

    Every cell of the preview is converted to a string: missing values become
    ``"NaN"``, newline and carriage-return characters are shown as the escaped
    text ``\\n`` and ``\\r``, and values longer than 80 characters are cut to
    80 characters followed by ``...``.

    Args:
        public_dir: Directory to search (not recursive).

    Returns:
        The description string; a fixed "no sample file found" message when
        the directory is missing or holds no matching file; or an error
        message when the file cannot be read. This function does not raise
        for read failures.
    """
    # A missing directory simply means there is no sample to show; without
    # this guard os.listdir would raise and abort the caller.
    if not os.path.isdir(public_dir):
        return "没有发现示例提交文件。"

    # os.listdir returns names in arbitrary order; sort so the same file is
    # chosen on every run when several candidates exist.
    submission_files = sorted(
        name for name in os.listdir(public_dir)
        if 'submission' in name.lower() and name.lower().endswith('.csv')
    )

    if not submission_files:
        return "没有发现示例提交文件。"

    file_path = os.path.join(public_dir, submission_files[0])
    max_len = 80  # maximum number of characters shown per cell

    def safe_format(val):
        """Render one cell as a single-line string of bounded length."""
        if pd.isna(val):
            return "NaN"
        s = str(val)
        # Escape line breaks so each row stays on one line
        s = s.replace('\n', '\\n').replace('\r', '\\r')
        # Truncate and mark the cut with an ellipsis
        if len(s) > max_len:
            s = s[:max_len] + "..."
        return s

    try:
        df = pd.read_csv(file_path)
        total_rows, total_cols = df.shape

        # Work on a copy of the first 3 rows so the loaded frame is untouched
        df_sample = df.head(3).copy()
        for col in df_sample.columns:
            df_sample[col] = df_sample[col].apply(safe_format)

        # max_colwidth is slightly larger than max_len plus the ellipsis so
        # pandas does not truncate the already-shortened cells a second time
        head_str = df_sample.to_string(
            index=True,
            line_width=200,
            max_colwidth=max_len + 5,
            justify='left'
        )

        return (
            f"输出形状为 ({total_rows}, {total_cols})，前3行示例如下（必须严格按照此格式）:\n{head_str}"
        )

    except Exception as e:
        # Reported as text on purpose: the result is embedded in a prompt
        return f"读取 submission 文件时出错: {str(e)}"

if __name__ == "__main__":
    # Competition id: first command-line argument when given, otherwise the
    # fallback below (edit the string to change the default competition).
    cli_args = sys.argv[1:]
    comp_id = cli_args[0] if cli_args else "bms-molecular-translation"

    run_mlebench_test(comp_id)
