
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math

def mlp(d_in, d_hidden, d_out, num_layers=3, act=nn.SiLU):
    """Build a feed-forward stack of ``nn.Linear`` layers as an ``nn.Sequential``.

    The stack maps ``d_in`` -> ``d_hidden`` (repeated) -> ``d_out``. Every
    linear layer except the last one is followed by a fresh ``act()`` module.

    Args:
        d_in: Input feature size of the first linear layer.
        d_hidden: Width used for every intermediate layer.
        d_out: Output feature size of the final linear layer.
        num_layers: Number of linear layers. Values below 1 produce the same
            single ``d_in -> d_out`` layer as ``num_layers=1``.
        act: Zero-argument callable returning the activation module.

    Returns:
        An ``nn.Sequential`` holding the layers in order.
    """
    widths = [d_in, *([d_hidden] * (num_layers - 1)), d_out]
    # Consecutive (fan_in, fan_out) pairs; the last pair is the output layer,
    # which is not followed by an activation.
    *inner_pairs, output_pair = zip(widths, widths[1:])
    modules = []
    for fan_in, fan_out in inner_pairs:
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(act())
    modules.append(nn.Linear(*output_pair))
    return nn.Sequential(*modules)

class LoRALinear(nn.Module):
    """
    Linear layer with K expert-specific low-rank adapters mixed by gate g (top-k optional).

    W_eff = W + sum_{i in topk} g_i A_i B_i

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        K: Number of adapters (experts).
        rank: Inner dimension of each adapter pair ``A_i`` / ``B_i``.
        top_k: If set and smaller than ``K``, only the ``top_k`` largest gate
            entries per batch item contribute. ``None`` uses all ``K`` adapters.
        use_adapters: When False no adapter parameters are created and the
            module behaves as a plain ``nn.Linear``.
    """
    def __init__(self, in_features, out_features, K: int, rank: int = 4, top_k: Optional[int] = None, use_adapters: bool = True):
        super().__init__()
        self.lin = nn.Linear(in_features, out_features)
        self.K = K
        self.rank = rank
        self.top_k = top_k
        self.use_adapters = use_adapters
        if use_adapters:
            self.A = nn.Parameter(torch.zeros(K, out_features, rank))
            self.B = nn.Parameter(torch.zeros(K, rank, in_features))
            # Both factors are randomly initialised, so the adapter term is
            # generally non-zero from the start whenever g is non-zero.
            nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.B, a=math.sqrt(5))

    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
        """
        Apply the base linear map plus the gated low-rank adapter correction.

        Args:
            x: (B, *, in_features) input with any number of middle dimensions
                (including none).
            g: (B, K) gate weights; None => no adapters. One gate vector per
                batch item is shared across all middle positions of that item.

        Returns:
            Tensor shaped like ``self.lin(x)``, i.e. (B, *, out_features).
        """
        y = self.lin(x)
        if not self.use_adapters or g is None:
            return y
        batch = x.shape[0]
        # Pick which adapters contribute for each batch item.
        if self.top_k is None or self.top_k >= self.K:
            idx = torch.arange(self.K, device=x.device).unsqueeze(0).expand(batch, -1)
            g_vals = g
        else:
            g_vals, idx = torch.topk(g, k=self.top_k, dim=1)
        A_sel = self.A[idx]  # (B, t, out, r)
        B_sel = self.B[idx]  # (B, t, r, in)
        # Normalise x to (B, n, in) so a single einsum handles every input
        # rank; n is 1 for a plain (B, in) input.
        x_flat = x.unsqueeze(1) if x.dim() == 2 else x.flatten(1, -2)
        # x @ B_i^T -> (B, t, n, r), then @ A_i^T -> (B, t, n, out)
        low = torch.einsum('bni,btri->btnr', x_flat, B_sel)
        delta = torch.einsum('btnr,btor->btno', low, A_sel)
        # Weight each selected adapter by its gate value and sum over adapters.
        g_w = g_vals.unsqueeze(-1).unsqueeze(-1)  # (B, t, 1, 1)
        delta = (g_w * delta).sum(dim=1)  # (B, n, out)
        # Restore the caller's layout; reshaping to y's shape avoids guessing
        # whether a size-1 middle dimension was present in the input.
        return y + delta.reshape(y.shape)

class AdapterMLP(nn.Module):
    """
    Three-layer MLP whose linear layers are all LoRALinear modules.

    The same gate tensor ``g`` is handed to every layer; SiLU is applied after
    the first two layers and the third layer's output is returned as-is.
    """
    def __init__(self, d_in, d_hidden, d_out, K, rank=4, top_k=2, use_adapters=True):
        super().__init__()
        # Adapter settings are identical for all three layers.
        adapter_cfg = dict(K=K, rank=rank, top_k=top_k, use_adapters=use_adapters)
        self.l1 = LoRALinear(d_in, d_hidden, **adapter_cfg)
        self.l2 = LoRALinear(d_hidden, d_hidden, **adapter_cfg)
        self.l3 = LoRALinear(d_hidden, d_out, **adapter_cfg)
        self.act = nn.SiLU()

    def forward(self, x, g=None):
        """Run ``x`` through l1 -> SiLU -> l2 -> SiLU -> l3, passing ``g`` to each layer."""
        hidden = self.act(self.l1(x, g))
        hidden = self.act(self.l2(hidden, g))
        return self.l3(hidden, g)

class ZDenoiser(nn.Module):
    """
    Denoiser for z-coefficients. Predicts epsilon given noisy z_t, timestep and state.
    Optionally uses parameter-level MoE via adapters mixed by g.

    Args:
        z_dim: Size of the last dimension of ``z_t`` (also the output size).
        s_dim: Size of the last dimension of ``s``.
        hidden: Hidden width of the AdapterMLP.
        K, use_adapters, top_k, rank: Forwarded to the AdapterMLP.
        num_timesteps: Number of rows in the timestep embedding table.
            Indices at or above this value are clamped to the last row.
        t_embed_dim: Width of the timestep embedding.
    """
    def __init__(self, z_dim, s_dim, hidden=128, K=4, use_adapters=True, top_k=2, rank=4,
                 num_timesteps=128, t_embed_dim=32):
        super().__init__()
        # Table size and width used to be fixed at 128 x 32; the defaults keep
        # that behaviour while allowing longer diffusion schedules.
        self.embed_t = nn.Embedding(num_timesteps, t_embed_dim)
        self.net = AdapterMLP(d_in=z_dim + s_dim + t_embed_dim, d_hidden=hidden, d_out=z_dim,
                              K=K, rank=rank, top_k=top_k, use_adapters=use_adapters)

    def forward(self, z_t, t_idx, s, g=None):
        """
        Args:
            z_t: Noisy z-coefficients; concatenated with ``s`` and the timestep
                embedding along the last dimension.
            t_idx: Integer timestep indices. Only the upper bound is clamped;
                negative indices are passed to the embedding unchanged.
            s: State features.
            g: Optional gate weights, passed unchanged to the AdapterMLP.

        Returns:
            The network output ``eps_hat``.
        """
        t_emb = self.embed_t(t_idx.clamp(max=self.embed_t.num_embeddings-1))
        x = torch.cat([z_t, s, t_emb], dim=-1)
        eps_hat = self.net(x, g)
        return eps_hat

class RDenoiser(nn.Module):
    """Optional residual denoiser for r-coefficients (orthogonal complement).

    Args:
        r_dim: Size of the last dimension of ``r_t`` (also the output size).
        s_dim: Size of the last dimension of ``s``.
        hidden: Hidden width of the MLP.
        num_timesteps: Number of rows in the timestep embedding table.
            Indices at or above this value are clamped to the last row.
        t_embed_dim: Width of the timestep embedding.
    """
    def __init__(self, r_dim, s_dim, hidden=128, num_timesteps=128, t_embed_dim=32):
        super().__init__()
        # Table size and width used to be fixed at 128 x 32; the defaults keep
        # that behaviour while allowing longer diffusion schedules.
        self.embed_t = nn.Embedding(num_timesteps, t_embed_dim)
        self.net = mlp(d_in=r_dim + s_dim + t_embed_dim, d_hidden=hidden, d_out=r_dim, num_layers=3)

    def forward(self, r_t, t_idx, s):
        """
        Args:
            r_t: Noisy r-coefficients; concatenated with ``s`` and the timestep
                embedding along the last dimension.
            t_idx: Integer timestep indices. Only the upper bound is clamped;
                negative indices are passed to the embedding unchanged.
            s: State features.

        Returns:
            The network output ``eps_hat``.
        """
        t_emb = self.embed_t(t_idx.clamp(max=self.embed_t.num_embeddings-1))
        x = torch.cat([r_t, s, t_emb], dim=-1)
        eps_hat = self.net(x)
        return eps_hat

class GatePosterior(nn.Module):
    """Map (s_t, a_t) to K strictly positive values, used as Dirichlet concentrations for q(g_t | s_t, a_t)."""
    def __init__(self, s_dim, a_dim, K, hidden=128):
        super().__init__()
        joint_dim = s_dim + a_dim
        self.net = mlp(d_in=joint_dim, d_hidden=hidden, d_out=K, num_layers=3)
        self.softplus = nn.Softplus()

    def forward(self, s, a):
        """Concatenate ``s`` and ``a`` on the last dim and return softplus(net(.)) + 1e-3."""
        joint = torch.cat((s, a), dim=-1)
        positive = self.softplus(self.net(joint))
        # The small offset keeps every concentration strictly above zero.
        return positive + 1e-3

class GatePrior(nn.Module):
    """Map s_t to K strictly positive values, used as Dirichlet concentrations for the prior gate p(g_t | s_t) (optional component)."""
    def __init__(self, s_dim, K, hidden=128):
        super().__init__()
        self.net = mlp(d_in=s_dim, d_hidden=hidden, d_out=K, num_layers=2)
        self.softplus = nn.Softplus()

    def forward(self, s):
        """Return softplus(net(s)) + 1e-3."""
        positive = self.softplus(self.net(s))
        # The small offset keeps every concentration strictly above zero.
        return positive + 1e-3

class BaselineActionDenoiser(nn.Module):
    """Baseline diffusion in action space; predicts epsilon_a given noisy a_t, t, s.

    Args:
        a_dim: Size of the last dimension of ``a_t`` (also the output size).
        s_dim: Size of the last dimension of ``s``.
        hidden: Hidden width of the MLP.
        num_timesteps: Number of rows in the timestep embedding table.
            Indices at or above this value are clamped to the last row.
        t_embed_dim: Width of the timestep embedding.
    """
    def __init__(self, a_dim, s_dim, hidden=128, num_timesteps=128, t_embed_dim=32):
        super().__init__()
        # Table size and width used to be fixed at 128 x 32; the defaults keep
        # that behaviour while allowing longer diffusion schedules.
        self.embed_t = nn.Embedding(num_timesteps, t_embed_dim)
        self.net = mlp(d_in=a_dim + s_dim + t_embed_dim, d_hidden=hidden, d_out=a_dim, num_layers=4)

    def forward(self, a_t, t_idx, s):
        """
        Args:
            a_t: Noisy actions; concatenated with ``s`` and the timestep
                embedding along the last dimension.
            t_idx: Integer timestep indices. Only the upper bound is clamped;
                negative indices are passed to the embedding unchanged.
            s: State features.

        Returns:
            The network output for the concatenated input.
        """
        t_emb = self.embed_t(t_idx.clamp(max=self.embed_t.num_embeddings-1))
        x = torch.cat([a_t, s, t_emb], dim=-1)
        return self.net(x)
