from typing import Dict, List, Optional, Any
from .chemical import ChemicalRewards
from .topological import TopologicalRewards
from ..environment.state import AssemblyState
from ..environment.actions import AssemblyAction


class RewardSystem:
    """Blend chemical and topological reward components into one scalar.

    ``calculate_reward`` asks both component calculators for a breakdown,
    mixes their totals with the configured weights and appends a record to
    ``reward_history``. The other public methods summarise that history or
    expose/update the configuration.
    """

    def __init__(self, chemical_weight: float = 0.4, topological_weight: float = 0.6,
                 chemical_config: Optional[Dict] = None, topological_config: Optional[Dict] = None):
        """Create the reward system and its two component calculators.

        Args:
            chemical_weight: Multiplier applied to the chemical total.
            topological_weight: Multiplier applied to the topological total.
            chemical_config: Forwarded as ``weights`` to ``ChemicalRewards``.
            topological_config: Forwarded as ``weights`` to
                ``TopologicalRewards``.
        """
        self.chemical_weight = chemical_weight
        self.topological_weight = topological_weight

        # Component calculators; each must offer get_reward_breakdown().
        self.chemical_rewards = ChemicalRewards(weights=chemical_config)
        self.topological_rewards = TopologicalRewards(weights=topological_config)

        # One record per calculate_reward() call. Nothing trims this list, so
        # it keeps growing until reset_history() is called.
        self.reward_history: List[Dict] = []

    def _weighted_total(self, chemical_total: float, topological_total: float) -> float:
        """Return the weighted sum of the two component totals."""
        return (
            self.chemical_weight * chemical_total +
            self.topological_weight * topological_total
        )

    def _component_breakdowns(self, prev_state: AssemblyState, action: AssemblyAction,
                              next_state: AssemblyState):
        """Return ``(chemical_breakdown, topological_breakdown)`` for a transition."""
        chemical = self.chemical_rewards.get_reward_breakdown(prev_state, action, next_state)
        topological = self.topological_rewards.get_reward_breakdown(prev_state, action, next_state)
        return chemical, topological

    def calculate_reward(self, prev_state: AssemblyState, action: AssemblyAction,
                        next_state: AssemblyState) -> float:
        """Compute the weighted reward for a transition and log it.

        Args:
            prev_state: State before the action.
            action: Action that was applied.
            next_state: State after the action; its ``step`` and
                ``terminated`` attributes are stored in the history record.

        Returns:
            ``chemical_weight * chemical_total +
            topological_weight * topological_total``, where a total missing
            from a breakdown counts as ``0.0``.
        """
        chemical_breakdown, topological_breakdown = self._component_breakdowns(
            prev_state, action, next_state
        )

        chemical_total = chemical_breakdown.get('total_chemical', 0.0)
        topological_total = topological_breakdown.get('total_topological', 0.0)
        total_reward = self._weighted_total(chemical_total, topological_total)

        # Keep the full breakdown so the statistics helpers can work on it.
        self.reward_history.append({
            'step': next_state.step,
            'action': action.to_dict(),
            'chemical_breakdown': chemical_breakdown,
            'topological_breakdown': topological_breakdown,
            'chemical_total': chemical_total,
            'topological_total': topological_total,
            'total_reward': total_reward,
            'terminated': next_state.terminated
        })

        return total_reward

    def calculate_detailed_reward(self, prev_state: AssemblyState, action: AssemblyAction,
                                next_state: AssemblyState) -> Dict[str, Any]:
        """Return the full reward breakdown for a transition.

        Unlike ``calculate_reward`` this does not append to
        ``reward_history``.

        Returns:
            A dict with keys ``'chemical'`` and ``'topological'`` (the raw
            component breakdowns), ``'weights'`` (the current mixing weights)
            and ``'totals'`` (both component totals and their weighted sum).
        """
        chemical_breakdown, topological_breakdown = self._component_breakdowns(
            prev_state, action, next_state
        )
        chemical_total = chemical_breakdown.get('total_chemical', 0.0)
        topological_total = topological_breakdown.get('total_topological', 0.0)

        return {
            'chemical': chemical_breakdown,
            'topological': topological_breakdown,
            'weights': {
                'chemical_weight': self.chemical_weight,
                'topological_weight': self.topological_weight
            },
            'totals': {
                'chemical_total': chemical_total,
                'topological_total': topological_total,
                'weighted_total': self._weighted_total(chemical_total, topological_total)
            }
        }

    def get_reward_statistics(self, recent_steps: int = 100) -> Dict[str, Any]:
        """Summarise the most recent history records.

        Args:
            recent_steps: How many of the latest records to analyse.

        Returns:
            A dict of aggregate statistics, or ``{}`` when the history is
            empty or ``recent_steps`` is not positive.
        """
        # A non-positive count must be rejected explicitly: history[-0:] is the
        # whole list and a negative count would slice from the front instead.
        if recent_steps <= 0 or not self.reward_history:
            return {}

        recent = self.reward_history[-recent_steps:]
        count = len(recent)
        totals = [record['total_reward'] for record in recent]

        stats = {
            'total_steps': len(self.reward_history),
            'recent_steps_analyzed': count,
            'mean_total_reward': sum(totals) / count,
            'mean_chemical_reward': sum(record['chemical_total'] for record in recent) / count,
            'mean_topological_reward': sum(record['topological_total'] for record in recent) / count,
            'reward_variance': self._calculate_variance(totals),
            'positive_reward_ratio': sum(1 for value in totals if value > 0) / count,
            'termination_reward_stats': self._get_termination_reward_stats(recent)
        }

        # Per-component statistics for each breakdown stored in the records.
        stats['chemical_component_stats'] = self._get_component_stats(recent, 'chemical_breakdown')
        stats['topological_component_stats'] = self._get_component_stats(recent, 'topological_breakdown')

        return stats

    def _calculate_variance(self, values: List[float]) -> float:
        """Return the population variance of ``values`` (0.0 for fewer than two)."""
        if len(values) < 2:
            return 0.0

        mean_val = sum(values) / len(values)
        return sum((x - mean_val) ** 2 for x in values) / len(values)

    def _get_termination_reward_stats(self, rewards: List[Dict]) -> Dict[str, Any]:
        """Summarise the records whose ``'terminated'`` flag is truthy.

        Returns:
            ``{'num_terminations': 0}`` when there are none; otherwise the
            count, the mean ``'total_reward'`` of those records and their
            share of ``rewards``.
        """
        terminal_records = [r for r in rewards if r.get('terminated', False)]

        if not terminal_records:
            return {'num_terminations': 0}

        num_terminal = len(terminal_records)
        return {
            'num_terminations': num_terminal,
            'mean_termination_reward': sum(r['total_reward'] for r in terminal_records) / num_terminal,
            'termination_rate': num_terminal / len(rewards)
        }

    def _get_component_stats(self, rewards: List[Dict],
                             component_key: str) -> Dict[str, Dict[str, float]]:
        """Aggregate every entry of one breakdown across ``rewards``.

        Args:
            rewards: History records to scan.
            component_key: Record key holding the breakdown dict, e.g.
                ``'chemical_breakdown'``. Records lacking it are skipped.

        Returns:
            Maps each breakdown entry name to its ``'mean'``, ``'count'`` and
            ``'positive_ratio'`` over the records that contain it.
        """
        # Group values by entry name in a single pass over the records.
        grouped: Dict[str, List[float]] = {}
        for record in rewards:
            if component_key not in record:
                continue
            for name, value in record[component_key].items():
                grouped.setdefault(name, []).append(value)

        return {
            name: {
                'mean': sum(values) / len(values),
                'count': len(values),
                'positive_ratio': sum(1 for v in values if v > 0) / len(values)
            }
            for name, values in grouped.items()
        }

    def analyze_reward_trends(self, window_size: int = 50) -> Dict[str, List[float]]:
        """Compute moving averages of the recorded rewards.

        Args:
            window_size: Number of consecutive records per average.

        Returns:
            Lists of moving averages under ``'total_reward'``,
            ``'chemical_reward'`` and ``'topological_reward'``, one value per
            full window from the oldest up to and including the window that
            ends at the newest record. ``{}`` when ``window_size`` is not
            positive or exceeds the history length.
        """
        history_len = len(self.reward_history)
        if window_size <= 0 or history_len < window_size:
            return {}

        # Trend name -> key of the value inside each history record.
        series = (
            ('total_reward', 'total_reward'),
            ('chemical_reward', 'chemical_total'),
            ('topological_reward', 'topological_total'),
        )
        trends: Dict[str, List[float]] = {trend_key: [] for trend_key, _ in series}

        # `end` is exclusive, so it must reach history_len for the newest
        # record to be part of the final window.
        for end in range(window_size, history_len + 1):
            window = self.reward_history[end - window_size:end]
            for trend_key, record_key in series:
                trends[trend_key].append(sum(r[record_key] for r in window) / window_size)

        return trends

    def suggest_reward_adjustments(self) -> Dict[str, Any]:
        """Derive tuning hints from the statistics of recent records.

        Returns:
            ``{'suggestion': ...}`` when fewer than 100 records exist.
            Otherwise a dict that may contain ``'balance'``, ``'sparsity'``
            and ``'termination'`` messages; it is empty when no check fires.
        """
        if len(self.reward_history) < 100:
            return {'suggestion': 'Not enough data for suggestions'}

        recent_stats = self.get_reward_statistics()
        suggestions = {}

        # Balance: flag either component whose mean magnitude is more than
        # twice the other's. A zero mean on the smaller side is not flagged.
        chemical_mean = recent_stats.get('mean_chemical_reward', 0)
        topological_mean = recent_stats.get('mean_topological_reward', 0)

        if topological_mean != 0 and abs(chemical_mean) > 2 * abs(topological_mean):
            suggestions['balance'] = f"Chemical rewards ({chemical_mean:.3f}) dominating topological ({topological_mean:.3f}). Consider adjusting weights."
        elif chemical_mean != 0 and abs(topological_mean) > 2 * abs(chemical_mean):
            suggestions['balance'] = f"Topological rewards ({topological_mean:.3f}) dominating chemical ({chemical_mean:.3f}). Consider adjusting weights."

        # Sparsity: too few steps yield a positive total reward.
        positive_ratio = recent_stats.get('positive_reward_ratio', 0)
        if positive_ratio < 0.3:
            suggestions['sparsity'] = f"Low positive reward ratio ({positive_ratio:.2%}). Consider making rewards less sparse."

        # Termination: very few terminal records in the analysed window.
        termination_stats = recent_stats.get('termination_reward_stats', {})
        if termination_stats.get('num_terminations', 0) < 5:
            suggestions['termination'] = "Few terminations observed. Check if episodes are terminating properly."

        return suggestions

    def reset_history(self):
        """Clear the reward history in place (e.g. for a new training phase)."""
        self.reward_history.clear()

    def set_weights(self, chemical_weight: float, topological_weight: float):
        """Replace both mixing weights; the values are stored without validation."""
        self.chemical_weight = chemical_weight
        self.topological_weight = topological_weight

    def get_current_config(self) -> Dict[str, Any]:
        """Return the mixing weights and each component's ``weights`` attribute.

        The component weights are returned as-is, not copied.
        """
        return {
            'chemical_weight': self.chemical_weight,
            'topological_weight': self.topological_weight,
            'chemical_component_weights': self.chemical_rewards.weights,
            'topological_component_weights': self.topological_rewards.weights
        }