"""
Adaptive Curriculum via Bandit-Based Task Selection (ACB).

This module implements the ACB algorithm that models curriculum selection
as a multi-armed bandit problem, selecting difficulty levels based on
real-time success rates and learning signals.
"""

import logging
import math
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .difficulty_levels import DifficultyLevel

logger = logging.getLogger(__name__)


@dataclass
class BanditArm:
    """Running statistics for one bandit arm, i.e. one difficulty level.

    The cumulative counters feed ``success_rate`` and ``mean_reward``; the
    two lists keep recent per-pull observations.
    """
    difficulty: DifficultyLevel
    num_pulls: int = 0
    total_reward: float = 0.0
    total_successes: int = 0
    recent_rewards: List[float] = field(default_factory=list)
    recent_gradients: List[float] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
        """Share of pulls that succeeded, or 0.5 while the arm is unpulled."""
        pulls = self.num_pulls
        # 0.5 is the placeholder used until there is any data to divide by.
        return 0.5 if pulls == 0 else self.total_successes / pulls

    @property
    def mean_reward(self) -> float:
        """Average reward per pull, or 0.5 while the arm is unpulled."""
        pulls = self.num_pulls
        return 0.5 if pulls == 0 else self.total_reward / pulls

    @property
    def recent_gradient(self) -> float:
        """Mean of the last 20 recorded gradient magnitudes.

        Returns 1.0 when nothing has been recorded yet, so that an arm
        without gradient data is not penalised by the learning signal.
        """
        history = self.recent_gradients
        if history:
            return np.mean(history[-20:])
        return 1.0


@dataclass
class ACBConfig:
    """Tunable parameters of the Adaptive Curriculum Bandit.

    Attributes:
        exploration_constant: UCB exploration parameter ``c``.
        alpha: Weight of the UCB term; the learning signal gets ``1 - alpha``.
        window_size: Maximum number of recent rewards / gradients kept per arm.
        min_pulls_before_ucb: Pulls an arm needs before its UCB formula is used.
        gradient_weight: Weight for gradient-based scoring.
            NOTE(review): not read by any code visible in this module.
        use_learning_signal: Whether gradient magnitude contributes to scores.
    """
    exploration_constant: float = 1.0
    alpha: float = 0.7
    window_size: int = 50
    min_pulls_before_ucb: int = 5
    gradient_weight: float = 0.3
    use_learning_signal: bool = True


class AdaptiveCurriculumBandit:
    """
    Adaptive Curriculum Bandit (ACB) for difficulty selection.

    Each ``DifficultyLevel`` is one arm. An arm's score blends a UCB
    (Upper Confidence Bound) term with a learning signal derived from recent
    gradient magnitudes, plus a small prior favouring mid-range levels.

    The key insight is that we want to:
    1. Explore difficulties where we're uncertain (UCB exploration)
    2. Exploit difficulties where we're learning (high gradient)
    3. Balance between too easy (no learning) and too hard (no success)
    """

    def __init__(self, config: Optional[ACBConfig] = None):
        """
        Initialize the ACB.

        Args:
            config: ACBConfig with algorithm parameters; defaults are used
                when omitted.
        """
        self.config = config or ACBConfig()
        self.total_steps = 0

        # One arm per difficulty level, in enum iteration order.
        self.arms: Dict[DifficultyLevel, BanditArm] = {
            level: BanditArm(difficulty=level)
            for level in DifficultyLevel
        }

        # History for analysis
        self.selection_history: List[DifficultyLevel] = []
        self.score_history: List[Dict[DifficultyLevel, float]] = []

    def compute_ucb_score(self, arm: BanditArm) -> float:
        """
        Compute UCB score for an arm.

        UCB(d) = mean_reward(d) + c * sqrt(log(N + 1) / (n_d + 1))

        where N is the total number of updates and n_d the arm's pull count.
        Arms with fewer than ``min_pulls_before_ucb`` pulls bypass the formula
        and receive ``10.0 - n_d`` so that they are tried first.

        Args:
            arm: BanditArm to score

        Returns:
            UCB score
        """
        if arm.num_pulls < self.config.min_pulls_before_ucb:
            # Forced exploration: a score that shrinks with every pull.
            return 10.0 - arm.num_pulls

        if self.total_steps == 0:
            # No updates yet, so there is nothing to base a bonus on.
            return arm.mean_reward

        exploration_bonus = self.config.exploration_constant * math.sqrt(
            math.log(self.total_steps + 1) / (arm.num_pulls + 1)
        )
        return arm.mean_reward + exploration_bonus

    def compute_learning_signal_score(self, arm: BanditArm) -> float:
        """
        Compute learning signal score based on gradient magnitude.

        Higher gradient magnitude indicates the model is learning from
        tasks at this difficulty level.

        Args:
            arm: BanditArm to score

        Returns:
            Learning signal score in (0, 1); exactly 0.5 when the learning
            signal is disabled in the config.
        """
        if not self.config.use_learning_signal:
            return 0.5

        gradient_mag = arm.recent_gradient

        # Logistic squashing centred at 0.01 with slope 10.
        # Gradients in range [0.001, 0.1] are typical.
        return 1.0 / (1.0 + math.exp(-10 * (gradient_mag - 0.01)))

    def compute_combined_score(self, arm: BanditArm) -> float:
        """
        Compute combined score for difficulty selection.

        score(d) = alpha * UCB(d) + (1 - alpha) * learning_signal(d)
                   + 0.1 * prior(d)

        Args:
            arm: BanditArm to score

        Returns:
            Combined score
        """
        alpha = self.config.alpha
        combined = (
            alpha * self.compute_ucb_score(arm)
            + (1 - alpha) * self.compute_learning_signal_score(arm)
        )

        # Curriculum prior: 1.0 at level value 3, dropping by 0.25 per level
        # away from it. This encodes our belief that mid-range difficulties
        # are often best.
        difficulty_prior = 1.0 - abs(arm.difficulty.value - 3) / 4.0
        return combined + 0.1 * difficulty_prior

    def _score_arms(self) -> Dict[DifficultyLevel, float]:
        """Return a fresh mapping of every arm's level to its combined score."""
        return {
            level: self.compute_combined_score(arm)
            for level, arm in self.arms.items()
        }

    def _record_selection(
        self,
        selected: DifficultyLevel,
        scores: Dict[DifficultyLevel, float]
    ) -> None:
        """Append one selection and the scores behind it to the histories."""
        self.selection_history.append(selected)
        self.score_history.append(scores)

    def select_difficulty(self) -> DifficultyLevel:
        """
        Select the next difficulty level using ACB (deterministic argmax).

        Ties go to the arm that comes first in ``self.arms``.

        Returns:
            Selected DifficultyLevel
        """
        scores = self._score_arms()
        selected = max(scores, key=scores.get)
        self._record_selection(selected, scores)

        # Lazy %-formatting: the score dict is only rendered when DEBUG is on.
        logger.debug(
            "ACB selected difficulty %s, scores: %s", selected.value, scores
        )
        return selected

    def select_difficulty_softmax(self, temperature: float = 1.0) -> DifficultyLevel:
        """
        Select difficulty using softmax over scores (stochastic).

        This adds more exploration compared to argmax selection.

        Args:
            temperature: Softmax temperature (higher = more random). Must be
                strictly positive.

        Returns:
            Selected DifficultyLevel

        Raises:
            ValueError: If ``temperature`` is zero, negative or NaN.
        """
        # `not (t > 0)` also rejects NaN, which compares false to everything.
        if not temperature > 0:
            raise ValueError(
                f"temperature must be positive, got {temperature!r}"
            )

        scores = self._score_arms()
        levels = list(scores)
        score_values = np.array([scores[level] for level in levels])

        # Subtracting the max keeps the exponentials from overflowing.
        exp_scores = np.exp((score_values - score_values.max()) / temperature)
        probs = exp_scores / exp_scores.sum()

        selected = levels[np.random.choice(len(levels), p=probs)]
        self._record_selection(selected, scores)
        return selected

    @staticmethod
    def _trim_to_window(values: List[float], window: int) -> None:
        """Drop the oldest entries in place so at most ``window`` remain."""
        excess = len(values) - window
        if excess > 0:
            del values[:excess]

    def update(
        self,
        difficulty: DifficultyLevel,
        reward: float,
        success: bool,
        gradient_magnitude: Optional[float] = None
    ):
        """
        Update statistics after observing outcome.

        Args:
            difficulty: The difficulty level that was used
            reward: Reward obtained (e.g., test pass rate)
            success: Whether the task was solved successfully
            gradient_magnitude: Optional gradient magnitude from training;
                when None the arm's gradient history is left untouched.
        """
        arm = self.arms[difficulty]
        window = self.config.window_size

        arm.num_pulls += 1
        arm.total_reward += reward
        if success:
            arm.total_successes += 1

        arm.recent_rewards.append(reward)
        self._trim_to_window(arm.recent_rewards, window)

        if gradient_magnitude is not None:
            arm.recent_gradients.append(gradient_magnitude)
            self._trim_to_window(arm.recent_gradients, window)

        self.total_steps += 1

    def update_batch(
        self,
        difficulties: List[DifficultyLevel],
        rewards: List[float],
        successes: List[bool],
        gradient_magnitudes: Optional[List[float]] = None
    ):
        """
        Update statistics for a batch of outcomes.

        Args:
            difficulties: List of difficulty levels used
            rewards: List of rewards obtained
            successes: List of success indicators
            gradient_magnitudes: Optional list of gradient magnitudes

        Raises:
            ValueError: If the provided sequences differ in length. Nothing
                is applied in that case.
        """
        # Materialise first so any iterable is accepted and can be measured.
        difficulties = list(difficulties)
        rewards = list(rewards)
        successes = list(successes)
        expected = len(difficulties)

        if gradient_magnitudes is None:
            gradient_magnitudes = [None] * expected
        else:
            gradient_magnitudes = list(gradient_magnitudes)

        lengths = {
            "rewards": len(rewards),
            "successes": len(successes),
            "gradient_magnitudes": len(gradient_magnitudes),
        }
        mismatched = {
            name: size for name, size in lengths.items() if size != expected
        }
        if mismatched:
            # zip() would silently drop the surplus outcomes otherwise.
            raise ValueError(
                f"update_batch got {expected} difficulties but mismatched "
                f"lengths for: {mismatched}"
            )

        for outcome in zip(difficulties, rewards, successes, gradient_magnitudes):
            self.update(*outcome)

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get current statistics for all arms.

        Returns:
            Dictionary with one ``"level_<value>"`` entry per arm plus a
            ``"total_steps"`` entry.
        """
        stats: Dict[str, Any] = {
            f"level_{level.value}": {
                "num_pulls": arm.num_pulls,
                "success_rate": arm.success_rate,
                "mean_reward": arm.mean_reward,
                # np.mean yields a NumPy scalar; report a plain float.
                "recent_gradient": float(arm.recent_gradient),
                "ucb_score": self.compute_ucb_score(arm),
                "combined_score": self.compute_combined_score(arm),
            }
            for level, arm in self.arms.items()
        }
        stats["total_steps"] = self.total_steps
        return stats

    def _selection_distribution(
        self,
        selections: List[DifficultyLevel]
    ) -> Dict[DifficultyLevel, float]:
        """Return per-arm selection frequencies, uniform when ``selections`` is empty."""
        # self.arms holds one entry per level, populated in __init__.
        if not selections:
            share = 1.0 / len(self.arms)
            return {level: share for level in self.arms}

        counts = Counter(selections)
        total = len(selections)
        return {level: counts[level] / total for level in self.arms}

    def get_difficulty_distribution(self) -> Dict[DifficultyLevel, float]:
        """
        Get the empirical distribution of difficulty selections.

        Returns:
            Dictionary mapping difficulty to selection probability; uniform
            if nothing has been selected yet.
        """
        return self._selection_distribution(self.selection_history)

    def get_recent_difficulty_distribution(
        self,
        window: int = 100
    ) -> Dict[DifficultyLevel, float]:
        """
        Get recent distribution of difficulty selections.

        Args:
            window: Number of recent selections to consider. A non-positive
                window covers no selections and yields the uniform
                distribution.

        Returns:
            Dictionary mapping difficulty to selection probability
        """
        # Guard the slice: history[-0:] would be the whole list and a
        # negative window would drop the oldest entries instead.
        recent = self.selection_history[-window:] if window > 0 else []
        return self._selection_distribution(recent)

    def reset(self):
        """Reset all statistics."""
        self.total_steps = 0
        for arm in self.arms.values():
            arm.num_pulls = 0
            arm.total_reward = 0.0
            arm.total_successes = 0
            # Rebind rather than clear so earlier references stay intact.
            arm.recent_rewards = []
            arm.recent_gradients = []
        self.selection_history = []
        self.score_history = []


class FixedCurriculum:
    """
    Fixed curriculum baseline (StepCoder CCCS style).

    Progressively increases difficulty over training steps, from level 1 to
    level 5, as a function of ``current_step / total_steps`` only. Outcomes
    reported through ``update`` are ignored.
    """

    _MIN_LEVEL = 1
    _MAX_LEVEL = 5
    # Exclusive upper progress bound of levels 1..5 in the "step" schedule.
    _STEP_THRESHOLDS = (0.2, 0.4, 0.6, 0.8, 1.0)

    def __init__(
        self,
        total_steps: int,
        warmup_ratio: float = 0.1,
        schedule: str = "linear"
    ):
        """
        Initialize fixed curriculum.

        Args:
            total_steps: Total expected training steps
            warmup_ratio: Ratio of steps to spend at easiest level
            schedule: "linear" or "step" progression; any value other than
                "linear" selects the step schedule.
        """
        self.total_steps = total_steps
        self.warmup_ratio = warmup_ratio
        self.schedule = schedule
        self.current_step = 0

    def _level_for_progress(self, progress: float) -> int:
        """Map a progress fraction to an integer level in [1, 5]."""
        if self.schedule == "linear":
            ramp_span = 1 - self.warmup_ratio
            # Still warming up, or warm-up covers the whole schedule: stay at
            # the easiest level (also avoids dividing by a zero-length ramp).
            if progress < self.warmup_ratio or ramp_span <= 0:
                return self._MIN_LEVEL

            adjusted_progress = (progress - self.warmup_ratio) / ramp_span
            level_float = 1.0 + 4.0 * adjusted_progress
            return int(
                min(self._MAX_LEVEL, max(self._MIN_LEVEL, round(level_float)))
            )

        # Step schedule: the first threshold above `progress` picks the level.
        for level, upper_bound in enumerate(self._STEP_THRESHOLDS, start=1):
            if progress < upper_bound:
                return level

        # progress >= 1.0: training ran to or past total_steps, so stay at the
        # hardest level rather than falling back to the easiest one.
        return self._MAX_LEVEL

    def select_difficulty(self) -> DifficultyLevel:
        """
        Select difficulty based on current step, then advance the step.

        Returns:
            The DifficultyLevel whose value is the scheduled level (1-5).
        """
        # max(1, ...) keeps a zero total_steps from dividing by zero.
        progress = self.current_step / max(1, self.total_steps)
        level = self._level_for_progress(progress)

        self.current_step += 1
        return DifficultyLevel(level)

    def update(self, *args, **kwargs):
        """No-op for fixed curriculum; accepts and ignores any arguments."""
        pass

    def get_statistics(self) -> Dict[str, Any]:
        """Return the current step and the fraction of total_steps completed."""
        return {
            "current_step": self.current_step,
            "progress": self.current_step / max(1, self.total_steps),
        }


if __name__ == "__main__":
    # Demo run: drive the bandit with synthetic outcomes and print progress.
    print("Testing AdaptiveCurriculumBandit...")

    bandit = AdaptiveCurriculumBandit(
        ACBConfig(exploration_constant=1.0, alpha=0.7)
    )

    # Seed NumPy's global RNG so the simulated outcomes repeat run to run.
    np.random.seed(42)

    for completed in range(1, 201):
        chosen = bandit.select_difficulty()

        # Success probability falls by 0.2 for each level above the first.
        p_success = 1.0 - (chosen.value - 1) * 0.2
        solved = np.random.random() < p_success
        payoff = 1.0 if solved else 0.0

        # Synthetic gradient magnitude that is largest at level value 3.
        grad = 0.01 * (3 - abs(chosen.value - 3))

        bandit.update(chosen, payoff, solved, grad)

        # Only report after every 50th update.
        if completed % 50 != 0:
            continue

        stats = bandit.get_statistics()
        recent = bandit.get_recent_difficulty_distribution(window=50)
        shares = {lvl.value: f"{frac:.2f}" for lvl, frac in recent.items()}
        print(f"\nStep {completed}:")
        print(f"  Recent distribution: {shares}")
        for lvl in DifficultyLevel:
            row = stats[f"level_{lvl.value}"]
            print(f"  Level {lvl.value}: pulls={row['num_pulls']}, "
                  f"success={row['success_rate']:.2f}, score={row['combined_score']:.3f}")

    print("\n\nTesting FixedCurriculum...")
    baseline = FixedCurriculum(total_steps=100, warmup_ratio=0.1, schedule="linear")

    for tick in range(100):
        # Called on every tick because each call advances the schedule.
        level = baseline.select_difficulty()
        if tick % 20 == 0:
            print(f"Step {tick}: Difficulty {level.value}")

    print("\nAll tests passed!")
