import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LoRA(nn.Module):
    """Low-rank adapter producing the additive update ``scaling * (x @ A^T) @ B^T``.

    ``A`` has shape ``(rank, in_features)`` and is Kaiming-uniform initialised;
    ``B`` has shape ``(out_features, rank)`` and starts at zero, so a freshly
    constructed adapter returns an all-zero update.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the returned update.
        rank: Inner dimension of the low-rank factorisation.
        alpha: Numerator of the output scale ``alpha / rank``.
        use_lora_dropout: When True, dropout is applied to the input first.
        dropout: Dropout probability, only used when ``use_lora_dropout`` is True.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        alpha: int,
        use_lora_dropout: bool = False,
        dropout: float = 0.05,
    ):
        super().__init__()
        self.rank, self.alpha = rank, alpha
        self.scaling = alpha / rank
        self.in_features, self.out_features = in_features, out_features
        self.use_lora_dropout = use_lora_dropout

        # Both factors are stored in float32, independent of the host model's dtype.
        down_init = torch.randn(rank, in_features, dtype=torch.float32)
        up_init = torch.zeros(out_features, rank, dtype=torch.float32)
        self.A = nn.Parameter(down_init)
        self.B = nn.Parameter(up_init)

        # A is overwritten with Kaiming-uniform values; B stays at zero.
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.zeros_(self.B)

        # The dropout submodule is only created when it was requested.
        if use_lora_dropout:
            self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x`` (last dim ``in_features``)."""
        features = self.dropout(x) if self.use_lora_dropout else x
        projected = torch.matmul(features, self.A.t())
        update = torch.matmul(projected, self.B.t())
        return update * self.scaling

class WrappedLoRALayer(nn.Module):
    """Frozen linear layer whose output is augmented by a trainable LoRA update.

    Args:
        original_module: The linear layer to wrap; all of its parameters are
            frozen (``requires_grad`` set to False).
        rank: LoRA rank forwarded to the adapter.
        alpha: LoRA alpha forwarded to the adapter.
        use_lora_dropout: Whether the adapter applies dropout to its input.
    """

    def __init__(
        self,
        original_module: nn.Linear,
        rank: int,
        alpha: int,
        use_lora_dropout: bool,
    ):
        super().__init__()
        self.original_module = original_module
        self.lora = LoRA(
            original_module.in_features,
            original_module.out_features,
            rank,
            alpha,
            use_lora_dropout,
        )

        # Only the adapter is meant to train; the wrapped projection stays fixed.
        self.original_module.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``original_module(x)`` plus the LoRA update for ``x``."""
        base = self.original_module(x)
        # The adapter runs on a float32 copy of the input; its result is cast
        # back to the dtype of the frozen layer's output before the sum.
        delta = self.lora(x.float())
        return base + delta.to(base.dtype)
    

class MoELoRA(nn.Module):
    """Mixture-of-Experts with LoRA experts.

    Wraps a frozen linear layer and adds to its output a weighted sum of
    LoRA updates: one always-active shared expert plus the ``top_k`` routed
    experts chosen per token. Routing logits come from ``router_fc`` applied
    to the input after it has been modulated by ``film_adapter`` with the
    task embedding selected by ``current_task_ids``; the experts themselves
    receive the unmodulated (float32) input.

    Args:
        original_module: Linear-like module exposing ``in_features`` and
            ``out_features``; all of its parameters are frozen.
        num_experts: Number of routed LoRA experts (the shared expert is extra).
        rank: LoRA rank used by every expert.
        top_k: Number of routed experts selected per token.
        alpha: LoRA alpha used by every expert.
        task_embeddings: Tensor indexed with ``current_task_ids`` to obtain
            the FiLM condition.
        use_lora_dropout: Whether the experts apply dropout to their input.
        film_adapter: Callable ``(x, task_emb) -> x'`` producing the router input.

    Attributes set at runtime:
        current_task_ids: Must be assigned by the caller before ``forward``.
        aux_loss: Load-balancing loss of the most recent forward pass.
        selected_count: Buffer accumulating, during training, how often each
            routed expert was selected.
    """

    def __init__(
        self,
        original_module,
        num_experts: int,
        rank: int,
        top_k: int,
        alpha: int,
        task_embeddings,
        use_lora_dropout: bool,
        film_adapter,
    ):
        super().__init__()
        self.original_module = original_module
        self.in_features = original_module.in_features
        self.out_features = original_module.out_features
        self.num_experts = num_experts
        self.top_k = top_k
        self.film_adapter = film_adapter

        # Freeze every parameter of the wrapped layer (weight and, if present,
        # bias), matching what WrappedLoRALayer does.
        for param in self.original_module.parameters():
            param.requires_grad = False

        self.experts = nn.ModuleList(
            [LoRA(self.in_features, self.out_features, rank, alpha, use_lora_dropout) for _ in range(num_experts)]
        )
        self.shared_expert = LoRA(self.in_features, self.out_features, rank, alpha, use_lora_dropout)
        # One logit per routed expert plus a final logit for the shared expert.
        self.router_fc = nn.Linear(self.in_features, num_experts + 1, dtype=torch.float32)

        self.task_embeddings = task_embeddings
        self.current_task_ids = None
        self.aux_loss = 0
        self.register_buffer('selected_count', torch.zeros(num_experts, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen layer's output plus the routed LoRA update.

        Args:
            x: Input of shape ``[batch_size, seq_len, in_features]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, out_features]``.

        Side effects:
            Sets ``self.aux_loss`` (load-balancing loss in training, a zero
            tensor otherwise) and, in training, adds the per-expert selection
            counts of this batch to ``self.selected_count``.
        """
        original_output = self.original_module(x)  # [batch_size, seq_len, out_features]
        batch_size, seq_len, in_features = x.shape
        num_tokens = batch_size * seq_len
        input_dtype = x.dtype

        task_emb = self.task_embeddings[self.current_task_ids]
        x = x.to(torch.float32)
        # FiLM-modulated features are used for routing only.
        router_input = self.film_adapter(x, task_emb)

        all_logits = self.router_fc(router_input)
        shared_expert_logits = all_logits[..., -1:]
        independent_experts_logits = all_logits[..., :-1]
        # topk_values / topk_indices: [batch_size, seq_len, top_k]
        topk_values, topk_indices = torch.topk(independent_experts_logits, self.top_k, dim=-1)

        # The shared expert competes with the selected experts in one softmax,
        # so the shared weight and the top-k weights sum to 1 per token.
        combined_weights = F.softmax(torch.cat((shared_expert_logits, topk_values), dim=-1), dim=-1)
        shared_weight = combined_weights[..., 0:1]
        topk_weights = combined_weights[..., 1:]

        if self.training:
            batch_counts = torch.bincount(topk_indices.reshape(-1), minlength=self.num_experts)
            self.selected_count += batch_counts.to(self.selected_count.device)

        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # activations that are already float32, are accepted.
        x_flat = x.reshape(num_tokens, in_features)
        shared_output = self.shared_expert(x_flat) * shared_weight.reshape(num_tokens, 1)

        # Every expert is evaluated on every token; the top-k outputs are
        # then gathered per token.
        expert_outputs = torch.stack([expert(x_flat) for expert in self.experts], dim=1)  # [N, E, out]
        flat_indices = topk_indices.reshape(num_tokens, self.top_k)
        token_indices = torch.arange(num_tokens, device=x.device).unsqueeze(1).expand(-1, self.top_k)
        selected_expert_outputs = expert_outputs[token_indices, flat_indices]  # [N, top_k, out]

        flat_weights = topk_weights.reshape(num_tokens, self.top_k)
        weighted_expert_output = (selected_expert_outputs * flat_weights.unsqueeze(-1)).sum(dim=1)  # [N, out]

        combined_expert_output = (shared_output + weighted_expert_output).reshape(
            batch_size, seq_len, self.out_features
        )
        final_output = original_output + combined_expert_output.to(input_dtype)

        if self.training:
            router_probs = F.softmax(independent_experts_logits, dim=-1)
            self.aux_loss = self.compute_aux_loss(router_probs, topk_indices, batch_size, seq_len)
        else:
            self.aux_loss = torch.tensor(0.0, device=x.device)
        return final_output

    @staticmethod
    def _mask_gradient(grad, momentum, k_topk, k_rand):
        """Zero, in place, every entry of ``grad`` outside the kept set.

        The kept set is the union of the entries whose ``|momentum|`` is at
        least the ``k_topk``-th largest value and ``k_rand`` uniformly drawn
        positions.
        """
        magnitude = momentum.abs().reshape(-1)
        if k_topk > 0:
            threshold = torch.topk(magnitude, k_topk, largest=True).values[-1]
            keep = magnitude >= threshold
        else:
            keep = torch.zeros_like(magnitude, dtype=torch.bool)

        if k_rand > 0:
            rand_indices = torch.randperm(grad.numel(), device=grad.device)[:k_rand]
            keep[rand_indices] = True

        grad.mul_(keep.view_as(grad))

    def sparsify_gradients_by_momentum_with_random(self, optimizer, step: int):
        """Sparsify the shared expert's gradients using the optimizer momentum.

        Intended to be called between ``backward()`` and ``optimizer.step()``.
        Reads ``optimizer.state[param]['exp_avg']`` for the shared expert's
        ``A`` and ``B`` and zeroes, in place, all gradient entries except
        those with the largest absolute momentum plus a few random ones.

        Schedule (by ``step``):
            * ``step < 500``: gradients are left dense.
            * ``500 <= step < 1000``: the top-k budget follows a cosine decay
              from all parameters down to its final value.
            * ``step >= 1000``: about 5% of the entries are kept
              (4.8% by momentum, 0.2% at random).

        Does nothing when either gradient is missing or the optimizer holds
        no state for the parameters yet.

        Args:
            optimizer: Optimizer whose per-parameter state contains ``exp_avg``.
            step: Current global training step.
        """
        params_A = self.shared_expert.A
        params_B = self.shared_expert.B
        if params_A.grad is None or params_B.grad is None:
            return
        if params_A not in optimizer.state or params_B not in optimizer.state:
            return
        momentum_A = optimizer.state[params_A]['exp_avg']
        momentum_B = optimizer.state[params_B]['exp_avg']

        warm_up_steps = 500
        decay_end_step = 1000
        final_update_ratio = 0.05
        random_ratio = 0.002
        topk_ratio = final_update_ratio - random_ratio

        # Warm-up keeps every gradient entry, so there is nothing to mask.
        if step < warm_up_steps:
            return

        num_params_A = params_A.numel()
        num_params_B = params_B.numel()
        total_params = num_params_A + num_params_B
        final_k_topk = int(total_params * topk_ratio)
        final_k_random = int(total_params * random_ratio)

        if step < decay_end_step:
            decay_steps = decay_end_step - warm_up_steps
            current_decay_step = step - warm_up_steps
            cosine_decay = 0.5 * (1 + math.cos(math.pi * current_decay_step / decay_steps))
            k_float = final_k_topk + (total_params - final_k_topk) * cosine_decay
            k_topk = max(final_k_topk, int(k_float))
        else:
            k_topk = final_k_topk
        k_random = final_k_random

        # Split the budgets proportionally to the parameter counts. Integer
        # arithmetic guarantees k_A <= numel(A) and k_B <= numel(B); a float
        # ratio could round down and push the remainder past numel(B).
        k_A_topk = k_topk * num_params_A // total_params
        k_B_topk = k_topk - k_A_topk
        k_A_rand = k_random * num_params_A // total_params
        k_B_rand = k_random - k_A_rand

        with torch.no_grad():
            self._mask_gradient(params_A.grad, momentum_A, k_A_topk, k_A_rand)
            self._mask_gradient(params_B.grad, momentum_B, k_B_topk, k_B_rand)

    def compute_aux_loss(self, router_probs, topk_indices, batch_size, seq_len):
        """Return the load-balancing loss ``num_experts * sum_i(p_i * f_i)``.

        ``p_i`` is the mean router probability of expert ``i`` over the first
        two dimensions of ``router_probs``; ``f_i`` is the fraction of the
        ``batch_size * seq_len * top_k`` routing slots assigned to expert ``i``.

        Args:
            router_probs: Probabilities of shape ``[batch, seq, num_experts]``.
            topk_indices: Selected expert indices, ``[batch, seq, top_k]``.
            batch_size: Batch size used to count routing slots.
            seq_len: Sequence length used to count routing slots.

        Returns:
            A scalar tensor, or ``0`` when there are no routed experts.
        """
        if self.num_experts == 0:
            return 0

        total_slots = batch_size * seq_len * self.top_k
        p_i = router_probs.mean(dim=[0, 1])  # [num_experts]
        # One bincount over all slots equals the per-slot counts summed.
        expert_counts = torch.bincount(topk_indices.reshape(-1), minlength=self.num_experts)
        f_i = expert_counts / total_slots
        return self.num_experts * torch.sum(p_i * f_i)

    def compute_orthogonality_loss(self):
        """Return the sum over routed experts of ``sum(|A_i @ A_shared^T|)``.

        The value is zero when every row of each expert's ``A`` is orthogonal
        to every row of the shared expert's ``A``. Returns the float ``0.0``
        when there are no routed experts.
        """
        shared_A_t = self.shared_expert.A.T
        orth_loss = 0.0
        for expert in self.experts:
            orth_loss = orth_loss + torch.mm(expert.A, shared_A_t).abs().sum()
        return orth_loss

    def get_aux_loss(self):
        """Return the auxiliary loss stored by the most recent forward pass."""
        return self.aux_loss

from transformers import AutoModelForCausalLM, AutoConfig
from transformers import Qwen3ForCausalLM, Qwen3Config
class MoELoRAQwen(Qwen3ForCausalLM):
    """Qwen3 causal LM whose linear layers can be replaced by LoRA adapters.

    After loading, call ``apply_moelora()`` to swap every ``nn.Linear`` whose
    qualified name contains ``"mlp"`` for a ``MoELoRA`` layer and every one
    whose name contains ``"attn"`` for a ``WrappedLoRALayer``. Before each
    forward pass, set ``current_task_ids`` and call
    ``set_current_task_ids_to_layers()`` so the MoE routers see the task ids.

    Args:
        config: Model configuration passed to the parent class.
        num_experts: Number of routed experts per MoELoRA layer.
        rank: LoRA rank.
        top_k: Routed experts selected per token.
        alpha: LoRA alpha.
        task_embeddings: Optional task-embedding tensor.
        use_lora_dropout: Whether LoRA adapters apply input dropout.
    """

    def __init__(self, config, num_experts=5, rank=8, top_k=2, alpha=16, task_embeddings=None,use_lora_dropout=False,):
        super().__init__(config)
        self.config = config
        self.num_experts = num_experts
        self.rank = rank
        self.top_k = top_k
        self.alpha = alpha
        self.current_task_ids = None
        self.use_lora_dropout = use_lora_dropout
        self.task_embeddings = task_embeddings
        # Condition width of the FiLM adapter: taken from the embeddings when
        # they are available, otherwise the historical default of 768.
        if isinstance(task_embeddings, torch.Tensor):
            self.task_emb_dim = task_embeddings.size(-1)
        else:
            self.task_emb_dim = 768

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load pretrained weights and attach the MoELoRA hyper-parameters.

        Extra keyword arguments consumed here (all others are forwarded to the
        parent ``from_pretrained``):
            num_experts (default 8), rank (default 8), top_k (default 2),
            alpha (default 16), use_lora_dropout (default False), and the
            required ``task_embeddings`` (a tensor or array-like of floats).

        Raises:
            KeyError: If ``task_embeddings`` is not supplied.
        """
        num_experts = kwargs.pop("num_experts", 8)
        rank = kwargs.pop("rank", 8)
        top_k = kwargs.pop("top_k", 2)
        alpha = kwargs.pop("alpha", 16)
        task_embeddings = kwargs.pop("task_embeddings")
        use_lora_dropout = kwargs.pop("use_lora_dropout", False)

        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model.num_experts = num_experts
        model.rank = rank
        model.top_k = top_k
        model.alpha = alpha

        # Accept array-likes as well as tensors; the parameter always owns a copy.
        task_embeddings = torch.as_tensor(task_embeddings)
        model.task_embeddings = nn.Parameter(
            task_embeddings.clone().detach().requires_grad_(True))
        # Keep the FiLM condition width in sync with the actual embeddings.
        model.task_emb_dim = task_embeddings.size(-1)
        model.current_task_ids = None
        model.use_lora_dropout = use_lora_dropout

        return model

    def apply_film(self):
        """Create the FiLM adapter shared by all MoELoRA layers."""
        self.film_adapter = DynamicFiLMMoLoRAAdapter(self.task_emb_dim, rank_dim=32,use_low_rank=True)

    def apply_moelora(self):
        """Freeze the base model and install the LoRA / MoELoRA layers.

        All existing parameters are frozen; the newly created adapter
        parameters and the FiLM adapter remain trainable, and the task
        embeddings are kept frozen. The FiLM adapter is created on demand if
        ``apply_film()`` has not been called yet.
        """
        print("Applying MoELoRA modifications...")
        print(f"task_embeddings:{self.task_embeddings}")

        if getattr(self, "film_adapter", None) is None:
            self.apply_film()

        for param in self.parameters():
            param.requires_grad = False

        # Snapshot the linear layers first: replacing modules while the
        # named_modules() generator is still running depends on traversal
        # details and could descend into freshly inserted adapters.
        linear_layers = [
            (name, module)
            for name, module in self.named_modules()
            if isinstance(module, nn.Linear)
        ]

        for name, module in linear_layers:
            lowered_name = name.lower()
            if "mlp" in lowered_name:
                replacement = MoELoRA(
                    module,
                    num_experts=self.num_experts,
                    rank=self.rank,
                    top_k=self.top_k,
                    alpha=self.alpha,
                    task_embeddings=self.task_embeddings,
                    use_lora_dropout=self.use_lora_dropout,
                    film_adapter=self.film_adapter
                )
            elif "attn" in lowered_name:
                replacement = WrappedLoRALayer(module, self.rank, self.alpha, self.use_lora_dropout)
            else:
                continue

            parent_module = self._get_parent_module(name)
            layer_name = name.split(".")[-1]
            setattr(parent_module, layer_name, replacement)

        print("Re-enabling gradients for the FiLM adapter...")
        for param in self.film_adapter.parameters():
            param.requires_grad = True
        if self.task_embeddings is not None:
            self.task_embeddings.requires_grad = False
        print("MoELoRA modification and gradient setup complete.")

    def set_current_task_ids_to_layers(self):
        """Copy ``self.current_task_ids`` onto every MoELoRA submodule."""
        for module in self.modules():
            if isinstance(module, MoELoRA):
                module.current_task_ids = self.current_task_ids

    def _get_parent_module(self, module_name):
        """Return the module that owns the attribute named by ``module_name``.

        ``module_name`` is a dotted path relative to this model; the path is
        followed up to, but not including, its last component.
        """
        module_path = module_name.split(".")
        parent_module = self
        for sub_module in module_path[:-1]:
            parent_module = getattr(parent_module, sub_module)
        return parent_module


class FiLMLayer(nn.Module):
    """Feature-wise linear modulation: ``x * (1 + gamma) + beta``.

    ``gamma`` and ``beta`` are generated from a condition vector, either by a
    low-rank two-layer generator (``condition_dim -> rank -> 2 * features_dim``)
    or by a single linear layer. The layer producing ``gamma``/``beta`` is
    zero-initialised, so a freshly built FiLMLayer is the identity on ``x``.

    Args:
        condition_dim: Size of the last dimension of the condition.
        features_dim: Size of the last dimension of ``x``.
        rank: Bottleneck width of the low-rank generator (unused when
            ``use_low_rank`` is False).
        use_low_rank: Select the low-rank generator (True) or a single
            full linear generator (False).
    """

    def __init__(self, condition_dim: int, features_dim: int,rank:int,use_low_rank: bool = True):
        super().__init__()
        self.features_dim = features_dim
        self.use_low_rank = use_low_rank
        if self.use_low_rank:
            self.generator_down = nn.Linear(condition_dim, rank, dtype=torch.float32)
            self.generator_up = nn.Linear(rank, features_dim * 2, dtype=torch.float32)
            nn.init.zeros_(self.generator_up.weight)
            nn.init.zeros_(self.generator_up.bias)
            # a=sqrt(5) is the negative-slope value nn.Linear itself uses and
            # the one the LoRA ``A`` matrix in this file is initialised with;
            # the previous a=5 shrank the init bound by roughly half.
            nn.init.kaiming_uniform_(self.generator_down.weight, a=math.sqrt(5))
            nn.init.zeros_(self.generator_down.bias)
        else:
            self.generator = nn.Linear(condition_dim, features_dim * 2, dtype=torch.float32)
            nn.init.zeros_(self.generator.weight)
            nn.init.zeros_(self.generator.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Modulate ``x`` with the scale and shift generated from ``condition``.

        ``gamma`` and ``beta`` get a new axis inserted at position 1 before
        being combined with ``x``, so they broadcast along that axis.
        NOTE(review): this presumably expects ``condition`` as
        ``[batch, condition_dim]`` and ``x`` as ``[batch, seq, features_dim]``
        -- confirm against callers.
        """
        if self.use_low_rank:
            gamma_beta = self.generator_up(self.generator_down(condition))
        else:
            gamma_beta = self.generator(condition)

        # First half of the generated vector scales, second half shifts.
        gamma, beta = torch.chunk(gamma_beta, 2, dim=-1)
        gamma = gamma.unsqueeze(1)
        beta = beta.unsqueeze(1)

        return x * (1 + gamma) + beta
    

class DynamicFiLMMoLoRAAdapter(nn.Module):
    """Dispatches to a FiLM layer chosen by the input's last dimension.

    Holds one ``FiLMLayer`` for feature width 4096 and one for 12288; any
    other width is rejected.

    Args:
        condition_dim: Size of the task-embedding vector fed to the FiLM layers.
        rank_dim: Bottleneck rank passed to each FiLM layer.
        use_low_rank: Passed through to each FiLM layer.
    """

    def __init__(self, condition_dim: int, rank_dim: int = 16,use_low_rank: bool = True):
        super().__init__()
        self.condition_dim, self.rank_dim = condition_dim, rank_dim
        # One FiLM layer per supported feature width.
        self.film_4096 = FiLMLayer(condition_dim, 4096, rank_dim, use_low_rank)
        self.film_12288 = FiLMLayer(condition_dim, 12288, rank_dim, use_low_rank)

    def forward(self, x: torch.Tensor, task_embedding: torch.Tensor) -> torch.Tensor:
        """Apply the FiLM layer matching ``x.shape[-1]``.

        Raises:
            ValueError: If the last dimension of ``x`` is neither 4096 nor 12288.
        """
        current_features_dim = x.size(-1)
        if current_features_dim not in (4096, 12288):
            raise ValueError(
                f"Unsupported features_dim: {current_features_dim}. "
                "This adapter only supports features_dim of 12288 or 4096."
            )
        film = self.film_4096 if current_features_dim == 4096 else self.film_12288
        return film(x, task_embedding)

class DynamicFiLMMoLoRAAdapter2(nn.Module):
    """Dispatches to a FiLM layer chosen by the input's last dimension.

    Holds one ``FiLMLayer`` for feature width 2560 and one for 9728; any
    other width is rejected.

    Args:
        condition_dim: Size of the task-embedding vector fed to the FiLM layers.
        rank_dim: Bottleneck rank passed to each FiLM layer.
        use_low_rank: Passed through to each FiLM layer.
    """

    def __init__(self, condition_dim: int, rank_dim: int = 16,use_low_rank: bool = True):
        super().__init__()
        self.condition_dim, self.rank_dim = condition_dim, rank_dim
        # One FiLM layer per supported feature width.
        self.film_2560 = FiLMLayer(condition_dim, 2560, rank_dim, use_low_rank)
        self.film_9728 = FiLMLayer(condition_dim, 9728, rank_dim, use_low_rank)

    def forward(self, x: torch.Tensor, task_embedding: torch.Tensor) -> torch.Tensor:
        """Apply the FiLM layer matching ``x.shape[-1]``.

        Raises:
            ValueError: If the last dimension of ``x`` is neither 2560 nor 9728.
        """
        current_features_dim = x.size(-1)
        if current_features_dim not in (2560, 9728):
            raise ValueError(
                f"Unsupported features_dim: {current_features_dim}. "
                "This adapter only supports features_dim of 9728 or 2560."
            )
        film = self.film_2560 if current_features_dim == 2560 else self.film_9728
        return film(x, task_embedding)