"""
Difficulty Level Definitions for Code Editing Curriculum.

This module defines the curriculum difficulty levels based on code completion
percentage, following the StepCoder CCCS approach but with adaptive selection.
"""

from dataclasses import dataclass
from enum import IntEnum
from typing import List, Tuple, Optional
import random


class DifficultyLevel(IntEnum):
    """Ordered curriculum stages, from easiest (1) to hardest (5).

    Lower levels leave only the tail of the reference solution to be
    written; the highest level provides no partial solution at all.
    Being an IntEnum, members compare and sort as plain integers.
    """
    LEVEL_1 = 1  # Only the final 10% of the solution is left to write
    LEVEL_2 = 2  # The final 30% of the solution is left to write
    LEVEL_3 = 3  # The final 50% of the solution is left to write
    LEVEL_4 = 4  # The final 70% of the solution is left to write
    LEVEL_5 = 5  # Nothing is revealed: solve the full problem from scratch


@dataclass
class CurriculumTask:
    """A code task paired with a partially revealed solution.

    Attributes:
        task_id: Identifier of this curriculum variant.
        original_prompt: Prompt of the underlying task, unmodified.
        curriculum_prompt: Prompt with the partial solution included.
        canonical_solution: Full reference solution.
        remaining_solution: The part the agent still has to complete.
        test_cases: Test code for the task.
        entry_point: Entry point name for the task.
        difficulty: Curriculum level this variant belongs to.
        completion_ratio: Share of the solution that is revealed,
            from 0.0 to 1.0.
    """
    task_id: str
    original_prompt: str
    curriculum_prompt: str
    canonical_solution: str
    remaining_solution: str
    test_cases: str
    entry_point: str
    difficulty: DifficultyLevel
    completion_ratio: float


@dataclass
class CodeTask:
    """A plain code task, before any curriculum modification is applied.

    Attributes:
        task_id: Identifier of the task.
        prompt: Prompt text shown to the agent.
        canonical_solution: Reference solution for the prompt.
        test_cases: Test code for the task.
        entry_point: Entry point name for the task.
    """
    task_id: str
    prompt: str
    canonical_solution: str
    test_cases: str
    entry_point: str


class CurriculumTaskGenerator:
    """
    Generates curriculum tasks at different difficulty levels.

    For each original task, creates versions where a leading portion of the
    canonical solution is appended to the prompt, so that only the tail of
    the solution remains to be completed. Generated tasks are cached per
    (task index, difficulty) pair, so repeated requests return the same
    CurriculumTask object.
    """

    # Fraction of the canonical solution's lines revealed at each level
    COMPLETION_RATIOS = {
        DifficultyLevel.LEVEL_1: 0.9,   # 90% revealed, complete last 10%
        DifficultyLevel.LEVEL_2: 0.7,   # 70% revealed, complete last 30%
        DifficultyLevel.LEVEL_3: 0.5,   # 50% revealed, complete last 50%
        DifficultyLevel.LEVEL_4: 0.3,   # 30% revealed, complete last 70%
        DifficultyLevel.LEVEL_5: 0.0,   # 0% revealed, generate from scratch
    }

    def __init__(self, tasks: List[CodeTask]):
        """
        Initialize with a list of original code tasks.

        Args:
            tasks: List of CodeTask objects
        """
        self.original_tasks = tasks
        # Maps (task_idx, difficulty) -> CurriculumTask already generated
        self._curriculum_cache = {}

    @staticmethod
    def _solution_lines(solution: str) -> List[str]:
        """
        Split a solution into lines, dropping blank lines at both ends.

        Unlike a plain ``strip()``, the first code line keeps its leading
        indentation, so the revealed part can be appended to a prompt that
        ends with a newline without corrupting the function body.

        Args:
            solution: Solution source text

        Returns:
            List of lines; empty if the solution is empty or whitespace-only
        """
        lines = solution.rstrip().split('\n')
        # Drop leading whitespace-only lines but keep real indentation
        while lines and not lines[0].strip():
            del lines[0]
        return lines

    def generate_task(
        self,
        task_idx: int,
        difficulty: DifficultyLevel
    ) -> CurriculumTask:
        """
        Generate a curriculum task at specified difficulty.

        The canonical solution is split on line boundaries. At least one
        line is always left for the agent to complete, so short solutions
        may reveal less than the nominal completion ratio.

        Args:
            task_idx: Index of original task
            difficulty: Desired difficulty level

        Returns:
            CurriculumTask with appropriate partial solution

        Raises:
            IndexError: If task_idx is out of range
            KeyError: If difficulty has no entry in COMPLETION_RATIOS
        """
        cache_key = (task_idx, difficulty)
        if cache_key in self._curriculum_cache:
            return self._curriculum_cache[cache_key]

        task = self.original_tasks[task_idx]
        completion_ratio = self.COMPLETION_RATIOS[difficulty]

        # Split solution at appropriate point
        solution_lines = self._solution_lines(task.canonical_solution)
        total_lines = len(solution_lines)

        if total_lines == 0:
            # Edge case: empty solution, nothing can be revealed
            revealed_solution = ""
            remaining_solution = task.canonical_solution
        else:
            # Calculate split point based on completion ratio; the upper
            # clamp keeps the final line for the agent to write
            split_point = int(total_lines * completion_ratio)
            split_point = max(0, min(split_point, total_lines - 1))

            revealed_solution = '\n'.join(solution_lines[:split_point])
            remaining_solution = '\n'.join(solution_lines[split_point:])

        # Create curriculum prompt
        if completion_ratio > 0 and revealed_solution:
            # The revealed lines carry their original indentation; end on a
            # newline so the remaining solution starts on its own line
            curriculum_prompt = task.prompt + revealed_solution
            if not curriculum_prompt.endswith('\n'):
                curriculum_prompt += '\n'
        else:
            curriculum_prompt = task.prompt

        curriculum_task = CurriculumTask(
            task_id=f"{task.task_id}_d{difficulty.value}",
            original_prompt=task.prompt,
            curriculum_prompt=curriculum_prompt,
            canonical_solution=task.canonical_solution,
            remaining_solution=remaining_solution,
            test_cases=task.test_cases,
            entry_point=task.entry_point,
            difficulty=difficulty,
            completion_ratio=completion_ratio,
        )

        self._curriculum_cache[cache_key] = curriculum_task
        return curriculum_task

    def generate_batch(
        self,
        difficulty: DifficultyLevel,
        batch_size: int,
        replace: bool = True
    ) -> List[CurriculumTask]:
        """
        Generate a batch of tasks at specified difficulty.

        Args:
            difficulty: Difficulty level for all tasks
            batch_size: Number of tasks to generate
            replace: Whether to sample with replacement. Without
                replacement the batch is capped at the number of original
                tasks, so it may be shorter than batch_size.

        Returns:
            List of CurriculumTask objects

        Raises:
            ValueError: If sampling with replacement from an empty task
                list, or if batch_size is negative without replacement
        """
        num_tasks = len(self.original_tasks)
        if replace:
            indices = [random.randrange(num_tasks) for _ in range(batch_size)]
        else:
            indices = random.sample(
                range(num_tasks),
                min(batch_size, num_tasks)
            )

        return [self.generate_task(idx, difficulty) for idx in indices]

    def generate_mixed_batch(
        self,
        difficulty_distribution: dict,
        batch_size: int
    ) -> List[CurriculumTask]:
        """
        Generate a batch with mixed difficulties according to distribution.

        Args:
            difficulty_distribution: Dict mapping DifficultyLevel to
                probability (weights need not sum to 1; they are normalized)
            batch_size: Total number of tasks

        Returns:
            List of CurriculumTask objects with mixed difficulties

        Raises:
            ValueError: If the distribution is empty or its weights do not
                sum to a positive value, or if there are no original tasks
                to sample from
        """
        difficulties = list(difficulty_distribution.keys())
        probs = [difficulty_distribution[d] for d in difficulties]

        # Normalize probabilities; reject distributions that cannot be
        # normalized instead of failing with a bare ZeroDivisionError
        total = sum(probs)
        if not difficulties or total <= 0:
            raise ValueError(
                "difficulty_distribution must contain at least one "
                "difficulty with a positive total weight"
            )
        probs = [p / total for p in probs]

        tasks = []
        num_tasks = len(self.original_tasks)
        for _ in range(batch_size):
            # Sample difficulty according to distribution
            difficulty = random.choices(difficulties, weights=probs, k=1)[0]
            task_idx = random.randrange(num_tasks)
            tasks.append(self.generate_task(task_idx, difficulty))

        return tasks

    def get_all_tasks_at_difficulty(
        self,
        difficulty: DifficultyLevel
    ) -> List[CurriculumTask]:
        """Get all original tasks converted to specified difficulty."""
        return [
            self.generate_task(i, difficulty)
            for i in range(len(self.original_tasks))
        ]

    def estimate_task_complexity(self, task: CurriculumTask) -> float:
        """
        Estimate the complexity of completing a curriculum task.

        Higher values indicate more complex tasks (more to complete).
        The score is a weighted sum of the difficulty level (60%) and the
        fraction of solution lines still to be written (40%).

        Args:
            task: CurriculumTask to analyze

        Returns:
            Complexity score between 0 and 1
        """
        # Base complexity from difficulty level (levels 1..5 -> 0.0..1.0)
        base_complexity = (task.difficulty.value - 1) / 4.0

        # Adjust based on remaining solution length
        remaining_lines = len(self._solution_lines(task.remaining_solution))
        total_lines = len(self._solution_lines(task.canonical_solution))

        if total_lines > 0:
            length_factor = remaining_lines / total_lines
        else:
            # Empty canonical solution: no length signal, use a neutral value
            length_factor = 0.5

        # Combine factors
        complexity = 0.6 * base_complexity + 0.4 * length_factor
        return min(1.0, max(0.0, complexity))


def create_curriculum_from_humaneval(tasks: List[CodeTask]) -> CurriculumTaskGenerator:
    """Build a curriculum generator over the given (e.g. HumanEval) tasks.

    Args:
        tasks: Original code tasks, handed unchanged to the generator.

    Returns:
        A new CurriculumTaskGenerator wrapping ``tasks``.
    """
    generator = CurriculumTaskGenerator(tasks)
    return generator


if __name__ == "__main__":
    # Manual smoke run: prints what the generator produces for a tiny,
    # hand-written task set. Nothing is asserted here.
    print("Testing CurriculumTaskGenerator...")

    # Two synthetic two-line tasks stand in for a real dataset
    add_task = CodeTask(
        task_id="test_0",
        prompt="def add(a, b):\n    \"\"\"Add two numbers.\"\"\"\n    ",
        canonical_solution="result = a + b\n    return result",
        test_cases="assert add(1, 2) == 3",
        entry_point="add",
    )
    multiply_task = CodeTask(
        task_id="test_1",
        prompt="def multiply(a, b):\n    \"\"\"Multiply two numbers.\"\"\"\n    ",
        canonical_solution="product = a * b\n    return product",
        test_cases="assert multiply(2, 3) == 6",
        entry_point="multiply",
    )
    generator = CurriculumTaskGenerator([add_task, multiply_task])

    # Show the first task rendered at every difficulty level
    for difficulty in DifficultyLevel:
        task = generator.generate_task(0, difficulty)
        print(f"\nDifficulty {difficulty.value}:")
        print(f"  Completion ratio: {task.completion_ratio}")
        print(f"  Curriculum prompt: {repr(task.curriculum_prompt[:50])}...")
        print(f"  Remaining solution: {repr(task.remaining_solution[:30])}...")
        print(f"  Complexity: {generator.estimate_task_complexity(task):.3f}")

    # Draw a small mid-difficulty batch (sampling with replacement)
    batch = generator.generate_batch(DifficultyLevel.LEVEL_3, batch_size=4)
    print(f"\nGenerated batch of {len(batch)} tasks at Level 3")

    print("\nAll tests passed!")
