import sys
import os
# current_dir = os.path.dirname(os.path.abspath(__file__))
# parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
# sys.path.append(parent_dir)
# sys.path.append(os.path.join(parent_dir, 'utils'))

import types
from torch import nn

from utils.utils import ActivationModule, Distribution, SparsifyFn, get_module_device

import torch

def _monkeypatch_experts(experts, file_path, grabbing_mode=False, sparse_mode=None, mask_by=None):
    """Patch an experts module in place so its forward pass can sparsify activations.

    The original ``forward`` is kept as ``experts.forward_old`` and replaced by
    ``_experts_forward``; ``add_sparse_fns`` is bound as a method so the
    sparsify functions can be (re)built later.

    Args:
        experts: Module exposing ``gate_up_proj`` and ``down_proj`` weight tensors.
        file_path: Path handed to ``Distribution`` and ``ActivationModule``.
        grabbing_mode: When true, no sparsify functions are created here.
        sparse_mode: ``'wina'``, ``'teal'`` or ``None`` (no sparsification).
        mask_by: Forwarded to ``SparsifyFn``; distributions are only loaded
            when it equals ``'threshold'``.

    Returns:
        The same ``experts`` object, patched.
    """
    experts.forward_old = experts.forward
    experts.forward = types.MethodType(_experts_forward, experts)
    experts.add_sparse_fns = types.MethodType(add_sparse_fns, experts)

    experts.file_path = file_path
    experts.grabbing_mode = grabbing_mode
    experts.sparse_mode = sparse_mode

    if sparse_mode == 'wina':
        # The forward pass computes bmm(hidden_states, gate_up_proj), so dim 1
        # of the weight is the input (hidden) axis and the last dim is the
        # output axis. Reducing over the output axis yields one norm per
        # (expert, input feature), which is the shape needed to scale the
        # input activations element-wise.
        experts.gate_up_norm_by_column = experts.gate_up_proj.norm(dim=-1)
        # One neutral weight per expert for the down projection input.
        experts.down_norm_by_column = torch.ones(experts.down_proj.shape[0])

    # Projection name -> hidden_type key used to look up its distribution.
    hidden_types_by_mode = {
        'wina': (('gate_up', 'gate_up'), ('down', 'down')),
        'teal': (('gate_up', 'h1'), ('down', 'h2')),
    }
    projections = hidden_types_by_mode.get(sparse_mode)

    if not grabbing_mode and projections is not None:
        experts.distrs = {
            hidden_type: Distribution(file_path, hidden_type=hidden_type) if mask_by == 'threshold' else None
            for _, hidden_type in projections
        }
        experts.sparse_fns = nn.ModuleDict({
            proj: SparsifyFn(
                experts.distrs[hidden_type], sparse_mode=sparse_mode, mask_by=mask_by
            ).to(get_module_device(experts))
            for proj, hidden_type in projections
        })

    experts.activation_module = ActivationModule(file_path)

    return experts

def add_sparse_fns(self, sparsity=0.25, mask_by=None):
    """Leave grabbing mode and (re)build the sparsify functions for this module.

    Reads ``self.sparse_mode`` and ``self.file_path``, resets ``self.distrs``
    and, for the ``'wina'`` and ``'teal'`` modes, fills ``self.distrs`` and
    ``self.sparse_fns``. Any other mode only resets ``self.distrs``.

    Args:
        sparsity: Accepted for interface compatibility; not used here.
        mask_by: Forwarded to ``SparsifyFn``; distributions are only loaded
            when it equals ``'threshold'``.
    """
    self.grabbing_mode = False
    mode = self.sparse_mode
    path = self.file_path
    self.distrs = {}

    # (projection name, hidden_type key of its distribution)
    if mode == 'wina':
        pairs = (('gate_up', 'gate_up'), ('down', 'down'))
    elif mode == 'teal':
        pairs = (('gate_up', 'h1'), ('down', 'h2'))
    else:
        return

    load_distribution = mask_by == 'threshold'
    for _, hidden_type in pairs:
        if load_distribution:
            self.distrs[hidden_type] = Distribution(path, hidden_type=hidden_type)
        else:
            self.distrs[hidden_type] = None

    built = {}
    for proj, hidden_type in pairs:
        sparsify = SparsifyFn(self.distrs[hidden_type], sparse_mode=mode, mask_by=mask_by)
        built[proj] = sparsify.to(get_module_device(self))
    self.sparse_fns = nn.ModuleDict(built)
        
def _sparsify_expert_input(experts, key, activations, norm_attr):
    """Return ``activations`` masked by ``experts.sparse_fns[key]``.

    ``activations`` is laid out as (num_experts, num_tokens, features). In
    ``'teal'`` mode the sparsify function sees the raw activations; in
    ``'wina'`` mode it sees them scaled by the per-expert norm stored under
    ``norm_attr``. The activations are returned unchanged when no sparse mode
    is set or while the module is in grabbing mode.
    """
    mode = experts.sparse_mode
    # sparse_fns are only built outside grabbing mode, so run dense while grabbing.
    if mode not in ('wina', 'teal') or getattr(experts, 'grabbing_mode', False):
        return activations
    if mode == 'teal':
        return activations * experts.sparse_fns[key](activations)

    norm = getattr(experts, norm_attr).to(activations.device)
    if norm.dim() == 1:
        # One scalar per expert -> (num_experts, 1, 1).
        norm = norm.view(-1, 1, 1)
    elif norm.dim() == 2:
        # One value per (expert, feature) -> (num_experts, 1, features);
        # the trailing size must match the activations' feature size.
        norm = norm.unsqueeze(1)
    return activations * experts.sparse_fns[key](activations * norm)


def _experts_forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
    """Run the experts, sparsifying activations in the inference path.

    Args:
        hidden_states: Input whose first dim is the batch size and whose
            trailing dim is ``self.hidden_size``; flattened to
            (num_tokens, hidden_size) internally.
        router_indices: Expert indices chosen per token; only read in the
            training path.
        routing_weights: Per-token weights with one column per expert
            (``shape[1]`` is taken as the number of experts).

    Returns:
        Tensor of shape (batch_size, -1, hidden_size).

    Note:
        Sparsification is applied only when ``self.training`` is false.
    """
    batch_size = hidden_states.shape[0]
    hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
    num_experts = routing_weights.shape[1]
    if self.training:
        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Indices of experts that received at least one token.
            expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hitted[:]:
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx[0]])
            current_state = hidden_states[token_idx]
            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
            # Gate and up values are interleaved along the last dim.
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            gated_output = (up + 1) * glu
            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
            weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        next_states = next_states.view(batch_size, -1, self.hidden_size)
    else:
        # Every expert processes every token: (num_experts, num_tokens, hidden_size).
        hidden_states = hidden_states.repeat(num_experts, 1)
        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
        # sparse activation on the gate/up projection input
        hidden_states = _sparsify_expert_input(self, 'gate_up', hidden_states, 'gate_up_norm_by_column')
        gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
        # Gate and up values are interleaved along the last dim.
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        gated_output = (up + 1) * glu
        # sparse activation on the down projection input
        gated_output = _sparsify_expert_input(self, 'down', gated_output, 'down_norm_by_column')
        next_states = torch.bmm(gated_output, self.down_proj)
        next_states = next_states + self.down_proj_bias[..., None, :]
        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
        # Weight each expert's output per token, then sum over experts.
        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
        next_states = next_states.sum(dim=0)

    return next_states