import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from torch.amp import custom_fwd, custom_bwd
from typing import Any, Dict, List, Optional
from torch import Tensor
class ParallelLinear(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type='cuda') #(cast_inputs=torch.float32)
    def forward(ctx, input, expert_size, weight, bias=None):
        output = ParallelLinear.forward_scriptable(input, expert_size, weight, bias)
        # assert torch.allclose(ParallelLinear._forward_scriptable(input, expert_size, weight, bias),  output)
        ctx.save_for_backward(input, expert_size, weight, bias)
        return output

    @staticmethod
    @torch.jit.script
    def forward_scriptable(input: Tensor, expert_size: Tensor,
                           weight: Tensor, bias: Optional[Tensor]):
        output_buf: Tensor = torch.empty((input.size(0), weight.size(2)),
                                         device=input.device, dtype=input.dtype)
        num_linears = weight.size(0)
        expert_size_list: List[int] = expert_size.tolist()
        input_list = input.split(expert_size_list, dim=0)
        output_buf_list = output_buf.split(expert_size_list)
        
        num_used_linears = len(input_list)

        for i in range(min(num_linears, num_used_linears)):
            torch.mm(input_list[i], weight[i], out=output_buf_list[i])

        if bias is not None:
            for i in range(min(num_linears, num_used_linears)):
                output_buf_list[i].add_(bias[i])

        output = output_buf
        return output

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_out):
        input, expert_size, weight, bias = ctx.saved_tensors
        return ParallelLinear.backward_scriptable(
            grad_out, input, expert_size,
            weight, bias
        )

    @staticmethod
    @torch.jit.script
    def backward_scriptable(grad_out: Tensor,
                 input: Tensor, expert_size: Tensor,
                 weight: Tensor, bias: Optional[Tensor]):
        num_linears = weight.size(0)
        expert_size_list: List[int] = expert_size.tolist()
        input_list = input.t().split(expert_size_list, dim=1)
        grad_list = grad_out.split(expert_size_list, dim=0)
        
        num_used_linears = len(input_list)

        d_input_buf = torch.empty_like(input)
        d_input_buf_list = d_input_buf.split(expert_size_list, dim=0)
        d_weight_buf = torch.empty_like(weight)

        weight_t = weight.permute(0, 2, 1)

        for i in range(min(num_linears, num_used_linears)):
            torch.mm(grad_list[i], weight_t[i], out=d_input_buf_list[i])
            torch.mm(input_list[i], grad_list[i], out=d_weight_buf[i])

        d_input = d_input_buf
        d_weight = d_weight_buf

        if bias is not None:
            d_bias_buf = torch.empty_like(bias)
            for i in range(min(num_linears, num_used_linears)):
                torch.sum(grad_list[i], dim=0, keepdim=False, out=d_bias_buf[i])
            d_bias = d_bias_buf
        else:
            d_bias = None

        return d_input, None, d_weight, d_bias

class ParallelExperts(nn.Module):
    """A bank of independent linear experts evaluated with ``ParallelLinear``.

    ``weight`` has shape ``(num_experts, input_size, output_size)`` and the
    optional ``bias`` has shape ``(num_experts, output_size)``.

    Experts can be grown after construction: ``add_experts`` creates the
    pending parameters ``new_weight`` / ``new_bias`` (used together with the
    existing ones in ``forward``), and ``merge_experts`` folds them back into
    ``weight`` / ``bias``.

    NOTE: while pending experts exist, ``new_weight`` / ``new_bias`` are
    ordinary registered parameters and therefore appear in ``state_dict()``;
    a freshly constructed module does not have them until ``add_experts`` is
    called.
    """

    def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_experts, input_size, output_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_experts, output_size))
        else:
            self.bias = None
        self.reset_parameters()

    def extra_repr(self):
        """Summarise the shape of ``weight`` for ``repr(module)``."""
        return 'num_experts={}, input_size={}, output_size={}'.format(
            self.weight.size(0), self.weight.size(1), self.weight.size(2))

    def reset_parameters(self) -> None:
        """Initialise ``weight`` uniformly in ``[-1/input_size, 1/input_size]``.

        The bias, when present, is drawn uniformly from ``[-bound, bound]``
        with ``bound = 1/sqrt(fan_in)``, where ``fan_in`` is whatever
        ``nn.init._calculate_fan_in_and_fan_out`` reports for ``weight[0]``.
        """
        nn.init.uniform_(self.weight, -1. / self.weight.size(1), 1. / self.weight.size(1))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, inputs, expert_size):
        """Apply the experts to ``inputs`` grouped according to ``expert_size``.

        When pending experts exist they are concatenated after the existing
        ones, so expert indices ``>= weight.size(0)`` address ``new_weight``.
        In that case a bias is only used if both ``bias`` and ``new_bias``
        exist.
        """
        if hasattr(self, 'new_weight'):
            weight = torch.cat([self.weight, self.new_weight], dim=0)
            bias = None
            if self.bias is not None and hasattr(self, 'new_bias'):
                bias = torch.cat([self.bias, self.new_bias], dim=0)
        else:
            weight = self.weight
            bias = self.bias

        return ParallelLinear.apply(inputs, expert_size, weight, bias)

    def add_experts(self, n_extra_experts: int, freeze_old_experts: bool = True):
        """Add ``n_extra_experts`` pending experts stored in ``new_weight``.

        If pending experts already exist (``add_experts`` was called before
        without an intervening ``merge_experts``), the fresh experts are
        appended after them instead of replacing them, so no expert is lost
        and the total always grows by ``n_extra_experts``.

        Args:
            n_extra_experts: number of experts to add.
            freeze_old_experts: when True, ``weight`` (and ``bias``) stop
                requiring gradients; pending experts stay trainable.
        """
        device = self.weight.device
        dtype = self.weight.dtype
        input_size = self.weight.size(1)
        output_size = self.weight.size(2)

        # Fresh experts use the same initialisation as reset_parameters().
        fresh_weight = torch.empty(n_extra_experts, input_size, output_size,
                                   device=device, dtype=dtype)
        nn.init.uniform_(fresh_weight, -1.0 / input_size, 1.0 / input_size)

        fresh_bias = None
        if self.bias is not None:
            fresh_bias = torch.empty(n_extra_experts, output_size,
                                     device=device, dtype=dtype)
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(fresh_weight[0])
            bound = 1.0 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(fresh_bias, -bound, bound)

        # Keep experts that are already pending rather than overwriting them.
        if hasattr(self, 'new_weight'):
            fresh_weight = torch.cat([self.new_weight.data, fresh_weight], dim=0)
            if fresh_bias is not None and hasattr(self, 'new_bias'):
                fresh_bias = torch.cat([self.new_bias.data, fresh_bias], dim=0)

        self.new_weight = nn.Parameter(fresh_weight)
        if fresh_bias is not None:
            self.new_bias = nn.Parameter(fresh_bias)

        if freeze_old_experts:
            self.weight.requires_grad_(False)
            if self.bias is not None:
                self.bias.requires_grad_(False)

    def merge_experts(self):
        """Fold ``new_weight`` / ``new_bias`` into ``weight`` / ``bias``.

        Intended to be called once training of the pending experts is done.
        The merged tensors are wrapped in new ``nn.Parameter`` objects, and the
        pending parameters are removed. Does nothing when there are no pending
        experts.
        """
        if not hasattr(self, 'new_weight'):
            return

        merged_weight = torch.cat([self.weight.data, self.new_weight.data], dim=0)
        self.weight = nn.Parameter(merged_weight)

        if self.bias is not None and hasattr(self, 'new_bias'):
            merged_bias = torch.cat([self.bias.data, self.new_bias.data], dim=0)
            self.bias = nn.Parameter(merged_bias)

        del self.new_weight
        if hasattr(self, 'new_bias'):
            del self.new_bias

    def print_params(self):
        """Print and return the first and last scalar of the full weight bank."""
        print('*'*100)
        if hasattr(self, 'new_weight'):
            weight = torch.cat([self.weight, self.new_weight], dim=0)
        else:
            weight = self.weight
        print('expert[0,0,0] weight', weight[0,0,0])
        print('expert[-1,-1,-1] weight', weight[-1,-1,-1])
        return weight[0,0,0], weight[-1,-1,-1]



@torch.jit.script
def compute_gating(k: int, probs: torch.Tensor, top_k_gates: torch.Tensor, top_k_indices: torch.Tensor):
    """Turn per-row top-k selections into an expert-sorted dispatch plan.

    Worked example (batch_size=2, num_experts=4, k=2):
        probs         = [[0.1, 0.4, 0.3, 0.2], [0.2, 0.3, 0.1, 0.4]]
        top_k_gates   = [[0.4, 0.3], [0.4, 0.3]]
        top_k_indices = [[1, 2], [3, 1]]
      gives
        gates         = [[0.0, 0.4, 0.3, 0.0], [0.0, 0.3, 0.0, 0.4]]
        expert_size   = [0, 2, 1, 1]
        batch_index   = [0, 1, 0, 1]   (rows grouped by expert 1, 1, 2, 3)

    Selections whose gate is not positive (e.g. rows zeroed by a skip mask)
    are dropped from the plan, so that ``batch_index`` / ``batch_gates``
    always have exactly ``expert_size.sum()`` entries.

    Args:
        k: number of selections per row (second dim of ``top_k_indices``).
        probs: [batch_size, num_experts] gate probabilities.
        top_k_gates: [batch_size, k] selected gate values.
        top_k_indices: [batch_size, k] expert index of each selected gate.

    Returns:
        batch_gates: gate value of every kept selection, grouped by expert.
        batch_index: source row of every kept selection, grouped by expert.
        expert_size: [num_experts] number of kept selections per expert.
        gates: [batch_size, num_experts] sparse gate matrix.
        index_sorted_experts: flat indices into the flattened top-k tensors,
            in the same expert-grouped order.
    """
    zeros = torch.zeros_like(probs)
    gates = zeros.scatter(1, top_k_indices, top_k_gates)

    top_k_gates_flat = top_k_gates.flatten()
    top_k_indices_flat = top_k_indices.flatten()

    # Keep only selections with a positive gate. This is the same predicate
    # used for expert_size below, which keeps the two in agreement.
    kept = (top_k_gates_flat > 0).nonzero().squeeze(-1)

    # Sort the kept selections by expert index, then map the sort order back
    # to positions in the flattened top-k tensors.
    _, order = top_k_indices_flat[kept].sort(0)
    top_k_indices_sorted = kept[order]

    # Number of rows routed to each expert.
    expert_size = (gates > 0).long().sum(0)

    # Flat position p belongs to row p // k.
    batch_index = top_k_indices_sorted.div(k, rounding_mode='trunc')
    batch_gates = top_k_gates_flat[top_k_indices_sorted]

    return batch_gates, batch_index, expert_size, gates, top_k_indices_sorted

# @torch.jit.script
# def compute_gating(k: int, probs: torch.Tensor, top_k_gates: torch.Tensor, top_k_indices: torch.Tensor):
#     zeros = torch.zeros_like(probs)
#     gates = zeros.scatter(1, top_k_indices, top_k_gates)
#     top_k_gates = top_k_gates.flatten()
#     top_k_experts = top_k_indices.flatten()
#     nonzeros = top_k_gates.nonzero().squeeze(-1)
#     top_k_experts_nonzero = top_k_experts[nonzeros]
#     _, _index_sorted_experts = top_k_experts_nonzero.sort(0)
#     expert_size = (gates > 0).long().sum(0)
#     index_sorted_experts = nonzeros[_index_sorted_experts]
#     batch_index = index_sorted_experts.div(k, rounding_mode='trunc')
#     batch_gates = top_k_gates[index_sorted_experts]
#     return batch_gates, batch_index, expert_size, gates, index_sorted_experts


class TaskMoE(nn.Module):

    """Sparsely gated mixture-of-experts layer with one gating network per task.

    Each expert is a two-layer feed-forward block
    (``input_size -> head_size -> GELU -> input_size``) implemented with two
    ``ParallelExperts`` banks.

    Args:
    input_size: integer - size of the input (and of the output)
    head_size: integer - hidden size of the experts
    k: an integer - how many experts to use for each batch element (capped at num_experts)
    num_experts: an integer - number of experts
    task_num: an integer - number of tasks; one gating network is created per task
    expert_bias: a boolean - whether the expert layers use a bias
    w_MI: weight of the mutual-information loss
    w_H: weight of the entropy loss
    w_finetune_MI: weight of the fine-tune mutual-information loss
    noisy_gating: a boolean - whether the gate also predicts a noise scale
    """

    def __init__(self,  input_size, head_size, 
                 k, num_experts, task_num, expert_bias, 
                 w_MI=0, w_H=0, w_finetune_MI=0, 
                 noisy_gating=True, **kwargs):
        super(TaskMoE, self).__init__()

        # model dimensions
        self.input_size = input_size
        self.head_size = head_size

        # expert parameters
        if k > num_experts:
            print(f"Warning: topk ({k}) is greater than num_experts ({num_experts}). Set k to num_experts.")

        self.k = min(k, num_experts)
        self.num_experts = num_experts
        self.task_num = task_num
        self.expert_bias = expert_bias
        self.experts = ParallelExperts(num_experts, input_size, head_size, self.expert_bias)
        self.expert_activation = nn.GELU()
        self.output_experts = ParallelExperts(num_experts, head_size, input_size, self.expert_bias)

        # loss weights: mutual information, entropy, fine-tune mutual information
        self.w_MI = w_MI
        self.w_H = w_H
        self.w_finetune_MI = w_finetune_MI

        # gating network: one small MLP per task. With noisy gating the last
        # layer emits 2 * num_experts values (logits and raw noise scale).
        self.noisy_gating = noisy_gating
        self.f_gate = nn.ModuleList([nn.Sequential(
                                        nn.Linear(input_size, input_size//4),
                                        nn.GELU(),
                                        nn.Linear(input_size//4,
                                                    2 * num_experts if self.noisy_gating else num_experts,
                                                    bias=True)
                                    ) for i in range(task_num)])

        # Zero the last layer's weight so every gate initially depends only
        # on its bias.
        for i in range(task_num):
            nn.init.zeros_(self.f_gate[i][-1].weight)

        # Running estimate of P(task, expert); PE is registered but not
        # updated anywhere in this class.
        self.register_buffer('PTE', torch.zeros(self.task_num, self.num_experts))
        self.register_buffer('PE', torch.zeros(self.num_experts))
        self.momentum = 0.0
        self.register_buffer('times',torch.zeros(1))

        # Per-task exponential moving averages used for monitoring.
        self.task_gate_freq = [0] * self.task_num
        self.topk_acc_probs = [0] * self.task_num
        self.token_probs = [0] * self.task_num

    def get_MIloss(self, logits, probs, gates, task_bh):
        """Compute the auxiliary gating loss and update the running statistics.

        Returns ``0.0`` outside training. During the first 100 training calls
        only the entropy term is returned while ``PTE`` is warmed up with a
        running mean; afterwards the mutual-information terms are added and
        ``PTE`` is updated with momentum 0.99.

        Args:
            logits: gate logits (unused here).
            probs: [batch, num_experts] gate probabilities, read as P(E|T).
            gates: [batch, num_experts] sparse gate matrix.
            task_bh: index of the task the batch belongs to.
        """
        if not self.training:
            return 0.0

        # Monitoring statistics (exponential moving averages, factor 0.95).
        top_k_gates, _ = probs.topk(self.k, dim=1)
        self.token_probs[task_bh] = self.token_probs[task_bh] * 0.95 + top_k_gates.mean(0).detach()*0.05

        self.task_gate_freq[task_bh] = self.task_gate_freq[task_bh]*0.95 + ((gates > 0).float().sum(0)).detach()*0.05

        self.topk_acc_probs[task_bh] = self.topk_acc_probs[task_bh]*0.95 + (probs.mean(0)).detach()*0.05

        PT = 1 / self.task_num # since we want each task to have equal weight

        # probs = P(E|T) in this batch
        # P(T,E) = P(E|T) * P(T)
        self.PTE[task_bh] = self.PTE[task_bh] * self.momentum + (1-self.momentum) * (probs.mean(0).detach() * PT)

        # Entropy term: equals w_H times the mean per-row entropy of probs.
        # NOTE(review): the original comment said "maximize the entropy";
        # confirm the intended sign of w_H.
        loss = -self.w_H * (probs * torch.log(probs + 0.0001)).sum(1).mean()

        if self.times[0] < 100:
            # Warm-up: momentum 1 - 1/t turns the PTE update into a running mean.
            self.times[0] = self.times[0] + 1
            self.momentum = 1 - 1/(self.times[0]) 
            return loss
        else:
            self.momentum = 0.99

        # P(E) = \sum_T (P(E,T))
        PE = self.PTE.sum(0).detach()

        # P(E,T) in this batch: only the row of the current task is non-zero.
        MI_task_gate = torch.zeros(self.task_num, self.num_experts, device=probs.device)
        MI_task_gate[task_bh] = MI_task_gate[task_bh] + probs.mean(0) * PT

        # P(E) in this batch
        P_EI = probs.mean(0) * PT

        # get the MI loss
        MI_loss = -((MI_task_gate * (1 + torch.log(self.PTE.detach() + 0.0001)) ).sum() - (P_EI * (1 + torch.log(PE + 0.0001))).sum())

        finetune_MI_loss = -((MI_task_gate * (1 + torch.log(self.PTE.detach() + 0.0001)) ).sum())

        loss = loss + self.w_MI * MI_loss + self.w_finetune_MI * finetune_MI_loss 

        return loss

    def top_k_gating(self, x, task_bh, skip_mask=None, sample_topk=0, noise_epsilon=1e-2):
        """Noisy top-k gating.
          See paper: https://arxiv.org/abs/1701.06538.

          The routing plan (``expert_size``, ``index_sorted_experts``,
          ``batch_index``, ``batch_gates``) is stored on the instance for use
          by ``forward``.

          Args:
            x: input Tensor with shape [batch_size, input_size]
            task_bh: index of the gating network (task) to use
            skip_mask: optional boolean mask; probabilities are set to 0 where it is True
            sample_topk: during training, how many of the k experts are sampled
              instead of taken greedily (must be <= k)
            noise_epsilon: a float added to the noise standard deviation
          Returns:
            loss: auxiliary loss from ``get_MIloss``
            probs: a Tensor with shape [batch_size, num_experts]
        """
        gate_out = self.f_gate[task_bh](x)
        if self.noisy_gating:
            clean_logits, raw_noise_stddev = gate_out.chunk(2, dim=-1)
            if self.training:
                # Noise is only added at training time, so that evaluation
                # is deterministic (matches TaskMoEGate.top_k_gating).
                noise_stddev = F.softplus(raw_noise_stddev) + noise_epsilon
                eps = torch.randn_like(clean_logits)
                logits = clean_logits + eps * noise_stddev
            else:
                logits = clean_logits
        else:
            logits = gate_out

        probs = torch.softmax(logits, dim=1) + 1e-4

        if skip_mask is not None:
            probs = torch.masked_fill(probs, skip_mask, 0)

        if self.training and (sample_topk > 0):
            assert sample_topk <= self.k

            # Take the top (k - sample_topk) experts greedily, then sample the
            # remaining ones from the other experts.
            _, top_km1_indices = probs.topk(self.k - sample_topk, dim=1)
            masked_probs = probs + 1e-6
            rows = torch.arange(probs.size(0), device=probs.device).unsqueeze(1)
            masked_probs[rows, top_km1_indices] = 0
            k_indices = torch.multinomial(masked_probs, sample_topk)
            top_k_indices = torch.cat([top_km1_indices, k_indices], dim=-1)
            top_k_gates = torch.gather(probs, 1, top_k_indices)
        else:
            top_k_gates, top_k_indices = probs.topk(self.k, dim=1)

        batch_gates, batch_index, expert_size, gates, index_sorted_experts = \
            compute_gating(self.k, probs, top_k_gates, top_k_indices)
        self.expert_size = expert_size
        self.index_sorted_experts = index_sorted_experts
        self.batch_index = batch_index
        self.batch_gates = batch_gates

        return self.get_MIloss(logits, probs, gates, task_bh),probs

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0, multiply_by_gates=True):
        """Route every token of ``x`` through its selected experts.

        Args:
            x: Tensor of shape [bsz, length, emb_size].
            task_bh: index of the task (selects the gating network).
            skip_mask: optional mask, reshaped to [bsz * length, 1].
            sample_topk: forwarded to ``top_k_gating``.
            multiply_by_gates: scale each expert output by its gate value.

        Returns:
            y: Tensor of shape [bsz, length, input_size]
            loss: auxiliary gating loss
            probs: gate probabilities of shape [bsz * length, num_experts]
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)
        loss,probs = self.top_k_gating(x, task_bh, skip_mask,  sample_topk=sample_topk)

        # expert_size: number of rows routed to each expert.
        # batch_index: for every routed row, the index of its source token,
        #              grouped by expert. Example:
        #   expert 1 gets tokens [0, 1], expert 2 gets [0, 1, 2], expert 3 gets [2]
        #   expert_size = [0, 2, 3, 1, 0]
        #   batch_index = [0, 1, 0, 1, 2, 2]
        expert_inputs = x[self.batch_index]
        h = self.experts(expert_inputs, self.expert_size)
        h = self.expert_activation(h)
        expert_outputs = self.output_experts(h, self.expert_size)

        if multiply_by_gates:
            expert_outputs = expert_outputs * self.batch_gates[:, None]

        # Scatter-add the expert outputs back to their source tokens.
        zeros = torch.zeros((bsz * length, self.input_size), 
            dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, self.batch_index, expert_outputs)
        y = y.view(bsz, length, self.input_size)
        return y, loss, probs



class TaskMoEGate(nn.Module):

    """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
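    NOTE: this class holds only the per-task gating networks; the
    constructor creates no expert layers.

    Args:
    input_size: integer - size of the input
    k: an integer - how many experts to use for each batch element (capped at num_experts)
    num_experts: an integer - number of experts
    task_num: an integer - number of tasks; one gating network is created per task
    w_MI, w_H, w_finetune_MI: weights stored for the mutual-information, entropy and fine-tune mutual-information losses
    noisy_gating: a boolean - whether the gate also predicts a noise scale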
    """

    def __init__(self,  input_size, 
                 k, num_experts, task_num, 
                 w_MI=0, w_H=0, w_finetune_MI=0, 
                 noisy_gating=True, **kwargs):
        """Build one gating MLP per task plus the routing-statistics buffers.

        Args:
            input_size: size of the gate input features.
            k: number of experts selected per element; capped at ``num_experts``.
            num_experts: number of experts the gates score.
            task_num: number of tasks (one gating network each).
            w_MI: weight stored for the mutual-information loss.
            w_H: weight stored for the entropy loss.
            w_finetune_MI: weight stored for the fine-tune mutual-information loss.
            noisy_gating: when True each gate emits ``2 * num_experts`` values.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        # model dimensions
        self.input_size = input_size

        # expert parameters
        if k > num_experts:
            print(f"Warning: topk ({k}) is greater than num_experts ({num_experts}). Set k to num_experts.")

        self.k = min(k, num_experts)
        self.num_experts = num_experts
        self.task_num = task_num

        # loss weights
        self.w_MI = w_MI
        self.w_H = w_H
        self.w_finetune_MI = w_finetune_MI

        # gating networks: Linear -> GELU -> Linear, one per task
        self.noisy_gating = noisy_gating
        gate_hidden = self.input_size // 4
        gate_width = 2 * self.num_experts if self.noisy_gating else self.num_experts
        gate_nets = []
        for _ in range(self.task_num):
            gate_nets.append(nn.Sequential(
                nn.Linear(self.input_size, gate_hidden),
                nn.GELU(),
                nn.Linear(gate_hidden, gate_width, bias=True),
            ))
        self.f_gate = nn.ModuleList(gate_nets)

        # Zero the output layer's weight (its bias is left as initialised).
        for gate_net in self.f_gate:
            nn.init.zeros_(gate_net[-1].weight)

        # routing statistics
        self.register_buffer('PTE', torch.zeros(self.task_num, self.num_experts))
        self.register_buffer('PE', torch.zeros(self.num_experts))
        self.momentum = 0.0
        self.register_buffer('times', torch.zeros(1))

        # per-task monitoring slots
        self.task_gate_freq = [0 for _ in range(self.task_num)]
        self.topk_acc_probs = [0 for _ in range(self.task_num)]
        self.token_probs = [0 for _ in range(self.task_num)]
            
    def count_parameters(self):
        """Count gate parameters in total and per task.

        Returns:
            dict with ``total_params`` (sum over all gating networks) and
            ``active_params`` (``total_params`` divided by ``task_num``).
        """
        total = 0
        for param in self.f_gate.parameters():
            total += param.numel()

        return {
            'total_params': total,
            'active_params': total / self.task_num,
        }

    def get_MIloss(self, logits, probs, gates, task_bh):
        """Return the auxiliary gating loss, which is disabled for this class.

        The loss is always ``0.0``: for continual learning only data of the
        new tasks is available, not data of the previous tasks, so the
        mutual-information / entropy loss is switched off. The code that
        previously followed the unconditional return was unreachable and has
        been removed; ``TaskMoE.get_MIloss`` contains an active implementation
        of the same loss.

        Args:
            logits: gate logits (unused).
            probs: gate probabilities (unused).
            gates: sparse gate matrix (unused).
            task_bh: task index (unused).

        Returns:
            The float ``0.0``.
        """
        return 0.0

    def get_probs(self):
        """Return the ``probs`` attribute stored by the last ``top_k_gating`` call."""
        stored_probs = self.probs
        return stored_probs
    
    def get_clean_logits(self):
        """Return the ``clean_logits`` attribute stored by ``top_k_gating``."""
        stored_logits = self.clean_logits
        return stored_logits

    def get_raw_noise_stddev(self):
        """Return the ``noise_stddev`` attribute stored by ``top_k_gating``."""
        stored_stddev = self.noise_stddev
        return stored_stddev

    def top_k_gating(self, x, task_bh, skip_mask=None, sample_topk=0, noise_epsilon=1e-2):
        """Noisy top-k gating.
          See paper: https://arxiv.org/abs/1701.06538.

          Side effects: stores ``probs`` (before masking), ``top_k_indices``
          and the routing plan (``expert_size``, ``index_sorted_experts``,
          ``batch_index``, ``batch_gates``) on the instance. With noisy gating
          it also stores ``clean_logits`` and ``noise_stddev``; the latter is
          the raw gate output, before softplus.

          Args:
            x: input Tensor with shape [batch_size, input_size]
            task_bh: index of the gating network (task) to use
            skip_mask: optional boolean mask; probabilities are set to 0 where it is True
            sample_topk: during training, how many of the k experts are sampled
              instead of taken greedily (must be <= k)
            noise_epsilon: a float added to the noise standard deviation
          Returns:
            loss: auxiliary loss from ``get_MIloss``
            probs: a Tensor with shape [batch_size, num_experts]
        """
        gate_out = self.f_gate[task_bh](x)
        if self.noisy_gating:
            clean_logits, raw_noise_stddev = gate_out.chunk(2, dim=-1)
            if self.training:
                # Noise is only added at training time.
                noise_stddev = F.softplus(raw_noise_stddev) + noise_epsilon
                eps = torch.randn_like(clean_logits)
                logits = clean_logits + eps * noise_stddev
            else:
                logits = clean_logits
            self.clean_logits = clean_logits
            self.noise_stddev = raw_noise_stddev
        else:
            logits = gate_out

        probs = torch.softmax(logits, dim=1) + 1e-4

        # Stored before masking, so get_probs() returns the unmasked values.
        self.probs = probs

        if skip_mask is not None:
            probs = torch.masked_fill(probs, skip_mask, 0)

        if self.training and (sample_topk > 0):
            assert sample_topk <= self.k

            # Take the top (k - sample_topk) experts greedily, then sample the
            # remaining ones from the other experts.
            _, top_km1_indices = probs.topk(self.k - sample_topk, dim=1)
            masked_probs = probs + 1e-6
            rows = torch.arange(probs.size(0), device=probs.device).unsqueeze(1)
            masked_probs[rows, top_km1_indices] = 0
            k_indices = torch.multinomial(masked_probs, sample_topk)
            top_k_indices = torch.cat([top_km1_indices, k_indices], dim=-1)
            top_k_gates = torch.gather(probs, 1, top_k_indices)
        else:
            top_k_gates, top_k_indices = probs.topk(self.k, dim=1)

        self.top_k_indices = top_k_indices

        batch_gates, batch_index, expert_size, gates, index_sorted_experts = \
            compute_gating(self.k, probs, top_k_gates, top_k_indices)

        self.expert_size = expert_size
        self.index_sorted_experts = index_sorted_experts
        self.batch_index = batch_index
        self.batch_gates = batch_gates

        return self.get_MIloss(logits, probs, gates, task_bh),probs
    
    def get_gate_infos(self):
        """Return the routing plan stored by the last ``top_k_gating`` call.

        Returns:
            dict with the keys ``batch_index``, ``expert_size`` and
            ``batch_gates`` mapped to the attributes of the same name.
        """
        infos = {}
        infos['batch_index'] = self.batch_index
        infos['expert_size'] = self.expert_size
        infos['batch_gates'] = self.batch_gates
        return infos

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0):
        """Compute the routing for ``x`` with the gate of task ``task_bh``.

        Delegates to ``top_k_gating``, which also stores the routing plan on
        the instance.

        Returns:
            ``(loss, probs)`` exactly as produced by ``top_k_gating``.
        """
        gating_out = self.top_k_gating(x, task_bh, skip_mask, sample_topk=sample_topk)
        aux_loss, gate_probs = gating_out
        return aux_loss, gate_probs
    
    def add_gates(
        self,
        n_new_tasks: int,
        n_new_experts: int,
        freeze_old: bool = True,
        noisy_gating: bool = False,
        k: int = -1,
    ):
        """
        Add one gating network per new task and grow the usage statistics.

        The old contents of ``self.PTE`` / ``self.PE`` are copied into the
        top-left block / leading slice of the enlarged tensors; the new rows
        and columns start at zero.

        Args:
            n_new_tasks: How many new tasks we are adding (>= 0).
            n_new_experts: How many new experts we are adding (>= 0, could be 0
                if we only add tasks).
            freeze_old: If True, every parameter of the existing gating
                networks gets ``requires_grad = False``.
            noisy_gating: Stored as ``self.noisy_gating``. When True the new
                gating networks emit ``2 * num_experts`` outputs, otherwise
                ``num_experts``.
            k: If positive, replaces ``self.k``; otherwise ``self.k`` is kept.

        Raises:
            ValueError: If ``n_new_tasks`` or ``n_new_experts`` is negative.
        """
        # Validate before touching any state, so a bad call cannot leave the
        # module half-updated (old gates frozen, k changed, buffers not grown).
        if n_new_tasks < 0 or n_new_experts < 0:
            raise ValueError(
                "n_new_tasks and n_new_experts must be non-negative, got "
                f"n_new_tasks={n_new_tasks}, n_new_experts={n_new_experts}"
            )

        device = self.PTE.device  # Assuming everything is on the same device
        # NOTE(review): k is not clamped to the expert count here -- confirm
        # the gating code copes with k > num_experts.
        if k > 0:
            self.k = k

        old_num_tasks = self.task_num
        old_num_experts = self.num_experts
        new_num_tasks = old_num_tasks + n_new_tasks
        new_num_experts = old_num_experts + n_new_experts

        # 1) Freeze old gates if requested
        if freeze_old:
            for gate_network in self.f_gate:
                for param in gate_network.parameters():
                    param.requires_grad = False

        # 2) If we are adding new experts, update num_experts
        if n_new_experts > 0:
            self.num_experts = new_num_experts

        # 3) Expand PTE from [old_num_tasks, old_num_experts] to
        #    [new_num_tasks, new_num_experts]. The old dtype is kept so that a
        #    module converted with .double()/.half() does not silently change
        #    the dtype of its statistics.
        old_PTE = self.PTE
        self.PTE = torch.zeros(
            new_num_tasks, new_num_experts, dtype=old_PTE.dtype, device=device
        )
        # copy old usage stats into the top-left block
        self.PTE[:old_num_tasks, :old_num_experts] = old_PTE

        # Similarly, expand PE from [old_num_experts] to [new_num_experts]
        old_PE = self.PE
        self.PE = torch.zeros(new_num_experts, dtype=old_PE.dtype, device=device)
        self.PE[:old_num_experts] = old_PE

        # 4) Now update the counters to reflect the new total tasks
        self.task_num = new_num_tasks

        # 5) Create new gating networks for each of the newly added tasks.
        #    New tasks see all experts (old + new), so the final dimension is
        #    new_num_experts, or 2 * new_num_experts with noisy gating.
        #    The hidden dimension is input_size // 4.
        # NOTE(review): this overwrites self.noisy_gating for the whole module
        # and the existing gating networks are not resized here -- confirm the
        # gating code handles old and new gates with different output widths.
        self.noisy_gating = noisy_gating
        in_dim = self.input_size
        hidden_dim = in_dim // 4
        out_dim = 2 * new_num_experts if self.noisy_gating else new_num_experts

        new_gate_networks = []
        for _ in range(n_new_tasks):
            net = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, out_dim, bias=True),
            )
            # Zero the last layer's weights; its bias keeps nn.Linear's
            # default initialisation.
            nn.init.zeros_(net[-1].weight)
            new_gate_networks.append(net)

        # 6) Append the new gating networks to the existing f_gate: the first
        #    old_num_tasks entries are the old (possibly frozen) gates, the new
        #    ones follow.
        for net in new_gate_networks:
            self.f_gate.append(net.to(device))

        # 7) Add a zero entry to each per-task usage tracker for every new task
        for _ in range(n_new_tasks):
            self.task_gate_freq.append(0)
            self.topk_acc_probs.append(0)
            self.token_probs.append(0)
            

    # def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
    #                          missing_keys, unexpected_keys, error_msgs):
    #     """Custom state dict loading. if the shape of self.PE and self.PTE is different from the state_dict, we need to load the state_dict manually"""
    #     # Check if state dict has new parameters       
    #     filtered_dict = {}
    #     for key, value in state_dict.items():
    #         # Only keep keys that start with the prefix and are for this module
    #         if key.startswith(prefix):
    #             # Remove the prefix to match local names
    #             local_key = key[len(prefix):]
    #             filtered_dict[local_key] = value

    #     if 'PE' in filtered_dict:
    #         # Load both old and new parameters
    #         self.PE = nn.Parameter(torch.empty_like(filtered_dict['PE'])).to(self.PE.device)

    #     if 'PTE' in filtered_dict:
    #         self.PTE = nn.Parameter(torch.empty_like(filtered_dict['PTE'])).to(self.PTE.device)
            
    #     # Load remaining parameters normally
    #     super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
    #                                 missing_keys, unexpected_keys, error_msgs)





class TaskMoEFFN(nn.Module):
    """Sparsely gated mixture-of-experts layer with two-projection feed-forward experts.

    Every routed token goes through ``experts``, a GELU and ``output_experts``;
    the per-expert results are then scattered back to their token positions.

    Args:
        input_size: integer - size of the last dimension of the input and of the output
        head_size: integer - width passed to the two ``ParallelExperts`` as the
            size between the first and the second projection
        k: integer - requested number of experts per token; ``self.k`` is
            capped at ``num_experts``
        num_experts: integer - number of experts
        task_num: integer - number of tasks, forwarded to the gating network
        expert_bias: forwarded to both ``ParallelExperts``
        w_MI: weight of the mutual information loss
        w_H: weight of the entropy loss
        w_finetune_MI: weight of the finetune mutual information loss
        noisy_gating: a boolean, forwarded to the gating network
        **kwargs: accepted and ignored
    """

    def __init__(self, input_size, head_size,
                 k, num_experts, task_num, expert_bias,
                 w_MI=0, w_H=0, w_finetune_MI=0,
                 noisy_gating=True, **kwargs):
        super().__init__()

        # model dimensions
        self.input_size = input_size
        self.head_size = head_size

        # expert parameters
        if k > num_experts:
            print(f"Warning: topk ({k}) is greater than num_experts ({num_experts}). Set k to num_experts.")

        self.k = min(k, num_experts)
        self.num_experts = num_experts
        self.task_num = task_num
        self.expert_bias = expert_bias
        self.experts = ParallelExperts(num_experts, input_size, head_size, self.expert_bias)
        self.expert_activation = nn.GELU()
        self.output_experts = ParallelExperts(num_experts, head_size, input_size, self.expert_bias)

        # loss parameters
        # mutual information loss
        self.w_MI = w_MI
        # entropy loss
        self.w_H = w_H
        # finetune mutual information loss
        self.w_finetune_MI = w_finetune_MI

        # gating network
        # NOTE(review): the gate receives the raw k, not the capped self.k --
        # confirm TaskMoEGate applies its own cap.
        self.gating_network = TaskMoEGate(input_size,
                                          k, num_experts, task_num,
                                          w_MI, w_H, w_finetune_MI,
                                          noisy_gating)

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0, multiply_by_gates=True):
        """Route every token through its selected experts and combine the results.

        Args:
            x: tensor of shape ``(bsz, length, emb_size)``.
            task_bh: task identifier forwarded to the gating network.
            skip_mask: optional mask; flattened to shape ``(-1, 1)`` before
                being forwarded to the gating network.
            sample_topk: forwarded to the gating network.
            multiply_by_gates: if True, each expert output row is scaled by
                its entry of the gating network's ``batch_gates``.

        Returns:
            tuple: ``(y, loss, probs)`` where ``y`` has shape
            ``(bsz, length, input_size)`` and ``loss`` / ``probs`` are what
            the gating network returned.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            # reshape (not view) so that non-contiguous masks are accepted,
            # matching how x itself is flattened above.
            skip_mask = skip_mask.reshape(-1, 1)

        loss, probs = self.gating_network(x, task_bh, skip_mask, sample_topk)

        # expert_size: [used time of expert 1, used time of expert 2, ...] the number of times each expert is used
        # batch_index: size = sum(expert_size), batch index of input for each expert
        #
        # Example:
        #   input for expert 1: [0, 1], input for expert 2: [0, 1, 2], input for expert 3: [2]
        #   batch_index = ['0, 1' for expert 1, '0, 1, 2' for expert 2, '2' for expert 3]
        #
        #   expert_size = [0, 2, 3, 1, 0]
        #   batch_index = [0, 1, 0, 1, 2, 2]
        gate = self.gating_network
        expert_inputs = x[gate.batch_index]
        h = self.experts(expert_inputs, gate.expert_size)
        h = self.expert_activation(h)
        expert_outputs = self.output_experts(h, gate.expert_size)

        if multiply_by_gates:
            expert_outputs = expert_outputs * gate.batch_gates[:, None]

        # Sum the expert outputs back into their token rows; tokens that were
        # routed to no expert stay zero.
        zeros = torch.zeros((bsz * length, self.input_size),
                            dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, gate.batch_index, expert_outputs)
        y = y.view(bsz, length, self.input_size)
        return y, loss, probs



class TaskMoEqkv(nn.Module):
    """Sparsely gated mixture-of-experts layer with two-projection feed-forward experts.

    Every routed token goes through ``experts``, a GELU and ``output_experts``;
    the per-expert results are then scattered back to their token positions.
    NOTE(review): the implementation currently matches ``TaskMoEFFN`` line for
    line -- confirm whether the two are meant to diverge.

    Args:
        input_size: integer - size of the last dimension of the input and of the output
        head_size: integer - width passed to the two ``ParallelExperts`` as the
            size between the first and the second projection
        k: integer - requested number of experts per token; ``self.k`` is
            capped at ``num_experts``
        num_experts: integer - number of experts
        task_num: integer - number of tasks, forwarded to the gating network
        expert_bias: forwarded to both ``ParallelExperts``
        w_MI: weight of the mutual information loss
        w_H: weight of the entropy loss
        w_finetune_MI: weight of the finetune mutual information loss
        noisy_gating: a boolean, forwarded to the gating network
        **kwargs: accepted and ignored
    """

    def __init__(self, input_size, head_size,
                 k, num_experts, task_num, expert_bias,
                 w_MI=0, w_H=0, w_finetune_MI=0,
                 noisy_gating=True, **kwargs):
        super().__init__()

        # model dimensions
        self.input_size = input_size
        self.head_size = head_size

        # expert parameters
        if k > num_experts:
            print(f"Warning: topk ({k}) is greater than num_experts ({num_experts}). Set k to num_experts.")

        self.k = min(k, num_experts)
        self.num_experts = num_experts
        self.task_num = task_num
        self.expert_bias = expert_bias
        self.experts = ParallelExperts(num_experts, input_size, head_size, self.expert_bias)
        self.expert_activation = nn.GELU()
        self.output_experts = ParallelExperts(num_experts, head_size, input_size, self.expert_bias)

        # loss parameters
        # mutual information loss
        self.w_MI = w_MI
        # entropy loss
        self.w_H = w_H
        # finetune mutual information loss
        self.w_finetune_MI = w_finetune_MI

        # gating network
        # NOTE(review): the gate receives the raw k, not the capped self.k --
        # confirm TaskMoEGate applies its own cap.
        self.gating_network = TaskMoEGate(input_size,
                                          k, num_experts, task_num,
                                          w_MI, w_H, w_finetune_MI,
                                          noisy_gating)

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0, multiply_by_gates=True):
        """Route every token through its selected experts and combine the results.

        Args:
            x: tensor of shape ``(bsz, length, emb_size)``.
            task_bh: task identifier forwarded to the gating network.
            skip_mask: optional mask; flattened to shape ``(-1, 1)`` before
                being forwarded to the gating network.
            sample_topk: forwarded to the gating network.
            multiply_by_gates: if True, each expert output row is scaled by
                its entry of the gating network's ``batch_gates``.

        Returns:
            tuple: ``(y, loss, probs)`` where ``y`` has shape
            ``(bsz, length, input_size)`` and ``loss`` / ``probs`` are what
            the gating network returned.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            # reshape (not view) so that non-contiguous masks are accepted,
            # matching how x itself is flattened above.
            skip_mask = skip_mask.reshape(-1, 1)

        loss, probs = self.gating_network(x, task_bh, skip_mask, sample_topk)

        # expert_size: [used time of expert 1, used time of expert 2, ...] the number of times each expert is used
        # batch_index: size = sum(expert_size), batch index of input for each expert
        #
        # Example:
        #   input for expert 1: [0, 1], input for expert 2: [0, 1, 2], input for expert 3: [2]
        #   batch_index = ['0, 1' for expert 1, '0, 1, 2' for expert 2, '2' for expert 3]
        #
        #   expert_size = [0, 2, 3, 1, 0]
        #   batch_index = [0, 1, 0, 1, 2, 2]
        gate = self.gating_network
        expert_inputs = x[gate.batch_index]
        h = self.experts(expert_inputs, gate.expert_size)
        h = self.expert_activation(h)
        expert_outputs = self.output_experts(h, gate.expert_size)

        if multiply_by_gates:
            expert_outputs = expert_outputs * gate.batch_gates[:, None]

        # Sum the expert outputs back into their token rows; tokens that were
        # routed to no expert stay zero.
        zeros = torch.zeros((bsz * length, self.input_size),
                            dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, gate.batch_index, expert_outputs)
        y = y.view(bsz, length, self.input_size)
        return y, loss, probs




if __name__ == '__main__':
    # Smoke test: build a small model, push one random batch through it for a
    # single task and print the shape of the result.
    # NOTE(review): TaskMoE is not defined in the part of the file shown here,
    # and the TaskMoEFFN / TaskMoEqkv layers return three values rather than
    # two -- confirm this script still runs.
    batch_size, sequence_length, input_size = 3, 10, 20

    model = TaskMoE(
        input_size=input_size,
        head_size=10,
        num_experts=5,
        k=2,
        activation=nn.Sequential(nn.GELU()),
        noisy_gating=False,
    )
    input_data = torch.randn(batch_size, sequence_length, input_size)

    # Task (or task batch) to run inference for; replace with the appropriate
    # task batch index.
    task_batch_index = 0

    # Tokens can be skipped during inference by passing a skip_mask; None
    # means no token is skipped.
    skip_mask = None

    # Forward pass for the selected task.
    output, loss = model(input_data, task_batch_index, skip_mask=skip_mask)
    print(output.shape)