"""
Curriculum-Aware Task Sampler.

This module provides task sampling that respects curriculum difficulty
and integrates with the Adaptive Curriculum Bandit (ACB) for adaptive
difficulty selection.
"""

import random
from typing import List, Optional, Dict, Any
from dataclasses import dataclass

from .difficulty_levels import (
    DifficultyLevel,
    CurriculumTask,
    CurriculumTaskGenerator,
    CodeTask,
)
from .adaptive_bandit import AdaptiveCurriculumBandit, FixedCurriculum


@dataclass
class SampledBatch:
    """Result of one sampling call: curriculum tasks plus per-item metadata.

    The three lists are parallel -- entry ``i`` of each list describes the
    same sampled item.
    """

    # Generated curriculum tasks, in sampling (or shuffled) order.
    tasks: List[CurriculumTask]
    # Difficulty level each corresponding task was generated at.
    difficulties: List[DifficultyLevel]
    # Index into the original CodeTask list that each task was derived from.
    task_indices: List[int]


class CurriculumTaskSampler:
    """
    Task sampler that integrates curriculum learning with task selection.

    Supports:
    - Adaptive difficulty selection via the Adaptive Curriculum Bandit (ACB)
    - Fixed curriculum schedules
    - Mixed-difficulty batches
    - Stratified sampling

    Every sampled task is counted in ``total_samples`` and
    ``samples_per_difficulty``; ``get_statistics`` reports these counters.
    """

    def __init__(
        self,
        tasks: List[CodeTask],
        curriculum_strategy: str = "adaptive",
        acb_config: Optional[Dict[str, Any]] = None,
        total_steps: Optional[int] = None,
    ):
        """
        Initialize the task sampler.

        Args:
            tasks: List of original CodeTask objects.
            curriculum_strategy: "adaptive" (ACB), "fixed", or "uniform".
                Any other value currently behaves like "uniform".
            acb_config: Keyword arguments passed to ``ACBConfig`` (adaptive only).
            total_steps: Total training steps (required for fixed curriculum).

        Raises:
            ValueError: If the strategy is "fixed" and ``total_steps`` is None.
        """
        self.task_generator = CurriculumTaskGenerator(tasks)
        self.curriculum_strategy = curriculum_strategy
        self.num_tasks = len(tasks)

        if curriculum_strategy == "adaptive":
            # Local import: only the adaptive strategy needs ACBConfig.
            from .adaptive_bandit import ACBConfig
            config = ACBConfig(**(acb_config or {}))
            self.difficulty_selector = AdaptiveCurriculumBandit(config)
        elif curriculum_strategy == "fixed":
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if total_steps is None:
                raise ValueError("total_steps required for fixed curriculum")
            self.difficulty_selector = FixedCurriculum(total_steps)
        else:  # uniform: difficulties are drawn uniformly in sample_batch
            self.difficulty_selector = None

        # Statistics
        self.total_samples = 0
        self.samples_per_difficulty: Dict[DifficultyLevel, int] = {
            level: 0 for level in DifficultyLevel
        }

    def sample_batch(
        self,
        batch_size: int,
        difficulty: Optional[DifficultyLevel] = None,
    ) -> SampledBatch:
        """
        Sample a batch of curriculum tasks.

        Each item picks a difficulty (fixed argument, curriculum selector, or
        uniform random) and a source task index uniformly at random.

        Args:
            batch_size: Number of tasks to sample (<= 0 yields an empty batch).
            difficulty: If provided, use this difficulty for all tasks.
                Otherwise, use the curriculum strategy to select.

        Returns:
            SampledBatch with tasks and metadata.

        Raises:
            ValueError: If tasks are requested but the sampler has no source tasks.
        """
        if batch_size > 0 and self.num_tasks == 0:
            # Fail before calling select_difficulty(), which can advance
            # selector state (the fixed schedule's step counter).
            raise ValueError("cannot sample from an empty task list")

        all_levels = list(DifficultyLevel)  # hoisted; only used for uniform draws
        tasks = []
        difficulties = []
        indices = []

        for _ in range(batch_size):
            # Select difficulty
            if difficulty is not None:
                selected_difficulty = difficulty
            elif self.difficulty_selector is not None:
                selected_difficulty = self.difficulty_selector.select_difficulty()
            else:
                # Uniform random difficulty
                selected_difficulty = random.choice(all_levels)

            # Select random task index
            task_idx = random.randrange(self.num_tasks)

            # Generate curriculum task
            curriculum_task = self.task_generator.generate_task(
                task_idx, selected_difficulty
            )

            tasks.append(curriculum_task)
            difficulties.append(selected_difficulty)
            indices.append(task_idx)

            # Update statistics
            self.total_samples += 1
            self.samples_per_difficulty[selected_difficulty] += 1

        return SampledBatch(
            tasks=tasks,
            difficulties=difficulties,
            task_indices=indices,
        )

    @staticmethod
    def _allocate_stratified_counts(
        batch_size: int,
        proportions: Dict[DifficultyLevel, float],
    ) -> Dict[DifficultyLevel, int]:
        """Split ``batch_size`` across levels proportionally (largest-remainder method).

        Each level first receives the floor of its exact share; the leftover
        slots go to the levels with the largest fractional remainders (ties
        broken by dict order). The counts sum to ``batch_size``, or are all
        zero when ``batch_size <= 0``. ``proportions`` must be non-negative
        with a positive sum; they need not be normalized.
        """
        if batch_size <= 0 or not proportions:
            return {level: 0 for level in proportions}

        total = sum(proportions.values())
        exact = {
            level: batch_size * prop / total for level, prop in proportions.items()
        }
        counts = {level: int(share) for level, share in exact.items()}

        leftover = batch_size - sum(counts.values())
        # sorted() is stable even with reverse=True, so ties keep dict order.
        by_remainder = sorted(
            exact, key=lambda level: exact[level] - counts[level], reverse=True
        )
        for level in by_remainder[:leftover]:
            counts[level] += 1
        return counts

    def sample_stratified_batch(
        self,
        batch_size: int,
        stratification: Optional[Dict[DifficultyLevel, float]] = None,
    ) -> SampledBatch:
        """
        Sample a batch with stratified difficulty distribution.

        Per-level counts are allocated with the largest-remainder method, so
        rounding error is spread across levels instead of all landing on the
        last one. The combined batch is shuffled to mix difficulties.

        Args:
            batch_size: Total number of tasks.
            stratification: Dict mapping difficulty to (unnormalized) weight.
                If None, uses a uniform distribution over all levels.

        Returns:
            SampledBatch with stratified difficulties.

        Raises:
            ValueError: If any weight is negative or the weights do not sum
                to a positive value.
        """
        if stratification is None:
            # Uniform distribution
            stratification = {
                level: 1.0 / len(DifficultyLevel)
                for level in DifficultyLevel
            }

        if any(prop < 0 for prop in stratification.values()):
            raise ValueError("stratification proportions must be non-negative")
        if sum(stratification.values()) <= 0:
            raise ValueError("stratification proportions must sum to a positive value")

        samples_per_level = self._allocate_stratified_counts(batch_size, stratification)

        # Sample from each difficulty level; sample_batch updates statistics.
        combined = []
        for level, count in samples_per_level.items():
            if count > 0:
                batch = self.sample_batch(count, difficulty=level)
                combined.extend(
                    zip(batch.tasks, batch.difficulties, batch.task_indices)
                )

        # Shuffle to mix difficulties
        random.shuffle(combined)

        return SampledBatch(
            tasks=[task for task, _, _ in combined],
            difficulties=[level for _, level, _ in combined],
            task_indices=[idx for _, _, idx in combined],
        )

    def update_curriculum(
        self,
        difficulties: List[DifficultyLevel],
        rewards: List[float],
        successes: List[bool],
        gradient_magnitudes: Optional[List[float]] = None,
    ):
        """
        Update the curriculum strategy with observed outcomes.

        Only the adaptive strategy consumes feedback; for "fixed" and
        "uniform" this is a no-op.

        Args:
            difficulties: Difficulty levels used.
            rewards: Rewards obtained.
            successes: Whether tasks were successful.
            gradient_magnitudes: Optional gradient magnitudes from training.
        """
        if self.curriculum_strategy == "adaptive":
            self.difficulty_selector.update_batch(
                difficulties, rewards, successes, gradient_magnitudes
            )
        # Fixed and uniform curricula do not learn from feedback.

    def get_statistics(self) -> Dict[str, Any]:
        """Return sampling counters, keyed by difficulty ``value``, plus selector stats if any."""
        stats = {
            "total_samples": self.total_samples,
            "samples_per_difficulty": {
                level.value: count
                for level, count in self.samples_per_difficulty.items()
            },
            "curriculum_strategy": self.curriculum_strategy,
        }

        if self.difficulty_selector is not None:
            stats["curriculum_stats"] = self.difficulty_selector.get_statistics()

        return stats

    def get_current_difficulty_distribution(self) -> Dict[DifficultyLevel, float]:
        """
        Get the current difficulty selection distribution.

        - adaptive: the bandit's recent difficulty distribution.
        - fixed: a one-hot distribution on the schedule's current difficulty.
        - otherwise: uniform over all levels.
        """
        if self.curriculum_strategy == "adaptive":
            return self.difficulty_selector.get_recent_difficulty_distribution()
        elif self.curriculum_strategy == "fixed":
            # select_difficulty() advances the schedule; restore the exact saved
            # step afterwards (rather than `-= 1`) so this query is side-effect
            # free even if the step was clamped or advanced differently.
            # NOTE(review): assumes current_step is the only state it mutates.
            selector = self.difficulty_selector
            saved_step = selector.current_step
            try:
                current_diff = selector.select_difficulty()
            finally:
                selector.current_step = saved_step
            return {
                level: 1.0 if level == current_diff else 0.0
                for level in DifficultyLevel
            }
        else:
            return {level: 1.0 / len(DifficultyLevel) for level in DifficultyLevel}


def create_sampler_from_tasks(
    tasks: List[CodeTask],
    strategy: str = "adaptive",
    **kwargs
) -> CurriculumTaskSampler:
    """
    Build a CurriculumTaskSampler for ``tasks``.

    Thin convenience wrapper: ``strategy`` is forwarded as the sampler's
    ``curriculum_strategy`` and every other keyword argument is passed
    through unchanged.

    Args:
        tasks: Source CodeTask objects to sample from.
        strategy: Curriculum strategy name: "adaptive", "fixed", or "uniform".
        **kwargs: Extra sampler arguments (e.g. ``acb_config``, ``total_steps``).

    Returns:
        A newly constructed CurriculumTaskSampler.
    """
    sampler = CurriculumTaskSampler(
        curriculum_strategy=strategy,
        tasks=tasks,
        **kwargs,
    )
    return sampler


if __name__ == "__main__":
    # Smoke test for the sampler. Because this module uses relative imports,
    # it must be executed as part of its package (e.g. via `python -m`).
    from collections import Counter

    print("Testing CurriculumTaskSampler...")

    # Create synthetic tasks
    test_tasks = [
        CodeTask(
            task_id=f"test_{i}",
            prompt=f"def func_{i}(x):\n    ",
            canonical_solution=f"return x + {i}",
            test_cases=f"assert func_{i}(1) == {i+1}",
            entry_point=f"func_{i}"
        )
        for i in range(10)
    ]

    def check_batch(batch, expected_size):
        """Assert a SampledBatch has parallel lists of the expected size and in-range indices."""
        assert len(batch.tasks) == expected_size
        assert len(batch.difficulties) == expected_size
        assert len(batch.task_indices) == expected_size
        assert all(0 <= idx < len(test_tasks) for idx in batch.task_indices)

    # Test adaptive sampler
    print("\n1. Testing Adaptive Sampler:")
    sampler = CurriculumTaskSampler(
        tasks=test_tasks,
        curriculum_strategy="adaptive",
    )

    batch_size = 4
    num_batches = 5
    for i in range(num_batches):
        batch = sampler.sample_batch(batch_size=batch_size)
        check_batch(batch, batch_size)
        print(f"  Batch {i+1} difficulties: {[d.value for d in batch.difficulties]}")

        # Simulate outcomes
        rewards = [0.5] * batch_size
        successes = [True, False, True, False]
        sampler.update_curriculum(batch.difficulties, rewards, successes)

    stats = sampler.get_statistics()
    assert stats["total_samples"] == batch_size * num_batches
    assert sum(stats["samples_per_difficulty"].values()) == stats["total_samples"]
    print(f"  Total samples: {stats['total_samples']}")
    print(f"  Samples per difficulty: {stats['samples_per_difficulty']}")

    # Test fixed sampler
    print("\n2. Testing Fixed Sampler:")
    fixed_sampler = CurriculumTaskSampler(
        tasks=test_tasks,
        curriculum_strategy="fixed",
        total_steps=100,
    )

    for step in [0, 25, 50, 75, 99]:
        fixed_sampler.difficulty_selector.current_step = step
        batch = fixed_sampler.sample_batch(batch_size=2)
        check_batch(batch, 2)
        print(f"  Step {step} difficulties: {[d.value for d in batch.difficulties]}")

    # Test stratified sampling
    print("\n3. Testing Stratified Sampling:")
    stratification = {
        DifficultyLevel.LEVEL_1: 0.3,
        DifficultyLevel.LEVEL_2: 0.3,
        DifficultyLevel.LEVEL_3: 0.2,
        DifficultyLevel.LEVEL_4: 0.1,
        DifficultyLevel.LEVEL_5: 0.1,
    }

    batch = sampler.sample_stratified_batch(batch_size=10, stratification=stratification)
    check_batch(batch, 10)
    dist = Counter([d.value for d in batch.difficulties])
    assert sum(dist.values()) == 10
    assert set(dist) <= {level.value for level in stratification}
    print(f"  Stratified batch distribution: {dict(dist)}")

    print("\nAll tests passed!")
