import copy
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class GHA(nn.Module):
    """Online subspace tracker based on the Generalized Hebbian Algorithm.

    Holds ``n_components`` unit-norm row vectors of length ``dim`` in
    ``self.basis``. The basis is stored as an ``nn.Parameter`` with
    ``requires_grad=False``: it is part of the state dict and follows the
    module across devices, but it is updated only by :meth:`partial_fit`,
    never through autograd.

    Args:
        dim: Length of each basis vector (feature dimension of the samples).
        n_components: Number of basis vectors to track.
        learning_rate: Step size applied to every Hebbian update.
    """

    def __init__(self, dim, n_components, learning_rate=1e-1):
        super().__init__()
        self.dim = dim
        self.n_components = n_components
        self.lr = learning_rate
        # Random start, each row scaled to unit L2 norm.
        initial = F.normalize(torch.randn(n_components, dim), dim=1)
        self.basis = nn.Parameter(initial, requires_grad=False)

    def partial_fit(self, X, iters=10):
        """Apply ``iters`` Sanger-rule updates to the basis using one batch.

        Each step computes, for component ``k``, the batch mean of
        ``y_k * (x - sum_{j<=k} y_j * w_j)`` with ``y = x @ basis.T``, adds
        ``lr`` times that to the basis and re-normalises every row.

        Args:
            X: Samples whose last dimension matches the basis vectors. Any
                leading dimensions are flattened into a single batch axis,
                and the data is cast to the dtype of the basis.
            iters: Number of update steps performed on this batch.

        An empty batch leaves the basis untouched; averaging over zero
        samples would otherwise write NaN into it.
        """
        if X.numel() == 0:
            return

        with torch.no_grad():
            samples = X.detach().reshape(-1, X.shape[-1]).to(self.basis.dtype)
            n_samples = samples.shape[0]
            for _ in range(iters):
                Y = samples @ self.basis.T  # (N, K) projections
                # Matrix form of the per-sample rule: the inclusive sum over
                # j <= k becomes the lower triangle (with diagonal) of Y^T Y.
                # This avoids materialising any (N, K, D) intermediate.
                hebbian = Y.T @ samples  # (K, D)
                deflation = torch.tril(Y.T @ Y) @ self.basis  # (K, D)
                delta = (hebbian - deflation) / n_samples
                self.basis.add_(self.lr * delta)
                self.basis.copy_(F.normalize(self.basis, dim=1))

    def components_(self):
        """Return the current basis, detached from autograd (shares storage)."""
        return self.basis.detach()
    
# Module-level rank value, fixed at 0. Nothing in the visible definitions
# reads it -- presumably consumed or overwritten by external code (TODO confirm).
rank = 0
class STARGate(nn.Module):
    """Top-k expert gate blending a learned linear router with a GHA router.

    For an input ``x`` the gate computes two sets of per-expert logits:

    * ``x @ weight.T`` from a trainable weight matrix, and
    * ``x @ (mixing_coef @ basis).T`` where ``basis`` comes from a
      :class:`GHA` tracker that is updated on every training forward pass.

    They are combined per expert as
    ``sigmoid(alpha) * linear + (1 - sigmoid(alpha)) * gha``.

    During training, once ``forward_count`` reaches ``total_steps // 2``, up
    to ``collect_samples`` snapshots of the routing probabilities and top-k
    choices are kept on the CPU in ``routing_distributions``.

    Args:
        model_dim: Feature dimension of the inputs.
        num_global_experts: Number of experts (number of logits produced).
        k: Requested top-k; clamped to ``num_global_experts``.
        fp32_gate: If true, inputs and gate tensors are cast to float32
            before the logits are computed.
        glr: Learning rate handed to the GHA tracker.
        **options: Only ``gate_noise``, ``capacity_factor``, ``total_steps``,
            ``collect_samples`` and ``normalize_one_score_gate`` are accepted;
            of these only ``total_steps`` (default 10000) and
            ``collect_samples`` (default 1000) are read here.

    Raises:
        Exception: If ``options`` contains any other key.
    """

    _ACCEPTED_OPTIONS = (
        'gate_noise',
        'capacity_factor',
        'total_steps',
        'collect_samples',
        'normalize_one_score_gate',
    )

    def __init__(self, model_dim, num_global_experts, k=1, fp32_gate=False, glr=2e-5, **options):
        super().__init__()

        # Fail fast: reject unknown options before allocating any parameters.
        for opt in options:
            if opt not in self._ACCEPTED_OPTIONS:
                raise Exception('Unrecognized argument provided to STAR Gate: %s' % opt)

        self.dim = model_dim
        self.K = num_global_experts
        self.top_k = min(num_global_experts, int(k))
        self.glr = glr

        # NOTE: parameters are created in a fixed order so that seeded
        # initialisation and ``parameters()`` ordering stay stable.
        self.mixing_coef = nn.Parameter(torch.rand(self.K, self.K))
        nn.init.orthogonal_(self.mixing_coef)
        self.fp32_gate = fp32_gate
        # Not read inside this class -- presumably consumed by the caller of
        # the gate (TODO confirm).
        self.enable_softmax_logits = True

        # Assigning an nn.Module attribute registers it as a submodule, so no
        # explicit add_module call is needed.
        self.gha = GHA(self.dim, self.K, learning_rate=self.glr)
        self.routing_ready = True  # not read inside this class
        self.weight = nn.Parameter(torch.randn(self.K, self.dim))
        self.alpha = nn.Parameter(torch.randn(self.K))

        # Bookkeeping for the routing-distribution snapshots.
        self.forward_count = 0
        self.total_steps = options.get('total_steps', 10000)
        self.collect_start = self.total_steps // 2
        self.collect_samples = options.get('collect_samples', 1000)
        self.routing_distributions = []
        self.collecting = False

    def update(self, x):
        """Feed ``x`` to the GHA tracker (three steps); no-op in eval mode.

        The input is detached and cast to the dtype of the GHA basis so a
        reduced-precision ``x`` does not clash with a float32 basis.
        """
        if not self.training:
            return
        basis_dtype = self.gha.components_().dtype
        self.gha.partial_fit(x.detach().to(basis_dtype), iters=3)

    def basismix(self):
        """Return ``mixing_coef @ basis``, the mixed GHA routing directions.

        The basis is detached, so gradients reach ``mixing_coef`` only.
        """
        return self.mixing_coef @ self.gha.components_()

    def forward(self, x):
        """Compute blended gate logits for ``x``.

        Args:
            x: Input whose last dimension is ``model_dim`` -- assumed to be
                ``(tokens, model_dim)``; TODO confirm against the caller.

        Returns:
            A tuple ``(logits, top_k)``; ``logits`` has one entry per expert
            along the last dimension.
        """
        self.update(x)
        rbasis = self.basismix()

        if self.fp32_gate:
            x = x.float()
            weight = self.weight.float()
            alpha = self.alpha.float()
            rbasis = rbasis.float()
        else:
            weight = self.weight
            alpha = self.alpha

        linear_logits = torch.matmul(x, weight.T)
        gha_logits = torch.matmul(x, rbasis.T)
        blend = torch.sigmoid(alpha)  # per-expert weight in (0, 1)
        logits = blend * linear_logits + (1 - blend) * gha_logits

        if self.training:
            self.forward_count += 1
            if self.forward_count >= self.collect_start and not self.collecting:
                # First step at/after the threshold: start a fresh collection.
                self.collecting = True
                self.routing_distributions = []

            if self.collecting and len(self.routing_distributions) < self.collect_samples:
                with torch.no_grad():
                    probs = F.softmax(logits, dim=-1)
                    _, indices = torch.topk(probs, self.top_k, dim=-1)
                    self.routing_distributions.append({
                        'step': self.forward_count,
                        'expert_choices': indices.cpu(),
                        'probs': probs.cpu()
                    })

        return logits, self.top_k

    def save_routing_distributions(self, path):
        """Pickle the collected snapshots to ``path``; no-op if none exist."""
        if not self.routing_distributions:
            return
        with open(path, 'wb') as f:
            pickle.dump(self.routing_distributions, f)
        print(f"Saved {len(self.routing_distributions)} routing distributions to {path}")
    
# Alias exposing STARGate under the generic name ``Gate`` -- presumably the
# name an external gate loader looks up in this module (TODO confirm).
Gate = STARGate
