"""
Natural Difficulty Stratification for Code Generation Tasks.

This module provides functionality to stratify code tasks by "natural" difficulty,
measured as the pass rate of a base model. This complements the truncation-based
curriculum difficulty by measuring inherent task complexity.

The key insight: truncation-based difficulty controls solution space size synthetically,
while natural difficulty captures real-world problem complexity.
"""

import json
import logging
import numpy as np
import torch
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from collections import defaultdict
from transformers import GenerationConfig

logger = logging.getLogger(__name__)


class NaturalDifficultyLevel(IntEnum):
    """Natural difficulty levels based on base model pass rate.

    Values increase with difficulty: 1 is the easiest bucket (highest base
    pass rate) and 5 the hardest (lowest base pass rate). The bands quoted
    next to each member mirror the ``(low, high)`` pairs stored in
    ``DIFFICULTY_THRESHOLDS``; ``classify_difficulty`` treats the lower bound
    of a band as inclusive.
    """
    EASY = 1        # Base pass rate band 0.8 to 1.0
    MODERATE = 2    # Base pass rate band 0.5 (inclusive) to 0.8 (exclusive)
    MEDIUM = 3      # Base pass rate band 0.2 (inclusive) to 0.5 (exclusive)
    HARD = 4        # Base pass rate band 0.05 (inclusive) to 0.2 (exclusive)
    VERY_HARD = 5   # Base pass rate below 0.05


# Pass-rate band ``(low, high)`` for each difficulty level.
# ``classify_difficulty`` walks this mapping in insertion order and tests
# ``low <= pass_rate < high``, so each lower bound is inclusive and each
# upper bound exclusive. The bands are contiguous and span 0.0 to 1.0.
DIFFICULTY_THRESHOLDS = {
    NaturalDifficultyLevel.EASY: (0.8, 1.0),
    NaturalDifficultyLevel.MODERATE: (0.5, 0.8),
    NaturalDifficultyLevel.MEDIUM: (0.2, 0.5),
    NaturalDifficultyLevel.HARD: (0.05, 0.2),
    NaturalDifficultyLevel.VERY_HARD: (0.0, 0.05),
}


@dataclass
class NaturalDifficultyTask:
    """A task with natural difficulty annotation.

    Plain record pairing a code-generation task with its calibrated
    difficulty. It is not constructed anywhere in the visible part of this
    module, so field meanings marked "presumably" are inferred from the
    field names only and should be confirmed against the producer.
    """
    # Identifier of the task.
    task_id: str
    # Prompt text for the task.
    prompt: str
    # Presumably the reference solution source code -- TODO confirm format.
    canonical_solution: str
    # Presumably the test code for the task as one string -- TODO confirm.
    test_cases: str
    # Presumably the name of the function under test -- TODO confirm.
    entry_point: str
    # Difficulty bucket assigned to this task.
    natural_difficulty: NaturalDifficultyLevel
    # Base-model pass rate associated with the bucket (fraction, 0-1).
    base_model_pass_rate: float
    # Name of the benchmark the task came from; "unknown" when not recorded.
    source_benchmark: str = "unknown"
    num_samples: int = 0  # Number of samples used for calibration


@dataclass
class DifficultyCalibrationConfig:
    """Configuration for difficulty calibration.

    ``num_samples``, ``temperature`` and ``max_new_tokens`` are read by
    ``compute_pass_rate_for_task``. ``batch_size`` and ``timeout_seconds``
    are not read by any code visible in this module.
    """
    num_samples: int = 20  # Samples (generations) per task for calibration
    temperature: float = 0.8  # Sampling temperature passed to generation
    max_new_tokens: int = 256  # Cap on generated tokens per sample
    batch_size: int = 4  # NOTE(review): unused in the visible code
    timeout_seconds: float = 10.0  # NOTE(review): unused in the visible code


@dataclass
class DifficultyCalibrationResult:
    """Result of difficulty calibration for a dataset.

    Produced by ``calibrate_dataset_difficulty`` and persisted as JSON by
    ``save_calibration`` / ``load_calibration``.
    """
    # Name of the calibrated dataset.
    dataset_name: str
    # Name of the model used for calibration; the calibration routine fills
    # this with the model object's class name.
    model_name: str
    task_difficulties: Dict[str, float]  # task_id -> pass_rate
    # NaturalDifficultyLevel value -> number of tasks classified at that level
    difficulty_distribution: Dict[int, int]  # level -> count
    # Snapshot of the calibration settings that were used.
    calibration_config: Dict[str, Any]
    # ISO-8601 creation time; empty string when not recorded.
    timestamp: str = ""


def compute_pass_rate_for_task(
    model,
    tokenizer,
    task,
    env,
    config: DifficultyCalibrationConfig,
    device: str = "cuda",
) -> Tuple[float, int]:
    """
    Compute pass rate for a single task using the given model.

    The prompt is tokenized once, then ``config.num_samples`` completions are
    sampled one at a time. Each decoded completion is handed to ``env.step``
    and counts as a success when ``result.info["passed"]`` is truthy.

    Args:
        model: The model to evaluate (must provide ``generate``)
        tokenizer: Tokenizer for the model
        task: CodeTask object; only ``task.prompt`` is read for generation
        env: CurriculumCodeEnv for execution
        config: Calibration configuration (``num_samples``, ``temperature``
            and ``max_new_tokens`` are used)
        device: Device the tokenized inputs are moved to

    Returns:
        Tuple of (pass_rate, num_successful). ``(0.0, 0)`` when
        ``config.num_samples`` is not positive.
    """
    attempts = config.num_samples
    if attempts <= 0:
        # Nothing to sample: skip tokenization and generation setup entirely.
        return 0.0, 0

    # The prompt and the sampling settings are identical for every sample, so
    # they are built once instead of once per iteration.
    inputs = tokenizer(
        task.prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)
    prompt_length = inputs.input_ids.shape[1]

    # Use GenerationConfig to avoid parameter deprecation warnings
    gen_config = GenerationConfig(
        max_new_tokens=config.max_new_tokens,
        temperature=config.temperature,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    successes = 0
    for attempt in range(attempts):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=gen_config,
            )

        # Keep only the newly generated tokens of the single returned sequence.
        response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        # NOTE(review): env.step() is called without this function pointing
        # env at `task`; presumably the caller has already selected the task
        # on the environment -- confirm.
        try:
            result = env.step(response)
            if result.info.get("passed", False):
                successes += 1
        except Exception:
            # Best effort: a failed execution counts as a failed sample, but
            # it is recorded so systematic environment errors stay visible.
            logging.getLogger(__name__).debug(
                "Execution failed for task %s (sample %d/%d)",
                getattr(task, "task_id", "<unknown>"),
                attempt + 1,
                attempts,
                exc_info=True,
            )

    return successes / attempts, successes


def classify_difficulty(pass_rate: float) -> NaturalDifficultyLevel:
    """
    Classify a task's natural difficulty based on its pass rate.

    Each band in ``DIFFICULTY_THRESHOLDS`` is half-open, ``low <= rate < high``,
    except the band whose upper bound reaches 1.0: that one also accepts
    rates at or above its upper bound, so a perfect pass rate of 1.0 is
    classified as EASY instead of falling through to the fallback.

    Args:
        pass_rate: Base model pass rate (0-1)

    Returns:
        NaturalDifficultyLevel. Values matching no band (negative numbers,
        NaN) fall back to VERY_HARD.
    """
    for level, (low, high) in DIFFICULTY_THRESHOLDS.items():
        if low <= pass_rate < high:
            return level
        # Close the top band on the right: without this, exactly 1.0 matches
        # no band and would be reported as VERY_HARD.
        if high >= 1.0 and pass_rate >= high:
            return level
    return NaturalDifficultyLevel.VERY_HARD


def calibrate_dataset_difficulty(
    model,
    tokenizer,
    tasks: List[Any],
    env,
    config: Optional[DifficultyCalibrationConfig] = None,
    device: str = "cuda",
    progress_callback=None,
    dataset_name: str = "unknown",
) -> DifficultyCalibrationResult:
    """
    Calibrate natural difficulty for all tasks in a dataset.

    Every task is evaluated with ``compute_pass_rate_for_task``; the pass
    rate is stored per ``task_id`` and bucketed with ``classify_difficulty``.

    Args:
        model: Base model for calibration
        tokenizer: Tokenizer
        tasks: List of CodeTask objects
        env: Code execution environment
        config: Calibration configuration; a default
            ``DifficultyCalibrationConfig`` is used when None
        device: Device to use
        progress_callback: Optional callback(task_idx, total, pass_rate),
            called after each task with the zero-based task index
        dataset_name: Name stored in the result; defaults to "unknown"

    Returns:
        DifficultyCalibrationResult whose ``difficulty_distribution`` maps
        the integer level value to the number of tasks at that level and
        whose ``model_name`` is the class name of ``model``.
    """
    import datetime

    if config is None:
        config = DifficultyCalibrationConfig()

    total = len(tasks)
    task_difficulties = {}
    level_counts = defaultdict(int)

    logger.info(
        "Calibrating difficulty for %s tasks with %s samples each",
        total,
        config.num_samples,
    )

    for idx, task in enumerate(tasks):
        pass_rate, _ = compute_pass_rate_for_task(
            model, tokenizer, task, env, config, device
        )

        task_difficulties[task.task_id] = pass_rate
        level_counts[classify_difficulty(pass_rate).value] += 1

        if progress_callback is not None:
            progress_callback(idx, total, pass_rate)

        # Periodic heartbeat so long calibrations show progress in the logs.
        if (idx + 1) % 10 == 0:
            logger.info("Calibrated %s/%s tasks", idx + 1, total)

    # Plain dict so the result does not silently grow keys on lookup and
    # serializes like an ordinary mapping.
    difficulty_distribution = dict(level_counts)

    result = DifficultyCalibrationResult(
        dataset_name=dataset_name,
        model_name=str(type(model).__name__),
        task_difficulties=task_difficulties,
        difficulty_distribution=difficulty_distribution,
        calibration_config={
            "num_samples": config.num_samples,
            "temperature": config.temperature,
            "max_new_tokens": config.max_new_tokens,
        },
        timestamp=datetime.datetime.now().isoformat(),
    )

    logger.info("Difficulty distribution: %s", difficulty_distribution)
    return result


def save_calibration(result: DifficultyCalibrationResult, output_path: str):
    """Write the fields of ``result`` to ``output_path`` as indented JSON.

    Args:
        result: Calibration result to persist
        output_path: Destination file; opened in text write mode, so an
            existing file is overwritten
    """
    # Attributes are serialized in this order, under the same key names.
    exported_fields = (
        "dataset_name",
        "model_name",
        "task_difficulties",
        "difficulty_distribution",
        "calibration_config",
        "timestamp",
    )
    payload = {name: getattr(result, name) for name in exported_fields}

    with open(output_path, "w") as handle:
        json.dump(payload, handle, indent=2)

    logger.info(f"Saved calibration to {output_path}")


def load_calibration(input_path: str) -> DifficultyCalibrationResult:
    """Load calibration results from JSON.

    Args:
        input_path: Path of a JSON file with the keys ``dataset_name``,
            ``model_name``, ``task_difficulties``,
            ``difficulty_distribution`` and ``calibration_config``
            (``timestamp`` is optional)

    Returns:
        DifficultyCalibrationResult with integer keys in
        ``difficulty_distribution``.

    Raises:
        KeyError: If a required key is missing from the file.
        ValueError: If a ``difficulty_distribution`` key is not an integer.
    """
    with open(input_path, "r") as f:
        data = json.load(f)

    # JSON object keys are always strings, so the integer level keys come
    # back as "1", "2", ...; convert them back to honor Dict[int, int].
    difficulty_distribution = {
        int(level): count
        for level, count in data["difficulty_distribution"].items()
    }

    return DifficultyCalibrationResult(
        dataset_name=data["dataset_name"],
        model_name=data["model_name"],
        task_difficulties=data["task_difficulties"],
        difficulty_distribution=difficulty_distribution,
        calibration_config=data["calibration_config"],
        timestamp=data.get("timestamp", ""),
    )


class NaturalDifficultyStratifier:
    """
    Stratifies tasks by natural difficulty based on pre-computed calibration.

    This can be used alongside truncation-based curriculum to ensure
    the staleness budget works on naturally hard tasks, not just truncated ones.

    Tasks without a calibrated pass rate (or all tasks, when no calibration
    is available) are placed in the MEDIUM bucket. Random sampling draws
    indices from NumPy's global random state (``np.random``), so the task
    objects themselves are never converted to NumPy arrays.
    """

    def __init__(
        self,
        tasks: List[Any],
        calibration: Optional[DifficultyCalibrationResult] = None,
        calibration_path: Optional[str] = None,
    ):
        """
        Initialize the stratifier.

        Args:
            tasks: List of CodeTask objects (each must expose ``task_id``)
            calibration: Pre-computed calibration result; takes precedence
                over ``calibration_path``
            calibration_path: Path to load calibration from
        """
        self.tasks = tasks
        self.task_map = {t.task_id: t for t in tasks}

        if calibration is not None:
            self.calibration = calibration
        elif calibration_path:
            self.calibration = load_calibration(calibration_path)
        else:
            # No calibration - all tasks treated as MEDIUM
            self.calibration = None
            logger.warning("No calibration provided - using default MEDIUM difficulty")

        # Build stratified task lists
        self._stratify_tasks()

    def _stratify_tasks(self):
        """Build task lists stratified by natural difficulty."""
        self.tasks_by_difficulty = {
            level: [] for level in NaturalDifficultyLevel
        }

        # get_difficulty() is the single place that decides a task's bucket,
        # including the MEDIUM default for uncalibrated tasks.
        for task in self.tasks:
            bucket = self.tasks_by_difficulty[self.get_difficulty(task.task_id)]
            bucket.append(task)

        # Log distribution
        for level, bucket in self.tasks_by_difficulty.items():
            logger.info("  %s: %s tasks", level.name, len(bucket))

    def get_tasks_by_difficulty(
        self,
        difficulty: NaturalDifficultyLevel
    ) -> List[Any]:
        """Get all tasks at a specific natural difficulty level.

        Returns the internal bucket list for a known level and a fresh empty
        list for an unknown one.
        """
        return self.tasks_by_difficulty.get(difficulty, [])

    def sample_task(
        self,
        difficulty: Optional[NaturalDifficultyLevel] = None,
        exclude_ids: Optional[set] = None,
    ) -> Optional[Any]:
        """
        Sample a task, optionally filtered by difficulty.

        Args:
            difficulty: Natural difficulty level (None for any)
            exclude_ids: Task IDs to exclude

        Returns:
            Sampled CodeTask or None if no matching tasks
        """
        if difficulty is not None:
            candidates = self.tasks_by_difficulty[difficulty]
        else:
            candidates = self.tasks

        if exclude_ids:
            candidates = [t for t in candidates if t.task_id not in exclude_ids]

        if not candidates:
            return None

        # Draw an index rather than passing the task list to
        # np.random.choice: that would coerce the tasks into an ndarray,
        # which fails for sequence-like task objects.
        return candidates[np.random.randint(len(candidates))]

    def get_stratified_sample(
        self,
        samples_per_level: int = 10,
        levels: Optional[List[NaturalDifficultyLevel]] = None,
    ) -> List[Any]:
        """
        Get a stratified sample across difficulty levels.

        Within a level, tasks are drawn without replacement; a level with
        fewer than ``samples_per_level`` tasks contributes all of them.

        Args:
            samples_per_level: Number of samples per difficulty level
            levels: Specific levels to include (None for all)

        Returns:
            List of sampled tasks
        """
        if levels is None:
            levels = list(NaturalDifficultyLevel)

        sampled = []
        for level in levels:
            available = self.tasks_by_difficulty[level]
            n = min(samples_per_level, len(available))
            if n > 0:
                # Sample positions, then look the tasks up, so the original
                # task objects are returned untouched.
                picked = np.random.choice(len(available), size=n, replace=False)
                sampled.extend(available[i] for i in picked)

        return sampled

    def get_pass_rate(self, task_id: str) -> Optional[float]:
        """Get the base model pass rate for a task (None when unknown)."""
        if self.calibration is None:
            return None
        return self.calibration.task_difficulties.get(task_id)

    def get_difficulty(self, task_id: str) -> NaturalDifficultyLevel:
        """Get the natural difficulty level for a task (MEDIUM when unknown)."""
        pass_rate = self.get_pass_rate(task_id)
        if pass_rate is not None:
            return classify_difficulty(pass_rate)
        return NaturalDifficultyLevel.MEDIUM


def validate_staleness_on_natural_difficulty(
    trainer,
    stratifier: NaturalDifficultyStratifier,
    staleness_levels: Optional[List[int]] = None,
    samples_per_cell: int = 50,
) -> Dict[str, Any]:
    """
    Validate that staleness tolerance patterns hold on naturally hard tasks.

    This experiment tests the hypothesis: naturally hard tasks should show
    the same exponential staleness sensitivity as truncation-hard tasks.

    NOTE: this is currently a scaffold. Tasks are sampled for every
    (difficulty, staleness) cell, but no experience is collected with
    ``trainer``, so every reported ``success_rate`` is 0 and the
    "truncation_difficulty" section stays empty.

    Args:
        trainer: ACEASTrainer with CSC enabled (not used yet)
        stratifier: NaturalDifficultyStratifier with calibration
        staleness_levels: Staleness values to test; defaults to
            [0, 2, 4, 6, 8] when None
        samples_per_cell: Samples per (difficulty, staleness) cell, capped
            at the number of tasks available at that difficulty

    Returns:
        Dict with keys "natural_difficulty" and "truncation_difficulty".
        "natural_difficulty" maps each non-empty level name to
        ``{staleness: {"success_rate": float, "num_samples": int}}``.
    """
    if staleness_levels is None:
        # Built per call so the default list is never shared between calls.
        staleness_levels = [0, 2, 4, 6, 8]

    results = {
        "natural_difficulty": {},
        "truncation_difficulty": {},
    }

    # Test on naturally stratified tasks
    for nat_level in NaturalDifficultyLevel:
        nat_tasks = stratifier.get_tasks_by_difficulty(nat_level)
        if not nat_tasks:
            continue

        draws_per_cell = min(samples_per_cell, len(nat_tasks))
        level_results = {}
        for staleness in staleness_levels:
            successes = 0
            total = 0

            for _ in range(draws_per_cell):
                # Pick by index so task objects are not coerced to an ndarray.
                _task = nat_tasks[np.random.randint(len(nat_tasks))]
                # TODO: collect experience for `_task` with forced staleness
                # via the trainer's collection mechanism and update
                # `successes`; for now only the result structure is recorded.
                total += 1

            level_results[staleness] = {
                "success_rate": successes / total if total > 0 else 0.0,
                "num_samples": total,
            }

        results["natural_difficulty"][nat_level.name] = level_results

    return results
