import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from .baseline import BaseFLModel
from encoder_manager import create_encoder_manager


class Expert(nn.Module):
    """Bias-free low-rank two-layer MLP expert computing ``fc2(relu(fc1(x)))``.

    ``fc1`` (input_dim -> rank) is re-drawn from N(0, 0.01**2) and ``fc2``
    (rank -> output_dim) is zero-initialised. A freshly built expert
    therefore outputs exactly zero until ``fc2`` has been updated.
    """

    def __init__(self, input_dim, rank, output_dim):
        super().__init__()
        # Both layers are constructed (fc1 first) before either is
        # re-initialised, matching the original RNG consumption order.
        self.fc1, self.fc2 = (
            nn.Linear(input_dim, rank, bias=False),
            nn.Linear(rank, output_dim, bias=False),
        )
        nn.init.normal_(self.fc1.weight, 0.0, 0.01)
        nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

class MoERouter(nn.Module):
    """Linear top-k token router.

    Args:
        in_features: width of the token features fed to ``gate``.
        num_experts: number of experts to route between.
        top_k: number of experts selected per token.
        normalize_weights: if True (default, original behaviour) the selected
            top-k probabilities are renormalised to sum to 1 per token. With
            ``top_k == 1`` this makes every weight exactly ``w / w == 1.0``,
            so the weights carry no gradient back to ``gate``. Pass False to
            keep the raw softmax probabilities of the selected experts instead.
    """

    def __init__(self, in_features, num_experts, top_k, normalize_weights=True):
        super().__init__()
        self.gate = nn.Linear(in_features, num_experts)
        self.num_experts = num_experts
        self.top_k = top_k
        self.normalize_weights = normalize_weights

    def forward(self, x):
        """Route tokens to experts.

        Args:
            x: 2-D token features ``(tokens, in_features)``. The
                ``permute(2, 1, 0)`` below needs the one-hot mask to be 3-D,
                so ``x`` must be 2-D.

        Returns:
            router_logits: raw gate output, ``(tokens, num_experts)``.
            router_weights: float32 weights of the selected experts,
                ``(tokens, top_k)``.
            expert_mask: one-hot selection permuted to
                ``(num_experts, top_k, tokens)``.
        """
        router_logits = self.gate(x)
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
        router_weights, selected_indices = torch.topk(router_probs, self.top_k, dim=-1)
        if self.normalize_weights:
            # Out-of-place so the topk output is not mutated under autograd.
            router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)

        expert_mask = F.one_hot(selected_indices, num_classes=self.num_experts)
        return router_logits, router_weights, expert_mask.permute(2, 1, 0)

class SparseMoEFFN(nn.Module):
    """Frozen feed-forward block plus a sparse mixture of low-rank experts.

    The forward output is ``moe_alpha * experts(x) + original_ffn(x)``. Each
    token's expert contribution is the router-weighted sum over its
    ``top_k`` selected experts. The wrapped FFN's parameters are frozen here.

    After every forward pass ``self.balance_loss`` holds
    KL(mean routing probability || uniform). It is 0 when no expert's mean
    routing probability exceeds ``balance_threshold``.

    Args:
        original_ffn: module exposing ``fc1`` (input ``nn.Linear``) and
            ``fc2`` (output ``nn.Linear``); it is called on the full input.
        num_experts: number of experts.
        rank: hidden width of each expert.
        top_k: experts selected per token.
        moe_alpha: scale applied to the expert output (default 4, as before).
        balance_threshold: minimum max-mean-load that activates the balance
            loss (default 0.2, as before). With ``num_experts < 5`` the
            uniform load ``1 / num_experts`` already exceeds 0.2, so the loss
            is then always applied.
    """

    def __init__(self, original_ffn, num_experts, rank, top_k, moe_alpha=4, balance_threshold=0.2):
        super().__init__()

        in_features = original_ffn.fc1.in_features
        out_features = original_ffn.fc2.out_features

        self.experts = nn.ModuleList(
            [Expert(in_features, rank, out_features) for _ in range(num_experts)]
        )
        self.router = MoERouter(in_features, num_experts, top_k)

        # The original FFN is kept as a frozen path that is always added.
        self.original_ffn = original_ffn
        for param in self.original_ffn.parameters():
            param.requires_grad = False

        self.num_experts = num_experts
        self.balance_loss = 0
        self.MoE_alpha = moe_alpha
        self.balance_threshold = balance_threshold

    def forward(self, x):
        """Apply frozen FFN + routed experts to ``x`` of shape ``(batch, seq, hidden)``.

        Side effect: updates ``self.balance_loss``.
        """
        ffn_out = self.original_ffn(x)

        batch_size, seq_len, hidden_dim = x.size()
        # reshape (not view) so non-contiguous inputs are accepted too.
        hidden_states = x.reshape(-1, hidden_dim)

        router_logits, router_weights, expert_masks = self.router(hidden_states)

        out_features = self.original_ffn.fc2.out_features
        final_hidden_states = torch.zeros_like(ffn_out.reshape(-1, out_features))

        for i, expert_layer in enumerate(self.experts):
            # slot_idx: which of the top_k slots chose expert i; top_x: token indices.
            slot_idx, top_x = torch.where(expert_masks[i])
            if top_x.shape[0] == 0:
                continue
            expert_out = expert_layer(hidden_states[top_x])
            current_router_weights = router_weights[top_x, slot_idx].unsqueeze(1)
            weighted_out = expert_out * current_router_weights
            # index_add_ requires source and destination dtypes to match. The
            # accumulator follows ffn_out, which can differ from x (e.g. under
            # torch.autocast nn.Linear runs in reduced precision), so cast to
            # the accumulator's dtype rather than the input's.
            final_hidden_states.index_add_(0, top_x, weighted_out.to(final_hidden_states.dtype))

        expected_load = torch.softmax(router_logits, dim=-1).mean(dim=0)

        if expected_load.max() > self.balance_threshold:
            uniform_dist = torch.ones_like(expected_load) / self.num_experts
            self.balance_loss = torch.sum(
                expected_load * torch.log(expected_load / uniform_dist + 1e-9)
            )
        else:
            self.balance_loss = torch.tensor(0.0, device=expected_load.device)

        final_output = final_hidden_states.view(batch_size, seq_len, -1)
        return self.MoE_alpha * final_output + ffn_out

class FMoE(BaseFLModel, nn.Module):
    """Federated classifier: frozen encoder backbone with sparse MoE adapters.

    Every ``block.mlp`` of the backbone is wrapped in a ``SparseMoEFFN``.
    Only parameters whose names contain ``experts``, ``router`` or
    ``classification_head`` are trainable and exchanged with the server.

    NOTE(review): the forward pass reads ``patch_embed``, ``cls_token``,
    ``pos_embed``, ``pos_drop``, ``blocks`` and ``norm`` from the backbone,
    so it presumably expects a timm-style ViT. Confirm against
    ``create_encoder_manager``.
    """

    def __init__(self, cfg, device):
        BaseFLModel.__init__(self, cfg, device)
        nn.Module.__init__(self)

        m, t = cfg.get("model", {}), cfg.get("train", {})

        self.enc_mgr = create_encoder_manager(cfg, self.device)
        self.backbone = self.enc_mgr.model

        self.num_experts = int(m.get("num_experts", 8))
        self.rank_per_expert = int(m.get("rank_per_expert", 4))
        self.top_k = int(m.get("top_k", 1))
        self.aux_lambda = float(t.get("aux_lambda", 1e-5))

        self._replace_ffn_with_sparsemoe(self.backbone, self.num_experts, self.rank_per_expert, self.top_k)

        self.classification_head = nn.Linear(
            self.enc_mgr.feature_dim, int(m.get("num_classes", 10))
        ).to(self.device)

        self.lr = float(t.get("lr", 3e-4))
        self.weight_decay = float(t.get("weight_decay", 1e-2))
        self.epochs = int(t.get("local_epochs", 4))
        self.bs = int(cfg["data"].get("batch_size", 128))

        self.trainable_keys = self._get_trainable_keys()

    def _get_trainable_keys(self):
        """Freeze all parameters except experts, routers and the head; return trainable names."""
        keys = []
        for name, param in self.named_parameters():
            trainable = "experts" in name or "router" in name or "classification_head" in name
            param.requires_grad = trainable
            if trainable:
                keys.append(name)
        return keys

    def _replace_ffn_with_sparsemoe(self, model, num_experts, rank, top_k):
        """Wrap each ``block.mlp`` in ``model.blocks`` with a SparseMoEFFN on ``self.device``."""
        for block in model.blocks:
            block.mlp = SparseMoEFFN(block.mlp, num_experts, rank, top_k).to(self.device)

    def _trainable_state_cpu(self):
        """Return CPU copies of the ``state_dict`` entries listed in ``trainable_keys``."""
        wanted = set(self.trainable_keys)  # set: O(1) membership per key
        return {k: v.cpu().clone() for k, v in self.state_dict().items() if k in wanted}

    def _moe_logits(self, xb):
        """Run backbone + MoE blocks + head on image batch ``xb``; return class logits.

        Shared by training and evaluation so both use the same pipeline.
        ``pos_drop`` is a no-op in eval mode if it is a dropout layer.
        """
        x = self.backbone.patch_embed(xb)
        cls_tokens = self.backbone.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.backbone.pos_drop(x + self.backbone.pos_embed)
        for block in self.backbone.blocks:
            x = block(x)
        x = self.backbone.norm(x)
        # Position 0 is the prepended cls token.
        return self.classification_head(x[:, 0])

    def _moe_aux_loss(self):
        """Sum the ``balance_loss`` left by each SparseMoEFFN's most recent forward pass."""
        aux_loss = 0
        for block in self.backbone.blocks:
            if isinstance(block.mlp, SparseMoEFFN):
                aux_loss += block.mlp.balance_loss
        return aux_loss

    def get_requirements(self):
        return {"input_type": "images"}

    def init_global(self, enc_info=None):
        """Return the initial global state: CPU copies of the trainable tensors."""
        return {"trainable": self._trainable_state_cpu()}

    def client_update(self, global_state, client_data, round_idx, enc_mgr=None):
        """Train the trainable parameters locally for ``self.epochs`` epochs.

        ``round_idx`` and ``enc_mgr`` are unused here.

        Returns:
            ``({"trainable": state}, {"trainable": {"scalar": num_samples}})``.
            ``num_samples`` is ``len(client_data[1])``; presumably
            ``client_data[1]`` holds the labels. Verify against ``_as_loader``.
        """
        self.load_state_dict(global_state["trainable"], strict=False)
        self.train()

        loader = self._as_loader(client_data, shuffle=True, batch_size=self.bs)
        # NOTE(review): torch.optim.Adam applies weight_decay as an L2 term added
        # to the gradient (not decoupled like AdamW). Confirm this is intended
        # given the 1e-2 default.
        opt = optim.Adam(
            (p for p in self.parameters() if p.requires_grad),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        ce_loss_fn = nn.CrossEntropyLoss()

        for _ in range(self.epochs):
            for xb, yb in loader:
                xb, yb = xb.to(self.device), yb.to(self.device)

                logits = self._moe_logits(xb)
                task_loss = ce_loss_fn(logits, yb)
                # Must run after the forward pass: it reads this batch's balance losses.
                total_loss = task_loss + self.aux_lambda * self._moe_aux_loss()

                opt.zero_grad(set_to_none=True)
                total_loss.backward()
                opt.step()

        num_samples = len(client_data[1])
        return {"trainable": self._trainable_state_cpu()}, {"trainable": {"scalar": num_samples}}

    @torch.no_grad()
    def evaluate(self, global_state, testset, enc_mgr=None):
        """Evaluate the global trainable state on ``testset``.

        Returns:
            ``(mean per-sample cross-entropy, accuracy in percent)``. The loss is
            summed over samples and divided by the sample count. Averaging
            per-batch means would overweight a smaller final batch.
        """
        self.load_state_dict(global_state["trainable"], strict=False)
        self.eval()

        loader = self._as_loader(testset, shuffle=False, batch_size=self.bs)
        ce_sum_fn = nn.CrossEntropyLoss(reduction="sum")
        total_loss, correct, count = 0.0, 0, 0

        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)

            logits = self._moe_logits(xb)

            total_loss += ce_sum_fn(logits, yb).item()
            correct += (logits.argmax(1) == yb).sum().item()
            count += yb.numel()

        return total_loss / max(1, count), 100.0 * correct / max(1, count)