import copy
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# Module-level rank value, fixed at 0 here. Not read by RandGate in this file;
# presumably consumed or reassigned by importing code -- TODO confirm.
rank = 0
class RandGate(nn.Module):
    """Top-k gate blending a learned projection with a random-basis projection.

    Each per-expert logit is a per-expert convex combination of two scores:

    * ``x @ weight.T``: a freely trainable ``(K, model_dim)`` projection.
    * ``x @ (rand_coef @ basis).T``: a projection onto a mix of ``K`` fixed
      random vectors. ``basis`` has unit-norm rows and is not trained. It is
      mixed by the trainable ``(K, K)`` matrix ``rand_coef``, which is
      initialized orthogonal.

    The blend weight for expert ``i`` is ``sigmoid(alpha[i] * temp)``.

    Args:
        model_dim: Feature size of the incoming tokens (last dim of ``x``).
        num_global_experts: Number of experts ``K``; logits have ``K`` columns.
        k: Requested top-k; clamped to at most ``num_global_experts``.
        fp32_gate: If true, all gating math (including the random-basis mix)
            is done in float32 regardless of input/parameter dtype.
        glr: Stored as ``self.glr``; not used by this class.
        **options: Accepted and ignored here.
    """

    def __init__(self, model_dim, num_global_experts, k=1, fp32_gate=False, glr=5e-2, **options):
        super().__init__()
        self.dim = model_dim
        self.K = num_global_experts
        self.top_k = min(num_global_experts, int(k))
        # Not read by this class; presumably used by an external trainer -- confirm.
        self.glr = glr
        # Trainable K x K mixing matrix over the fixed basis; starts orthogonal.
        self.rand_coef = nn.Parameter(torch.rand(self.K, self.K))
        nn.init.orthogonal_(self.rand_coef)
        # Scale applied to alpha before the sigmoid.
        self.temp = 1.0
        self.fp32_gate = fp32_gate
        # Flag not read by this class; presumably consumed by the MoE layer -- confirm.
        self.enable_softmax_logits = True

        # Fixed (requires_grad=False) random basis with one unit-norm row per expert.
        self.basis = nn.Parameter(torch.randn(self.K, self.dim), requires_grad=False)
        self.basis.data = self.basis.data / torch.norm(self.basis.data, dim=1, keepdim=True)
        # Flag not read by this class; presumably consumed by the MoE layer -- confirm.
        self.routing_ready = True
        # Learned (K, model_dim) projection.
        self.weight = nn.Parameter(torch.randn(self.K, self.dim))
        # Per-expert blend logits: sigmoid(alpha * temp) weights the learned scores.
        self.alpha = nn.Parameter(torch.randn(self.K))

    def randommix(self, dtype=None):
        """Return ``rand_coef @ basis``, the ``(K, model_dim)`` mixed random basis.

        Args:
            dtype: Optional dtype to cast both operands to *before* the matmul,
                so the mix is computed at that precision. ``None`` (the
                default) uses the parameters' own dtype.
        """
        basis = self.basis.data
        coef = self.rand_coef
        if dtype is not None:
            basis = basis.to(dtype)
            coef = coef.to(dtype)
        return coef @ basis

    def forward(self, x):
        """Compute gating logits for ``x``.

        Args:
            x: Token features whose last dim is ``model_dim``
                (e.g. ``(tokens, model_dim)`` -- TODO confirm caller's shape).

        Returns:
            ``(logits, top_k)``. ``logits`` has the leading dims of ``x`` and a
            last dim of ``K``. It is float32 when ``fp32_gate`` is set.
        """
        if self.fp32_gate:
            x = x.float()
            weight = self.weight.float()
            alpha = self.alpha.float()
            # Do the basis mix in fp32 as well. Upcasting only after a
            # half/bf16 matmul would discard the precision fp32_gate promises.
            rbasis = self.randommix(torch.float32)
        else:
            weight = self.weight
            alpha = self.alpha
            rbasis = self.randommix()

        learned_logits = torch.matmul(x, weight.T)
        random_logits = torch.matmul(x, rbasis.T)
        # alpha has shape (K,) and broadcasts across the last (expert) dim.
        mix = torch.sigmoid(alpha * self.temp)
        logits = mix * learned_logits + (1 - mix) * random_logits
        return logits, self.top_k
    
    
# Public alias for the gate class; presumably the name the surrounding MoE
# code looks up when loading this module -- TODO confirm.
Gate = RandGate
