import dataclasses
from typing import List

import torch
import torch.nn.functional as F
from simple_parsing.helpers import Serializable
from torch import nn


@dataclasses.dataclass
class MoeArgs(Serializable):
    """Hyper-parameters for the mixture-of-experts layers in this module.

    ``Serializable`` comes from ``simple_parsing.helpers``; presumably it is
    used to load/save these args from config files — confirm with callers.
    """
    num_shared_experts: int  # presumably the number of always-on experts; not read in this chunk — confirm
    num_experts: int  # presumably the number of routed experts; not read in this chunk — confirm
    num_experts_per_tok: int  # top-k passed to torch.topk in MoeLayer.forward
    gate_feature_dim: int  # presumably the gate's input feature size; not read in this chunk — confirm


class MoeLayer(nn.Module):
    """Mixture-of-experts layer routed by a learned top-k gate.

    The gate scores one feature vector per sample (position ``-2`` along the
    second axis of ``gate_feature``). The ``moe_args.num_experts_per_tok``
    best-scoring experts are selected for that sample, and their outputs are
    combined with softmax weights computed over the selected gate logits.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor, gate_feature: torch.Tensor) -> torch.Tensor:
        """Route each sample to its top-k experts and mix their outputs.

        Args:
            inputs: Tensor indexed as ``inputs[batch, pos]``; assumed to be
                (batch, seq, dim) — TODO confirm with callers.
            gate_feature: Tensor whose ``[:, -2, :]`` slice is fed to the gate.

        Returns:
            Tensor shaped like ``inputs`` holding the weighted expert outputs.
            Entries not selected by the routing remain zero.
        """
        # Keep a singleton axis so the routing tensors are laid out like
        # ``inputs[batch, pos]``. NOTE(review): if the gate preserves that
        # axis, ``seq_idx`` below is always 0, so only position 0 of
        # ``inputs``' second axis is written; presumably inputs is
        # (batch, 1, dim) — confirm.
        gate_logits = self.gate(gate_feature[:, -2, :].unsqueeze(1))
        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
        # Normalize over the k selected experts (the last axis). Normalizing
        # over dim=1 would hit the singleton axis added above and make every
        # weight exactly 1.0.
        weights = F.softmax(weights, dim=-1, dtype=torch.float).to(inputs.dtype)

        results = torch.zeros_like(inputs)
        for i, expert in enumerate(self.experts):
            batch_idx, seq_idx, nth_expert = torch.where(selected_experts == i)
            if batch_idx.numel() == 0:
                # No sample routed here; avoid calling the expert on an empty batch.
                continue
            expert_out = expert(inputs[batch_idx, seq_idx])
            results[batch_idx, seq_idx] += weights[batch_idx, seq_idx, nth_expert, None] * expert_out

        return results

class MoeLayer1(nn.Module):
    """Mixture-of-experts layer routed by an explicit, caller-supplied expert id.

    Each sample in the batch is sent to exactly one of ``unshared_experts``,
    chosen by ``use_expert_i``, with weight 1. There is no learned gate, so
    ``use_expert_i`` is required.

    ``shared_experts`` are registered (their parameters belong to the
    module) but are NOT applied in ``forward``. ``num_experts_per_tok`` is
    stored for compatibility but unused, because routing is a hard
    one-expert choice.
    """

    def __init__(self, shared_experts: List[nn.Module], unshared_experts: List[nn.Module],
                 num_experts_per_tok: int = 1):
        super().__init__()
        assert len(unshared_experts) > 0

        self.shared_experts = nn.ModuleList(shared_experts)
        self.unshared_experts = nn.ModuleList(unshared_experts)
        self.num_shared_experts = len(shared_experts)
        self.num_unshared_experts = len(unshared_experts)

        self.num_experts_per_tok = num_experts_per_tok

    def forward(self, inputs: torch.Tensor, use_expert_i: torch.Tensor = None) -> torch.Tensor:
        """Apply the selected unshared expert to each sample.

        Args:
            inputs: Tensor whose first axis is the batch axis, e.g.
                (B, seq, D). Each expert receives the sub-batch routed to it.
            use_expert_i: Expert index per sample. Accepts a Python int or a
                0-d tensor (same expert for the whole batch), a (B,) tensor,
                a (B, 1) tensor, or anything ``torch.as_tensor`` converts to
                one of those.

        Returns:
            Tensor shaped like ``inputs``. Row ``b`` holds the output of
            expert ``use_expert_i[b]``, cast to ``inputs.dtype``.

        Raises:
            ValueError: if ``use_expert_i`` is None, has an incompatible
                shape, or contains an index outside ``[0, num_unshared_experts)``.
        """
        expert_ids = self._normalize_expert_ids(inputs, use_expert_i)

        results = torch.zeros_like(inputs)
        for i, expert in enumerate(self.unshared_experts):
            (rows,) = torch.nonzero(expert_ids == i, as_tuple=True)
            if rows.numel() == 0:
                # No sample routed here; avoid calling the expert on an empty batch.
                continue
            # Hard routing: every sample belongs to exactly one expert, so a
            # plain assignment is equivalent to accumulating with weight 1.
            results[rows] = expert(inputs[rows]).to(results.dtype)

        return results

    def _normalize_expert_ids(self, inputs: torch.Tensor, use_expert_i) -> torch.Tensor:
        """Return ``use_expert_i`` as a validated (B,) long tensor on ``inputs.device``."""
        if use_expert_i is None:
            raise ValueError("use_expert_i is required: MoeLayer1 has no gate to route with")

        batch = inputs.shape[0]
        expert_ids = torch.as_tensor(use_expert_i, dtype=torch.long, device=inputs.device)
        if expert_ids.ndim == 0:
            # One id for the whole batch.
            expert_ids = expert_ids.expand(batch)
        elif expert_ids.ndim == 2 and expert_ids.shape[1] == 1:
            expert_ids = expert_ids.squeeze(1)

        if tuple(expert_ids.shape) != (batch,):
            raise ValueError(
                f"use_expert_i must have shape ({batch},), ({batch}, 1) or be a scalar; "
                f"got {tuple(expert_ids.shape)}"
            )
        if batch > 0:
            lowest, highest = int(expert_ids.min()), int(expert_ids.max())
            if lowest < 0 or highest >= self.num_unshared_experts:
                raise ValueError(
                    f"use_expert_i values must be in [0, {self.num_unshared_experts}); "
                    f"got range [{lowest}, {highest}]"
                )
        return expert_ids