import copy
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN


class TokenHashMoeBlock(nn.Module):  ### Test
    """Token-level mixture-of-experts with fixed hash routing.

    Every token is sent to exactly one expert, chosen by looking up its router
    label in ``hash_list`` (a table of expert ids). There is no learned gate.

    Args:
        config: Object providing ``num_experts``, ``topk``, ``vocab_size`` and,
            optionally, ``hash_list_path``. That is a path to a pickled sequence
            of expert ids, one per router label. When it is missing or ``None``,
            a random table of ``vocab_size`` ids is drawn instead.
        expert: Callable ``expert(config) -> nn.Module`` building one expert.
            Each expert must map ``(n, hidden_dim)`` rows to
            ``(n, hidden_dim)`` rows.

    Raises:
        ValueError: if any entry of the hash table is outside
            ``[0, num_experts)``.
    """

    def __init__(self, config, expert):
        super().__init__()
        self.num_experts = config.num_experts
        # NOTE(review): top_k is stored but never used; routing here is top-1.
        self.top_k = config.topk

        hash_list_path = getattr(config, "hash_list_path", None)
        if hash_list_path is not None:
            # pickle.load can execute arbitrary code: only point
            # hash_list_path at trusted files.
            with open(hash_list_path, "rb") as file:
                result = pickle.load(file)
            hash_list = torch.as_tensor(result, dtype=torch.int64)
        else:
            hash_list = torch.randint(0, self.num_experts, (config.vocab_size,))

        # Fail early on a bad table instead of deep inside forward().
        if hash_list.numel() and (hash_list.min() < 0 or hash_list.max() >= self.num_experts):
            raise ValueError(
                f"hash_list entries must lie in [0, {self.num_experts}); "
                f"got range [{int(hash_list.min())}, {int(hash_list.max())}]"
            )
        # A buffer follows Module.to()/.cuda() together with the experts.
        # It is non-persistent so state_dict keys stay unchanged.
        self.register_buffer("hash_list", hash_list, persistent=False)

        # deepcopy guards against factories that return a shared instance.
        self.experts = nn.ModuleList([copy.deepcopy(expert(config)) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor, router_labels: torch.Tensor) -> "tuple[torch.Tensor, None]":
        """Route each token to its hashed expert and recombine the results.

        Args:
            hidden_states: Tensor of shape ``(batch_size, seq_len, hidden_dim)``.
            router_labels: Integer indices into ``hash_list``, one per token
                (``batch_size * seq_len`` elements; presumably the
                ``(batch_size, seq_len)`` token ids — confirm with callers).

        Returns:
            ``(output, None)``. ``output`` has shape
            ``(batch_size, seq_len, hidden_dim)`` and keeps the original token
            order.
        """
        batch_size, seq_len, hidden_dim = hidden_states.size()

        # reshape (not view) so non-contiguous inputs are accepted.
        flat = hidden_states.reshape(-1, hidden_dim)
        hash_list = self.hash_list.to(flat.device)  # no-op once the module lives there
        gate = hash_list[router_labels.reshape(-1)]  # expert id for every token

        # Group tokens by expert: sort by expert id, then cut into contiguous
        # chunks whose sizes are the per-expert token counts.
        order = gate.argsort(0)
        # bincount avoids materialising an (n_tokens, num_experts) one-hot matrix.
        num_tokens = torch.bincount(gate, minlength=self.num_experts)
        chunks = flat[order].split(num_tokens.tolist(), dim=0)  # one chunk per expert

        # Call the modules (not .forward) so registered hooks still run.
        outputs = torch.vstack([expert(chunk) for expert, chunk in zip(self.experts, chunks)])

        # Invert the sorting permutation with an O(n) scatter, not a second argsort.
        inverse = torch.empty_like(order)
        inverse[order] = torch.arange(order.numel(), device=order.device)
        output = outputs[inverse].view(batch_size, seq_len, hidden_dim)
        return output, None