"""
This file defines the Routing-Optimized Group Relative Policy Optimization (RO-GRPO)
Trainer. It extends the base GRPOTrainer to incorporate a mechanism-aware reward
signal derived from the internal routing statistics of a LoRA-MoE model.

The core logic is encapsulated within the `ROGRPOTrainer` class, which overrides
key methods to collect routing data and integrate it into the reward scoring process.
"""

import gc
import inspect
from typing import Any, Dict, List, Optional, Tuple

import torch

from swift.utils import get_logger
from .grpo_trainer import GRPOTrainer, InputsType  # Assumes grpo_trainer.py is in the same directory

# Optional dependency: the LoRA-MoE layer is expected next to this module. When it
# is missing, routing statistics cannot be collected and routing-aware rewards are off.
try:
    from .layer import LoraLayer
except ImportError:
    LoraLayer = None  # Placeholder so the name always exists at module level.
    HAS_LORAMOE_LAYER = False
else:
    HAS_LORAMOE_LAYER = True

logger = get_logger()


class ROGRPOTrainer(GRPOTrainer):
    """
    Routing-Optimized GRPO Trainer.

    This trainer extends the standard GRPOTrainer to support mechanism-aware
    fine-tuning of LoRA-MoE models. It introduces a reward component based on
    the model's internal expert routing behavior, encouraging both balanced
    expert utilization and confident routing decisions, as described in the paper.

    After each generation pass, the per-sample routing weights held by every
    LoRA-MoE ``LoraLayer`` (attribute ``routing_weights_for_generation``) are
    collected, cleared from the layers, and handed to any reward function whose
    signature declares a ``routing_stats`` parameter.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes the ROGRPOTrainer.

        Checks for the presence of the custom LoRA-MoE layer and warns if it's
        not available, as routing-aware features will be disabled.
        """
        super().__init__(*args, **kwargs)
        if not HAS_LORAMOE_LAYER:
            logger.warning(
                "The custom `LoraLayer` for LoRA-MoE could not be imported. "
                "RO-GRPO's routing-aware reward features will be disabled. "
                "Please ensure `layer.py` is in the correct path."
            )

    def _lora_moe_layers(self) -> List[Tuple[str, Any]]:
        """
        Return ``(qualified_name, module)`` for every ``LoraLayer`` in the unwrapped model.

        Returns an empty list when the LoRA-MoE layer class could not be imported.
        """
        if not HAS_LORAMOE_LAYER:
            return []
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        return [(name, module) for name, module in unwrapped_model.named_modules()
                if isinstance(module, LoraLayer)]

    def _discard_pending_routing_stats(self) -> None:
        """
        Clear routing weights that LoRA-MoE layers hold but that were never collected.

        Only layers with a non-empty ``routing_weights_for_generation`` are cleared,
        mirroring the condition used during collection.
        """
        for _, module in self._lora_moe_layers():
            if getattr(module, 'routing_weights_for_generation', None):
                module.clear_routing_stats()

    def _collect_routing_stats_after_generation(self) -> List[Dict[str, torch.Tensor]]:
        """
        Collects routing statistics from all LoRA-MoE layers after generation.

        Each ``LoraLayer`` is read through its ``routing_weights_for_generation``
        attribute, which is indexed by sample; each entry is a list of routing
        tensors for that sample (presumably one per generation step -- TODO
        confirm against layer.py). The sample count is taken from the first layer
        that holds data. Layers whose count differs are skipped with a warning.
        Each sample's tensors are stacked along a new leading dimension (so they
        must share a shape) and moved to CPU.

        Every layer that held data is cleared afterwards, including skipped ones.
        This keeps stale weights from accumulating or leaking into the next collection.

        Returns:
            A list of dictionaries, where each dictionary corresponds to a single
            sample in the generation batch and maps layer names to their
            respective routing weight tensors. Empty if LoRA-MoE is unavailable
            or nothing was recorded.
        """
        if not HAS_LORAMOE_LAYER:
            return []

        # Traverse the model once and reuse the result for both passes below.
        layers = self._lora_moe_layers()

        # Find the number of samples from the first layer that recorded anything.
        num_samples = 0
        for _, module in layers:
            recorded = getattr(module, 'routing_weights_for_generation', None)
            if recorded:
                num_samples = len(recorded)
                break

        if num_samples == 0:
            logger.debug("No routing statistics found in any LoRA-MoE layer.")
            return []

        sample_stats: List[Dict[str, torch.Tensor]] = [{} for _ in range(num_samples)]

        for name, module in layers:
            all_sample_weights = getattr(module, 'routing_weights_for_generation', None)
            if not all_sample_weights:
                continue

            try:
                # Ensure the number of samples is consistent across layers.
                if len(all_sample_weights) != num_samples:
                    logger.warning(
                        f"Inconsistent number of samples in layer {name} "
                        f"({len(all_sample_weights)}) vs expected ({num_samples}). Skipping layer."
                    )
                    continue

                for sample_idx, sample_weights_list in enumerate(all_sample_weights):
                    if not sample_weights_list:
                        continue
                    # Stack this sample's routing tensors along a new dim 0 (best-effort: a failure
                    # only drops this layer for this sample).
                    try:
                        sample_stats[sample_idx][name] = torch.stack(sample_weights_list, dim=0).cpu()
                    except Exception as e:
                        logger.error(
                            f"Error concatenating routing weights for sample {sample_idx} in layer {name}: {e}")
            finally:
                # Runs for skipped layers too (the `continue` above), so nothing is left behind.
                module.clear_routing_stats()

        logger.debug(f"Collected routing statistics for {len(sample_stats)} samples.")
        return sample_stats

    @staticmethod
    def _extract_single_reward(result: Any) -> Any:
        """
        Extract the reward from a reward-function call made with a single completion.

        Accepts a list/tuple (first element), a tensor (first element, as a Python
        number) or a bare int/float. Anything else, including an empty sequence,
        yields 0.0. ``None`` elements are passed through unchanged so the caller
        can map them to NaN.
        """
        if isinstance(result, (list, tuple)):
            return result[0] if result else 0.0
        if isinstance(result, torch.Tensor):
            return result.reshape(-1)[0].item() if result.numel() > 0 else 0.0
        if isinstance(result, (int, float)):
            return result
        return 0.0

    def _score_completions(self,
                           inputs: InputsType,
                           routing_stats: Optional[List[Dict[str, torch.Tensor]]] = None
                           ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
        """
        Scores completions using reward functions, with added support for routing-aware rewards.

        Reward functions are dispatched in three ways:

        - ``torch.nn.Module`` reward models are scored through their reward-model plugin.
        - Callables whose signature declares ``routing_stats`` are called once per
          completion, as ``reward_func([completion], routing_stats=<that sample's dict>, ...)``.
          If the stats are missing, or their count differs from ``len(inputs)``,
          every sample gets 0.0 and a warning is logged.
        - Any other callable is called once with the full list of completions.

        ``trainer_state``, ``global_step`` and ``max_steps`` are passed to callables
        that name them, or to all callables that accept ``**kwargs``. Exceptions
        raised by a reward function are logged and scored as 0.0. ``None`` rewards
        become NaN, and NaN entries are ignored by the weighted ``nansum``.

        Args:
            inputs: A list of input dictionaries; the completion is read from
                ``example['messages'][-1]['content']``.
            routing_stats: One routing-statistics dictionary per sample, as
                produced by ``_collect_routing_stats_after_generation``.

        Returns:
            A tuple containing:
            - total_rewards_per_func: rewards from each reward function, gathered across processes.
            - total_rewards: the weighted sum over reward functions (NaN-ignoring).
            - completions: the local (un-gathered) list of completion strings.
        """
        device = self.accelerator.device
        completions = [example['messages'][-1]['content'] for example in inputs]
        rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device)

        # Prepare kwargs that might be needed by various reward functions.
        reward_kwargs = {
            'trainer_state': self.state,
            'global_step': self.state.global_step,
            'max_steps': getattr(self.state, 'max_steps', None),
        }
        # Per-sample stats are only usable when there is exactly one entry per completion.
        stats_aligned = bool(routing_stats) and len(routing_stats) == len(inputs)

        for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate(
                zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)):

            if isinstance(reward_func, torch.nn.Module):
                # Standard reward model scoring.
                output_reward_func = reward_model_plugin(inputs=inputs)
            else:
                params = inspect.signature(reward_func).parameters
                # A `**kwargs` parameter appears under its own name (e.g. 'kwargs'), so a plain
                # name filter would silently drop every kwarg for such functions.
                takes_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
                func_kwargs = {k: v for k, v in reward_kwargs.items() if takes_var_kwargs or k in params}

                if 'routing_stats' in params:
                    if stats_aligned:
                        # Routing-aware: score each completion with its own routing statistics.
                        output_reward_func = []
                        for idx, completion in enumerate(completions):
                            sample_kwargs = {**func_kwargs, 'routing_stats': routing_stats[idx]}
                            try:
                                r = reward_func([completion], **sample_kwargs)
                                output_reward_func.append(self._extract_single_reward(r))
                            except Exception as e:
                                logger.error(f"Error in reward_func '{reward_func_name}' for sample {idx}: {e}")
                                output_reward_func.append(0.0)
                    else:
                        if routing_stats:
                            logger.warning(
                                f"Reward function '{reward_func_name}' expects routing_stats, but got "
                                f"{len(routing_stats)} entries for {len(inputs)} completions; "
                                f"assigning zero reward.")
                        else:
                            logger.warning(
                                f"Reward function '{reward_func_name}' expects routing_stats, but none were provided.")
                        output_reward_func = [0.0] * len(inputs)
                else:
                    # Standard reward function call over the whole batch.
                    try:
                        output_reward_func = reward_func(completions, **func_kwargs)
                    except Exception as e:
                        logger.error(f"Error in reward_func '{reward_func_name}': {e}")
                        output_reward_func = [0.0] * len(inputs)

            # Map None rewards to NaN so that the weighted nansum below ignores them.
            output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
            rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather rewards from all processes (the bare `gather` used previously was never imported).
        total_rewards_per_func = self.accelerator.gather(rewards_per_func)
        total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

        return total_rewards_per_func, total_rewards, completions

    def _generate_and_score_completions(self, inputs: InputsType) -> InputsType:
        """
        Orchestrates the full process of generating, scoring, and preparing data for a batch.

        This overridden method inserts routing-statistics collection into the standard
        pipeline. Routing weights are collected right after generation and passed to
        ``_score_completions``.

        Args:
            inputs: The prompt batch to generate completions for.

        Returns:
            The encoded batch returned by ``_prepare_batch_inputs``.
        """
        # Step 1: Generate completions for the input prompts.
        inputs_with_completions = self._generate_completions(inputs)

        # Step 2: Collect (and clear) routing statistics from LoRA-MoE layers.
        routing_stats = self._collect_routing_stats_after_generation()

        try:
            # Step 3: Score the completions, passing the routing stats to the scoring function.
            total_rewards_per_func, total_rewards, completions = self._score_completions(
                inputs_with_completions, routing_stats=routing_stats
            )

            mode = 'train' if self.model.training else 'eval'

            # Step 4 (Optional): Resample if there are groups with zero reward variance.
            if self.args.dynamic_sample and mode == 'train':
                inputs_with_completions, total_rewards, total_rewards_per_func, completions = \
                    self._dynamic_sampling(inputs_with_completions, total_rewards, total_rewards_per_func, completions)

            # Step 5: Prepare the final batch with advantages, log probabilities, etc.
            batch_encoded_inputs = self._prepare_batch_inputs(inputs_with_completions, total_rewards)

            # Step 6: Log all relevant metrics and textual data.
            messages = [inp['messages'][:-1] for inp in inputs_with_completions]
            self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func)

        finally:
            # Routing weights recorded after step 2 were never collected. This could happen,
            # for example, if dynamic sampling regenerates completions (unverified here). Drop
            # them so they are not mistaken for the next batch's statistics.
            self._discard_pending_routing_stats()
            # Ensure memory from routing statistics is released.
            if routing_stats:
                del routing_stats
                gc.collect()

        return batch_encoded_inputs