import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from moe_peft.common import Linear, LLMMoeBlock, LLMFeedForward

import pdb

from .config import LoraMoeConfig

def _m2lora_load_balancing_loss_func(
    activations: torch.Tensor,
    num_experts: int,
    top_k: int,
    attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Compute a load-balancing auxiliary loss from per-expert activations.

    The activations are normalised over the expert axis to obtain routing
    weights, the ``top_k`` largest weights of every token define the selected
    experts, and the loss is ``N * sum(f * P)`` where ``f`` is the fraction of
    tokens routed to each expert (per top-k slot) and ``P`` is the mean
    routing weight of each expert.

    Args:
        activations: Tensor of shape ``(L, N, B, T)`` - layers, experts,
            batch, tokens.
        num_experts: Ignored; the expert count is always taken from
            ``activations.shape[1]``. Kept for interface compatibility.
        top_k: Number of experts selected per token.
        attention_mask: Optional ``(B, T)`` mask; entries equal to zero mark
            padding tokens that are excluded from both averages. ``None``
            means every token is counted.

    Returns:
        A scalar tensor holding the load-balancing loss.
    """
    num_hidden_layers, num_experts, _, _ = activations.shape

    # (L, N, B, T) -> (L, B, T, N), then normalise over the expert axis.
    routing_weights = activations.permute(0, 2, 3, 1)
    routing_weights = routing_weights / (
        routing_weights.sum(-1, keepdim=True) + 1e-8
    )
    _, selected_experts = torch.topk(routing_weights, k=top_k, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    # Collapse (L, B, T) into one token axis so both branches reduce over the
    # same dimension.  Without this the unmasked branch averaged over layers
    # only and produced mismatched (B, T, K, N) / (B, T, N) operands.
    flat_weights = routing_weights.flatten(0, 2)  # (L*B*T, N)
    flat_expert_mask = expert_mask.flatten(0, 2).float()  # (L*B*T, K, N)

    if attention_mask is None:
        # Fraction of tokens routed to each expert, per top-k slot: (K, N)
        tokens_per_expert = torch.mean(flat_expert_mask, dim=0)
        # Average routing weight of each expert: (N,)
        router_prob_per_expert = torch.mean(flat_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape

        # One mask value per flattened (layer, batch, token) position; it is
        # broadcast over the top-k / expert axes instead of being expanded.
        token_mask = (
            attention_mask.to(routing_weights.device)[None, :, :]
            .expand((num_hidden_layers, batch_size, sequence_length))
            .reshape(-1)
        )
        valid_tokens = torch.sum(token_mask)

        # Fraction of non-padding tokens routed to each expert: (K, N)
        tokens_per_expert = (
            torch.sum(flat_expert_mask * token_mask[:, None, None], dim=0)
            / valid_tokens
        )

        # Average routing weight over non-padding tokens: (N,)
        router_prob_per_expert = (
            torch.sum(flat_weights * token_mask[:, None], dim=0) / valid_tokens
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts

class SelLoss(torch.nn.Module):
    """Selection loss plus load-balancing loss over per-expert activations.

    Both terms are scaled by ``config.router_aux_loss_coef_``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        # Weight applied to both the selection and the load-balancing term.
        self.aux_loss_coef = config.router_aux_loss_coef_
        # Forwarded to the load-balancing helper as its expert count.
        self.experts = config.num_experts_
        # Number of experts kept per token when building the numerator.
        self.topk = config.topk

    def forward(
        self,
        activations: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        r"""Return the weighted selection + load-balancing loss.

        L_{sel} = -ln[
            \sum_{j\in top-K}\sum \frac{B_jA_j}{||B_jA_j||_c}x / \sum_n\sum\frac{B_nA_n}{||B_nA_n||_c}x
        ]

        Args:
            activations: Tensor of shape ``(L, N, B, T)`` - layers, experts,
                batch, tokens.
            attention_mask: ``(B, T)`` mask whose zero entries exclude padding
                tokens. ``None`` is treated as "all tokens are valid".

        Returns:
            Scalar tensor ``coef * (-log(mean top-k share) + balance_loss)``.
        """
        # NOTE (original author): the hidden state probably should not
        # back-propagate this loss into the earlier layers.
        num_layers, _, batch_size, seq_len = activations.shape

        if attention_mask is None:
            # An all-ones mask gives plain averages over every token.
            attention_mask = torch.ones(
                (batch_size, seq_len),
                dtype=activations.dtype,
                device=activations.device,
            )

        # Shift by one before forming the ratio; presumably to keep the
        # denominator away from zero - TODO confirm with the author.
        activations = activations + 1
        top_norm, _ = activations.topk(k=self.topk, dim=1)  # L, K, B, T
        numerator = top_norm.sum(1)  # L, B, T
        denominator = activations.sum(1)  # L, B, T

        # Share of the total activation captured by the top-k experts,
        # zeroed on padding tokens and averaged over valid tokens and layers.
        selected_share = numerator / (denominator + 1e-8) * attention_mask[None, ...]
        mean_share = selected_share.sum() / (attention_mask.sum() * num_layers)

        selection_loss = -self.aux_loss_coef * torch.log(mean_share)
        # The shifted activations are what the balancing term receives.
        balance_loss = self.aux_loss_coef * _m2lora_load_balancing_loss_func(
            activations, self.experts, self.topk, attention_mask
        )
        return selection_loss + balance_loss

class LoraMoe(LLMMoeBlock):
    r"""Mixture of LoRA experts supporting an M2LoRA and a LoRAMoE pipeline.

    The pipeline is chosen per call from the ``use_mlora_`` flag of expert 0.

    M2LoRA update rule applied by the routed experts:

        y = Wx + \alpha \times m_{lora} \otimes \sum_i \frac{B_iA_i}{||B_iA_i||_c}
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: LoraMoeConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        """Build the gate and record the expert layout from ``config``.

        Args:
            in_features: Input width of the gate projection.
            device: Device on which the gate weight is allocated.
            config: Supplies ``adapter_name``, ``num_experts_``,
                ``shared_experts_``, ``router_init_range_`` and (read in
                ``forward``) ``router_loss_``.
            gate: Optional pre-trained gate weight copied into the gate;
                when omitted the gate is Kaiming-uniform initialised.
        """
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=torch.float32,
        )
        self.experts_ = config.num_experts_
        # Router logits of the most recent forward pass (flattened per token).
        self.router_logits_: Optional[torch.Tensor] = None
        # Experts [0, shared_experts_) are always applied as plain LoRA.
        self.shared_experts_ = config.shared_experts_
        self.config = config

        if gate is None:
            torch.nn.init.kaiming_uniform_(
                self.gate_.weight, a=math.sqrt(config.router_init_range_)
            )
        else:
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def _expert_lora(self, lora_linear: Linear, expert_idx: int):
        """Return the LoRA adapter registered for expert ``expert_idx``."""
        return lora_linear.loras_[f"moe.{self.adapter_name_}.experts.{expert_idx}"]

    def _forward_m2lora(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        lora_linear: Linear,
        example_lora,
    ):
        """Run shared experts as plain LoRA, then the magnitude-scaled experts."""
        router_logits = self.gate_(hidden_states.to(self.dtype_))
        self.router_logits_ = router_logits.reshape(-1, self.experts_).detach()

        # Shared experts: temporarily disable the M2LoRA flag so that
        # lora_forward takes its general path and returns a single tensor.
        for expert_idx in range(self.shared_experts_):
            expert_lora = self._expert_lora(lora_linear, expert_idx)
            expert_lora.use_mlora_ = False
            try:
                hsr = expert_lora.lora_forward(hidden_states)
            finally:
                # Always restore the flag, even if the forward pass raises.
                expert_lora.use_mlora_ = True
            residual = residual + hsr.to(hidden_states.dtype)

        # The first out_features entries of the flattened gate weight act as
        # the magnitude vector shared by all routed experts.  Assumes the gate
        # holds at least out_features elements - TODO confirm.
        out_features = example_lora.out_features_
        shared_magnitude = self.gate_.weight.view(-1)[:out_features].to(
            hidden_states.dtype
        )

        collect_activations = self.config.router_loss_
        activations = []
        for expert_idx in range(self.shared_experts_, self.experts_):
            expert_lora = self._expert_lora(lora_linear, expert_idx)
            if collect_activations:
                hsr, logit = expert_lora.lora_forward(hidden_states)
                activations.append(logit)
            else:
                hsr, *_ = expert_lora.lora_forward(hidden_states)
            residual = residual + hsr.to(hidden_states.dtype) * shared_magnitude

        if collect_activations:
            return residual, torch.stack(activations)
        return residual

    def _forward_loramoe(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        lora_linear: Linear,
    ) -> torch.Tensor:
        """Add every expert's output weighted by its softmax routing weight."""
        router_logits = self.gate_(hidden_states.to(self.dtype_))
        self.router_logits_ = router_logits.reshape(-1, self.experts_)
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)

        for expert_idx in range(self.experts_):
            expert_lora = self._expert_lora(lora_linear, expert_idx)
            # Indexing fixes hidden_states to three dimensions; the expert
            # axis of the routing weights is the last one.
            expert_weight = torch.unsqueeze(routing_weights[:, :, expert_idx], -1)
            residual = residual + (
                expert_weight * expert_lora.lora_forward(hidden_states)
            ).to(hidden_states.dtype)
        return residual

    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        lora_linear: Optional[Linear] = None,
    ) -> Tuple:
        """Add the expert contributions to ``residual``.

        Args:
            residual: Output of the base projection; expert outputs are added
                to it.
            hidden_states: Input fed to the gate and to every expert.
            lora_linear: Linear wrapper whose ``loras_`` mapping holds the
                experts under ``moe.<adapter>.experts.<idx>``. Required.

        Returns:
            * M2LoRA pipeline with ``config.router_loss_`` set:
              ``(residual, activations)`` where ``activations`` stacks the
              second value returned by each routed expert.
            * Otherwise: the updated ``residual`` tensor alone.

        Raises:
            ValueError: If ``lora_linear`` is ``None``.
        """
        # Validate before the first dereference of lora_linear.
        if lora_linear is None:
            raise ValueError("lora_linear must be provided")

        example_lora = self._expert_lora(lora_linear, 0)
        if example_lora.use_mlora_:
            return self._forward_m2lora(
                residual, hidden_states, lora_linear, example_lora
            )
        return self._forward_loramoe(residual, hidden_states, lora_linear)

class M2LoRA(LLMMoeBlock):
    r"""MoE block in which every expert runs through the M2LoRA path.

        y = Wx + \alpha \times m_{lora} \otimes \sum_i \frac{B_iA_i}{||B_iA_i||_c}
    """

    def __init__(
        self,
        in_features: int,
        device: torch.device,
        config: LoraMoeConfig,
        gate: Optional[torch.Tensor] = None,
    ) -> None:
        """Build the gate projection.

        Args:
            in_features: Input width of the gate projection.
            device: Device on which the gate weight is allocated.
            config: Supplies ``adapter_name``, ``num_experts_`` and
                ``router_init_range_``.
            gate: Optional pre-trained gate weight copied into the gate;
                when omitted the gate is Kaiming-uniform initialised.
        """
        super().__init__()

        self.adapter_name_: str = config.adapter_name
        self.dtype_: torch.dtype = torch.float32
        self.gate_ = torch.nn.Linear(
            in_features,
            config.num_experts_,
            bias=False,
            device=device,
            dtype=torch.float32,
        )
        self.experts_ = config.num_experts_
        # Detached router logits of the most recent forward pass.
        self.router_logits_: Optional[torch.Tensor] = None

        if gate is None:
            torch.nn.init.kaiming_uniform_(
                self.gate_.weight, a=math.sqrt(config.router_init_range_)
            )
        else:
            with torch.no_grad():
                self.gate_.weight.copy_(gate)

    def forward(
        self,
        residual: torch.Tensor,
        hidden_states: torch.Tensor,
        lora_linear: Optional[Linear] = None,
    ) -> Tuple:
        """Add every expert's magnitude-scaled output to ``residual``.

        Args:
            residual: Output of the base projection; expert outputs are added
                to it.
            hidden_states: Input fed to the gate and to every expert.
            lora_linear: Linear wrapper whose ``loras_`` mapping holds the
                experts under ``moe.<adapter>.experts.<idx>``. Required.

        Returns:
            ``(residual, activations)`` where ``activations`` stacks the
            second value returned by each expert's ``lora_forward``.

        Raises:
            ValueError: If ``lora_linear`` is ``None``.
            AssertionError: If expert 0 does not have ``use_mlora_`` set.
        """
        # Validate before the first dereference of lora_linear.
        if lora_linear is None:
            raise ValueError("lora_linear must be provided")

        example_lora = lora_linear.loras_[f"moe.{self.adapter_name_}.experts.0"]
        assert example_lora.use_mlora_, "M2LoRA must set 'use_mlora' to True"

        # The logits are only recorded (detached); they do not weight experts.
        router_logits = self.gate_(hidden_states.to(self.dtype_))
        self.router_logits_ = router_logits.reshape(-1, self.experts_).detach()

        # The first out_features entries of the flattened gate weight act as
        # the magnitude vector shared by all experts.  Assumes the gate holds
        # at least out_features elements - TODO confirm.
        out_features = example_lora.out_features_
        shared_magnitude = self.gate_.weight.view(-1)[:out_features].to(
            hidden_states.dtype
        )

        activations = []
        for expert_idx in range(self.experts_):
            expert_lora = lora_linear.loras_[
                f"moe.{self.adapter_name_}.experts.{expert_idx}"
            ]
            hsr, logit = expert_lora.lora_forward(hidden_states)
            activations.append(logit)
            residual = residual + hsr.to(hidden_states.dtype) * shared_magnitude

        return residual, torch.stack(activations)
