#!/usr/bin/env python
# -*- coding: utf-8 -*-
from moosf_combined_prompts import PromptsCombined
from moosf_combined_problem import MOOSFCombinedProblem
from eoh.utils.getParas import Paras
from eoh import eoh
import re
import torch
import numpy as np
import textwrap
import os
import logging
import sys

# --- Read the experiment tag and the evaluation tasks from the environment ---
experiment_tag = os.getenv('EXPERIMENT_TAG')
eval_tasks_str = os.getenv('EVAL_TASKS')

# Echo the raw values first so a misconfigured launch is easy to diagnose.
print(f"DEBUG runEoH_combined.py: Initial EXPERIMENT_TAG from env: '{experiment_tag}'")
print(f"DEBUG runEoH_combined.py: Initial EVAL_TASKS_STR from env: '{eval_tasks_str}'")

# Both variables are mandatory: an unset or empty value aborts the run.
if not (experiment_tag and eval_tasks_str):
    print("ERROR: Environment variables EXPERIMENT_TAG and EVAL_TASKS must be set.")
    print("Example: EXPERIMENT_TAG=\"my_experiment\" EVAL_TASKS=\"task1,task2,task3\" python runEoH_combined.py")
    sys.exit(1)

# Split on commas, trim surrounding whitespace and drop empty entries.
eval_tasks_list = [
    task for task in map(str.strip, eval_tasks_str.split(',')) if task
]
if not eval_tasks_list:
    # Reached when EVAL_TASKS holds only separators, e.g. ",,".
    print("ERROR: EVAL_TASKS environment variable was empty or malformed (e.g., only commas or whitespace). Please provide comma-separated task names.")
    sys.exit(1)

print(f"DEBUG runEoH_combined.py: Running with EXPERIMENT_TAG: {experiment_tag}")
print(f"DEBUG runEoH_combined.py: Parsed eval_tasks_list: {eval_tasks_list}")
# --- End of environment variable handling ---

# Configure root logging up front: INFO level, written to stderr.
# Switch level to logging.DEBUG for more verbose output.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)

# Make the project root (three directories above this script) importable.
# NOTE(review): the project imports at the top of the file run before this
# insertion takes effect - confirm they do not rely on it.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir))
sys.path.insert(0, project_root)

# --- GPU Configuration ---
# Enumerate CUDA devices in PCI bus order. CUDA_VISIBLE_DEVICES is
# deliberately not set here; see the message printed below.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
print("GPU assignment will be handled by individual EoH worker processes.")

# --- Configuration (Defined here, used in paras.set_paras below) ---
llm_model = "gpt-4o-mini"
# SECURITY: prefer supplying the key through the LLM_API_KEY environment
# variable. The embedded value is kept only as a backward-compatible fallback
# for existing launch commands; a key committed to source control should be
# treated as leaked, so rotate it and remove the fallback when possible.
llm_api_key = (os.environ.get('LLM_API_KEY')
               or "sk-CUezkFgHUZrxYg2PCMxFbC97DxDecZc2K9bGsU8GZZEsw6RG")
llm_api_endpoint = "api.agicto.cn"
# Evolution size: individuals per population and number of generations.
population_size = 4
num_generations = 4
# --- Paths ---
# The LOS project lives three directories above this script.
LOS_PROJECT_DIR = os.path.abspath(os.path.join(
    os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, 'LOS'))

# Results are grouped by imbalance ratio first, then by experiment tag.
imb_ratio = os.getenv('IMB_RATIO', '100')  # falls back to '100' when unset
base_output_dir = os.path.join(
    os.path.dirname(__file__), 'results_los{}_combined'.format(imb_ratio))
output_dir = os.path.join(base_output_dir, experiment_tag)

# --- Flexible GPU ID Config --- #
# <<< SET THE DESIRED GPU IDs IN THIS LIST >>>
TARGET_GPU_IDS = [0, 1, 2, 3]  # currently GPUs 0-3; e.g. [0, 1] for two GPUs

# The list must be non-empty and contain integers only.
if not (TARGET_GPU_IDS and all(isinstance(gid, int) for gid in TARGET_GPU_IDS)):
    print("Error: TARGET_GPU_IDS must be a non-empty list of integers. Exiting.")
    sys.exit(1)

# One parallel evaluation per listed GPU.
exp_proc = num_gpus_to_use = len(TARGET_GPU_IDS)
# --- End Flexible GPU ID Config ---

# --- Command-line arguments forwarded to main_stage2.py ---
# Built as a single literal, in order: the common MOOSF flag, the task names
# taken from the environment, the fixed fine-tuning settings, and finally the
# switches that trim work from each evaluation run.
# NOTE(review): the checkpoint path mentions "IR=10" while imb_ratio defaults
# to '100' - confirm the checkpoint matches the ratio being evaluated.
EVAL_ARGS_FOR_STAGE2 = [
    '--MOOSF',
    '--tasks', *eval_tasks_list,
    '--finetune_epoch', '10',
    '--finetune_lr', '0.0005',
    '--finetune_wd', '0',
    '--network', 'resnet34',
    '--dataset', 'cifar100',
    '--imb_ratio', imb_ratio,
    '--num_class', '100',
    '--pretrained_pth', '/data/zz/zhiheng/LOS-main/output/cifar100_IR=10_stage1/5-10_23-2-18_NEW/best_model_stage1.pth',
    '--batch_size', '64',
    '--no-save-model',  # Disable saving best model checkpoint
    '--no-head-val',    # Disable individual head validation sanity check
]

# --- Control List for Evolution --- #
# List the names of modules you want the LLM to evolve.
# Any module *not* in this list will be fixed to its original code.
# Options: "hica", "ghop", "eoss"
# Current setting: HICA and GHOP are evolved; EOSS is not listed, so it stays
# fixed to its original code.
modules_to_evolve = ["hica", "ghop"]
# Example: modules_to_evolve = ["hica"] # Only evolve HICA
# Example: modules_to_evolve = ["ghop", "eoss"] # Evolve GHOP and EOSS
# ---------------------------------- #

# --- Define Original/Baseline Code Snippets --- #
# --- MODIFIED: Load from LOS files ---
# Helper function to load code safely


def load_original_code(file_path):
    """Read a UTF-8 source file and return its text with common indentation removed.

    Failures never propagate: the problem is printed and a ``#``-prefixed
    placeholder string is returned in place of the file content.

    Args:
        file_path: Path of the file to read.

    Returns:
        The dedented file content, the raw content if dedenting fails, or a
        ``# Error ...`` comment string when the file is missing or unreadable.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as source_file:
            raw_text = source_file.read()
    except FileNotFoundError:
        print(f"Error: Original code file not found: {file_path}")
        return f"# Error: File not found at {file_path}"
    except Exception as e:
        # Any other read problem (permissions, decoding, path is a directory).
        print(f"Error loading code from {file_path}: {e}")
        return f"# Error loading code from {file_path}"

    # Dedenting is best-effort; fall back to the untouched text if it fails.
    try:
        dedented_text = textwrap.dedent(raw_text)
    except Exception as dedent_e:
        print(
            f"Warning: Could not dedent {file_path}: {dedent_e}. Returning raw content.")
        return raw_text
    return dedented_text

# Helper function to extract class/method (basic version)


def extract_code_block(full_code, block_type, name):
    """Extract a named block of source text from ``full_code``.

    For ``block_type == "class"`` the class header line and its complete
    indented body (methods included) are returned. The body ends at the first
    non-blank, non-comment line indented no deeper than the ``class`` keyword.
    Blank lines and shallow comment lines are kept only when more body code
    follows them, so comments belonging to the next definition are excluded.

    Limitations of this text-based scan: decorators above the class are not
    included, and a header spanning several lines or a multi-line string whose
    content sits at or left of the class indentation ends the block early.

    Args:
        full_code: Source text to search.
        block_type: ``"class"``; ``"method_body"`` is recognised but not
            implemented.
        name: Name of the class (or method) to extract.

    Returns:
        The dedented, stripped class source, or a ``# Error: ...`` comment
        string when the class is not found, the block type is unknown, or
        method-body extraction is requested.
    """
    if block_type == "class":
        # Escape the name so regex metacharacters cannot alter the pattern.
        header = re.search(
            rf"^([ \t]*)class[ \t]+{re.escape(name)}\b", full_code, re.M)
        if header is None:
            return f"# Error: Class '{name}' not found"

        class_indent = len(header.group(1).expandtabs())
        lines = full_code[header.start():].splitlines()
        block = [lines[0]]
        pending = []  # blank/comment lines not yet known to be inside the body
        for line in lines[1:]:
            text = line.strip()
            expanded = line.expandtabs()
            indent = len(expanded) - len(expanded.lstrip())
            if text and indent > class_indent:
                # Real body line: whatever was pending sits inside the class.
                block.extend(pending)
                pending = []
                block.append(line)
            elif not text or text.startswith("#"):
                pending.append(line)
            else:
                # Code at or left of the class indentation ends the class.
                break
        return textwrap.dedent("\n".join(block)).strip()
    elif block_type == "method_body":
        # Find class first if needed, then method, then extract body
        # This is complex and error-prone with regex, manual extraction is safer for now
        # Placeholder:
        return f"# Error: Method body extraction for '{name}' not implemented robustly. Manual extraction recommended."
    return f"# Error: Unknown block type '{block_type}'"


# Source files inside the LOS project that hold the original modules.
hica_parent_file, ghop_parent_file, eoss_parent_file = (
    os.path.join(LOS_PROJECT_DIR, 'MultiBackward', *relative_parts)
    for relative_parts in (
        ('MBACK.py',),
        ('GradFun', 'pcgrad.py'),
        ('ACCFun', 'Mul_same_task.py'),
    )
)

# Read the three files in the same order: HICA, GHOP, EOSS.
hica_full_code, ghop_full_code, eoss_full_code = map(
    load_original_code,
    (hica_parent_file, ghop_parent_file, eoss_parent_file),
)

# Pull the baseline blocks out of the loaded text. Extraction is basic and
# may need manual refinement.
print("Attempting to extract original code blocks from LOS files...")
original_hica_code_str = extract_code_block(
    full_code=hica_full_code, block_type="class", name="Pla")
# --- GHOP: Use actual code from LOS/MultiBackward/GradFun/pcgrad.py --- #
# The GHOP baseline is kept as an inline string literal (the body of
# _project_conflicting, per the message printed below) and dedented once here.
# The snippet references torch, copy, random and self, none of which it
# defines itself.
# NOTE(review): the text is hard-coded rather than extracted from the loaded
# pcgrad.py content, so it can drift from that file - confirm it still matches.
original_ghop_body_str = textwrap.dedent("""
       # Note: Imports (torch, copy, random) should be handled by pcgrad.py context
       shared = torch.stack(has_grads).prod(0).bool() 
       pc_grad, num_task = copy.deepcopy(grads), len(grads) 
       
       for g_i in pc_grad:
           random.shuffle(grads) 
           for g_j in grads:
               g_i_g_j = torch.dot(g_i, g_j) 
               if g_i_g_j < 0: 
                   g_i -= (g_i_g_j) * g_j / (g_j.norm()**2) # No epsilon here in actual code
       
       merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
       if self._reduction == 'mean': 
           merged_grad[shared] = torch.stack([g[shared] 
                                          for g in pc_grad]).mean(dim=0)
       elif self._reduction == 'sum':
            merged_grad[shared] = torch.stack([g[shared]
                                          for g in pc_grad]).sum(dim=0)
       else: exit('invalid reduction method')

       merged_grad[~shared] = torch.stack([g[~shared]
                                           for g in pc_grad]).sum(dim=0)
       return merged_grad
""")
print("Info: Using actual _project_conflicting body from LOS/MultiBackward/GradFun/pcgrad.py for GHOP baseline.")
# Extract the EOSS baseline class from the loaded Mul_same_task.py text.
original_eoss_class_str = extract_code_block(
    eoss_full_code, "class", "Weight_acc")

# Warn when an extraction failed. extract_code_block signals failure with a
# string that starts with "# Error:", so test for that prefix: a plain
# substring search for "Error:" would also fire on successfully extracted
# code that merely mentions "Error:" in a message or docstring.
if original_hica_code_str.startswith("# Error:"):
    print(f"Warning: {original_hica_code_str}")
if original_eoss_class_str.startswith("# Error:"):
    print(f"Warning: {original_eoss_class_str}")

# --- End MODIFIED ---

# -------------------------------------------------- #

if __name__ == "__main__":
    # --- Create the experiment output directory first ---
    # exist_ok=True removes the check-then-create race: a concurrent launch
    # that creates the same directory no longer makes this run exit with an
    # error. A path that exists but is not a directory still raises OSError
    # and aborts the run here instead of failing later.
    output_dir_preexisting = os.path.isdir(output_dir)
    try:
        os.makedirs(output_dir, exist_ok=True)
    except OSError as e:
        print(
            f"Error creating experiment output directory {output_dir}: {e}")
        sys.exit(1)  # Exit if cannot create output dir
    if not output_dir_preexisting:
        print(f"Created experiment output directory: {output_dir}")

    # 1. Initialize Paras object
    paras = Paras()

    # 2. Every module that is not listed in modules_to_evolve is passed to the
    #    prompts as fixed code (its original/baseline source).
    all_modules = {
        "hica": original_hica_code_str,
        "ghop": original_ghop_body_str,
        "eoss": original_eoss_class_str  # Use the full class string now
    }
    fixed_code = {}
    print("--- Evolution Configuration ---")
    print(f"Modules targeted for evolution: {modules_to_evolve}")
    for name, code in all_modules.items():
        if name not in modules_to_evolve:
            fixed_code[name] = code
            print(f"  - Fixing module: {name}")
    if not fixed_code:
        print("  - All modules will be evolved.")
    print("-------------------------------")

    # 3. Instantiate prompts and the problem, passing the fixed code, the LOS
    #    project directory and the evaluation arguments.
    prompts_instance = PromptsCombined(fixed_modules_code=fixed_code)
    problem_instance = MOOSFCombinedProblem(
        prompts_instance=prompts_instance,
        output_dir=output_dir,
        target_gpus=TARGET_GPU_IDS,
        project_dir=LOS_PROJECT_DIR,
        evaluation_script='main_stage2.py',
        evaluation_args=EVAL_ARGS_FOR_STAGE2,
        # Used for directory naming
        population_size=population_size,
        total_generations=num_generations,
        modules_to_evolve=modules_to_evolve
    )

    # 4. Set parameters using paras.set_paras
    paras.set_paras(
        method="eoh",
        problem=problem_instance,
        # LLM Configuration
        llm_api_endpoint=llm_api_endpoint,
        llm_api_key=llm_api_key,
        llm_model=llm_model,
        # Evolution Configuration
        ec_pop_size=population_size,
        ec_n_pop=num_generations,
        ec_m=2,
        # Experiment Configuration
        exp_n_proc=exp_proc,
        exp_output_path=output_dir,
        exp_debug_mode=False,
        exp_use_seed=False,
        exp_use_continue=False,
        # Evaluation Configuration
        eva_timeout=800,  # Keep timeout sufficient
        eva_numba_decorator=False,
        # Single-GPU configuration (one GPU per process)
        exp_multi_gpu=False,  # Each evaluation process uses a single GPU
        # exp_gpu_ids is not passed: the specific GPU is assigned by the worker
    )

    # Summarise the run configuration before starting.
    print("Starting LOS+MOOSF Evolution...")
    print(f"Population Size: {population_size}")
    print(f"Number of Generations: {num_generations}")
    print(f"Output Directory: {output_dir}")
    print(f"Using LLM: {llm_model}")
    print(f"LLM Endpoint: {llm_api_endpoint}")
    print(
        f"Configured for {exp_proc} parallel evaluations using specific GPUs: {TARGET_GPU_IDS}.")
    print(f"Using Evaluation Script: {problem_instance.evaluation_script}")
    print(
        f"Evaluation Args (passed to problem_instance): {' '.join(problem_instance.evaluation_args)}")

    # 5. Initialize EoH evolution object (No extra kwargs needed now)
    evolution = eoh.EVOL(paras)

    # 6. Run the evolution process
    evolution.run()

    print("LOS+MOOSF Evolution finished.")
