from dataclasses import dataclass, replace
from enum import Enum
from typing import Dict, List, Optional, Any, Tuple

import numpy as np


class CurriculumStage(Enum):
    """Phases of the training curriculum.

    Members are declared in the same order in which ``CurriculumScheduler``
    lists them in its default ``stage_transitions`` table.
    """
    # Scheduler default config: hard action masking on, no over-connections.
    STRICT_RECONSTRUCTION = 1
    # Scheduler default config: hard masking off, at most 1 over-connection.
    SOFT_TOPOLOGY_LEARNING = 2
    # Scheduler default config: hard masking off, at most 2 over-connections.
    TOPOLOGY_PROPERTY_OPTIMIZATION = 3
    # Scheduler default config: hard masking off, at most 5 over-connections.
    FREE_EXPLORATION = 4


@dataclass
class CurriculumConfig:
    """Training settings associated with one curriculum stage.

    Every field except ``stage`` defaults to the value the scheduler uses for
    ``CurriculumStage.STRICT_RECONSTRUCTION``.
    """
    # Stage this configuration belongs to.
    stage: CurriculumStage
    # Copied onto the trainer as ``set_bc_weight`` by CurriculumLearning.
    set_bc_weight: float = 1.0
    # Multiplies the 0.6 base topology weight when reward weights are rebuilt.
    topology_constraint_strength: float = 1.0
    # Un-normalised weight of the property reward term.
    property_reward_weight: float = 0.0
    # Scaled by 0.1 to obtain the trainer's ``entropy_coef``.
    exploration_bonus: float = 0.0
    # When False, CurriculumLearning.apply_action_masking returns the actions unfiltered.
    hard_masking: bool = True
    # Forwarded to ``env.topology_constraints``; 0 disables over-connections there.
    max_over_connections: int = 0


class CurriculumScheduler:
    """Map a training-iteration counter onto curriculum stages and configs.

    The run is divided into consecutive stages by ``stage_transitions``, a
    mapping from stage to the fraction of ``total_iterations`` at which that
    stage ends (the end point itself still belongs to the stage). Lookups rely
    on the mapping's insertion order, so entries must be listed with ascending
    thresholds.
    """

    def __init__(self, total_iterations: int = 10000):
        """Create a scheduler positioned at iteration 0.

        Args:
            total_iterations: Length of the whole curriculum in iterations.

        Raises:
            ValueError: If ``total_iterations`` is not positive; every progress
                computation divides by it.
        """
        if total_iterations <= 0:
            raise ValueError(
                f"total_iterations must be positive, got {total_iterations}"
            )

        self.total_iterations = total_iterations
        self.current_iteration = 0

        # Stage reported by the previous should_transition_stage() call;
        # None until that method has been called once.
        self._last_stage: Optional[CurriculumStage] = None

        # Upper bound of each stage as a fraction of total_iterations.
        # Order matters: the first threshold >= progress wins.
        self.stage_transitions = {
            CurriculumStage.STRICT_RECONSTRUCTION: 0.25,           # First 25%
            CurriculumStage.SOFT_TOPOLOGY_LEARNING: 0.50,          # 25% - 50%
            CurriculumStage.TOPOLOGY_PROPERTY_OPTIMIZATION: 0.75,  # 50% - 75%
            CurriculumStage.FREE_EXPLORATION: 1.0                  # 75% - 100%
        }

        self.stage_configs = self._initialize_stage_configs()

    def _initialize_stage_configs(self) -> Dict[CurriculumStage, CurriculumConfig]:
        """Return the base (start-of-stage) configuration for every stage."""
        return {
            CurriculumStage.STRICT_RECONSTRUCTION: CurriculumConfig(
                stage=CurriculumStage.STRICT_RECONSTRUCTION,
                set_bc_weight=1.0,
                topology_constraint_strength=1.0,
                property_reward_weight=0.0,
                exploration_bonus=0.0,
                hard_masking=True,
                max_over_connections=0
            ),

            CurriculumStage.SOFT_TOPOLOGY_LEARNING: CurriculumConfig(
                stage=CurriculumStage.SOFT_TOPOLOGY_LEARNING,
                set_bc_weight=0.3,
                topology_constraint_strength=0.8,
                property_reward_weight=0.1,
                exploration_bonus=0.1,
                hard_masking=False,
                max_over_connections=1
            ),

            CurriculumStage.TOPOLOGY_PROPERTY_OPTIMIZATION: CurriculumConfig(
                stage=CurriculumStage.TOPOLOGY_PROPERTY_OPTIMIZATION,
                set_bc_weight=0.1,
                topology_constraint_strength=0.6,
                property_reward_weight=0.3,
                exploration_bonus=0.2,
                hard_masking=False,
                max_over_connections=2
            ),

            CurriculumStage.FREE_EXPLORATION: CurriculumConfig(
                stage=CurriculumStage.FREE_EXPLORATION,
                set_bc_weight=0.05,
                topology_constraint_strength=0.3,
                property_reward_weight=0.5,
                exploration_bonus=0.3,
                hard_masking=False,
                max_over_connections=5
            )
        }

    def _overall_progress(self) -> float:
        """Return ``current_iteration / total_iterations`` (may exceed 1.0)."""
        return self.current_iteration / self.total_iterations

    def get_current_stage(self) -> CurriculumStage:
        """Return the stage that contains the current iteration.

        A stage owns its upper threshold, i.e. progress exactly equal to a
        threshold still maps to the earlier stage.
        """
        progress = self._overall_progress()

        for stage, threshold in self.stage_transitions.items():
            if progress <= threshold:
                return stage

        # The counter ran past the last threshold: stay in the final stage.
        return CurriculumStage.FREE_EXPLORATION

    def get_current_config(self) -> CurriculumConfig:
        """Return a fresh config for the current stage, ramped by in-stage progress."""
        base_config = self.stage_configs[self.get_current_stage()]
        progress_in_stage = self._get_progress_in_current_stage()
        return self._adapt_config_for_progress(base_config, progress_in_stage)

    def _get_progress_in_current_stage(self) -> float:
        """Return how far the current stage has advanced, clamped to [0.0, 1.0].

        A stage with zero duration is reported as fully complete (1.0).
        """
        current_stage = self.get_current_stage()

        stages = list(self.stage_transitions)
        thresholds = list(self.stage_transitions.values())
        stage_idx = stages.index(current_stage)

        # The first stage starts at 0; every other stage starts where the
        # previous one ends.
        stage_start = thresholds[stage_idx - 1] if stage_idx > 0 else 0.0
        stage_duration = thresholds[stage_idx] - stage_start

        if stage_duration == 0:
            return 1.0

        fraction = (self._overall_progress() - stage_start) / stage_duration
        # Clamp both ends: the counter may overshoot total_iterations or be
        # set to a negative value by the caller.
        return max(0.0, min(fraction, 1.0))

    def _adapt_config_for_progress(self, base_config: CurriculumConfig,
                                   progress: float) -> CurriculumConfig:
        """Return a copy of ``base_config`` adjusted for in-stage ``progress``.

        Args:
            base_config: Start-of-stage configuration; it is never mutated.
            progress: Fraction of the stage already completed.

        Returns:
            A new ``CurriculumConfig``. Fields not ramped for the stage keep
            their base values.
        """
        stage = base_config.stage
        overrides: Dict[str, Any] = {}

        if stage == CurriculumStage.STRICT_RECONSTRUCTION:
            # Gradually reduce BC weight (up to -30%) and constraint strength (up to -20%).
            overrides["set_bc_weight"] = base_config.set_bc_weight * (1.0 - 0.3 * progress)
            overrides["topology_constraint_strength"] = (
                base_config.topology_constraint_strength * (1.0 - 0.2 * progress)
            )

        elif stage == CurriculumStage.SOFT_TOPOLOGY_LEARNING:
            # Gradually increase property rewards and exploration.
            overrides["property_reward_weight"] = base_config.property_reward_weight + 0.2 * progress
            overrides["exploration_bonus"] = base_config.exploration_bonus + 0.1 * progress

        elif stage == CurriculumStage.TOPOLOGY_PROPERTY_OPTIMIZATION:
            # Shift weight from topology (up to -30%) towards property reward.
            overrides["property_reward_weight"] = base_config.property_reward_weight + 0.2 * progress
            overrides["topology_constraint_strength"] = (
                base_config.topology_constraint_strength * (1.0 - 0.3 * progress)
            )

        elif stage == CurriculumStage.FREE_EXPLORATION:
            # Push exploration further; property weight is capped at 0.7.
            overrides["exploration_bonus"] = base_config.exploration_bonus + 0.2 * progress
            overrides["property_reward_weight"] = min(
                0.7, base_config.property_reward_weight + 0.2 * progress
            )

        # replace() copies every dataclass field, so fields added to
        # CurriculumConfig later are carried over automatically.
        return replace(base_config, **overrides)

    def step(self):
        """Advance the iteration counter by one."""
        self.current_iteration += 1

    def should_transition_stage(self) -> Tuple[bool, Optional[CurriculumStage]]:
        """Report whether the stage changed since the previous call.

        The very first call always reports a transition into the current
        stage. Calling this method updates the remembered stage.

        Returns:
            ``(True, stage)`` when a new stage was entered, otherwise
            ``(False, None)``.
        """
        current_stage = self.get_current_stage()

        if current_stage == self._last_stage:
            return False, None

        self._last_stage = current_stage
        return True, current_stage

    def get_progress_info(self) -> Dict[str, Any]:
        """Return a snapshot of iteration counters, stage name, progress and config."""
        return {
            'current_iteration': self.current_iteration,
            'total_iterations': self.total_iterations,
            'overall_progress': self._overall_progress(),
            'current_stage': self.get_current_stage().name,
            'progress_in_stage': self._get_progress_in_current_stage(),
            'stage_config': self.get_current_config()
        }


class CurriculumLearning:
    """Push curriculum configs into training components and track per-stage metrics."""

    def __init__(self, scheduler: CurriculumScheduler):
        """Bind to ``scheduler`` and create an empty metrics list for every stage."""
        self.scheduler = scheduler
        # One list of per-episode metric dicts for each stage.
        self.stage_metrics: Dict[CurriculumStage, List[Dict[str, float]]] = {
            stage: [] for stage in CurriculumStage
        }
        # NOTE(review): nothing in this class appends to this log; presumably
        # the training loop does. Confirm, otherwise 'stage_transitions' in
        # get_curriculum_statistics() is always 0.
        self.stage_transitions_log: List[Any] = []

    def update_trainer_config(self, trainer, config: CurriculumConfig):
        """Copy the BC weight and a scaled exploration bonus onto ``trainer``."""
        trainer.set_bc_weight = config.set_bc_weight
        # The exploration bonus is scaled by 0.1 to obtain the entropy coefficient.
        trainer.entropy_coef = config.exploration_bonus * 0.1

    def update_reward_system(self, reward_system, config: CurriculumConfig):
        """Set normalised topological/chemical weights on ``reward_system``.

        The topology term is ``0.6 * topology_constraint_strength``, the
        chemical term is a fixed 0.4 and the property term comes from the
        config. The three are normalised to sum to 1; the property share is
        then folded into the chemical weight.
        """
        topology_weight = 0.6 * config.topology_constraint_strength
        chemical_weight = 0.4
        property_weight = config.property_reward_weight

        total_weight = topology_weight + chemical_weight + property_weight
        if total_weight > 0:
            topology_weight /= total_weight
            chemical_weight /= total_weight
            property_weight /= total_weight

        reward_system.topological_weight = topology_weight
        # Add property weight to chemical weight for now.
        reward_system.chemical_weight = chemical_weight + property_weight

    def update_environment_constraints(self, env, config: CurriculumConfig):
        """Forward the over-connection limit to ``env.topology_constraints``.

        Environments without a ``topology_constraints`` attribute are left
        untouched.
        """
        if not hasattr(env, 'topology_constraints'):
            return

        constraints = env.topology_constraints
        constraints.allow_over_connections = config.max_over_connections > 0
        constraints.max_over_connections = config.max_over_connections

    def apply_action_masking(self, valid_actions, state, config: CurriculumConfig):
        """Filter ``valid_actions`` according to the curriculum stage.

        Filtering only happens when hard masking is enabled, the stage is
        ``STRICT_RECONSTRUCTION``, and ``state`` has a truthy ``target_graph``
        with ``mode == "reconstruction"``. In that case stop actions and
        actions whose (source, target) motif pair is an edge of the target
        graph, in either direction, are kept. In every other case the input is
        returned unchanged.
        """
        if not config.hard_masking:
            return valid_actions

        strict_reconstruction = (
            config.stage == CurriculumStage.STRICT_RECONSTRUCTION
            and state.target_graph
            and state.mode == "reconstruction"
        )
        if not strict_reconstruction:
            return valid_actions

        # Built once so every membership test below is O(1).
        target_edges = set(state.target_graph.graph.edges())

        def is_allowed(action) -> bool:
            if action.is_stop_action():
                return True
            edge = (action.source_motif, action.target_motif)
            # Accept either orientation of the edge.
            return edge in target_edges or edge[::-1] in target_edges

        return [action for action in valid_actions if is_allowed(action)]

    def log_stage_metrics(self, stage: CurriculumStage, metrics: Dict[str, float]):
        """Append one episode's ``metrics`` to the history of ``stage``."""
        self.stage_metrics[stage].append(metrics)

    def get_curriculum_statistics(self) -> Dict[str, Any]:
        """Return scheduler progress plus per-stage averages of logged metrics.

        Stages without logged episodes are omitted from
        ``'stage_metrics_summary'``. Each metric is averaged over the episodes
        that reported it, including metrics that first appear in a later
        episode.
        """
        stats = {
            'current_stage_info': self.scheduler.get_progress_info(),
            'stage_transitions': len(self.stage_transitions_log),
            'stage_metrics_summary': {}
        }

        for stage, metrics_list in self.stage_metrics.items():
            if not metrics_list:
                continue

            # Union of keys over all episodes, in first-seen order.
            metric_keys: Dict[str, None] = {}
            for episode in metrics_list:
                metric_keys.update(dict.fromkeys(episode))

            avg_metrics = {}
            for key in metric_keys:
                values = [episode[key] for episode in metrics_list if key in episode]
                avg_metrics[key] = float(np.mean(values))

            stats['stage_metrics_summary'][stage.name] = {
                'episodes': len(metrics_list),
                'avg_metrics': avg_metrics
            }

        return stats

    def should_adapt_difficulty(self, recent_success_rate: float,
                                target_success_rate: float = 0.7) -> bool:
        """Return True when the recent success rate strictly exceeds the target."""
        return recent_success_rate > target_success_rate

    def get_stage_difficulty_metrics(self, stage: CurriculumStage) -> Dict[str, float]:
        """Summarise the last 10 logged episodes of ``stage``.

        Returns:
            An empty dict when nothing has been logged for ``stage``; otherwise
            the keys ``'success_rate'``, ``'avg_episode_length'``,
            ``'avg_reward'`` and ``'reward_std'``. A metric missing from an
            episode counts as 0.
        """
        history = self.stage_metrics.get(stage)
        if not history:
            return {}

        recent_metrics = history[-10:]  # Last 10 episodes

        success_flags = [m.get('success', 0) for m in recent_metrics]
        episode_lengths = [m.get('episode_length', 0) for m in recent_metrics]
        rewards = [m.get('total_reward', 0) for m in recent_metrics]

        return {
            'success_rate': float(np.mean(success_flags)),
            'avg_episode_length': float(np.mean(episode_lengths)),
            'avg_reward': float(np.mean(rewards)),
            'reward_std': float(np.std(rewards)),
        }