from typing import Generator, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.model.config import ModelConfig, Transformer
from src.model.base import make_attention_mask, Rotary

class LoRA(nn.Module):
    """Low-rank adapter computing ``x @ A @ B``.

    ``A`` has shape ``(in_dim, rank)`` and is drawn from N(0, 0.02^2);
    ``B`` has shape ``(rank, out_dim)`` and starts at zero, so a freshly
    constructed adapter outputs zeros.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int = 16,
    ) -> None:
        super().__init__()

        # Down-projection: random init. Registered first so parameter order is A, B.
        down = nn.init.normal_(torch.empty(in_dim, rank), mean=0.0, std=0.02)
        self.A = nn.Parameter(down)

        # Up-projection: zero init, so the adapter is a no-op at construction.
        up = nn.init.zeros_(torch.empty(rank, out_dim))
        self.B = nn.Parameter(up)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dim ``in_dim``) down to ``rank``, then up to ``out_dim``."""
        compressed = torch.matmul(x, self.A)
        return torch.matmul(compressed, self.B)


class LoRALinear(nn.Module):
    """A dense "core" linear layer plus optional additive LoRA adapters.

    ``self.experts[0]`` is an ``nn.Linear`` (the core). Entries
    ``1 .. num_experts - 1`` are ``LoRA`` adapters whose outputs are added
    to the core output.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        rank: int = 16,
        num_experts: int = 1,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts

        # Index 0 is always the core projection; adapters follow in order.
        layers = [nn.Linear(in_dim, out_dim, bias=bias)]
        layers.extend(LoRA(in_dim, out_dim, rank) for _ in range(num_experts - 1))
        self.experts = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor, select_mask: torch.Tensor, exp_idx: int) -> torch.Tensor:
        """Apply the core layer and the selected adapter(s).

        x: (B, T, in_dim)
        select_mask: 1-D multi-hot selection over experts; entry 0 gates the
            core output, entries 1.. gate the adapters.
        exp_idx: routing shortcut.
            ``0``  -> core output only.
            ``>0`` -> core output plus the single adapter at that index
                      (the adapter's mask entry is not consulted).
            ``<0`` -> core output plus the mask-weighted sum of all adapters.
            When the layer has no adapters, this is treated as ``0``.
        """
        n_experts = len(self.experts)
        if n_experts == 1:
            # No adapters exist on this layer: always take the core-only path.
            exp_idx = 0

        core_out = self.experts[0](x) * select_mask[0]

        if exp_idx == 0:
            return core_out

        if exp_idx > 0:
            return core_out + self.experts[exp_idx](x)

        # General path: evaluate every adapter in one batched contraction.
        adapters = self.experts[1:]
        stacked_a = torch.stack([adapter.A for adapter in adapters], dim=0)  # (K-1, in_dim, rank)
        stacked_b = torch.stack([adapter.B for adapter in adapters], dim=0)  # (K-1, rank, out_dim)

        # (B, T, in_dim) x (K-1, in_dim, rank) -> (B, T, K-1, rank)
        low_rank = torch.einsum('bte,ker->btkr', x, stacked_a)
        # (B, T, K-1, rank) x (K-1, rank, out_dim) -> (B, T, K-1, out_dim)
        per_adapter = torch.einsum('btkr,kro->btko', low_rank, stacked_b)

        # Zero out unselected adapters, then collapse the adapter axis.
        gate = select_mask[1:].view(1, 1, n_experts - 1, 1)
        adapter_sum = (per_adapter * gate).sum(dim=2)  # (B, T, out_dim)

        return core_out + adapter_sum

    def get_params(self, expert_idx: int) -> Generator[torch.Tensor, None, None]:
        """Yield the parameters of the expert at ``expert_idx`` (0 is the core)."""
        assert 0 <= expert_idx < self.num_experts, f"Expert index {expert_idx} out of range"
        for param in self.experts[expert_idx].parameters():
            yield param


class Attention(nn.Module):
    """Multi-head attention with grouped key/value heads.

    The query, fused key/value, and output projections are ``LoRALinear``
    layers, so each accepts the expert ``select_mask`` / ``exp_idx`` routing
    arguments.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        num_key_value: int,
        bias: bool,
        num_experts: int,
        lora_rank: int,
    ) -> None:
        super().__init__()

        assert embed_dim % num_heads == 0
        assert num_heads % num_key_value == 0

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.num_key_value = num_key_value
        self.head_dim = embed_dim // num_heads

        self.rotary = Rotary(self.head_dim)

        def projection(out_dim: int) -> LoRALinear:
            # Every projection reads embed_dim features and shares bias/expert settings.
            return LoRALinear(
                in_dim=embed_dim,
                out_dim=out_dim,
                bias=bias,
                num_experts=num_experts,
                rank=lora_rank,
            )

        self.c_attn_q = projection(embed_dim)
        # Keys and values are produced by one fused projection and split in forward().
        self.c_attn_kv = projection(2 * num_key_value * self.head_dim)
        self.c_proj = projection(embed_dim)

    def forward(
        self, x: torch.Tensor,
        attn_mask: torch.Tensor,
        select_mask: torch.Tensor,
        exp_idx: int,
    ) -> torch.Tensor:
        """Attend over a 3-D input ``x`` of shape (B, T, E) and return (B, T, E).

        ``attn_mask`` is forwarded unchanged to
        ``F.scaled_dot_product_attention``; ``select_mask`` and ``exp_idx``
        are forwarded unchanged to each ``LoRALinear`` projection.
        """
        batch, seq_len, width = x.size()
        kv_width = self.num_key_value * self.head_dim

        def to_heads(t: torch.Tensor, n_heads: int) -> torch.Tensor:
            # (B, T, n_heads * head_dim) -> (B, n_heads, T, head_dim)
            return t.view(batch, seq_len, n_heads, self.head_dim).transpose(1, 2)

        queries = to_heads(self.c_attn_q(x, select_mask, exp_idx), self.num_heads)

        fused_kv = self.c_attn_kv(x, select_mask, exp_idx)
        keys, values = fused_kv.split(kv_width, dim=2)
        keys = to_heads(keys, self.num_key_value)
        values = to_heads(values, self.num_key_value)

        # self.rotary is applied to queries and keys only (values are untouched).
        queries = self.rotary(queries)
        keys = self.rotary(keys)

        # enable_gqa lets the num_key_value K/V heads serve all num_heads query heads.
        attended = F.scaled_dot_product_attention(
            queries, keys, values, attn_mask=attn_mask, enable_gqa=True
        )

        # (B, num_heads, T, head_dim) -> (B, T, E)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged, select_mask, exp_idx)


class MLP(nn.Module):
    """Two-layer feed-forward network built from ``LoRALinear`` layers.

    Expands ``embed_dim -> mlp_dim``, applies tanh-approximated GELU, then
    projects ``mlp_dim -> embed_dim``. Both layers always carry a bias.
    """

    def __init__(
        self,
        embed_dim: int,
        mlp_dim: int,
        num_experts: int,
        lora_rank: int,
    ) -> None:
        super().__init__()

        # Settings common to both projections.
        shared = dict(bias=True, num_experts=num_experts, rank=lora_rank)

        self.c_fc = LoRALinear(in_dim=embed_dim, out_dim=mlp_dim, **shared)
        self.c_proj = LoRALinear(in_dim=mlp_dim, out_dim=embed_dim, **shared)

    def forward(
        self, x: torch.Tensor, 
        select_mask: torch.Tensor, 
        exp_idx: int,
    ) -> torch.Tensor:
        """Run the feed-forward stack, forwarding the routing arguments to both layers."""
        expanded = self.c_fc(x, select_mask, exp_idx)
        activated = F.gelu(expanded, approximate="tanh")
        return self.c_proj(activated, select_mask, exp_idx)


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection.

    ``lora_attn`` / ``lora_mlp`` decide which sublayer receives
    ``num_experts`` experts; a sublayer that is switched off is built with a
    single (core-only) expert.
    """

    def __init__(
        self,
        config: ModelConfig,
        num_experts: int,
        mlp_dim: int,
        lora_rank: int,
        lora_attn: bool,
        lora_mlp: bool,
    ) -> None:
        super().__init__()

        width = config.embed_dim

        self.attn = Attention(
            num_heads=config.num_heads,
            embed_dim=width,
            num_key_value=config.num_key_value,
            bias=config.attn_bias,
            num_experts=num_experts if lora_attn else 1,
            lora_rank=lora_rank,
        )

        self.mlp = MLP(
            embed_dim=width,
            mlp_dim=mlp_dim,
            num_experts=num_experts if lora_mlp else 1,
            lora_rank=lora_rank,
        )

        self.norm_1 = nn.RMSNorm(width)
        self.norm_2 = nn.RMSNorm(width)

    def forward(
        self, 
        x: torch.Tensor, 
        attn_mask: torch.Tensor, 
        select_mask: torch.Tensor,
        exp_idx: int,
    ) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        # Each sublayer sees a normalised copy of the stream; its output is added back.
        attn_out = self.attn(
            self.norm_1(x),
            attn_mask=attn_mask,
            select_mask=select_mask,
            exp_idx=exp_idx,
        )
        x = x + attn_out

        mlp_out = self.mlp(self.norm_2(x), select_mask=select_mask, exp_idx=exp_idx)
        return x + mlp_out


class LoRATransformer(Transformer):
    """Transformer whose linear layers carry a core expert plus LoRA adapters.

    ``self.labels`` names the experts: index 0 is ``"core"`` and the rest are
    ``config.aux_labels``. Layers listed in ``config.target_layers`` get one
    expert per label; every other layer is built core-only.
    """

    def __init__(
        self, 
        config: ModelConfig, 
        mlp_dim: int,
        lora_rank: int,
        lora_attn: bool, 
        lora_mlp: bool, 
    ) -> None:
        super().__init__(config)

        self.model_type = "routed"
        self.labels = ["core"] + config.aux_labels

        assert lora_attn or lora_mlp, "At least one of lora_attn or lora_mlp must be True"
        assert lora_rank > 0, "lora_rank must be greater than 0"

        self.embed = nn.Embedding(config.vocab_size, config.embed_dim)
        self.norm = nn.RMSNorm(config.embed_dim)
        self.unembed = nn.Linear(config.embed_dim, config.vocab_size, bias=True)

        # Only targeted layers receive the auxiliary experts.
        self.blocks = nn.ModuleList([
            Block(
                config=config,
                num_experts=len(self.labels) if idx in config.target_layers else 1,
                mlp_dim=mlp_dim,
                lora_rank=lora_rank,
                lora_attn=lora_attn,
                lora_mlp=lora_mlp,
            )
            for idx in range(config.num_layers)
        ])

        # Re-initialises nn.Linear / nn.Embedding / nn.RMSNorm modules only;
        # LoRA adapters keep the initialisation from their own constructor.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one submodule (used as the callback for ``self.apply``)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if getattr(module, "bias", None) is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.RMSNorm):
            module.weight.data.fill_(1.0)

    def get_params(self, label: str) -> Generator[torch.Tensor, None, None]:
        """Yield the parameters that belong to the expert named ``label``.

        For ``"core"`` this covers the embedding, final norm, unembedding,
        every block's core linear layers, and every block's norms. For an
        auxiliary label it covers that label's adapter in each layer that
        actually has one.

        Raises:
            AssertionError: if ``label`` is not in ``self.labels``.
        """
        assert label in self.labels, f"Label {label} not found in {self.labels}"

        is_core = label == "core"
        if is_core:
            yield from self.embed.parameters()
            yield from self.norm.parameters()
            yield from self.unembed.parameters()

        e_idx = self.labels.index(label)

        for block in self.blocks:
            sublayers = (
                (block.attn, ("c_attn_q", "c_attn_kv", "c_proj")),
                (block.mlp, ("c_fc", "c_proj")),
            )
            for owner, names in sublayers:
                for name in names:
                    module = getattr(owner, name)
                    # Core-only layers have a single expert; skip missing adapters.
                    if e_idx < len(module.experts):
                        yield from module.experts[e_idx].parameters()

            if is_core:
                yield from block.norm_1.parameters()
                yield from block.norm_2.parameters()

    def forward(
        self,
        tokens: torch.Tensor,
        targets: Optional[torch.Tensor],
        select_mask: torch.Tensor,
        optimize: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute logits and, when ``targets`` is given, the mean cross-entropy loss.

        Args:
            tokens: token ids fed to the embedding (presumably (B, T) — confirm
                against the caller).
            targets: target ids, or ``None`` to skip the loss. Positions equal
                to -1000 are ignored by the loss.
            select_mask: 1-D multi-hot selection over ``self.labels``. Entry 0
                (core) is forced on inside this method; the caller's tensor is
                left unmodified.
            optimize: when True and at most two experts are selected (core plus
                at most one adapter), route through that single expert
                directly instead of the batched all-adapter path.

        Returns:
            ``(logits, loss)`` with logits of shape (B, T, V) and ``loss`` equal
            to ``None`` when ``targets`` is ``None``.
        """
        # Work on a private copy: forcing the core expert on must not leak
        # back into the mask object owned by the caller.
        select_mask = select_mask.clone()
        select_mask[0] = True  # core is always on

        # exp_idx < 0 tells each LoRALinear to use the masked sum over adapters.
        exp_idx = -1
        if optimize and select_mask.sum() <= 2:
            exp_idx = select_mask.nonzero()[-1].item()  # rightmost selected expert

        attn_mask = make_attention_mask(tokens, self.config.eos_token_id)
        h = self.embed(tokens)
        for block in self.blocks:
            h = block(h, attn_mask=attn_mask, select_mask=select_mask, exp_idx=exp_idx)
        logits = self.unembed(self.norm(h))  # (B, T, V)

        loss = None
        if targets is not None:
            # NOTE(review): ignore_index is -1000, not PyTorch's default -100 —
            # confirm this matches the padding value used by the data pipeline.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1000,
                reduction="mean",
            )

        return logits, loss