"""Cross-attention hypernetwork that produces domain merging weights.

Takes per-domain LoRA columns and a target-domain representation and produces
attention-based weights over the source domains.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class CrossAttnHyperNet(nn.Module):
    """Attention-based hypernetwork that outputs domain merging weights.

    Given a set of per-domain LoRA columns and a target-domain embedding,
    computes attention logits/weights over source domains. The projected
    target-domain embedding is the (single) query, the projected LoRA columns
    are the keys, and a softmax over the domain axis turns logits into weights.
    Supports three output granularities via 'output_mode':
    - 'column': weights per layer and per column
    - 'layer' : weights per layer (columns pooled before attention)
    - 'model' : one global weight per source model (sites and columns pooled)
    """

    # Set to True (class-wide or on one instance) to print the produced weight
    # shapes and assert that the weights sum to 1 over the domain axis.
    debug: bool = False

    def __init__(self, *,
        domain_dim: int = 512,
        num_domains: int = 5,
        emb_dim: int = 128,
        hidden_dim: int = 256,
        column_dim_qkv: int = 2304,
        column_dim_proj: int = 768,
        cross_nheads: int = 1,
        head_combine: str = "mean",
        num_sites_max: int = 12,
        site_embed_dim: int = 32,
        temperature: float = 1.0,
        use_domain_token: bool = True,
        domain_token_scale: float = 0.1,
        output_mode: str = "column",   # 'column' (default), 'layer', 'model'
    ):
        """Initialize the cross-attention hypernetwork.

        Parameters
        ----------
        domain_dim : int
            Dimension of the input domain representation.
        num_domains : int
            Number of source domains.
        emb_dim : int
            Embedding size used internally for projections and attention.
            Must be divisible by the number of attention heads.
        hidden_dim : int
            Hidden size of the MLP that projects the domain representation.
        column_dim_qkv : int
            Feature size of qkv columns to be projected.
        column_dim_proj : int
            Feature size of proj columns to be projected.
        cross_nheads : int
            Number of attention heads; values below 1 are clamped to 1.
        head_combine : str
            Stored for reference only; head logits are always averaged.
        num_sites_max : int
            Maximum number of sites/layers for which embeddings are defined.
        site_embed_dim : int
            Dimension of the site/layer index embedding.
        temperature : float
            Temperature dividing the attention logits before softmax; must be > 0.
        use_domain_token : bool
            Whether to add learned per-domain prior tokens to content.
        domain_token_scale : float
            Scale for the domain prior tokens when added to content.
        output_mode : str
            One of {'column','layer','model'} determining pooling behavior.

        Raises
        ------
        AssertionError
            If ``output_mode`` is not one of 'column', 'layer', 'model'.
        ValueError
            If ``emb_dim`` is not divisible by the (clamped) head count, or if
            ``temperature`` is not strictly positive.
        """
        super().__init__()
        self.num_domains = num_domains
        self.emb_dim = emb_dim
        self.cross_nheads = max(1, cross_nheads)
        self.head_combine = head_combine
        assert output_mode in ("column", "layer", "model"), f"invalid output_mode={output_mode}"
        self.output_mode = output_mode
        # Fail at construction: otherwise every forward call dies in the
        # per-head reshape with an opaque view/shape RuntimeError.
        if emb_dim % self.cross_nheads != 0:
            raise ValueError(
                f"emb_dim={emb_dim} must be divisible by cross_nheads={self.cross_nheads}"
            )
        # Logits are divided by the temperature: 0 yields inf/NaN weights and a
        # negative value silently inverts the attention preference.
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature

        # NOTE: module creation order below determines default-init RNG
        # consumption; keep it stable so seeded runs reproduce.

        # projections
        self.col_proj_qkv  = nn.Linear(column_dim_qkv,  emb_dim)
        self.col_proj_proj = nn.Linear(column_dim_proj, emb_dim)
        self.dom_proj = nn.Sequential(
            nn.Linear(domain_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, emb_dim)
        )

        # layer/site index token
        self.enable_site_embed = (num_sites_max > 0 and site_embed_dim > 0)
        if self.enable_site_embed:
            self.site_embed = nn.Embedding(num_sites_max, site_embed_dim)
            self.site_embed_proj = nn.Linear(site_embed_dim, emb_dim)

        # layer-type token
        self.site_type_embed = nn.Embedding(2, emb_dim)  # 0=qkv, 1=proj

        # per-source domain token prior
        self.use_domain_token = use_domain_token
        if use_domain_token:
            tok_dim = max(8, emb_dim // 4)
            self.domain_id_embed = nn.Embedding(num_domains, tok_dim)
            self.domain_id_proj  = nn.Linear(tok_dim, emb_dim, bias=False)
            self.domain_token_scale = domain_token_scale

        self.key_proj   = nn.Linear(emb_dim, emb_dim, bias=False)
        self.query_proj = nn.Linear(emb_dim, emb_dim, bias=False)

        self.ln_cols  = nn.LayerNorm(emb_dim)
        self.ln_query = nn.LayerNorm(emb_dim)

        self._init_params()

    def forward(self,
        column_tensor: torch.Tensor,          # [S, D, C, F]
        site_type: str,                       # 'qkv' or 'proj'
        domain_representation: torch.Tensor,  # [E']
        mask_domain_idx: Optional[int] = None,
        site_ids: Optional[torch.Tensor] = None,    # [S]
        domain_ids: Optional[torch.Tensor] = None,  # [D]
    ) -> torch.Tensor:
        """Compute attention weights over source domains.

        Parameters
        ----------
        column_tensor : torch.Tensor
            Tensor of per-domain parameter LoRA columns with shape [S, D, C, F],
            where S=sites/layers, D=domains, C=columns per site, F=features.
        site_type : str
            Either 'qkv' or 'proj', selects which projection to use.
        domain_representation : torch.Tensor
            Target-domain embedding of shape [E'].
        mask_domain_idx : Optional[int]
            If provided, mask this source domain (weight goes to 0). Masking
            the only domain (D == 1) makes the softmax return NaN.
        site_ids : Optional[torch.Tensor]
            Optional site indices [S] used for site embeddings (ignored in
            'model' mode).
        domain_ids : Optional[torch.Tensor]
            Optional domain indices [D] used for domain prior tokens; defaults
            to ``arange(D)``.

        Returns
        -------
        torch.Tensor
            Weights of shape [S, D, C] summing to 1 over the domain axis (dim=1).
            - 'column': one weight per (site, column)
            - 'layer' : constant along C (broadcast from [S, D])
            - 'model' : constant along S and C (broadcast from [D])
            The result is produced with ``expand``; in 'layer'/'model' mode the
            broadcast entries share storage, so ``.clone()`` before in-place writes.
        """
        S0, D, C0, _ = column_tensor.shape
        assert D == self.num_domains and site_type in ("qkv", "proj")

        col_emb = self._encode_columns(column_tensor, site_type, site_ids, domain_ids)
        logits = self._attention_logits(col_emb, domain_representation)  # [S_eff,D,C_eff]

        if mask_domain_idx is not None:
            # A -inf logit receives exactly 0 weight from the softmax. In
            # 'layer'/'model' mode the pooled axes already have size 1, so this
            # is the same as masking the pooled [S,D] / [D] logits.
            logits[:, mask_domain_idx, :] = float("-inf")

        # Normalize over the domain axis, then broadcast any pooled axes back.
        weights = torch.softmax(logits, dim=1).expand(S0, D, C0)  # [S0,D,C0]

        if self.debug:
            sums = weights.sum(dim=1)
            print(f"[CrossAttnHyperNet] output_mode={self.output_mode}, softmax_dim=1, logits.shape={tuple(logits.shape)}, weights.shape={tuple(weights.shape)}")
            print(f"[CrossAttnHyperNet] sum over dim=1: shape={tuple(sums.shape)}, min={sums.min().item():.6f}, max={sums.max().item():.6f}")
            assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5), f"{self.output_mode}-mode weights do not sum to 1 along dim=1"
        return weights

    def _encode_columns(self,
        column_tensor: torch.Tensor,
        site_type: str,
        site_ids: Optional[torch.Tensor],
        domain_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Project LoRA columns, add context/prior tokens and pool per output_mode.

        Returns [S, D, C, E] ('column'), [S, D, 1, E] ('layer') or
        [1, D, 1, E] ('model').
        """
        S0, D = column_tensor.shape[0], column_tensor.shape[1]
        proj = self.col_proj_qkv if site_type == "qkv" else self.col_proj_proj

        # 1) Project columns
        col_emb = proj(column_tensor)                      # [S,D,C,E]

        # 2) Add layer/site context tokens (skipped when pooling to one weight per model)
        if self.output_mode != "model":
            if self.enable_site_embed and site_ids is not None:
                se = self.site_embed_proj(self.site_embed(site_ids)).view(S0, 1, 1, -1)
                col_emb = col_emb + se                     # layer/site context
            st = self.site_type_embed.weight[0 if site_type == "qkv" else 1].view(1, 1, 1, -1)
            col_emb = col_emb + st                         # site-type (qkv/proj) content cue

        # 3) Normalize content -> comparable keys
        col_emb = self.ln_cols(col_emb)

        # 4) Domain-bias prior token (added after LayerNorm, so its scale is kept)
        if self.use_domain_token:
            if domain_ids is None:
                domain_ids = torch.arange(D, device=col_emb.device)
            dom_tok = self.domain_id_proj(self.domain_id_embed(domain_ids)).view(1, D, 1, -1)
            col_emb = col_emb + self.domain_token_scale * dom_tok

        # 4.5) Pre-attention pooling for ablations (per-layer, per-model)
        if self.output_mode == "layer":
            col_emb = col_emb.mean(dim=2, keepdim=True)       # [S, D, 1, E]
        elif self.output_mode == "model":
            col_emb = col_emb.mean(dim=(0, 2), keepdim=True)  # [1, D, 1, E]
        return col_emb

    def _attention_logits(self,
        col_emb: torch.Tensor,
        domain_representation: torch.Tensor,
    ) -> torch.Tensor:
        """Return head-averaged, temperature-divided query·key logits [S_eff, D, C_eff].

        No 1/sqrt(head_dim) scaling is applied; the temperature is the only scale.
        """
        H = self.cross_nheads
        Eh = self.emb_dim // H  # exact: divisibility is checked in __init__

        # 5) Keys, split into heads
        K = self.key_proj(col_emb)
        K = K.view(*K.shape[:-1], H, Eh)                   # [S_eff,D,C_eff,H,Eh]

        # Query from the target-domain representation
        q_dom = self.dom_proj(domain_representation)
        # Ablation (disabled): use noise instead of the domain representation.
        # q_dom = torch.randn_like(q_dom)
        q_dom = self.ln_query(q_dom)
        Q = self.query_proj(q_dom).view(H, Eh)             # [H,Eh]

        logits_h = torch.einsum("sdchf,hf->hsdc", K, Q)    # [H,S_eff,D,C_eff]
        logits_h = logits_h / self.temperature
        return logits_h.mean(dim=0)                        # [S_eff,D,C_eff]

    def _init_params(self):
        """Initialize learnable parameters.

        Every nn.Linear gets Xavier-uniform weights and zero bias; the site and
        domain-id embeddings get N(0, 0.02). ``site_type_embed`` is not touched
        here and keeps nn.Embedding's default N(0, 1) initialization.
        """
        def init_linear(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
        self.apply(init_linear)
        if hasattr(self, "site_embed"):
            nn.init.normal_(self.site_embed.weight, std=0.02)
        # site_embed_proj / domain_id_proj were already initialized by apply();
        # the second draw is kept so seeded initializations stay unchanged.
        if hasattr(self, "site_embed_proj"):
            nn.init.xavier_uniform_(self.site_embed_proj.weight); nn.init.zeros_(self.site_embed_proj.bias)
        if hasattr(self, "domain_id_embed"):
            nn.init.normal_(self.domain_id_embed.weight, std=0.02)
        if hasattr(self, "domain_id_proj"):
            nn.init.xavier_uniform_(self.domain_id_proj.weight)

