"""MLP-based hypernetwork used to produce merging-weights.
"""

import math, torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple
from experiments.hypernet_helpers import _find_module
from tqdm import tqdm
import re

# -----------------------------
# MLP-based HyperNet 
# -----------------------------

class MLPHyperNet(nn.Module):
    """MLP hypernet that generates per-column merging factors.

    For each LoRA column, projects the column's feature vector to ``emb_dim``
    and concatenates it with three conditioning tokens:

    * a projection of the target-domain representation,
    * the sum of a layer/site-index embedding and a site-type (qkv/proj)
      embedding,
    * a per-source-domain id embedding.

    A small MLP scores this ``4 * emb_dim`` vector into one logit per
    (source domain, column). A softmax over the domain axis then turns the
    logits into merging weights.
    """
    def __init__(
        self,
        *,
        domain_dim: int = 512,
        num_domains: int = 5,
        emb_dim: int = 128,
        hidden_dim: int = 256,
        column_dim_qkv: int = 2304,
        column_dim_proj: int = 768,
        dropout: float = 0.0,
        num_sites_max: int = 12,
    ):
        """Initialize the MLP-based hypernetwork.

        Parameters
        ----------
        domain_dim : int
            Dimension of the input domain representation.
        num_domains : int
            Number of source domains (rows of the domain-id embedding table).
        emb_dim : int
            Embedding size used for all projections/tokens.
        hidden_dim : int
            Hidden size of the MLPs.
        column_dim_qkv : int
            Feature length F for qkv columns.
        column_dim_proj : int
            Feature length F for proj columns.
        dropout : float
            Dropout probability in the scoring MLP.
        num_sites_max : int
            Maximum number of sites/layers for which embeddings are defined
            (at least one embedding is always allocated).
        """
        super().__init__()
        self.num_domains = num_domains
        self.emb_dim = emb_dim

        # Project per-column vectors (length F) to model dim E; one projection
        # per site type because qkv and proj columns have different lengths.
        self.col_proj_qkv  = nn.Linear(column_dim_qkv, emb_dim)
        self.col_proj_proj = nn.Linear(column_dim_proj, emb_dim)

        # Conditioning tokens
        self.site_embed  = nn.Embedding(2, emb_dim) # site type token: 0=qkv, 1=proj
        self.layer_embed = nn.Embedding(max(1, num_sites_max), emb_dim) # default 12, e.g. the 12 blocks of ViT-B-16
        self.domain_embed = nn.Embedding(num_domains, emb_dim) # source-domain id token
        self.dom_proj = nn.Sequential(
            nn.Linear(domain_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, emb_dim)
        ) # target-domain representation (default 512-d) -> emb_dim (default 128)

        # Per-column scorer: [cols | dom_tok | layer+site tok | dom_id_tok] (4E) -> 1 logit
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim * 4, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

        # Init: Xavier for every Linear (including those inside dom_proj/mlp),
        # small-normal for the layer and domain-id embeddings.
        # NOTE(review): site_embed is not re-initialised here and keeps
        # nn.Embedding's default init -- confirm whether that asymmetry is intended.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        nn.init.normal_(self.layer_embed.weight, std=0.02)
        nn.init.normal_(self.domain_embed.weight, std=0.02)

    def forward(
        self,
        column_tensor: torch.Tensor,          # [S, D, C, F]
        site_type: str,                       # 'qkv' or 'proj'
        domain_representation: torch.Tensor,  # [domain_dim]
        mask_domain_idx: Optional[int] = None,
        site_ids: Optional[torch.Tensor] = None,  # [S]
    ) -> torch.Tensor:
        """Compute per-domain weights per column via MLP scoring.

        Parameters
        ----------
        column_tensor : torch.Tensor
            Input tensor [S, D, C, F] of per-domain column features. F must
            match ``column_dim_qkv`` or ``column_dim_proj`` for ``site_type``.
        site_type : str
            Either 'qkv' or 'proj'; selects the projection and site token.
        domain_representation : torch.Tensor
            Target-domain representation with ``domain_dim`` elements.
        mask_domain_idx : Optional[int]
            If provided, this source domain gets logit -inf, i.e. weight 0.
            Negative indices count from the end, as in Python indexing.
        site_ids : torch.Tensor
            Integer tensor [S] of site/layer indices for the layer embeddings.
            Required despite the ``Optional`` annotation.

        Returns
        -------
        torch.Tensor
            Weights tensor [S, D, C] summing to 1 over dim=1 for each (S, C).

        Raises
        ------
        AssertionError
            If ``site_type`` is invalid or ``site_ids`` is None.
        ValueError
            If D exceeds ``num_domains``, or if masking would remove the only
            domain (every weight would otherwise be NaN).
        IndexError
            If ``mask_domain_idx`` is out of range for D domains.
        """
        assert site_type in ("qkv", "proj")
        assert site_ids is not None, "site_ids must be provided to build the layer token."
        S, D, C, _ = column_tensor.shape
        E = self.emb_dim
        if D > self.num_domains:
            # domain_embed only has num_domains rows; fail with a clear message
            # instead of an opaque embedding index error / CUDA device assert.
            raise ValueError(
                f"column_tensor has {D} domains but the hypernet was built with "
                f"num_domains={self.num_domains}."
            )

        # (1) Project columns (each F-long column -> E) -> [S, D, C, E]
        proj = self.col_proj_qkv if site_type == "qkv" else self.col_proj_proj
        cols = proj(column_tensor)

        # (2) Conditioning tokens, broadcast (no copy) to [S, D, C, E]
        dom_tok = self.dom_proj(domain_representation).view(1, 1, 1, E).expand(S, D, C, E)
        layer_tok = self.layer_embed(site_ids).view(S, 1, 1, E).expand(S, D, C, E)
        site_idx = 0 if site_type == "qkv" else 1
        site_tok = self.site_embed.weight[site_idx].view(1, 1, 1, E).expand(S, D, C, E)
        dom_ids = torch.arange(D, device=cols.device)
        dom_id_tok = self.domain_embed(dom_ids).view(1, D, 1, E).expand(S, D, C, E)

        # (3) Score per (domain, column)
        x = torch.cat([cols, dom_tok, layer_tok + site_tok, dom_id_tok], dim=-1)  # [S, D, C, 4E]
        logits = self.mlp(x).squeeze(-1)  # [S, D, C]

        # (4) Optional: mask a domain from contributing
        if mask_domain_idx is not None:
            # Range-check on the host so a bad index raises IndexError rather
            # than triggering a device-side assert on CUDA.
            if not -D <= mask_domain_idx < D:
                raise IndexError(
                    f"mask_domain_idx={mask_domain_idx} out of range for {D} domains."
                )
            if D == 1:
                # Softmax over a single -inf logit would return NaN everywhere.
                raise ValueError("Cannot mask the only available domain.")
            drop = torch.zeros(D, dtype=torch.bool, device=logits.device)
            drop[mask_domain_idx] = True
            # Out-of-place fill: same values/gradients as the in-place write,
            # without mutating a tensor saved in the autograd graph.
            logits = logits.masked_fill(drop.view(1, D, 1), float("-inf"))

        # (5) Softmax over domains -> per-column weights
        return torch.softmax(logits, dim=1)  # [S, D, C]
