"""
This file presents a modification of the PEFT LoRA layer to implement a
Mixture-of-Experts (MoE) architecture, referred to as LoRA-MoE. This implementation
is simplified to support only `torch.nn.Linear` layers, focusing on the core
logic required for the experiments in the paper "Balancing the Experts: Unlocking
LoRA-MoE for GRPO via Mechanism-Aware Rewards".

Key modifications include:
- A trainable router network for each LoRA-MoE layer.
- Multiple LoRA experts whose outputs are weighted by the router's decision.
- A mechanism to collect routing statistics during the forward pass, which is
  essential for calculating the mechanism-aware rewards in RO-GRPO.
"""
from __future__ import annotations

import math
import warnings
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft.tuners.lora.config import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.other import transpose


class LoraLayer(BaseTunerLayer):
    """
    A base layer for the LoRA-MoE implementation. It manages the creation,
    updating, and application of LoRA-MoE adapters to a `torch.nn.Linear` base layer.

    Each adapter consists of a bias-free linear router (``in_features -> num_experts``)
    and ``num_experts`` LoRA experts (``A: in_features -> r``, ``B: r -> out_features``).
    In ``forward`` the softmax of the router logits weights the expert outputs; the
    weighted sum is multiplied by the adapter's ``scaling`` and added to the base output.
    """

    # Names of adapter layers that are trainable
    adapter_layer_names = ("lora_A", "lora_B", "lora_router")
    # Per-adapter attributes that are not trainable weights (hyper-parameters and dropout)
    other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout", "num_experts")

    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.r = {}
        self.lora_alpha = {}
        self.scaling = {}
        self.lora_dropout = nn.ModuleDict({})
        self.num_experts = {}  # Store the number of experts for each adapter

        # LoRA-MoE components
        self.lora_A = nn.ModuleDict({})  # ModuleDict of ModuleLists of expert A matrices
        self.lora_B = nn.ModuleDict({})  # ModuleDict of ModuleLists of expert B matrices
        self.lora_router = nn.ModuleDict({})  # Router for each adapter

        # State variables
        self._disable_adapters = False
        self.merged_adapters = []
        self.lora_bias = {}

        # --- Mechanism for Routing Statistics Collection ---
        # Collects the detached softmax routing weights produced by THIS layer.
        # Structure: List[List[torch.Tensor]]. The outer list is indexed by the first
        # dimension of the router output (the batch dimension for batched inputs).
        # Each inner list receives one tensor per forward call of this layer (and per
        # active adapter), in call order. Routing weights of other LoRA-MoE layers are
        # stored on those layers' own instances.
        # The trainer is expected to call `clear_routing_stats()` at the start of each
        # generation. Assign None to this attribute to disable collection.
        self.routing_weights_for_generation: Optional[List[List[torch.Tensor]]] = []
        # --- End of Statistics Collection Mechanism ---

        # Ensure the base layer is a Linear layer, as this implementation is simplified.
        if not isinstance(self.get_base_layer(), nn.Linear):
            raise TypeError(f"LoRA-MoE is only implemented for nn.Linear, but got {type(self.get_base_layer())}")

        self.in_features = self.get_base_layer().in_features
        self.out_features = self.get_base_layer().out_features

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        init_lora_weights: Union[bool, str],
        use_rslora: bool,
        num_experts: int,  # Added to make number of experts configurable
        lora_bias: bool = False,
    ):
        """
        Update or create a LoRA-MoE adapter for the layer.

        Args:
            adapter_name (str): The name of the adapter to update.
            r (int): The rank of the LoRA experts.
            lora_alpha (int): The alpha parameter for scaling.
            lora_dropout (float): The dropout probability.
            init_lora_weights (Union[bool, str]): How to initialize weights.
            use_rslora (bool): Whether to use rank-stabilized LoRA scaling.
            num_experts (int): The number of experts to create for this adapter.
            lora_bias (bool): Whether to train a bias term for LoRA B experts.

        Raises:
            ValueError: If ``r`` or ``num_experts`` is not positive, or the
                initialization strategy is unknown.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer, but got {r}")
        if num_experts <= 0:
            raise ValueError(f"`num_experts` should be a positive integer, but got {num_experts}")

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        self.num_experts[adapter_name] = num_experts
        self.lora_bias[adapter_name] = lora_bias

        if lora_dropout > 0.0:
            self.lora_dropout[adapter_name] = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout[adapter_name] = nn.Identity()

        # Create the router
        self.lora_router[adapter_name] = nn.Linear(self.in_features, num_experts, bias=False)

        # Create the experts
        self.lora_A[adapter_name] = nn.ModuleList(
            nn.Linear(self.in_features, r, bias=False) for _ in range(num_experts)
        )
        self.lora_B[adapter_name] = nn.ModuleList(
            nn.Linear(r, self.out_features, bias=lora_bias) for _ in range(num_experts)
        )

        # Set scaling factor (rsLoRA divides by sqrt(r) instead of r)
        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        # Initialize weights if requested; False keeps nn.Linear's default (random) init.
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        # Move the new adapter to the same device as the base layer
        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def reset_lora_parameters(self, adapter_name: str, init_lora_weights: Union[bool, str]):
        """
        Initialize the weights of a LoRA-MoE adapter.

        The router and every expert's A matrix get a random init; every expert's B
        matrix (and B bias, if present) is zeroed, so a freshly initialized adapter
        adds nothing to the base output.

        Args:
            adapter_name (str): The name of the adapter to initialize.
            init_lora_weights (Union[bool, str]): ``False`` (no-op), ``True``
                (Kaiming-uniform A) or ``"gaussian"`` (normal A with std ``1/r``).

        Raises:
            ValueError: For any other strategy. Validation happens before any
                weight is touched, so the adapter is never left half-initialized.
        """
        if init_lora_weights is False:
            return

        use_gaussian = isinstance(init_lora_weights, str) and init_lora_weights.lower() == "gaussian"
        if init_lora_weights is not True and not use_gaussian:
            raise ValueError(f"Unknown initialization strategy: '{init_lora_weights}'")

        # Initialize router
        if adapter_name in self.lora_router:
            nn.init.kaiming_uniform_(self.lora_router[adapter_name].weight, a=math.sqrt(5))

        # Initialize experts
        if adapter_name not in self.lora_A:
            return
        for expert_A, expert_B in zip(self.lora_A[adapter_name], self.lora_B[adapter_name]):
            if use_gaussian:
                nn.init.normal_(expert_A.weight, std=1 / self.r[adapter_name])
            else:
                nn.init.kaiming_uniform_(expert_A.weight, a=math.sqrt(5))
            nn.init.zeros_(expert_B.weight)
            if getattr(expert_B, "bias", None) is not None:
                nn.init.zeros_(expert_B.bias)

    def clear_routing_stats(self):
        """
        Clears the collected routing statistics. Should be called by the trainer
        before starting a new generation or batch processing.

        The list is cleared in place so external references to it stay valid. If
        collection is disabled (the attribute is None) this is a no-op.
        """
        if self.routing_weights_for_generation is not None:
            self.routing_weights_for_generation.clear()

    def _get_delta_weight(self, adapter: str, expert_idx: int) -> torch.Tensor:
        """
        Compute the delta weight ``scaling * B @ A`` for a single expert.

        Shape is ``(out_features, in_features)``. The product is computed in
        float32 for stability and returned in the dtype of the expert's A weight.
        Note: This is used for merging, which is experimental for LoRA-MoE.
        """
        lora_A = self.lora_A[adapter][expert_idx]
        lora_B = self.lora_B[adapter][expert_idx]
        scaling = self.scaling[adapter]

        output_tensor = (lora_B.weight.float() @ lora_A.weight.float()) * scaling
        return output_tensor.to(lora_A.weight.dtype)

    def _record_routing_weights(self, route_weights: torch.Tensor) -> None:
        """Append detached routing weights to the per-sample statistics lists (if enabled)."""
        stats = self.routing_weights_for_generation
        if stats is None:
            return

        detached = route_weights.detach()
        batch_size = detached.shape[0]
        if len(stats) != batch_size:
            # A different batch size means a new collection round: reset in place so
            # references held by the trainer remain valid.
            stats.clear()
            stats.extend([] for _ in range(batch_size))

        # NOTE(review): with gradient checkpointing the forward is re-run during
        # backward, which would append these weights a second time.
        for sample_stats, sample_weights in zip(stats, detached):
            sample_stats.append(sample_weights)

    def _moe_delta(self, adapter: str, x: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """
        Return ``scaling * sum_i w_i * expert_i(dropout(x))`` for one adapter.

        ``w = softmax(router(x))`` over the last dimension; the weights are also
        passed to the statistics collector. The result has the shape and dtype
        of ``like`` (the base layer output).
        """
        router = self.lora_router[adapter]
        route_logits = router(x.to(router.weight.dtype))
        route_weights = F.softmax(route_logits, dim=-1)

        self._record_routing_weights(route_weights)

        # A single dropout mask is shared by all experts of this adapter.
        x_dropped = self.lora_dropout[adapter](x)
        delta = torch.zeros_like(like)
        experts = zip(self.lora_A[adapter], self.lora_B[adapter])
        for expert_idx, (lora_A_i, lora_B_i) in enumerate(experts):
            expert_output = lora_B_i(lora_A_i(x_dropped.to(lora_A_i.weight.dtype)))
            # Slice rather than index so the expert axis stays as size 1 and
            # broadcasts over the output features.
            weight_i = route_weights[..., expert_idx : expert_idx + 1]
            delta += weight_i * expert_output.to(delta.dtype)

        return delta * self.scaling[adapter]

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Forward pass for the LoRA-MoE layer.

        - Adapters disabled: any merged weights are unmerged first, then the plain
          base layer output is returned.
        - Adapters merged (and enabled): the merged base layer output is returned;
          the experts are not applied again, which would double-count the delta.
        - Otherwise: ``base_layer(x)`` plus the routed expert delta of every active
          adapter that has LoRA-MoE weights on this layer.
        """
        base_layer = self.get_base_layer()

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return base_layer(x, *args, **kwargs)

        if self.merged:
            return base_layer(x, *args, **kwargs)

        result = base_layer(x, *args, **kwargs)

        # Accumulate all adapters' deltas first (in the base output dtype), then add once.
        moe_delta_total = torch.zeros_like(result)
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A:
                continue
            moe_delta_total += self._moe_delta(active_adapter, x, result)

        return result + moe_delta_total


class Linear(LoraLayer, nn.Module):
    """
    LoRA-MoE implemented for a `torch.nn.Linear` layer.

    Merging is approximate: only expert 0 is folded into the base layer
    (see `merge`).
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        lora_bias: bool = False,
        num_experts: int = 4,
        **kwargs,
    ) -> None:
        super().__init__(base_layer, **kwargs)

        self.update_layer(
            adapter_name,
            r,
            lora_alpha,
            lora_dropout,
            init_lora_weights,
            use_rslora,
            num_experts,
            lora_bias,
        )

    def _expert0_bias_delta(self, adapter: str) -> Optional[torch.Tensor]:
        """
        Return expert 0's LoRA-B bias multiplied by the adapter's scaling.

        The forward pass scales the whole expert output (bias included), so the
        merged bias must be scaled like the merged weight delta. Returns None
        when the adapter has no trainable bias.
        """
        if not self.lora_bias.get(adapter, False):
            return None
        expert_bias = getattr(self.lora_B[adapter][0], "bias", None)
        if expert_bias is None:
            return None
        return expert_bias.data * self.scaling[adapter]

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the LoRA-MoE adapter into the base layer.

        Warning: This is an experimental feature for LoRA-MoE. It merges only the
        weights of the *first expert* (index 0) into the base layer. This is a
        simplification and does not represent a true fusion of the MoE capacity.

        Args:
            safe_merge: If True, the merged weights are computed on copies and
                checked for NaN/Inf before being written back. If they are not
                finite, ValueError is raised and that adapter is not merged.
            adapter_names: Adapters to merge. They are filtered through PEFT's
                `check_adapters_to_merge` (presumably defaulting to the active,
                not-yet-merged adapters; confirm against PEFT).
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return

        base_layer = self.get_base_layer()
        for active_adapter in adapter_names:
            if active_adapter not in self.lora_A:
                continue
            warnings.warn(
                f"Merging LoRA-MoE adapter '{active_adapter}'. Only the first expert (idx=0) "
                "will be merged. This is an experimental feature."
            )

            delta_weight = transpose(self._get_delta_weight(active_adapter, expert_idx=0), fan_in_fan_out=False)
            bias_delta = self._expert0_bias_delta(active_adapter)

            # With safe_merge we work on copies so a broken adapter leaves the base
            # layer untouched; otherwise we update the base tensors in place.
            # In-place `+=` keeps the base dtype even if the adapter dtype differs.
            new_weight = base_layer.weight.data.clone() if safe_merge else base_layer.weight.data
            new_weight += delta_weight

            new_bias = None
            if bias_delta is not None:
                if base_layer.bias is None:
                    # Match the base weight's dtype/device so F.linear keeps working.
                    new_bias = torch.zeros(
                        base_layer.out_features,
                        dtype=base_layer.weight.dtype,
                        device=base_layer.weight.device,
                    )
                else:
                    new_bias = base_layer.bias.data.clone() if safe_merge else base_layer.bias.data
                new_bias += bias_delta

            if safe_merge:
                weight_ok = bool(torch.isfinite(new_weight).all())
                bias_ok = new_bias is None or bool(torch.isfinite(new_bias).all())
                if not (weight_ok and bias_ok):
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

            base_layer.weight.data = new_weight
            if new_bias is not None:
                if base_layer.bias is None:
                    base_layer.bias = nn.Parameter(new_bias)
                else:
                    base_layer.bias.data = new_bias

            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        Unmerge the LoRA-MoE adapter from the base layer.

        Warning: This function assumes the merged weights came from the *first expert*
        (index 0), consistent with the `merge` method's behavior. A bias that
        `merge` had to create on a bias-free base layer is kept, holding
        (approximately) zeros after unmerging.
        """
        if not self.merged:
            warnings.warn("Layer is not merged, cannot unmerge.")
            return

        base_layer = self.get_base_layer()
        while self.merged_adapters:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.lora_A:
                continue

            # Subtract the same (scaled) expert-0 deltas that `merge` added.
            delta_weight = self._get_delta_weight(active_adapter, expert_idx=0)
            base_layer.weight.data -= transpose(delta_weight, fan_in_fan_out=False)

            bias_delta = self._expert0_bias_delta(active_adapter)
            if bias_delta is not None and base_layer.bias is not None:
                base_layer.bias.data -= bias_delta

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora_moe." + rep