"""
ACEAS Trainer: Integrated training loop for Adaptive Curriculum
with Execution-Aware Async Scheduling.

This module combines all ACEAS components into a unified training system.
"""

import os
import sys

import ray
import torch
import torch.nn as nn
import numpy as np
import time
import json
import logging
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Tuple
from pathlib import Path
from collections import deque
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, GenerationConfig
from peft import get_peft_model, LoraConfig, TaskType

from ..curriculum.difficulty_levels import DifficultyLevel, CurriculumTaskGenerator, CodeTask
from ..curriculum.adaptive_bandit import AdaptiveCurriculumBandit, ACBConfig, FixedCurriculum
from ..curriculum.task_sampler import CurriculumTaskSampler
from ..scheduling.async_scheduler import ACEASScheduler, SchedulerConfig
from ..scheduling.staleness_control import StaleExperience, StalenessConfig
from ..scheduling.execution_predictor import ExecutionTimePredictor
from ..code_environment.code_env import CurriculumCodeEnv, CodeEnvConfig
from .grpo import GRPOTrainer, GRPOConfig, GRPOBatch, GRPOExperience, create_grpo_batch
from ..utils.megatron_bridge import (
    setup_megatron_bridge,
    initialize_cublas,
    MegatronBridgeConfig,
    configure_cuda_backends,
)

logger = logging.getLogger(__name__)


@dataclass
class ACEASConfig:
    """Configuration for ACEAS training.

    Collects every tunable of the training system in one place: model
    loading, training schedule, curriculum, async scheduling, GRPO,
    the code-execution environment, generation and dataset selection.
    Every field has a default, so ``ACEASConfig()`` is a valid configuration.
    """
    # --- Model configuration ---
    model_name: str = "Salesforce/codegen-350M-mono"
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    # Model-specific LoRA targets. When None, the model loaders in this file
    # pick target modules from substrings of ``model_name``.
    lora_target_modules: Optional[List[str]] = None

    # --- Model scaling options (for 1B+ models) ---
    gradient_checkpointing: bool = False  # Enable gradient checkpointing to save memory
    load_in_8bit: bool = False  # Use 8-bit quantization (requires bitsandbytes)
    load_in_4bit: bool = False  # Use 4-bit quantization (requires bitsandbytes)
    # One of "float16", "bfloat16", or "float32". The loaders in this file
    # map any other string to torch.float16.
    torch_dtype: str = "bfloat16"

    # --- Worker GPU allocation (adjust for model size) ---
    gpu_per_worker: float = 0.5  # GPU fraction per worker (use 1.0 for large models)

    # --- Training configuration ---
    total_timesteps: int = 10000  # Also passed to GRPOConfig / FixedCurriculum as their step budget
    batch_size: int = 32
    eval_interval: int = 500
    save_interval: int = 1000
    log_interval: int = 50

    # --- Early stopping configuration ---
    early_stopping_patience: int = 5  # Stop if no improvement for N evals (0 to disable)

    # --- Curriculum configuration ---
    # "adaptive" builds an AdaptiveCurriculumBandit, "fixed" builds a
    # FixedCurriculum; any other value (e.g. "uniform") leaves the trainer
    # without a bandit.
    curriculum_strategy: str = "adaptive"
    acb_exploration: float = 1.0  # Passed to ACBConfig as exploration_constant
    acb_alpha: float = 0.7  # Passed to ACBConfig as alpha

    # --- Async/scheduling configuration ---
    num_workers: int = 4
    use_async: bool = True
    use_csc: bool = True
    use_eaas: bool = True  # Passed to SchedulerConfig as use_execution_aware
    eta_base: float = 8.0
    lambda_coupling: float = 0.5
    use_local_mode: bool = False  # Run without Ray workers
    broadcast_interval: int = 10  # Broadcast weights every N updates (higher = faster but staler)

    # --- GRPO configuration ---
    learning_rate: float = 1e-5
    clip_epsilon: float = 0.2
    kl_coef: float = 0.1
    group_size: int = 8

    # --- Environment configuration ---
    timeout_seconds: float = 10.0
    max_code_length: int = 2048

    # --- Generation configuration ---
    max_new_tokens: int = 128  # Reduced from 256 - most HumanEval solutions are under 100 tokens
    temperature: float = 0.8
    top_p: float = 0.95

    # --- Dataset configuration ---
    dataset: str = "humaneval"  # "humaneval", "mbpp", "mbpp_plus", "apps", "synthetic"
    apps_difficulty: str = "introductory"  # For APPS: "introductory", "interview", "competition"
    max_tasks: Optional[int] = None  # Limit number of tasks (for debugging)


def create_rollout_worker_class(gpu_fraction: float = 0.5):
    """
    Build a Ray actor class for rollout collection with a chosen GPU share.

    A fresh ``RolloutWorker`` subclass is created on every call and wrapped
    with ``ray.remote`` so the GPU request can be decided at runtime
    instead of being fixed when the module is imported.

    Args:
        gpu_fraction: Fraction of a GPU requested for each worker actor.
                     Use 0.5 for small models (350M), 1.0 for large models (1B+).

    Returns:
        Ray actor class that requests ``gpu_fraction`` GPUs per instance.
    """
    class DynamicRolloutWorker(RolloutWorker):
        """RolloutWorker with dynamic GPU allocation."""

    # Equivalent to decorating the class with @ray.remote(num_gpus=...).
    as_actor = ray.remote(num_gpus=gpu_fraction)
    return as_actor(DynamicRolloutWorker)


class RolloutWorker:
    """
    Ray actor for collecting rollouts.

    Each worker has its own copy of the model for inference.
    Note: This class should be instantiated via create_rollout_worker_class()
    to enable dynamic GPU allocation based on model size.
    """

    def __init__(
        self,
        worker_id: int,
        model_name: str,
        tasks: List[Dict[str, Any]],
        env_config: Dict[str, Any],
        use_lora: bool = True,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_target_modules: Optional[List[str]] = None,
        gradient_checkpointing: bool = False,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        torch_dtype: str = "bfloat16",
    ):
        """
        Build the worker's environment, tokenizer and inference model.

        Args:
            worker_id: Identifier used in log messages and in returned experiences.
            model_name: Model name or path passed to the Hugging Face loaders.
            tasks: Task dicts with keys ``task_id``, ``prompt``,
                ``canonical_solution``, ``test_cases`` and ``entry_point``.
            env_config: Keyword arguments forwarded to ``CodeEnvConfig``.
            use_lora: Wrap the model with a LoRA adapter.
            lora_r: LoRA rank.
            lora_alpha: LoRA alpha.
            lora_target_modules: Modules to adapt; when None they are chosen
                from substrings of ``model_name``.
            gradient_checkpointing: Enable non-reentrant gradient checkpointing.
            load_in_8bit: Request 8-bit quantization (used only if
                ``load_in_4bit`` is False).
            load_in_4bit: Request 4-bit NF4 quantization.
            torch_dtype: "float16", "bfloat16" or "float32"; any other value
                falls back to ``torch.float16``.
        """
        self._prepend_bundled_cuda_libs(worker_id)

        self.worker_id = worker_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize CUBLAS on this worker's GPU
        if self.device == "cuda":
            initialize_cublas()
            logger.info(f"Worker {worker_id}: CUBLAS initialized on {torch.cuda.get_device_name()}")

        # Convert task dicts to CodeTask objects
        code_tasks = [
            CodeTask(
                task_id=t["task_id"],
                prompt=t["prompt"],
                canonical_solution=t["canonical_solution"],
                test_cases=t["test_cases"],
                entry_point=t["entry_point"],
            )
            for t in tasks
        ]

        # Initialize environment and curriculum task generator
        self.env = CurriculumCodeEnv(code_tasks, CodeEnvConfig(**env_config))
        self.task_generator = CurriculumTaskGenerator(code_tasks)

        # Initialize tokenizer; reuse EOS as the pad token when none is defined
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        # Set up dtype
        dtype_map = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        model_dtype = dtype_map.get(torch_dtype, torch.float16)

        # Build the quantization config. It stays None when quantization was
        # not requested OR could not be set up, so the code below loads a
        # plain model on self.device in both cases.
        quantization_config = None
        if load_in_4bit or load_in_8bit:
            try:
                from transformers import BitsAndBytesConfig
                if load_in_4bit:
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=model_dtype,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    )
                    logger.info("Using 4-bit quantization")
                else:
                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                    logger.info("Using 8-bit quantization")
            except ImportError:
                quantization_config = None
                logger.warning(
                    f"bitsandbytes not available, loading unquantized model in {model_dtype}"
                )
        quantized = quantization_config is not None

        model_kwargs = {
            "torch_dtype": model_dtype,
            "device_map": "auto" if quantized else self.device,
            "trust_remote_code": True,
            "attn_implementation": "eager",  # Use eager attention for compatibility
        }
        if quantized:
            model_kwargs["quantization_config"] = quantization_config

        # Load config and disable tied weights to avoid checkpoint conflicts
        model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        model_config.tie_word_embeddings = False

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=model_config,
            **model_kwargs
        )

        # Explicitly convert to requested dtype (model config may override)
        if not quantized:
            self.model = self.model.to(model_dtype)
            logger.info(f"Worker model converted to {model_dtype}")

        # CRITICAL: For LoRA + gradient checkpointing, enable input_require_grads on BASE model
        # BEFORE applying PEFT wrapper. This ensures embeddings have requires_grad=True.
        if use_lora and gradient_checkpointing:
            if hasattr(self.model, "enable_input_require_grads"):
                self.model.enable_input_require_grads()
                logger.info("Enabled input_require_grads on base model (before LoRA)")

        # Apply LoRA FIRST (before gradient checkpointing)
        if use_lora:
            if lora_target_modules is None:
                lora_target_modules = self._default_lora_targets(model_name)

            def build_lora_config(targets: List[str]) -> LoraConfig:
                return LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    r=lora_r,
                    lora_alpha=lora_alpha,
                    lora_dropout=0.05,
                    target_modules=targets,
                    bias="none",
                )

            try:
                self.model = get_peft_model(self.model, build_lora_config(lora_target_modules))
                logger.info(f"Applied LoRA with target modules: {lora_target_modules}")
            except Exception as e:
                # Retry with the same generic targets ACEASTrainer._init_model
                # falls back to. Without this the worker could end up with
                # different parameter names than the trainer, and
                # update_weights(strict=False) would silently drop the weights.
                logger.warning(f"LoRA failed: {e}, trying alternative modules")
                try:
                    self.model = get_peft_model(
                        self.model, build_lora_config(["q_proj", "v_proj"])
                    )
                    logger.info("LoRA applied with fallback modules")
                except Exception as e2:
                    logger.error(f"LoRA completely failed: {e2}")

        # Enable gradient checkpointing AFTER LoRA
        # Use non-reentrant checkpointing to avoid "None of inputs have requires_grad=True" issue
        if gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
            logger.info("Enabled gradient checkpointing (non-reentrant)")

        self.model.eval()

        # Statistics
        self.total_episodes = 0
        self.total_successes = 0

    @staticmethod
    def _prepend_bundled_cuda_libs(worker_id: int) -> None:
        """Put PyTorch's bundled cuBLAS / CUDA runtime dirs first on LD_LIBRARY_PATH.

        Best effort: any failure is logged as a warning and otherwise ignored.
        """
        try:
            site_packages = os.path.dirname(os.path.dirname(torch.__file__))
            candidates = [
                os.path.join(site_packages, "nvidia", "cublas", "lib"),
                os.path.join(site_packages, "nvidia", "cuda_runtime", "lib"),
            ]
            new_paths = [path for path in candidates if os.path.exists(path)]
            if new_paths:
                existing = os.environ.get("LD_LIBRARY_PATH", "")
                # Only append the previous value when it is non-empty: a
                # trailing ":" creates an empty entry, which the dynamic
                # loader interprets as the current working directory.
                entries = new_paths + ([existing] if existing else [])
                os.environ["LD_LIBRARY_PATH"] = ":".join(entries)
                logger.info(f"Worker {worker_id}: Fixed CUDA library path")
        except Exception as e:
            logger.warning(f"Worker {worker_id}: Could not fix CUDA path: {e}")

    @staticmethod
    def _default_lora_targets(model_name: str) -> List[str]:
        """Pick LoRA target module names from substrings of the model name."""
        name = model_name.lower()
        # "codellama" is covered by the "llama" substring.
        if "qwen" in name or "deepseek" in name or "llama" in name:
            return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        if "codegen" in name:
            return ["qkv_proj", "out_proj", "fc_in", "fc_out"]
        if "gpt2" in name:
            return ["c_attn", "c_proj"]  # GPT-2 specific
        if "gpt" in name:
            return ["qkv_proj", "out_proj", "fc_in", "fc_out"]
        if "starcoder" in name:
            return ["c_attn", "c_proj", "c_fc"]
        return ["q_proj", "v_proj"]  # Fallback

    def update_weights(self, state_dict: Dict[str, Any]):
        """Update model weights from trainer."""
        logger.info(f"Worker {self.worker_id}: Updating weights...")
        try:
            # Use strict=False to handle potential key mismatches between
            # trainer and worker models (e.g., different device placements)
            self.model.load_state_dict(state_dict, strict=False)
            self.model.eval()
            logger.info(f"Worker {self.worker_id}: Weight update complete")
        except Exception as e:
            logger.error(f"Worker {self.worker_id}: Weight update failed: {e}")
            raise

    def collect_rollout(
        self,
        difficulty: int,
        num_episodes: int,
        policy_version: int,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[Dict[str, Any]]:
        """
        Collect rollout experiences at specified difficulty.

        For every episode a random task is drawn, turned into a curriculum
        task at the requested difficulty, completed by the model, executed
        in the environment and recorded as an experience dict.

        Args:
            difficulty: Curriculum difficulty level (1-5)
            num_episodes: Number of episodes to collect
            policy_version: Current policy version for staleness tracking
            max_new_tokens: Max tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling threshold (previously fixed at 0.95)

        Returns:
            List of experience dictionaries with keys ``prompt``, ``response``,
            ``reward``, ``log_prob``, ``difficulty``, ``policy_version``,
            ``timestamp``, ``task_id``, ``execution_time``,
            ``generation_time``, ``passed`` and ``worker_id``.
        """
        experiences = []
        difficulty_level = DifficultyLevel(difficulty)

        # --- Per-call invariants (hoisted out of the episode loop) ---
        # Only use chat template for chat models (Qwen, Llama-chat, etc.)
        # Base models (CodeGen, GPT-2, etc.) should use raw prompts
        model_name_lower = self.tokenizer.name_or_path.lower() if hasattr(self.tokenizer, 'name_or_path') else ""
        is_chat_model = any(x in model_name_lower for x in ['qwen', 'chat', 'instruct', 'llama-2-', 'llama-3'])
        use_chat_template = bool(
            is_chat_model and hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template
        )
        # For Qwen3: Add /no_think to disable thinking mode and get direct code output
        no_think_suffix = " /no_think" if 'qwen' in model_name_lower else ""

        # Use GenerationConfig with all parameters to avoid deprecation warnings
        gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            return_dict_in_generate=True,
            output_scores=True,
        )

        # Temporarily disable gradient checkpointing for generation (incompatible with use_cache).
        # Done once per call; the finally block restores it even if an episode raises.
        was_checkpointing = self._pause_gradient_checkpointing()
        try:
            for _ in range(num_episodes):
                # Sample random task and generate its curriculum variant
                task_idx = np.random.randint(len(self.env.tasks))
                curriculum_task = self.task_generator.generate_task(task_idx, difficulty_level)
                prompt = curriculum_task.curriculum_prompt

                if use_chat_template:
                    messages = [
                        {"role": "user", "content": f"Complete the following Python code. Only output the code completion, no explanations or markdown:{no_think_suffix}\n\n{prompt}"},
                    ]
                    formatted_prompt = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True
                    )
                else:
                    # Base model - use raw prompt directly
                    formatted_prompt = prompt

                # Generate completion
                gen_start = time.time()
                inputs = self.tokenizer(
                    formatted_prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=256,
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        generation_config=gen_config,
                    )
                gen_time = time.time() - gen_start

                # Decode only the newly generated tokens
                prompt_length = inputs.input_ids.shape[1]
                generated_ids = outputs.sequences[:, prompt_length:]
                raw_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

                response = self._strip_reasoning_and_fences(raw_text)
                response, full_code = self._assemble_code(prompt, response)
                log_prob = self._sequence_log_prob(outputs, generated_ids)

                # Log first few generations for debugging (use print for visibility in Ray logs)
                verbose = self.total_episodes < 3
                if verbose:
                    print(f"[DEBUG] Worker {self.worker_id} Episode {self.total_episodes}:", flush=True)
                    print(f"  Prompt: {prompt[:150]}...", flush=True)
                    print(f"  Raw response: {response[:300]}...", flush=True)
                    print(f"  Full code: {full_code[:400]}...", flush=True)

                # Execute in environment; reset with the curriculum task before stepping
                env_start = time.time()
                self.env.reset(curriculum_task=curriculum_task)
                result = self.env.step(full_code)
                env_time = time.time() - env_start

                passed = result.info.get("passed", False)
                if verbose:
                    print(f"  Result: reward={result.reward}, passed={result.info.get('passed')}", flush=True)
                    if result.info.get('stderr'):
                        print(f"  Stderr: {result.info.get('stderr')[:500]}...", flush=True)

                experiences.append({
                    "prompt": prompt,
                    "response": response,
                    "reward": result.reward,
                    "log_prob": log_prob,
                    "difficulty": difficulty,
                    "policy_version": policy_version,
                    "timestamp": time.time(),
                    "task_id": curriculum_task.task_id,
                    "execution_time": env_time,
                    "generation_time": gen_time,
                    "passed": passed,
                    "worker_id": self.worker_id,
                })

                # Update statistics
                self.total_episodes += 1
                if passed:
                    self.total_successes += 1
        finally:
            if was_checkpointing:
                self._resume_gradient_checkpointing()

        return experiences

    def _pause_gradient_checkpointing(self) -> bool:
        """Disable gradient checkpointing if it is on; return whether it was on."""
        # Check both the model and base_model (for PEFT-wrapped models)
        was_checkpointing = False
        if hasattr(self.model, 'is_gradient_checkpointing'):
            was_checkpointing = self.model.is_gradient_checkpointing
        elif hasattr(self.model, 'gradient_checkpointing'):
            was_checkpointing = self.model.gradient_checkpointing
        elif hasattr(self.model, 'base_model') and hasattr(self.model.base_model, 'gradient_checkpointing'):
            was_checkpointing = self.model.base_model.gradient_checkpointing

        if was_checkpointing:
            if hasattr(self.model, 'gradient_checkpointing_disable'):
                self.model.gradient_checkpointing_disable()
            elif hasattr(self.model, 'base_model') and hasattr(self.model.base_model, 'gradient_checkpointing_disable'):
                self.model.base_model.gradient_checkpointing_disable()
        return bool(was_checkpointing)

    def _resume_gradient_checkpointing(self) -> None:
        """Re-enable non-reentrant gradient checkpointing after generation."""
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        elif hasattr(self.model, 'base_model') and hasattr(self.model.base_model, 'gradient_checkpointing_enable'):
            self.model.base_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    @staticmethod
    def _strip_reasoning_and_fences(response: str) -> str:
        """Remove ``<think>`` sections and unwrap a markdown code fence, if present."""
        import re

        # Strip Qwen3 thinking tags before processing
        if "<think>" in response:
            # Remove everything between <think> and </think>
            response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL).strip()
            # Also handle unclosed <think> tags (model cut off mid-thought)
            if "<think>" in response:
                response = response.split("<think>")[0].strip()

        # Extract just the code from potential markdown formatting
        if "```python" in response:
            response = response.split("```python")[-1].split("```")[0].strip()
        elif "```" in response:
            response = response.split("```")[1].split("```")[0].strip()
        return response

    @staticmethod
    def _assemble_code(prompt: str, response: str) -> Tuple[str, str]:
        """Return ``(response, full_code)`` for execution.

        A response that starts with ``def `` is treated as a complete function
        and executed on its own. Otherwise it is treated as a continuation and
        appended to the prompt, after removing a repeated copy of the prompt.
        """
        response_stripped = response.strip()
        if response_stripped.startswith('def '):
            return response, response_stripped

        prompt_stripped = prompt.strip()
        if response_stripped.startswith(prompt_stripped):
            # Model repeated the prompt, extract only the new part
            remainder = response_stripped[len(prompt_stripped):]
            if remainder.lstrip().startswith(('return', 'pass', 'raise', 'yield', 'break', 'continue')):
                # Re-indent a bare statement so it lands inside the function body
                response = '\n    ' + remainder.lstrip()
            else:
                response = remainder
        return response, prompt + response

    def _sequence_log_prob(self, outputs, generated_ids) -> float:
        """Sum the log-probabilities of the generated, non-pad tokens.

        Uses the per-step scores returned by ``generate``.
        NOTE(review): these are the scores as returned with sampling enabled;
        confirm this is the distribution the trainer expects.
        """
        if not outputs.scores:
            return 0.0

        stacked_scores = torch.stack(outputs.scores, dim=1)
        log_probs_all = torch.log_softmax(stacked_scores, dim=-1)
        actual_len = min(generated_ids.shape[1], len(outputs.scores))

        # Pick each generated token's log-prob in one gather instead of a
        # device round trip (.item()) per token.
        token_ids = generated_ids[0, :actual_len]
        picked = log_probs_all[0, :actual_len].gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)

        pad_id = self.tokenizer.pad_token_id
        log_prob = 0.0
        for token_id, token_log_prob in zip(token_ids.tolist(), picked.tolist()):
            if token_id != pad_id:
                log_prob += token_log_prob
        return log_prob

    def get_stats(self) -> Dict[str, Any]:
        """Get worker statistics."""
        return {
            "worker_id": self.worker_id,
            "total_episodes": self.total_episodes,
            "total_successes": self.total_successes,
            "success_rate": self.total_successes / max(1, self.total_episodes),
        }


class ACEASTrainer:
    """
    Main ACEAS trainer that coordinates all components.

    Supports both synchronous and asynchronous training modes.
    """

    def __init__(
        self,
        tasks: List[CodeTask],
        config: Optional[ACEASConfig] = None,
        output_dir: str = "./experiments/results",
    ):
        """
        Build the trainer and all of its sub-components.

        Creates the output directory, then initializes the Megatron bridge,
        model, curriculum, scheduler and either the Ray workers or the local
        environment, and finally resets the training counters and metric logs.

        Args:
            tasks: List of code tasks for training
            config: ACEASConfig; a default-constructed one is used when omitted
            output_dir: Directory for saving results (created if missing)
        """
        self.config = config if config else ACEASConfig()

        results_dir = Path(output_dir)
        results_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = results_dir

        self.tasks = tasks
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Megatron bridge first (CUBLAS stability), then the components in order.
        self._init_megatron_bridge()
        self._init_model()
        self._init_curriculum()
        self._init_scheduler()

        if self.config.use_local_mode:
            self._init_local_env()
            self.workers = []
            logger.info("Using local mode (no Ray workers)")
        else:
            self._init_workers()

        # Training state counters all start at zero
        self.policy_version = 0
        self.total_timesteps = 0
        self.total_updates = 0

        # Early stopping bookkeeping
        self.best_pass_at_1 = 0.0
        self.no_improve_count = 0

        # Metric histories
        self.train_metrics: List[Dict[str, Any]] = []
        self.eval_metrics: List[Dict[str, Any]] = []
        self.timing_metrics: List[Dict[str, Any]] = []

    def _init_megatron_bridge(self):
        """Set up the Megatron bridge on CUDA devices; store its features.

        On non-CUDA devices ``self.megatron_features`` is set to an empty dict
        and the bridge is skipped.
        """
        if self.device != "cuda":
            self.megatron_features = {}
            logger.info("CUDA not available, skipping Megatron bridge initialization")
            return

        logger.info("Initializing Megatron bridge for CUBLAS stability...")
        bridge_options = dict(
            use_fused_layer_norm=True,
            use_fused_attention=True,
            use_flash_attention=True,
            enable_cublas_tf32=True,
            enable_cudnn_benchmark=True,
        )
        features = setup_megatron_bridge(MegatronBridgeConfig(**bridge_options))
        self.megatron_features = features
        logger.info(f"Megatron bridge initialized: {self.megatron_features}")

    def _init_model(self):
        """Initialize the policy model and trainer.

        Loads the tokenizer and causal LM named by ``self.config.model_name``,
        optionally with bitsandbytes quantization, applies LoRA and gradient
        checkpointing when configured, and wraps the result in a
        ``GRPOTrainer`` stored as ``self.trainer``.

        When both ``load_in_4bit`` and ``load_in_8bit`` are set, 4-bit wins,
        matching the precedence used by ``RolloutWorker`` so trainer and
        workers load the model the same way.
        """
        logger.info(f"Loading model: {self.config.model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            trust_remote_code=True,
        )
        # Reuse EOS as the pad token when the tokenizer defines none
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        # Determine torch dtype (unknown strings fall back to float16)
        dtype_map = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        torch_dtype = dtype_map.get(self.config.torch_dtype, torch.float16)

        # Prepare model loading kwargs
        model_kwargs = {
            "torch_dtype": torch_dtype,
            "trust_remote_code": True,
            "attn_implementation": "eager",  # Use eager attention for compatibility
        }

        # Build the quantization config. It stays None when quantization was
        # not requested OR could not be set up, so both cases take the plain
        # (unquantized) loading path below.
        quantization_config = None
        if self.config.load_in_4bit or self.config.load_in_8bit:
            try:
                from transformers import BitsAndBytesConfig
                if self.config.load_in_4bit:
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch_dtype,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    )
                    logger.info("Using 4-bit quantization")
                else:
                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                    logger.info("Using 8-bit quantization")
            except ImportError:
                quantization_config = None
                logger.warning(
                    f"bitsandbytes not available, loading unquantized model in {torch_dtype}"
                )
        quantized = quantization_config is not None

        if quantized:
            model_kwargs["quantization_config"] = quantization_config
            model_kwargs["device_map"] = "auto"
        else:
            model_kwargs["device_map"] = self.device

        # Load config and disable tied weights to avoid checkpoint conflicts
        model_config = AutoConfig.from_pretrained(
            self.config.model_name,
            trust_remote_code=True,
        )
        model_config.tie_word_embeddings = False

        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_name,
            config=model_config,
            **model_kwargs,
        )

        # Explicitly convert to requested dtype (model config may override)
        if not quantized:
            self.model = self.model.to(torch_dtype)
            logger.info(f"Model explicitly converted to {torch_dtype}")

        # CRITICAL: For LoRA + gradient checkpointing, enable input_require_grads on BASE model
        # BEFORE applying PEFT wrapper. This ensures embeddings have requires_grad=True.
        if self.config.use_lora and self.config.gradient_checkpointing:
            if hasattr(self.model, "enable_input_require_grads"):
                self.model.enable_input_require_grads()
                logger.info("Enabled input_require_grads on base model (before LoRA)")

        # Apply LoRA FIRST (before gradient checkpointing)
        if self.config.use_lora:
            # Determine target modules based on model architecture
            target_modules = self.config.lora_target_modules
            if target_modules is None:
                # Auto-detect based on model name
                model_name_lower = self.config.model_name.lower()
                if "qwen" in model_name_lower or "deepseek" in model_name_lower or "llama" in model_name_lower:
                    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
                elif "codegen" in model_name_lower:
                    target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
                elif "gpt2" in model_name_lower:
                    target_modules = ["c_attn", "c_proj"]  # GPT-2 specific
                elif "gpt" in model_name_lower:
                    target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
                elif "starcoder" in model_name_lower:
                    target_modules = ["c_attn", "c_proj", "c_fc"]
                else:
                    # Generic fallback
                    target_modules = ["q_proj", "v_proj"]
                logger.info(f"Auto-detected LoRA target modules: {target_modules}")

            def build_lora_config(targets):
                return LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    r=self.config.lora_r,
                    lora_alpha=self.config.lora_alpha,
                    lora_dropout=0.05,
                    target_modules=targets,
                    bias="none",
                )

            def log_trainable_params(prefix):
                trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
                total_params = sum(p.numel() for p in self.model.parameters())
                logger.info(f"{prefix}: {trainable_params:,} trainable / {total_params:,} total params "
                           f"({100*trainable_params/total_params:.2f}%)")

            try:
                self.model = get_peft_model(self.model, build_lora_config(target_modules))
                log_trainable_params("LoRA applied")
            except Exception as e:
                logger.warning(f"LoRA failed: {e}, trying alternative modules")
                # Fallback to generic modules, using a fresh config rather
                # than mutating the one that just failed.
                try:
                    self.model = get_peft_model(self.model, build_lora_config(["q_proj", "v_proj"]))
                    logger.info("LoRA applied with fallback modules")
                    log_trainable_params("LoRA applied")
                except Exception as e2:
                    logger.error(f"LoRA completely failed: {e2}")

        # Enable gradient checkpointing AFTER LoRA
        # Use non-reentrant checkpointing to avoid "None of inputs have requires_grad=True" issue
        if self.config.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
            logger.info("Gradient checkpointing enabled (non-reentrant)")

        # Initialize GRPO trainer
        grpo_config = GRPOConfig(
            learning_rate=self.config.learning_rate,
            clip_epsilon=self.config.clip_epsilon,
            kl_coef=self.config.kl_coef,
            group_size=self.config.group_size,
            total_steps=self.config.total_timesteps,  # For LR scheduler
        )
        self.trainer = GRPOTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            config=grpo_config,
            device=self.device,
        )

    def _init_curriculum(self):
        """Build the difficulty selector and the curriculum task generator.

        ``self.acb`` depends on ``config.curriculum_strategy``:

        * ``"adaptive"`` -> ``AdaptiveCurriculumBandit``
        * ``"fixed"``    -> ``FixedCurriculum`` with a 10% warmup ratio
        * anything else  -> ``None`` (the local and sync collectors then
          draw a random difficulty themselves)
        """
        strategy = self.config.curriculum_strategy

        if strategy == "adaptive":
            bandit_settings = ACBConfig(
                exploration_constant=self.config.acb_exploration,
                alpha=self.config.acb_alpha,
            )
            selector = AdaptiveCurriculumBandit(bandit_settings)
        elif strategy == "fixed":
            selector = FixedCurriculum(
                total_steps=self.config.total_timesteps,
                warmup_ratio=0.1,
            )
        else:
            selector = None

        self.acb = selector
        self.task_generator = CurriculumTaskGenerator(self.tasks)

    def _init_scheduler(self):
        """Create ``self.scheduler`` from the scheduling and staleness settings.

        ``eta_base`` and ``lambda_coupling`` are passed to both the scheduler
        config and the staleness config; the scheduler also receives the
        curriculum selector stored in ``self.acb``.
        """
        cfg = self.config

        sched_cfg = SchedulerConfig(
            num_workers=cfg.num_workers,
            use_csc=cfg.use_csc,
            use_execution_aware=cfg.use_eaas,
            eta_base=cfg.eta_base,
            lambda_coupling=cfg.lambda_coupling,
        )
        stale_cfg = StalenessConfig(
            eta_base=cfg.eta_base,
            lambda_coupling=cfg.lambda_coupling,
        )

        self.scheduler = ACEASScheduler(
            config=sched_cfg,
            acb=self.acb,
            staleness_config=stale_cfg,
        )

    def _init_workers(self):
        """Start Ray if necessary and create one rollout actor per worker slot.

        Tasks are flattened to plain dicts before being handed to the actors,
        and the actor class is built for ``config.gpu_per_worker`` GPUs each.
        Actor creation is not awaited here, so this returns while workers may
        still be loading their models.
        """
        logger.info("Initializing Ray workers...")
        if not ray.is_initialized():
            logger.info("Initializing Ray runtime...")
            ray.init(ignore_reinit_error=True)
            logger.info("Ray runtime initialized")

        # Plain dicts keep the task payload simple to serialize for Ray.
        task_dicts = []
        for task in self.tasks:
            task_dicts.append(
                {
                    "task_id": task.task_id,
                    "prompt": task.prompt,
                    "canonical_solution": task.canonical_solution,
                    "test_cases": task.test_cases,
                    "entry_point": task.entry_point,
                }
            )

        env_config = dict(
            timeout_seconds=self.config.timeout_seconds,
            max_code_length=self.config.max_code_length,
        )

        # The worker class is parameterised by the GPU share each actor requests.
        WorkerClass = create_rollout_worker_class(self.config.gpu_per_worker)
        logger.info(f"Creating workers with {self.config.gpu_per_worker} GPU per worker")

        # Everything except the worker id is identical across actors.
        cfg = self.config
        shared_actor_kwargs = dict(
            model_name=cfg.model_name,
            tasks=task_dicts,
            env_config=env_config,
            use_lora=cfg.use_lora,
            lora_r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_target_modules=cfg.lora_target_modules,
            gradient_checkpointing=cfg.gradient_checkpointing,
            load_in_8bit=cfg.load_in_8bit,
            load_in_4bit=cfg.load_in_4bit,
            torch_dtype=cfg.torch_dtype,
        )

        self.workers = []
        for worker_id in range(self.config.num_workers):
            logger.info(f"Creating worker {worker_id}/{self.config.num_workers}...")
            actor = WorkerClass.remote(worker_id=worker_id, **shared_actor_kwargs)
            self.workers.append(actor)
            logger.info(f"Worker {worker_id} actor created (model loading in background)")

        logger.info(f"All {len(self.workers)} worker actors created. Note: Workers load models asynchronously.")

    def _init_local_env(self):
        """Create ``self.local_env``, the in-process environment for local-mode training.

        The environment is built over ``self.tasks`` with the configured
        execution timeout and maximum code length.
        """
        self.local_env = CurriculumCodeEnv(
            self.tasks,
            CodeEnvConfig(
                timeout_seconds=self.config.timeout_seconds,
                max_code_length=self.config.max_code_length,
            ),
        )

    def _broadcast_weights(self, force_sync: bool = False):
        """
        Broadcast model weights to all workers.

        Args:
            force_sync: If True, wait for broadcast to complete (blocking).
                       If False, fire-and-forget for better throughput.
        """
        if self.config.use_local_mode or not self.workers:
            return  # No workers to broadcast to in local mode

        # Move state dict to CPU for serialization
        state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
        futures = [w.update_weights.remote(state_dict) for w in self.workers]

        if force_sync:
            # Blocking mode - wait for all workers to update
            logger.info("Broadcasting weights to workers (sync)...")
            try:
                ray.get(futures, timeout=120)
                logger.info("Weight broadcast complete")
            except ray.exceptions.GetTimeoutError:
                logger.error("Weight broadcast timed out after 120s")
                raise RuntimeError("Weight broadcast to workers timed out.")
        else:
            # Non-blocking mode - fire and forget for better throughput
            # Workers will use slightly stale weights but training continues
            self._pending_weight_updates = futures
            logger.debug("Weight broadcast dispatched (async)")

    def _collect_experiences_local(
        self,
        num_episodes: int,
    ) -> List[Dict[str, Any]]:
        """Collect experiences locally without Ray workers."""
        experiences = []

        for _ in range(num_episodes):
            # Select difficulty
            if self.acb is not None:
                diff = self.acb.select_difficulty()
                if hasattr(diff, 'value'):
                    diff = diff.value
            else:
                diff = np.random.randint(1, 6)

            difficulty_level = DifficultyLevel(diff)

            # Sample random task
            task_idx = np.random.randint(len(self.tasks))

            # Generate curriculum task
            curriculum_task = self.task_generator.generate_task(task_idx, difficulty_level)

            # Get prompt
            prompt = curriculum_task.curriculum_prompt

            # Format prompt based on model type
            # Only use chat template for chat models (Qwen, Llama-chat, etc.)
            # Base models (CodeGen, GPT-2, etc.) should use raw prompts
            model_name_lower = self.tokenizer.name_or_path.lower() if hasattr(self.tokenizer, 'name_or_path') else ""
            is_chat_model = any(x in model_name_lower for x in ['qwen', 'chat', 'instruct', 'llama-2-', 'llama-3'])

            if is_chat_model and hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
                # For Qwen3: Add /no_think to disable thinking mode and get direct code output
                no_think_suffix = " /no_think" if 'qwen' in model_name_lower else ""
                messages = [
                    {"role": "user", "content": f"Complete the following Python code. Only output the code completion, no explanations or markdown:{no_think_suffix}\n\n{prompt}"},
                ]
                formatted_prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                # Base model - use raw prompt directly
                formatted_prompt = prompt

            # Generate completion
            gen_start = time.time()
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=256,
            ).to(self.device)

            with torch.no_grad():
                # Temporarily disable gradient checkpointing during generation for KV cache
                was_checkpointing = getattr(self.model, 'gradient_checkpointing', False) or \
                                   getattr(self.model.base_model, 'gradient_checkpointing', False)
                if was_checkpointing:
                    if hasattr(self.model, 'gradient_checkpointing_disable'):
                        self.model.gradient_checkpointing_disable()
                    elif hasattr(self.model.base_model, 'gradient_checkpointing_disable'):
                        self.model.base_model.gradient_checkpointing_disable()

                # Use greedy decoding to avoid numerical instability with bfloat16 sampling
                gen_config = GenerationConfig(
                    max_new_tokens=self.config.max_new_tokens,
                    do_sample=False,  # Greedy decoding for stability
                    pad_token_id=self.tokenizer.pad_token_id,
                    return_dict_in_generate=True,
                    output_scores=True,
                )
                outputs = self.model.generate(
                    **inputs,
                    generation_config=gen_config,
                    use_cache=True,  # Enable KV cache for faster generation
                )

                # Re-enable gradient checkpointing for training
                if was_checkpointing:
                    if hasattr(self.model, 'gradient_checkpointing_enable'):
                        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
                    elif hasattr(self.model, 'base_model') and hasattr(self.model.base_model, 'gradient_checkpointing_enable'):
                        self.model.base_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

            gen_time = time.time() - gen_start

            # Decode response
            prompt_length = inputs.input_ids.shape[1]
            generated_ids = outputs.sequences[:, prompt_length:]
            response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            # Strip Qwen3 thinking tags before processing
            if "<think>" in response:
                import re
                # Remove everything between <think> and </think>
                response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL).strip()
                # Also handle unclosed <think> tags (model cut off mid-thought)
                if "<think>" in response:
                    response = response.split("<think>")[0].strip()

            # Extract just the code from potential markdown formatting
            if "```python" in response:
                response = response.split("```python")[-1].split("```")[0].strip()
            elif "```" in response:
                response = response.split("```")[1].split("```")[0].strip()

            # Handle model output:
            # If response is a complete function (starts with 'def'), use it directly
            # Otherwise, concatenate with prompt
            response_stripped = response.strip()
            if response_stripped.startswith('def '):
                # Model output a complete function - use it directly
                full_code = response_stripped
            else:
                # Model output a continuation - concatenate with prompt
                prompt_stripped = prompt.strip()
                if response_stripped.startswith(prompt_stripped):
                    # Model repeated the prompt, extract only the new part
                    remainder = response_stripped[len(prompt_stripped):]
                    if remainder.lstrip().startswith(('return', 'pass', 'raise', 'yield', 'break', 'continue')):
                        response = '\n    ' + remainder.lstrip()
                    else:
                        response = remainder
                full_code = prompt + response

            # Compute log probability
            if outputs.scores:
                stacked_scores = torch.stack(outputs.scores, dim=1)
                log_probs_all = torch.log_softmax(stacked_scores, dim=-1)

                gen_len = generated_ids.shape[1]
                scores_len = len(outputs.scores)
                actual_len = min(gen_len, scores_len)

                log_prob = 0.0
                for j in range(actual_len):
                    token_id = generated_ids[0, j].item()
                    if token_id != self.tokenizer.pad_token_id:
                        log_prob += log_probs_all[0, j, token_id].item()
            else:
                log_prob = 0.0

            # Log first few generations for debugging (use print for visibility)
            if len(experiences) < 3:
                print(f"[DEBUG] Local Episode {len(experiences)}:", flush=True)
                print(f"  Prompt: {prompt[:150]}...", flush=True)
                print(f"  Raw response: {response[:300]}...", flush=True)
                print(f"  Full code: {full_code[:400]}...", flush=True)

            # Execute in environment (full_code already set above)
            env_start = time.time()
            self.local_env.reset(curriculum_task=curriculum_task)
            result = self.local_env.step(full_code)
            env_time = time.time() - env_start

            # Log execution result for first few episodes (use print for visibility)
            if len(experiences) < 3:
                print(f"  Result: reward={result.reward}, passed={result.info.get('passed')}", flush=True)
                if result.info.get('stderr'):
                    print(f"  Stderr: {result.info.get('stderr')[:500]}...", flush=True)

            # Create experience
            exp = {
                "prompt": prompt,
                "response": response,
                "reward": result.reward,
                "log_prob": log_prob,
                "difficulty": diff,
                "policy_version": self.policy_version,
                "timestamp": time.time(),
                "task_id": curriculum_task.task_id,
                "execution_time": env_time,
                "generation_time": gen_time,
                "passed": result.info.get("passed", False),
                "worker_id": 0,
            }
            experiences.append(exp)

        return experiences

    def _collect_experiences_sync(
        self,
        num_per_worker: int = 8,
    ) -> List[Dict[str, Any]]:
        """Run one rollout on every worker and wait for all of them.

        A difficulty is chosen for each worker first (from ``self.acb`` when
        set, otherwise a random integer in 1..5); then every worker is asked
        for ``num_per_worker`` episodes at its difficulty and the call blocks
        until all rollouts have returned.

        Args:
            num_per_worker: Episodes requested from each worker.

        Returns:
            The experiences of all workers concatenated in worker order.
        """
        def next_difficulty():
            # Enum-like selections are reduced to their plain value.
            if self.acb is None:
                return np.random.randint(1, 6)
            choice = self.acb.select_difficulty()
            return choice.value if hasattr(choice, 'value') else choice

        # All difficulties are drawn before any rollout is dispatched.
        difficulties = [next_difficulty() for _ in self.workers]

        futures = [
            worker.collect_rollout.remote(
                difficulty=level,
                num_episodes=num_per_worker,
                policy_version=self.policy_version,
                max_new_tokens=self.config.max_new_tokens,
                temperature=self.config.temperature,
            )
            for worker, level in zip(self.workers, difficulties)
        ]

        per_worker_results = ray.get(futures)

        # Flatten the per-worker lists into a single list.
        return [exp for worker_exps in per_worker_results for exp in worker_exps]

    def _collect_experiences_async(
        self,
        target_count: int,
    ) -> List[Dict[str, Any]]:
        """Collect experiences asynchronously."""
        experiences = []
        pending_futures = {}

        # Start initial rollouts
        for i, worker in enumerate(self.workers):
            diff = self.scheduler.select_difficulty_for_worker(i)
            batch_size = self.scheduler.assign_worker_batch_size(i)

            future = worker.collect_rollout.remote(
                difficulty=diff,
                num_episodes=batch_size,
                policy_version=self.policy_version,
            )
            pending_futures[future] = i

        # Collect results as they complete
        while len(experiences) < target_count and pending_futures:
            ready, _ = ray.wait(list(pending_futures.keys()), timeout=0.1)

            for future in ready:
                worker_id = pending_futures.pop(future)
                try:
                    worker_exps = ray.get(future)
                    experiences.extend(worker_exps)

                    # Add to scheduler
                    stale_exps = [
                        StaleExperience(
                            prompt=e["prompt"],
                            response=e["response"],
                            reward=e["reward"],
                            log_prob=e["log_prob"],
                            value=0.0,
                            difficulty=e["difficulty"],
                            policy_version=e["policy_version"],
                            timestamp=e["timestamp"],
                            task_id=e["task_id"],
                            execution_time=e["execution_time"],
                        )
                        for e in worker_exps
                    ]
                    self.scheduler.add_experiences(stale_exps, worker_id)

                except Exception as e:
                    logger.error(f"Worker {worker_id} failed: {e}")

                # Start new rollout if needed
                if len(experiences) < target_count:
                    diff = self.scheduler.select_difficulty_for_worker(worker_id)
                    batch_size = self.scheduler.assign_worker_batch_size(worker_id)

                    new_future = self.workers[worker_id].collect_rollout.remote(
                        difficulty=diff,
                        num_episodes=batch_size,
                        policy_version=self.policy_version,
                    )
                    pending_futures[new_future] = worker_id

        return experiences

    def train(self) -> Dict[str, Any]:
        """
        Run the full training loop.

        Each iteration collects a batch of experiences (locally, async or
        sync depending on the config), runs one GRPO update, updates the
        curriculum and scheduler, optionally broadcasts the new weights, and
        records timing and training metrics. Evaluation and checkpointing are
        triggered every ``eval_interval`` / ``save_interval`` timesteps,
        converted to a number of updates by dividing by ``batch_size`` (at
        least every update when the interval is smaller than a batch).

        The loop ends when ``total_timesteps`` reaches the configured target
        or when early stopping triggers. A final checkpoint and the results
        file are always written afterwards.

        Returns:
            Dictionary with training results: ``total_timesteps``,
            ``total_updates``, ``total_time``, ``avg_throughput``,
            ``train_metrics``, ``eval_metrics``, ``timing_metrics`` and
            ``scheduler_stats``.
        """
        logger.info("Starting ACEAS training")
        print(f"\n{'='*70}", flush=True)
        print(f"ACEAS Training: {self.config.total_timesteps} steps, batch_size={self.config.batch_size}", flush=True)
        print(f"Log interval: {self.config.log_interval}, Eval interval: {self.config.eval_interval}", flush=True)
        print(f"{'='*70}\n", flush=True)
        start_time = time.time()

        # Intervals are configured in timesteps but checked once per update.
        # Clamp to at least 1: an interval smaller than batch_size would floor
        # to 0 and make the modulo below raise ZeroDivisionError.
        eval_every = max(1, self.config.eval_interval // self.config.batch_size)
        save_every = max(1, self.config.save_interval // self.config.batch_size)

        # Initial weight broadcast (synchronous to ensure workers are ready)
        self._broadcast_weights(force_sync=True)

        while self.total_timesteps < self.config.total_timesteps:
            update_start = time.time()

            # Collect experiences
            if self.config.use_local_mode:
                experiences = self._collect_experiences_local(
                    num_episodes=self.config.batch_size,
                )
            elif self.config.use_async:
                experiences = self._collect_experiences_async(
                    target_count=self.config.batch_size,
                )
            else:
                exp_per_worker = max(1, self.config.batch_size // len(self.workers))
                experiences = self._collect_experiences_sync(
                    num_per_worker=exp_per_worker,
                )

            if not experiences:
                # NOTE: there is no retry limit; this loops for as long as
                # collection keeps coming back empty.
                logger.warning("No experiences collected, retrying...")
                continue

            collection_time = time.time() - update_start

            # Convert to GRPO experiences
            grpo_exps = [
                GRPOExperience(
                    prompt=e["prompt"],
                    response=e["response"],
                    reward=e["reward"],
                    old_log_prob=e["log_prob"],
                    difficulty=e["difficulty"],
                    policy_version=e["policy_version"],
                )
                for e in experiences
            ]

            # Create batch and train
            train_start = time.time()
            batch = create_grpo_batch(grpo_exps)
            train_metrics = self.trainer.train_step(batch)
            train_time = time.time() - train_start

            # Get gradient magnitude for ACB
            grad_mag = self.trainer.get_gradient_magnitude()

            # Update curriculum
            difficulties = [e["difficulty"] for e in experiences]
            rewards = [e["reward"] for e in experiences]
            successes = [e.get("passed", False) for e in experiences]

            if self.config.curriculum_strategy == "adaptive":
                # The single per-update gradient magnitude is repeated for every experience.
                self.scheduler.update_curriculum(
                    difficulties, rewards, successes,
                    gradient_magnitudes=[grad_mag] * len(difficulties),
                )

            # Update state
            self.policy_version += 1
            self.scheduler.on_policy_update()
            self.total_timesteps += len(experiences)
            self.total_updates += 1

            # Broadcast updated weights (async, every N updates for throughput)
            broadcast_interval = getattr(self.config, 'broadcast_interval', 1)
            broadcast_time = 0.0
            if self.total_updates % broadcast_interval == 0:
                broadcast_start = time.time()
                self._broadcast_weights(force_sync=False)  # Non-blocking for throughput
                broadcast_time = time.time() - broadcast_start

            update_time = time.time() - update_start

            # Record metrics
            cumulative_wall_time = time.time() - start_time
            timing = {
                "timestep": self.total_timesteps,
                "update": self.total_updates,
                "cumulative_wall_time": cumulative_wall_time,  # For wall-clock analysis
                "cumulative_gpu_hours": cumulative_wall_time / 3600,  # Convert to hours
                "collection_time": collection_time,
                "train_time": train_time,
                "broadcast_time": broadcast_time,
                "update_time": update_time,
                "throughput": len(experiences) / update_time,
            }
            self.timing_metrics.append(timing)

            train_metrics["timestep"] = self.total_timesteps
            train_metrics["num_experiences"] = len(experiences)
            train_metrics["success_rate"] = sum(successes) / len(successes)
            train_metrics["avg_reward"] = np.mean(rewards)

            # Add difficulty distribution
            diff_dist = {}
            for d in range(1, 6):
                diff_dist[f"difficulty_{d}_ratio"] = sum(1 for x in difficulties if x == d) / len(difficulties)
            train_metrics.update(diff_dist)

            self.train_metrics.append(train_metrics)

            # Logging - use print with flush for Ray worker visibility
            if self.total_updates % self.config.log_interval == 0:
                elapsed = time.time() - start_time
                loss_str = f"Loss: {train_metrics.get('loss', 0):.4f}" if 'loss' in train_metrics else ""
                progress_pct = 100 * self.total_timesteps / self.config.total_timesteps
                print(
                    f"[Step {self.total_timesteps:>6}/{self.config.total_timesteps}] ({progress_pct:5.1f}%) | "
                    f"Reward: {train_metrics['avg_reward']:>7.3f} | "
                    f"Success: {train_metrics['success_rate']:>6.2%} | "
                    f"{loss_str} | "
                    f"Throughput: {timing['throughput']:.1f}/s | "
                    f"Elapsed: {elapsed:.1f}s",
                    flush=True
                )

            # Evaluation
            if self.total_updates % eval_every == 0:
                eval_metrics = self._evaluate()
                self.eval_metrics.append(eval_metrics)
                print(f"\n>>> EVAL @ Step {self.total_timesteps}: Pass@1 = {eval_metrics['pass_at_1']:.2%} <<<\n", flush=True)

                # Early stopping check
                if self.config.early_stopping_patience > 0:
                    if eval_metrics['pass_at_1'] > self.best_pass_at_1:
                        self.best_pass_at_1 = eval_metrics['pass_at_1']
                        self.no_improve_count = 0
                        logger.info(f"New best Pass@1: {self.best_pass_at_1:.2%}")
                    else:
                        self.no_improve_count += 1
                        logger.info(f"No improvement for {self.no_improve_count}/{self.config.early_stopping_patience} evals")

                    if self.no_improve_count >= self.config.early_stopping_patience:
                        print(f"\n>>> EARLY STOPPING: No improvement for {self.config.early_stopping_patience} evals. Best Pass@1 = {self.best_pass_at_1:.2%} <<<\n", flush=True)
                        logger.info(f"Early stopping triggered at step {self.total_timesteps}")
                        break

            # Checkpointing
            if self.total_updates % save_every == 0:
                self._save_checkpoint()

        total_time = time.time() - start_time
        print(f"\n{'='*70}", flush=True)
        print(f"Training COMPLETE in {total_time:.1f}s ({self.total_timesteps} steps)", flush=True)
        print(f"Final Avg Reward: {self.train_metrics[-1]['avg_reward']:.3f}" if self.train_metrics else "", flush=True)
        print(f"{'='*70}\n", flush=True)
        logger.info(f"Training complete in {total_time:.1f}s")

        # Final save
        self._save_checkpoint()
        self._save_results()

        return {
            "total_timesteps": self.total_timesteps,
            "total_updates": self.total_updates,
            "total_time": total_time,
            # Guard against a zero elapsed time (e.g. the loop never ran on a coarse clock).
            "avg_throughput": self.total_timesteps / total_time if total_time > 0 else 0.0,
            "train_metrics": self.train_metrics,
            "eval_metrics": self.eval_metrics,
            "timing_metrics": self.timing_metrics,
            "scheduler_stats": self.scheduler.get_statistics() if hasattr(self, 'scheduler') else {},
        }

    def _evaluate(self, num_tasks: int = 50) -> Dict[str, Any]:
        """Evaluate current policy on held-out tasks."""
        self.model.eval()

        # Use full difficulty (level 5) for evaluation
        successes = 0
        total_reward = 0.0

        task_indices = np.random.choice(len(self.tasks), min(num_tasks, len(self.tasks)), replace=False)

        for idx in task_indices:
            task = self.tasks[idx]

            inputs = self.tokenizer(
                task.prompt,
                return_tensors="pt",
                truncation=True,
                max_length=256,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.config.max_new_tokens,
                    do_sample=False,  # Greedy decoding for stability (avoid multinomial NaN issues with bfloat16)
                    pad_token_id=self.tokenizer.pad_token_id,
                )

            response = self.tokenizer.decode(
                outputs[0, inputs.input_ids.shape[1]:],
                skip_special_tokens=True,
            )

            # Evaluate
            full_code = task.prompt + response
            env = CurriculumCodeEnv([task])
            env.reset()
            result = env.step(full_code)

            if result.info.get("passed", False):
                successes += 1
            total_reward += result.reward

        return {
            "timestep": self.total_timesteps,
            "pass_at_1": successes / len(task_indices),
            "avg_reward": total_reward / len(task_indices),
            "num_tasks": len(task_indices),
        }

    def _save_checkpoint(self):
        """Write a trainer checkpoint named after the current update count.

        The file is ``checkpoint_<total_updates>.pt`` inside
        ``self.output_dir``; saving itself is delegated to the GRPO trainer.
        """
        filename = f"checkpoint_{self.total_updates}.pt"
        destination = self.output_dir / filename
        self.trainer.save_checkpoint(str(destination))
        logger.info(f"Saved checkpoint to {destination}")

    def _save_results(self):
        """Write config, metrics and scheduler statistics to ``results.json``.

        Config values that are not int/float/bool/str/list/dict are stored as
        their ``str()`` form; anything else the JSON encoder cannot handle is
        stringified as well.
        """
        passthrough_types = (int, float, bool, str, list, dict)

        config_snapshot = {}
        for key, value in self.config.__dict__.items():
            if isinstance(value, passthrough_types):
                config_snapshot[key] = value
            else:
                config_snapshot[key] = str(value)

        payload = {
            "config": config_snapshot,
            "train_metrics": self.train_metrics,
            "eval_metrics": self.eval_metrics,
            "timing_metrics": self.timing_metrics,
            "scheduler_stats": self.scheduler.get_statistics(),
        }

        destination = self.output_dir / "results.json"
        with open(destination, "w") as handle:
            json.dump(payload, handle, indent=2, default=str)
        logger.info(f"Saved results to {destination}")


def create_trainer(
    tasks: List[CodeTask],
    output_dir: str = "./experiments/results",
    **config_overrides,
) -> ACEASTrainer:
    """
    Build an ``ACEASTrainer`` from keyword overrides.

    Args:
        tasks: Code tasks the trainer will work with.
        output_dir: Directory passed to the trainer for its outputs.
        **config_overrides: Fields of ``ACEASConfig`` to set; every field not
            named keeps its default.

    Returns:
        The constructed ``ACEASTrainer``.
    """
    return ACEASTrainer(
        tasks=tasks,
        config=ACEASConfig(**config_overrides),
        output_dir=output_dir,
    )


def _run_smoke_test():
    """Build synthetic tasks and a small config, printing progress.

    No trainer is constructed; this only checks that the module imports and
    that a config can be created.
    """
    print("Testing ACEASTrainer initialization...")

    # Synthetic tasks stand in for a real dataset.
    from ..code_environment.code_env import create_synthetic_tasks
    synthetic_tasks = create_synthetic_tasks(20)

    print(f"Created {len(synthetic_tasks)} tasks")

    # Small configuration, just to exercise the dataclass.
    smoke_config = ACEASConfig(
        model_name="Salesforce/codegen-350M-mono",
        total_timesteps=100,
        batch_size=8,
        num_workers=2,
    )
    print(f"Config: {smoke_config}")

    print("\nACEAS Trainer module loaded successfully!")


if __name__ == "__main__":
    _run_smoke_test()
