"""
This file defines the custom mechanism-aware reward functions for RO-GRPO,
as described in the paper "Balancing the Experts: Unlocking LoRA-MoE for GRPO
via Mechanism-Aware Rewards".

Two reward strategies are implemented:
1.  RoutingRewardSmooth: Implements the curriculum-based reward scheduling,
    corresponding to the "RO-GRPO (Smooth)" method.
2.  RoutingRewardRelative: Implements the relative improvement gating mechanism,
    corresponding to the "RO-GRPO (Relative)" method.
"""

import math
from collections import deque
from typing import Dict, List, Optional

import numpy as np
import torch
from swift.llm.plugin import ORM, orms
from swift.utils import get_logger

logger = get_logger()


class RoutingRewardSmooth(ORM):
    """
    Implements the curriculum-based reward scheduling for RO-GRPO (Smooth).

    This reward function uses a sigmoid-based curriculum to smoothly transition
    from an entropy-based reward (encouraging confident routing) early in
    training to a load-balancing reward (encouraging uniform expert usage)
    later in training.

    Both terms are penalties: the returned reward is
    ``-(lambda_bal * mse) - (lambda_H * entropy)`` and is therefore never
    positive when the weights and metrics are non-negative.
    """

    def __init__(self,
                 num_experts: int = 2,
                 lambda_H_start: float = 0.5,
                 lambda_bal_end: float = 2.0,
                 sigmoid_center: float = 0.5,
                 sigmoid_steepness: float = 10.0,
                 normalize_metrics: bool = True,
                 log_level: str = "info",
                 **kwargs):
        """
        Initializes the smooth routing reward module.

        Args:
            num_experts: The number of experts in the LoRA-MoE layers.
            lambda_H_start: The initial weight for the entropy penalty.
            lambda_bal_end: The final weight for the load balancing (MSE) penalty.
            sigmoid_center: The training progress (0 to 1) where the curriculum transition is centered.
            sigmoid_steepness: Controls how sharp the transition is.
            normalize_metrics: Whether to normalize entropy and MSE to a [0, 1] range.
            log_level: The logging level for reward-related messages. Unknown
                level names fall back to ``logger.info``.
            **kwargs: Forwarded unchanged to the ``ORM`` base-class initializer.

        Raises:
            ValueError: If ``num_experts`` is not positive.
        """
        super().__init__(**kwargs)
        if num_experts <= 0:
            raise ValueError("num_experts must be a positive integer.")

        self.num_experts = num_experts
        self.lambda_H_start = lambda_H_start
        self.lambda_bal_end = lambda_bal_end
        self.sigmoid_center = sigmoid_center
        self.sigmoid_steepness = sigmoid_steepness
        self.normalize_metrics = normalize_metrics
        self.log_func = getattr(logger, log_level.lower(), logger.info)

        # Pre-calculate theoretical maximums for normalization.
        # max_entropy: entropy of the uniform distribution over the experts.
        # max_mse: MSE against the uniform target when all usage sits on one expert.
        self.max_entropy = math.log(self.num_experts) if self.num_experts > 1 else 1.0
        self.max_mse = (self.num_experts - 1.0) / (self.num_experts**2) if self.num_experts > 1 else 0.0
        self.target_usage = 1.0 / self.num_experts

        self.log_func("Initialized RoutingRewardSmooth (RO-GRPO Smooth).")
        self.log_func(f"  - Num Experts: {self.num_experts}")
        self.log_func(f"  - Curriculum: H_weight={lambda_H_start}->0, MSE_weight=0->{lambda_bal_end}")
        self.log_func(f"  - Sigmoid: center={sigmoid_center}, steepness={sigmoid_steepness}")

    def _calculate_scheduled_lambdas(self, global_step: int, max_steps: int) -> tuple[float, float]:
        """
        Calculates the current weights for entropy and balancing rewards based on training progress.

        Args:
            global_step: The current training step.
            max_steps: The total number of training steps.

        Returns:
            A ``(lambda_bal, lambda_H)`` tuple. When ``max_steps`` is not
            positive the initial weights ``(0.0, lambda_H_start)`` are returned.
        """
        if max_steps <= 0:
            return 0.0, self.lambda_H_start

        progress = global_step / max_steps
        z = self.sigmoid_steepness * (progress - self.sigmoid_center)

        # Numerically stable sigmoid: only ever exponentiate a non-positive
        # number, so a large steepness cannot trigger OverflowError in math.exp.
        if z >= 0:
            sigmoid_value = 1.0 / (1.0 + math.exp(-z))
        else:
            exp_z = math.exp(z)
            sigmoid_value = exp_z / (1.0 + exp_z)

        # Sigmoid value transitions from ~0 to ~1 as training progresses
        current_lambda_bal = self.lambda_bal_end * sigmoid_value
        current_lambda_H = self.lambda_H_start * (1.0 - sigmoid_value)

        return current_lambda_bal, current_lambda_H

    def _compute_routing_metrics(self,
                                 routing_stats: Optional[Dict[str, torch.Tensor]]) -> Optional[tuple[float, float]]:
        """
        Computes the (optionally normalized) load-balancing MSE and routing entropy.

        Only 2-D, non-empty tensors are used. Rows are averaged to obtain the
        per-expert usage, so each tensor is presumably shaped
        (num_tokens, num_experts) -- TODO confirm against the caller.

        Returns:
            ``(mse, entropy)``, or ``None`` when no usable tensor is present.
        """
        if not routing_stats:
            return None

        # Zero-row tensors are dropped here: they contribute nothing to the MSE
        # and would turn the mean entropy into NaN if they were the only input.
        all_weights = [
            stats.float() for stats in routing_stats.values()
            if stats is not None and stats.dim() == 2 and stats.shape[0] > 0
        ]
        if not all_weights:
            return None

        # 1. Load Balancing MSE: per-layer deviation of mean expert usage from
        #    the uniform target, averaged over layers.
        if self.num_experts > 1:
            layer_mses = [
                torch.mean((weights.mean(dim=0) - self.target_usage)**2).item() for weights in all_weights
            ]
            avg_mse = float(np.mean(layer_mses))
        else:
            avg_mse = 0.0

        # 2. Routing Entropy: concatenate all token weights from all layers and
        #    average the per-token entropy. Clamping avoids log(0).
        flat_weights = torch.cat(all_weights, dim=0)
        clamped_weights = torch.clamp(flat_weights, min=1e-9)
        entropy_per_token = -torch.sum(clamped_weights * torch.log(clamped_weights), dim=-1)
        avg_entropy = entropy_per_token.mean().item()

        if self.normalize_metrics:
            if self.max_mse > 0:
                avg_mse = avg_mse / self.max_mse
            if self.max_entropy > 0:
                avg_entropy = avg_entropy / self.max_entropy

        return avg_mse, avg_entropy

    def __call__(self,
                 completions: List[str],
                 routing_stats: Optional[Dict[str, torch.Tensor]] = None,
                 global_step: Optional[int] = None,
                 max_steps: Optional[int] = None,
                 **kwargs) -> List[float]:
        """
        Computes the curriculum-weighted routing penalty.

        Args:
            completions: Generated completions; not inspected by this reward.
            routing_stats: Mapping to routing-weight tensors (see
                ``_compute_routing_metrics``).
            global_step: Current training step. If missing, the initial
                weights are used.
            max_steps: Total number of training steps. If missing, the initial
                weights are used.
            **kwargs: Ignored.

        Returns:
            A single-element list holding the reward; ``[0.0]`` when no usable
            routing statistics are available.
            NOTE(review): the list has one element regardless of
            ``len(completions)`` -- confirm this matches what the caller expects.
        """
        if global_step is None or max_steps is None:
            logger.warning("`global_step` or `max_steps` not provided. Using initial reward weights.")
            lambda_bal, lambda_H = 0.0, self.lambda_H_start
        else:
            lambda_bal, lambda_H = self._calculate_scheduled_lambdas(global_step, max_steps)

        metrics = self._compute_routing_metrics(routing_stats)
        if metrics is None:
            return [0.0]
        norm_mse, norm_entropy = metrics

        # Penalties are negative rewards
        reward_bal = -lambda_bal * norm_mse
        reward_entropy = -lambda_H * norm_entropy
        total_reward = float(reward_bal + reward_entropy)

        self.log_func(
            f"[RO-GRPO Smooth] Step: {global_step}/{max_steps}, "
            f"NormMSE: {norm_mse:.4f}, NormEntropy: {norm_entropy:.4f}, "
            f"Reward: {total_reward:.4f} (Bal: {reward_bal:.4f}, H: {reward_entropy:.4f})"
        )

        return [total_reward]


class RoutingRewardRelative(ORM):
    """
    Implements the relative improvement gating for RO-GRPO (Relative).

    This reward function provides a sparse, positive reward only when both
    routing confidence (entropy) and load balance (MSE) simultaneously improve
    compared to their historical moving averages.

    No reward is granted until the history window has been completely filled.
    """

    def __init__(self,
                 num_experts: int = 2,
                 history_size: int = 1000,
                 reward_value: float = 1.0,
                 normalize_metrics: bool = True,
                 log_level: str = "info",
                 **kwargs):
        """
        Initializes the relative routing reward module.

        Args:
            num_experts: The number of experts in the LoRA-MoE layers.
            history_size: The size of the moving window for historical averages.
            reward_value: The constant positive reward 'C' to grant on improvement.
            normalize_metrics: Whether to normalize entropy and MSE to a [0, 1] range.
            log_level: The logging level for reward-related messages. Unknown
                level names fall back to ``logger.info``.
            **kwargs: Forwarded unchanged to the ``ORM`` base-class initializer.

        Raises:
            ValueError: If ``num_experts`` or ``history_size`` is not positive.
        """
        super().__init__(**kwargs)
        if num_experts <= 0:
            raise ValueError("num_experts must be a positive integer.")
        # A zero-length window would count as permanently "full" and yield a
        # NaN reference average, so it is rejected up front.
        if history_size <= 0:
            raise ValueError("history_size must be a positive integer.")

        self.num_experts = num_experts
        self.reward_value = reward_value
        self.normalize_metrics = normalize_metrics
        self.log_func = getattr(logger, log_level.lower(), logger.info)

        # Bounded windows: appending to a full deque drops the oldest entry.
        self.history_buffer = {
            'mse': deque(maxlen=history_size),
            'entropy': deque(maxlen=history_size)
        }

        # Pre-calculate theoretical maximums for normalization.
        # max_entropy: entropy of the uniform distribution over the experts.
        # max_mse: MSE against the uniform target when all usage sits on one expert.
        self.max_entropy = math.log(self.num_experts) if self.num_experts > 1 else 1.0
        self.max_mse = (self.num_experts - 1.0) / (self.num_experts**2) if self.num_experts > 1 else 0.0
        self.target_usage = 1.0 / self.num_experts

        self.log_func("Initialized RoutingRewardRelative (RO-GRPO Relative).")
        self.log_func(f"  - Num Experts: {self.num_experts}")
        self.log_func(f"  - History Size: {history_size}")
        self.log_func(f"  - Reward Value (C): {reward_value}")

    def _compute_routing_metrics(self, routing_stats: Dict[str, torch.Tensor]) -> tuple[float, float]:
        """
        Computes the average MSE and entropy for a single sample across all its layers.

        Only 2-D, non-empty tensors are used. Rows are averaged to obtain the
        per-expert usage, so each tensor is presumably shaped
        (num_tokens, num_experts) -- TODO confirm against the caller.

        Returns:
            ``(mse, entropy)``, normalized when ``normalize_metrics`` is set.
            ``(inf, inf)`` is returned when no usable tensor is present, which
            can never count as an improvement.
        """
        if not routing_stats:
            return float('inf'), float('inf')

        # Zero-row tensors are dropped here: they contribute nothing to the MSE
        # and would turn the mean entropy into NaN if they were the only input.
        all_weights = [
            stats.float() for stats in routing_stats.values()
            if stats is not None and stats.dim() == 2 and stats.shape[0] > 0
        ]
        if not all_weights:
            return float('inf'), float('inf')

        # MSE: per-layer deviation of mean expert usage from the uniform
        # target, averaged over layers.
        if self.num_experts > 1:
            layer_mses = [
                torch.mean((weights.mean(dim=0) - self.target_usage)**2).item() for weights in all_weights
            ]
            avg_mse = float(np.mean(layer_mses))
        else:
            avg_mse = 0.0

        # Entropy: mean per-token entropy over all layers. Clamping avoids log(0).
        flat_weights = torch.cat(all_weights, dim=0)
        clamped_weights = torch.clamp(flat_weights, min=1e-9)
        entropy_per_token = -torch.sum(clamped_weights * torch.log(clamped_weights), dim=-1)
        avg_entropy = entropy_per_token.mean().item()

        # Normalization
        norm_mse = avg_mse / self.max_mse if self.normalize_metrics and self.max_mse > 0 else avg_mse
        norm_entropy = avg_entropy / self.max_entropy if self.normalize_metrics and self.max_entropy > 0 else avg_entropy

        return norm_mse, norm_entropy

    def __call__(self,
                 completions: List[str],
                 routing_stats: Optional[Dict[str, torch.Tensor]] = None,
                 global_step: Optional[int] = None,
                 **kwargs) -> List[float]:
        """
        Grants ``reward_value`` when both metrics beat their moving averages.

        Args:
            completions: Generated completions; not inspected by this reward.
            routing_stats: Mapping to routing-weight tensors (see
                ``_compute_routing_metrics``).
            global_step: Current training step; used for logging only.
            **kwargs: Ignored.

        Returns:
            A single-element list: ``[reward_value]`` on simultaneous
            improvement once the history window is full, otherwise ``[0.0]``.
            NOTE(review): the list has one element regardless of
            ``len(completions)`` -- confirm this matches what the caller expects.
        """
        if not routing_stats:
            return [0.0]

        # Calculate metrics for the current sample
        current_mse, current_entropy = self._compute_routing_metrics(routing_stats)
        reward = 0.0

        mse_history = self.history_buffer['mse']
        entropy_history = self.history_buffer['entropy']

        # Compare against historical average if buffer is full
        if len(mse_history) == mse_history.maxlen:
            ref_mse = float(np.mean(mse_history))
            ref_entropy = float(np.mean(entropy_history))

            # Grant positive reward only if both metrics improve (i.e., decrease)
            if current_mse < ref_mse and current_entropy < ref_entropy:
                reward = self.reward_value

            self.log_func(
                f"[RO-GRPO Relative] Step: {global_step}, "
                f"MSE: {current_mse:.4f} (vs {ref_mse:.4f}), "
                f"Entropy: {current_entropy:.4f} (vs {ref_entropy:.4f}), "
                f"Reward: {reward:.4f}"
            )
        else:
            self.log_func(
                f"[RO-GRPO Relative] Step: {global_step}, Populating history buffer "
                f"({len(mse_history)}/{mse_history.maxlen}). "
                f"MSE: {current_mse:.4f}, Entropy: {current_entropy:.4f}"
            )

        # Update history only after the comparison, so a sample is never
        # compared against itself. Non-finite values (inf placeholder or NaN
        # from degenerate routing weights) are skipped: a single NaN would
        # poison the moving average until it leaves the window.
        if math.isfinite(current_mse) and math.isfinite(current_entropy):
            mse_history.append(current_mse)
            entropy_history.append(current_entropy)

        return [reward]


# Register the custom reward functions in the `orms` registry imported from
# swift.llm.plugin, under the keys 'ro_grpo_smooth' and 'ro_grpo_relative'.
# NOTE(review): presumably these keys are how a training configuration selects
# a reward function by name -- confirm against the ms-swift version in use.
orms['ro_grpo_smooth'] = RoutingRewardSmooth
orms['ro_grpo_relative'] = RoutingRewardRelative