"""
ACEAS Async Scheduler.

This module coordinates the full ACEAS system:
- Adaptive curriculum task selection (ACB)
- Execution-aware scheduling (EAAS)
- Curriculum-staleness coupling (CSC)
"""

import ray
import time
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Tuple
from collections import deque
import numpy as np

from .execution_predictor import ExecutionTimePredictor
from .staleness_control import (
    CurriculumStalenessController,
    StalenessController,
    StaleExperience,
    StalenessConfig,
)

logger = logging.getLogger(__name__)


@dataclass
class SchedulerConfig:
    """Tunable settings for ``ACEASScheduler``.

    All fields have defaults, so ``SchedulerConfig()`` is a valid
    configuration. Field order is part of the positional constructor
    signature and must not be rearranged.
    """

    # --- Workers -----------------------------------------------------------
    # Number of WorkerState slots the scheduler creates (ids 0..num_workers-1).
    num_workers: int = 4
    # Base number of experiences a worker is asked to collect per batch.
    experiences_per_worker: int = 8

    # --- Staleness ---------------------------------------------------------
    # True  -> CurriculumStalenessController (curriculum-staleness coupling).
    # False -> plain StalenessController built from eta_base only.
    use_csc: bool = True
    # Base staleness threshold handed to the staleness controller.
    eta_base: float = 8.0
    # Coupling strength; only forwarded when use_csc is True.
    lambda_coupling: float = 0.5

    # --- Scheduling --------------------------------------------------------
    # Toggle for execution-aware scheduling.
    use_execution_aware: bool = True
    # Reference execution time: seeds the time predictor and normalises
    # predicted times when computing task priority.
    reference_time: float = 0.1

    # --- Buffer ------------------------------------------------------------
    # Maximum number of experiences retained (oldest are evicted first).
    buffer_size: int = 1024
    # Below this many buffered experiences no training batch is produced.
    min_buffer_for_training: int = 64

    # --- Load balancing ----------------------------------------------------
    # Enables throughput-based adjustment of difficulty and batch size.
    enable_load_balancing: bool = True
    # Rebalance every N updates.
    balance_interval: int = 10


@dataclass
class WorkerState:
    """Per-worker bookkeeping record kept by the scheduler.

    Only ``worker_id`` is required; every other field starts from a
    neutral default. Field order is part of the positional constructor
    signature.
    """

    # Identifier of the worker this record belongs to.
    worker_id: int
    # Whether the worker is currently occupied.
    is_busy: bool = False
    # Difficulty level the worker is currently assigned (0 = none yet).
    current_difficulty: int = 0
    # Count of tasks handed out but not yet completed.
    pending_tasks: int = 0
    # Running count of experiences received from this worker.
    completed_tasks: int = 0
    # Running sum of the execution times of those experiences.
    total_execution_time: float = 0.0
    # Tasks per second: completed_tasks / total_execution_time once data exists.
    avg_throughput: float = 1.0


class ACEASScheduler:
    """
    Main scheduler for ACEAS that coordinates all components.

    Integrates:
    1. ACB for difficulty selection
    2. EAAS for execution-time-aware scheduling
    3. CSC for staleness control

    The scheduler owns a bounded experience buffer, per-worker throughput
    statistics, an execution-time predictor and a staleness controller.
    Callers push experiences with ``add_experiences``, pull training data
    with ``get_training_batch`` and report each optimisation step with
    ``on_policy_update``.
    """

    def __init__(
        self,
        config: Optional[SchedulerConfig] = None,
        acb: Optional[Any] = None,  # AdaptiveCurriculumBandit
        staleness_config: Optional[StalenessConfig] = None,
    ):
        """
        Initialize the ACEAS scheduler.

        Args:
            config: SchedulerConfig; a default-constructed one is used when
                omitted.
            acb: AdaptiveCurriculumBandit instance (optional, may also be
                attached later with ``set_acb``).
            staleness_config: Config for staleness control. Only consulted
                when ``config.use_csc`` is True; when omitted it is built
                from ``config.eta_base`` and ``config.lambda_coupling``.
        """
        self.config = config or SchedulerConfig()

        # Execution time predictor
        self.time_predictor = ExecutionTimePredictor(
            reference_time=self.config.reference_time,
            use_history=True,
        )

        # Staleness controller: curriculum-coupled variant or the plain one
        if self.config.use_csc:
            s_config = staleness_config or StalenessConfig(
                eta_base=self.config.eta_base,
                lambda_coupling=self.config.lambda_coupling,
            )
            self.staleness_controller = CurriculumStalenessController(s_config)
        else:
            self.staleness_controller = StalenessController(eta=self.config.eta_base)

        # Store reference to ACB (may be set externally later)
        self.acb = acb

        # Worker states, keyed by worker id 0..num_workers-1
        self.worker_states: Dict[int, WorkerState] = {
            i: WorkerState(worker_id=i)
            for i in range(self.config.num_workers)
        }

        # Experience buffer; the deque drops the oldest entry once full
        self.experience_buffer: deque = deque(maxlen=self.config.buffer_size)

        # Policy version tracking
        self.policy_version = 0

        # Statistics
        self.total_experiences = 0
        self.total_updates = 0
        self.scheduling_decisions: List[Dict[str, Any]] = []

    def set_acb(self, acb: Any):
        """Set the ACB instance."""
        self.acb = acb

    def select_difficulty_for_worker(self, worker_id: int) -> int:
        """
        Select difficulty level for a worker's next batch.

        Without an ACB a level is drawn uniformly from 1..5. With an ACB its
        recommendation is used; when load balancing is enabled and the
        worker's average throughput is below 0.5 tasks/s, the level is
        lowered by one (never below 1).

        Args:
            worker_id: Worker identifier

        Returns:
            Selected difficulty level (1-5)

        Raises:
            KeyError: If an ACB is set, load balancing is enabled and
                ``worker_id`` is unknown.
        """
        if self.acb is None:
            # No ACB, use uniform random (upper bound of randint is exclusive)
            return np.random.randint(1, 6)

        # Get ACB recommendation; the selected object exposes the integer
        # level through ``.value``.
        difficulty = self.acb.select_difficulty().value

        # Optional: Adjust based on worker performance
        if self.config.enable_load_balancing:
            worker = self.worker_states[worker_id]
            # Slower workers get easier tasks
            if worker.avg_throughput < 0.5:
                difficulty = max(1, difficulty - 1)

        return difficulty

    def compute_task_priority(
        self,
        difficulty: int,
        predicted_exec_time: float,
    ) -> float:
        """
        Compute scheduling priority for a task.

        Higher priority = should be scheduled sooner. The score is the
        equal-weight average of a difficulty term (peaks at difficulty 3)
        and a time term (peaks when the predicted time equals
        ``config.reference_time``).

        Args:
            difficulty: Task difficulty level
            predicted_exec_time: Predicted execution time

        Returns:
            Priority score

        Raises:
            ZeroDivisionError: If ``config.reference_time`` is 0.
        """
        # Base priority from difficulty (prefer moderate difficulties)
        difficulty_priority = 1.0 - abs(difficulty - 3) / 4.0

        # Time priority (prefer tasks that match reference time)
        time_ratio = predicted_exec_time / self.config.reference_time
        if time_ratio < 1:
            time_priority = time_ratio  # Fast tasks get lower priority
        else:
            time_priority = 1.0 / time_ratio  # Slow tasks also get lower priority

        return 0.5 * difficulty_priority + 0.5 * time_priority

    def assign_worker_batch_size(self, worker_id: int) -> int:
        """
        Determine batch size for a worker based on performance.

        With load balancing enabled the base size
        (``config.experiences_per_worker``) is scaled by the worker's
        throughput relative to the mean throughput of all workers, then
        clamped to between half and double the base size. For the default
        base size of 8 this is the range [4, 16].

        Args:
            worker_id: Worker identifier

        Returns:
            Number of experiences to collect

        Raises:
            KeyError: If ``worker_id`` is unknown.
        """
        base_size = self.config.experiences_per_worker
        worker = self.worker_states[worker_id]

        if not self.config.enable_load_balancing:
            return base_size

        # Adjust based on worker throughput relative to the fleet average
        global_avg = np.mean([w.avg_throughput for w in self.worker_states.values()])

        if global_avg > 0:
            ratio = worker.avg_throughput / global_avg
            # Bounds follow the configured base size so that a non-default
            # experiences_per_worker is not silently capped.
            lower = max(1, base_size // 2)
            upper = max(lower, base_size * 2)
            return max(lower, min(upper, int(base_size * ratio)))

        return base_size

    def add_experiences(
        self,
        experiences: List[StaleExperience],
        worker_id: int,
    ):
        """
        Add new experiences to the buffer.

        Each experience updates the execution-time predictor and is appended
        to the buffer; the producing worker's counters and average
        throughput are then refreshed.

        Args:
            experiences: List of experiences from worker
            worker_id: Worker that generated the experiences

        Raises:
            KeyError: If ``worker_id`` is unknown. Raised before any state
                is modified.
        """
        # Resolve the worker first so an unknown id fails before the
        # predictor, buffer or counters are touched.
        worker = self.worker_states[worker_id]

        for exp in experiences:
            # Update execution time predictor
            self.time_predictor.update(
                exp.execution_time,
                task_id=exp.task_id,
                difficulty=exp.difficulty,
            )

            # Stored as-is: the policy version is not stamped here, so it is
            # presumably set by whoever builds the experience.
            self.experience_buffer.append(exp)

        self.total_experiences += len(experiences)

        # Update worker stats
        worker.completed_tasks += len(experiences)
        worker.total_execution_time += sum(exp.execution_time for exp in experiences)
        if worker.total_execution_time > 0:
            worker.avg_throughput = worker.completed_tasks / worker.total_execution_time

    def get_training_batch(
        self,
        batch_size: int,
    ) -> Tuple[List[StaleExperience], List[float]]:
        """
        Get a batch of experiences for training.

        Applies staleness filtering and importance weighting. The buffer is
        not consumed by this call.

        Args:
            batch_size: Desired batch size

        Returns:
            Tuple of (experiences, importance_weights). Both lists are empty
            when ``batch_size`` is not positive or the buffer holds fewer
            than ``config.min_buffer_for_training`` experiences.
        """
        if batch_size <= 0:
            return [], []

        if len(self.experience_buffer) < self.config.min_buffer_for_training:
            return [], []

        # Convert buffer to list for processing
        all_experiences = list(self.experience_buffer)

        # Filter using staleness controller
        if isinstance(self.staleness_controller, CurriculumStalenessController):
            filtered = self.staleness_controller.sample_batch(
                all_experiences,
                batch_size=min(batch_size, len(all_experiences)),
                prioritized=True,
            )
        else:
            filtered = self.staleness_controller.filter_buffer(all_experiences)
            if len(filtered) > batch_size:
                # Uniform subsample without replacement down to batch_size
                indices = np.random.choice(len(filtered), batch_size, replace=False)
                filtered = [filtered[i] for i in indices]

        # Compute importance weights
        weights = [
            self.staleness_controller.compute_importance_weight(exp)
            for exp in filtered
        ]

        return filtered, weights

    def on_policy_update(self):
        """Called after each policy update."""
        self.policy_version += 1
        self.staleness_controller.increment_policy_version()
        self.total_updates += 1

    def update_curriculum(
        self,
        difficulties: List[int],
        rewards: List[float],
        successes: List[bool],
        gradient_magnitudes: Optional[List[float]] = None,
    ):
        """
        Update the curriculum (ACB) with observed outcomes.

        Does nothing when no ACB is attached.

        Args:
            difficulties: Difficulty levels used
            rewards: Rewards obtained
            successes: Success indicators
            gradient_magnitudes: Optional gradient magnitudes
        """
        if self.acb is None:
            return

        # Imported lazily, so the curriculum package is only needed when an
        # ACB is actually in use.
        from ..curriculum.difficulty_levels import DifficultyLevel
        difficulty_enums = [DifficultyLevel(d) for d in difficulties]

        self.acb.update_batch(
            difficulty_enums,
            rewards,
            successes,
            gradient_magnitudes,
        )

    def get_statistics(self) -> Dict[str, Any]:
        """Get comprehensive scheduler statistics.

        Returns:
            Dict with scheduler counters, the current buffer length, the
            predictor's and staleness controller's own statistics, per-worker
            figures under ``"workers"`` and, when an ACB is attached, its
            statistics under ``"curriculum"``.
        """
        stats = {
            "total_experiences": self.total_experiences,
            "total_updates": self.total_updates,
            "policy_version": self.policy_version,
            "buffer_size": len(self.experience_buffer),
            "time_predictor": self.time_predictor.get_statistics(),
            "staleness": self.staleness_controller.get_statistics(),
            "workers": {
                worker_id: {
                    "completed_tasks": worker.completed_tasks,
                    "avg_throughput": worker.avg_throughput,
                }
                for worker_id, worker in self.worker_states.items()
            },
        }

        if self.acb is not None:
            stats["curriculum"] = self.acb.get_statistics()

        return stats

    def get_difficulty_distribution(self) -> Dict[int, float]:
        """Get current difficulty distribution from ACB.

        Returns:
            Mapping of difficulty level to probability. Without an ACB this
            is uniform (0.2 each) over levels 1..5.
        """
        if self.acb is None:
            return {d: 0.2 for d in range(1, 6)}

        dist = self.acb.get_recent_difficulty_distribution()
        return {level.value: prob for level, prob in dist.items()}


def create_scheduler(
    num_workers: int = 4,
    use_csc: bool = True,
    use_eaas: bool = True,
    eta_base: float = 8.0,
    lambda_coupling: float = 0.5,
) -> ACEASScheduler:
    """
    Build an ACEASScheduler from a handful of top-level switches.

    The staleness parameters are written into both the scheduler config and
    a dedicated StalenessConfig, and both are handed to the scheduler.

    Args:
        num_workers: Number of rollout workers
        use_csc: Enable curriculum-staleness coupling
        use_eaas: Enable execution-aware scheduling
        eta_base: Base staleness threshold
        lambda_coupling: CSC coupling strength

    Returns:
        Configured ACEASScheduler
    """
    # The same two staleness knobs feed both configuration objects.
    staleness_knobs = {
        "eta_base": eta_base,
        "lambda_coupling": lambda_coupling,
    }

    scheduler_settings = SchedulerConfig(
        num_workers=num_workers,
        use_csc=use_csc,
        use_execution_aware=use_eaas,
        **staleness_knobs,
    )

    return ACEASScheduler(
        config=scheduler_settings,
        staleness_config=StalenessConfig(**staleness_knobs),
    )


if __name__ == "__main__":
    # Smoke test: feed simulated experiences through a scheduler, pull one
    # training batch and print a few statistics.
    print("Testing ACEASScheduler...")

    num_workers = 4
    experiences_per_batch = 8

    # Create scheduler
    scheduler = create_scheduler(
        num_workers=num_workers,
        use_csc=True,
        use_eaas=True,
    )

    # get_training_batch() yields nothing until the buffer holds at least
    # min_buffer_for_training experiences, so simulate enough batches to
    # reach that threshold (and never fewer than five).
    min_buffer = scheduler.config.min_buffer_for_training
    num_batches = max(
        5,
        (min_buffer + experiences_per_batch - 1) // experiences_per_batch,
    )

    # Simulate adding experiences
    for batch_idx in range(num_batches):
        worker_id = batch_idx % num_workers
        difficulty = scheduler.select_difficulty_for_worker(worker_id)

        experiences = []
        for i in range(experiences_per_batch):
            exp = StaleExperience(
                prompt=f"prompt_{batch_idx}_{i}",
                response=f"response_{batch_idx}_{i}",
                reward=np.random.random(),
                log_prob=-np.random.random() * 2,
                value=np.random.random(),
                difficulty=difficulty,
                policy_version=scheduler.policy_version,
                timestamp=time.time(),
                task_id=f"task_{batch_idx}_{i}",
                execution_time=0.05 + 0.1 * np.random.random(),
            )
            experiences.append(exp)

        scheduler.add_experiences(experiences, worker_id)
        print(f"Batch {batch_idx}: Added {len(experiences)} experiences at difficulty {difficulty}")

        # Bump the policy version after every second batch
        if batch_idx % 2 == 1:
            scheduler.on_policy_update()

    # Get training batch
    train_batch, weights = scheduler.get_training_batch(batch_size=16)
    print(f"\nTraining batch: {len(train_batch)} experiences")
    if weights:
        print(f"Weight range: [{min(weights):.3f}, {max(weights):.3f}]")
    else:
        # min()/max() raise ValueError on an empty sequence
        print("Weight range: n/a (no experiences available for training)")

    # Print statistics
    stats = scheduler.get_statistics()
    print("\nStatistics:")
    print(f"  Total experiences: {stats['total_experiences']}")
    print(f"  Policy version: {stats['policy_version']}")
    print(f"  Buffer size: {stats['buffer_size']}")
    print(f"  Staleness discard rate: {stats['staleness']['discard_rate']:.3f}")

    print("\nAll tests passed!")
