"""
Staleness Control for Curriculum-Staleness Coupling (CSC).

This module implements staleness-aware training that couples
curriculum difficulty with staleness tolerance.
"""

import math
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Tuple
from collections import deque
import numpy as np


@dataclass
class StaleExperience:
    """
    One experience record together with the data needed to track staleness.

    ``policy_version`` is the version the record is attributed to and is the
    only field read by :meth:`staleness`. ``task_id`` and ``execution_time``
    are optional and default to ``""`` and ``0.0``.
    """
    prompt: str
    response: str
    reward: float
    log_prob: float
    value: float
    difficulty: int
    policy_version: int
    timestamp: float
    task_id: str = ""
    execution_time: float = 0.0

    def staleness(self, current_version: int) -> int:
        """
        Return how many policy versions ``current_version`` is ahead of this record.

        No clamping is applied: the result is negative when the record's
        ``policy_version`` is larger than ``current_version``.
        """
        produced_at = self.policy_version
        return current_version - produced_at


@dataclass
class StalenessConfig:
    """
    Configuration for staleness control.

    Attributes:
        eta_base: Base staleness threshold.
        lambda_coupling: Coupling strength between difficulty and staleness.
        freshness_weight_beta: Exponent for freshness weighting.
        max_staleness: Absolute maximum staleness.
        min_staleness: Minimum staleness threshold.
    """
    eta_base: float = 8.0
    lambda_coupling: float = 0.5
    freshness_weight_beta: float = 0.5
    max_staleness: int = 10
    min_staleness: int = 1
    min_staleness: int = 1  # Minimum staleness threshold


class StalenessController:
    """
    Basic staleness controller without curriculum coupling.

    Implements AReaL-style staleness control with a fixed eta parameter.
    An experience is accepted when its staleness (policy version difference)
    does not exceed ``min(eta, max_staleness)``.
    """

    # Decay rate of the exponential importance weight per version of staleness.
    WEIGHT_DECAY_RATE = 0.1
    # Lower bound on the importance weight given to a stale sample.
    MIN_IMPORTANCE_WEIGHT = 0.1

    def __init__(
        self,
        eta: float = 8.0,
        max_staleness: int = 10,
        history_window: int = 100,
    ):
        """
        Initialize staleness controller.

        Args:
            eta: Maximum allowed staleness (policy version difference)
            max_staleness: Absolute cap on staleness; it bounds the effective
                threshold even when ``eta`` is larger
            history_window: Number of most recent accepted staleness values
                kept for the ``avg_staleness`` statistic
        """
        self.eta = eta
        self.max_staleness = max_staleness
        self.current_policy_version = 0

        # Statistics
        self.total_samples_seen = 0
        self.samples_discarded = 0
        # Bounded so a long training run does not accumulate one entry per
        # accepted sample; only the most recent window is ever reported.
        self.staleness_history = deque(maxlen=history_window)

    def increment_policy_version(self):
        """Called after a policy update."""
        self.current_policy_version += 1

    def is_fresh_enough(self, experience: "StaleExperience") -> bool:
        """
        Check if experience is fresh enough to use.

        Updates the seen/discarded counters and, for accepted samples, the
        staleness history.

        Args:
            experience: The experience to check

        Returns:
            True if the staleness is within ``min(eta, max_staleness)``
        """
        staleness = experience.staleness(self.current_policy_version)
        self.total_samples_seen += 1

        # max_staleness is an absolute cap, so it wins when eta is larger.
        threshold = min(self.eta, self.max_staleness)
        if staleness > threshold:
            self.samples_discarded += 1
            return False

        self.staleness_history.append(staleness)
        return True

    def compute_importance_weight(self, experience: "StaleExperience") -> float:
        """
        Compute importance weight for off-policy correction.

        Fresher samples get higher weight: the weight is 1.0 for staleness
        <= 0 and otherwise decays exponentially with staleness, never
        dropping below ``MIN_IMPORTANCE_WEIGHT``.
        """
        staleness = experience.staleness(self.current_policy_version)

        if staleness <= 0:
            return 1.0

        # Exponential decay with staleness
        weight = math.exp(-self.WEIGHT_DECAY_RATE * staleness)
        return max(self.MIN_IMPORTANCE_WEIGHT, min(1.0, weight))

    def filter_buffer(
        self,
        experiences: List["StaleExperience"]
    ) -> List["StaleExperience"]:
        """
        Filter a buffer of experiences, keeping only fresh ones.

        Every experience is passed through :meth:`is_fresh_enough`, so the
        statistics counters are updated once per element.
        """
        return [exp for exp in experiences if self.is_fresh_enough(exp)]

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get staleness statistics.

        Returns:
            Dict with the current policy version, eta, seen/discarded
            counts, the discard rate, and the average staleness over the
            retained history window (0.0 when nothing was accepted yet).
        """
        history = self.staleness_history
        avg_staleness = sum(history) / len(history) if history else 0.0

        return {
            "current_policy_version": self.current_policy_version,
            "eta": self.eta,
            "total_samples_seen": self.total_samples_seen,
            "samples_discarded": self.samples_discarded,
            "discard_rate": self.samples_discarded / max(1, self.total_samples_seen),
            "avg_staleness": avg_staleness,
        }


class CurriculumStalenessController:
    """
    Curriculum-Staleness Coupling (CSC) Controller.

    Key insight: Lower difficulty tasks are more robust to staleness
    (completing last 10% is similar across policy versions).
    Higher difficulty tasks need fresh policy weights.

    max_staleness(d) = eta_base * exp(-lambda * d)
    """

    # Difficulty levels that always appear in the statistics, even when unseen.
    DEFAULT_DIFFICULTIES = range(1, 6)

    # Floor applied to sampling priorities. A sample whose staleness equals
    # its threshold is still "fresh" but has a freshness weight of exactly 0;
    # without a floor, sampling without replacement fails as soon as fewer
    # than batch_size samples have a non-zero probability.
    MIN_PRIORITY = 1e-8

    def __init__(self, config: Optional["StalenessConfig"] = None):
        """
        Initialize CSC controller.

        Args:
            config: StalenessConfig with parameters; a default-constructed
                StalenessConfig is used when omitted
        """
        self.config = config if config is not None else StalenessConfig()
        self.current_policy_version = 0

        # Per-difficulty statistics
        self.difficulty_stats: Dict[int, Dict[str, Any]] = {
            d: {"seen": 0, "discarded": 0, "staleness_sum": 0}
            for d in self.DEFAULT_DIFFICULTIES
        }

        # Global statistics
        self.total_samples_seen = 0
        self.samples_discarded = 0

    def _stats_for(self, difficulty: int) -> Dict[str, Any]:
        """Return the counters for ``difficulty``, creating them on first use."""
        bucket = self.difficulty_stats.get(difficulty)
        if bucket is None:
            bucket = {"seen": 0, "discarded": 0, "staleness_sum": 0}
            self.difficulty_stats[difficulty] = bucket
        return bucket

    def get_max_staleness(self, difficulty: int) -> float:
        """
        Get maximum allowed staleness for a difficulty level.

        max_staleness(d) = eta_base * exp(-lambda * d)

        The result is clamped to the configured
        ``[min_staleness, max_staleness]`` range.

        Args:
            difficulty: Difficulty level (1-5)

        Returns:
            Maximum allowed staleness
        """
        eta = self.config.eta_base * math.exp(
            -self.config.lambda_coupling * difficulty
        )

        # Clamp to configured range
        eta = max(self.config.min_staleness, min(self.config.max_staleness, eta))
        return eta

    def increment_policy_version(self):
        """Called after a policy update."""
        self.current_policy_version += 1

    def is_fresh_enough(self, experience: "StaleExperience") -> bool:
        """
        Check if experience is fresh enough based on its difficulty.

        Easy tasks tolerate more staleness; hard tasks need freshness.
        Updates the global and per-difficulty counters. Difficulty levels
        that were not pre-registered get their counters created on demand.

        Args:
            experience: The experience to check

        Returns:
            True if the staleness does not exceed the difficulty's threshold
        """
        staleness = experience.staleness(self.current_policy_version)
        difficulty = experience.difficulty
        max_staleness = self.get_max_staleness(difficulty)
        bucket = self._stats_for(difficulty)

        self.total_samples_seen += 1
        bucket["seen"] += 1

        if staleness > max_staleness:
            self.samples_discarded += 1
            bucket["discarded"] += 1
            return False

        bucket["staleness_sum"] += staleness
        return True

    def compute_importance_weight(
        self,
        experience: "StaleExperience"
    ) -> float:
        """
        Compute importance weight combining freshness and difficulty.

        w(sample) = (1 - staleness/max_staleness)^beta * difficulty_weight

        The freshness ratio is clamped to [0, 1], so a sample whose
        staleness equals its threshold gets weight 0.

        Args:
            experience: The experience to weight

        Returns:
            Importance weight
        """
        staleness = experience.staleness(self.current_policy_version)
        difficulty = experience.difficulty
        max_staleness = self.get_max_staleness(difficulty)

        # Freshness weight
        if max_staleness > 0:
            freshness_ratio = 1.0 - (staleness / max_staleness)
            freshness_ratio = max(0.0, min(1.0, freshness_ratio))
        else:
            freshness_ratio = 1.0

        freshness_weight = freshness_ratio ** self.config.freshness_weight_beta

        # Difficulty weight (harder tasks are more valuable)
        difficulty_weight = 0.5 + 0.1 * difficulty

        return freshness_weight * difficulty_weight

    def compute_priority(self, experience: "StaleExperience") -> float:
        """
        Compute sampling priority for prioritized experience replay.

        Higher priority for:
        - Fresher experiences
        - Higher difficulty (harder tasks)
        - Successful completions

        Args:
            experience: The experience to prioritize

        Returns:
            Priority score
        """
        importance_weight = self.compute_importance_weight(experience)

        # Bonus for successful completions
        success_bonus = 1.5 if experience.reward > 0.5 else 1.0

        # Bonus for harder tasks
        difficulty_bonus = 1.0 + 0.2 * (experience.difficulty - 1)

        return importance_weight * success_bonus * difficulty_bonus

    def filter_buffer(
        self,
        experiences: List["StaleExperience"],
        backfill: bool = True,
    ) -> List["StaleExperience"]:
        """
        Filter a buffer keeping only fresh-enough experiences.

        Every experience is passed through :meth:`is_fresh_enough`, so the
        statistics counters are updated once per element.

        Args:
            experiences: List of experiences to filter
            backfill: Accepted for interface compatibility; it has no effect
                on the result. Fresh easier experiences from the same pool
                are already part of the returned list, so there is nothing
                further to backfill from.

        Returns:
            Filtered list of experiences
        """
        return [exp for exp in experiences if self.is_fresh_enough(exp)]

    def sample_batch(
        self,
        experiences: List["StaleExperience"],
        batch_size: int,
        prioritized: bool = True,
    ) -> List["StaleExperience"]:
        """
        Sample a batch from experiences with curriculum-aware prioritization.

        The pool is filtered for freshness first (which updates the
        statistics). If no more than ``batch_size`` fresh experiences
        remain, all of them are returned without sampling.

        Args:
            experiences: Pool of experiences to sample from
            batch_size: Number of experiences to sample
            prioritized: Whether to use prioritized sampling

        Returns:
            Sampled batch of experiences
        """
        # Filter to fresh experiences first
        fresh = self.filter_buffer(experiences, backfill=False)

        if not fresh:
            return []

        if len(fresh) <= batch_size:
            return fresh

        if not prioritized:
            # Uniform sampling
            indices = np.random.choice(len(fresh), batch_size, replace=False)
            return [fresh[i] for i in indices]

        # Prioritized sampling. Zero priorities are lifted to a tiny positive
        # value so that the probabilities are well defined and sampling
        # without replacement can always draw batch_size distinct items.
        priorities = np.array(
            [self.compute_priority(exp) for exp in fresh], dtype=float
        )
        priorities = np.maximum(priorities, self.MIN_PRIORITY)
        probs = priorities / priorities.sum()

        indices = np.random.choice(len(fresh), batch_size, replace=False, p=probs)
        return [fresh[i] for i in indices]

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get CSC statistics.

        Returns:
            Dict with global counters plus a ``by_difficulty`` mapping that
            contains levels 1-5 and any other difficulty level seen so far.
        """
        stats = {
            "current_policy_version": self.current_policy_version,
            "total_samples_seen": self.total_samples_seen,
            "samples_discarded": self.samples_discarded,
            "discard_rate": self.samples_discarded / max(1, self.total_samples_seen),
            "by_difficulty": {},
        }

        for d in sorted(self.difficulty_stats):
            d_stats = self.difficulty_stats[d]
            seen = d_stats["seen"]
            discarded = d_stats["discarded"]
            # Average over accepted samples only; staleness_sum excludes discards.
            avg_staleness = d_stats["staleness_sum"] / max(1, seen - discarded)

            stats["by_difficulty"][d] = {
                "max_staleness": self.get_max_staleness(d),
                "seen": seen,
                "discarded": discarded,
                "discard_rate": discarded / max(1, seen),
                "avg_staleness": avg_staleness,
            }

        return stats


if __name__ == "__main__":
    # Smoke demo: prints the decisions of both controllers for a handful of
    # hand-built experiences. Nothing is asserted; output is for inspection.
    print("Testing Staleness Controllers...")

    # Constructor arguments shared by every demo experience.
    common_fields = dict(
        prompt="test",
        response="test",
        log_prob=-1.0,
        value=0.5,
        timestamp=0.0,
    )

    # ---- Fixed-eta controller -------------------------------------------
    print("\n1. Basic StalenessController:")
    basic = StalenessController(eta=5)
    # The controller stays at version 8 for every probe below.
    basic.current_policy_version = 8

    for version in range(10):
        probe = StaleExperience(
            reward=1.0,
            difficulty=3,
            policy_version=version,
            **common_fields,
        )
        accepted = basic.is_fresh_enough(probe)
        weight = basic.compute_importance_weight(probe)
        print(f"  Version {version}: fresh={accepted}, weight={weight:.3f}")

    basic_stats = basic.get_statistics()
    print(f"  Stats: discarded={basic_stats['samples_discarded']}, rate={basic_stats['discard_rate']:.2f}")

    # ---- Curriculum-coupled controller ----------------------------------
    print("\n2. CurriculumStalenessController (CSC):")
    csc = CurriculumStalenessController(
        StalenessConfig(eta_base=8.0, lambda_coupling=0.5)
    )
    csc.current_policy_version = 5

    print("  Max staleness by difficulty:")
    for level in range(1, 6):
        print(f"    Difficulty {level}: max_staleness={csc.get_max_staleness(level):.2f}")

    # Probe a grid of difficulties and producing policy versions.
    print("\n  Testing freshness check:")
    for difficulty in [1, 3, 5]:
        for version in [3, 4, 5]:
            probe = StaleExperience(
                reward=0.5,
                difficulty=difficulty,
                policy_version=version,
                **common_fields,
            )

            accepted = csc.is_fresh_enough(probe)
            staleness = probe.staleness(csc.current_policy_version)
            limit = csc.get_max_staleness(difficulty)
            print(f"    Difficulty {difficulty}, staleness {staleness}: "
                  f"max={limit:.2f}, fresh={accepted}")

    csc_stats = csc.get_statistics()
    print(f"\n  Overall discard rate: {csc_stats['discard_rate']:.2f}")
    print("  By difficulty:")
    for level in range(1, 6):
        per_level = csc_stats['by_difficulty'][level]
        print(f"    D{level}: discard_rate={per_level['discard_rate']:.2f}")

    print("\nAll tests passed!")
