import torch
import torch.nn as nn

def weight_init(m):
    """Initialise an ``nn.Linear`` module in place; ignore anything else.

    The weight matrix receives an orthogonal initialisation and the bias,
    when the layer has one, is set to zero. Intended to be passed to
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.orthogonal_(m.weight.data)

    # A Linear layer built with bias=False exposes ``bias`` as None.
    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.0)
            
def mlp(input_dim, hidden_dim, output_dim, hidden_depth, activation=nn.ReLU):
    """Build a feed-forward network as an ``nn.Sequential``.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of output features.
        hidden_depth: Number of hidden layers. ``0`` produces a single
            Linear layer mapping ``input_dim`` straight to ``output_dim``.
        activation: Module class instantiated (with no arguments) after
            every hidden Linear layer. The output layer has no activation.

    Returns:
        nn.Sequential: The assembled network. No weight initialisation is
        applied here; callers may apply one to the parent module.
    """
    if hidden_depth == 0:
        return nn.Sequential(nn.Linear(input_dim, output_dim))

    def hidden_block(in_features):
        # One hidden stage: affine map followed by the non-linearity.
        return [nn.Linear(in_features, hidden_dim), activation()]

    layers = hidden_block(input_dim)
    for _ in range(hidden_depth - 1):
        layers.extend(hidden_block(hidden_dim))
    layers.append(nn.Linear(hidden_dim, output_dim))

    return nn.Sequential(*layers)

# --- Feature Encoder (for embedding_source='feature') ---

class FeatureEncoderMLP(nn.Module):
    """Encode raw input features into embeddings with an MLP.

    Args:
        input_dim: Number of raw input features.
        output_embedding_dim: Size of the produced embedding.
        hidden_dim: Width of each hidden layer.
        hidden_depth: Number of hidden layers in the MLP.
    """

    def __init__(self, input_dim, output_embedding_dim, hidden_dim, hidden_depth):
        super().__init__()
        self.mlp = mlp(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_embedding_dim,
            hidden_depth=hidden_depth,
        )
        # Run the custom initialiser over every submodule of this encoder.
        self.apply(weight_init)

    def forward(self, features):
        """Return the embedding the MLP computes for ``features``."""
        embedding = self.mlp(features)
        return embedding
    

class MultiAnchorFeaturePredictor(nn.Module):
    """
    Feature-based predictor using multiple anchors selected via a transducer,
    fusing information with multi-head attention.

    The module owns a FeatureEncoderMLP (``self.encoder``, exposed through
    ``extract_embedding``) that turns raw features into embeddings; ``forward``
    consumes such embeddings for the query and its candidate anchors.
    Optionally incorporates anchor weights into the fusion layer.

    Args:
        cfg: Configuration object. The constructor reads
            ``cfg.model.encoder.output_dim``, ``cfg.model.encoder.hidden_dim``,
            ``cfg.model.encoder.hidden_depth``, ``cfg.model.latent_dim``,
            ``cfg.model.num_heads``, ``cfg.model.num_candidates`` and
            ``cfg.transducer.use_anchor_weights``.
        input_dim: Number of raw input features fed to the encoder.
        num_tasks: Number of outputs produced by the final linear layer.
    """
    # Note: Arguments match MultiAnchorGNNPredictor except for pretrained_encoder -> input_dim
    def __init__(self, cfg, input_dim, num_tasks):
        super().__init__()
        # input_dim is the RAW feature dimension consumed by the encoder; the
        # encoder's output (cfg.model.encoder.output_dim) is the embedding
        # dimension that forward() expects.
        self.encoder = FeatureEncoderMLP(input_dim=input_dim,
                                         output_embedding_dim=cfg.model.encoder.output_dim,
                                         hidden_dim=cfg.model.encoder.hidden_dim,
                                         hidden_depth=cfg.model.encoder.hidden_depth)
        self.num_tasks = num_tasks
        self.embedding_dim = cfg.model.encoder.output_dim
        self.latent_dim = cfg.model.latent_dim
        self.num_heads = cfg.model.num_heads
        self.use_anchor_weights = cfg.transducer.use_anchor_weights
        self.num_candidates = cfg.model.num_candidates

        # Project embeddings to latent_dim only when the two sizes differ;
        # otherwise pass them through unchanged.
        needs_projection = self.embedding_dim != self.latent_dim
        self.query_proj = nn.Linear(self.embedding_dim, self.latent_dim) if needs_projection else nn.Identity()
        self.anchor_proj = nn.Linear(self.embedding_dim, self.latent_dim) if needs_projection else nn.Identity()

        # Attention mechanism (operates on latent_dim)
        self.attention = nn.MultiheadAttention(self.latent_dim, self.num_heads, batch_first=True)

        # Fusion layer input dimension
        fusion_input_dim = self.latent_dim * 2  # Projected Query + Attended Anchors
        if self.use_anchor_weights:
            fusion_input_dim += self.num_candidates  # Add k for the anchor weights

        # Fusion MLP and Output Layer
        self.fusion = mlp(fusion_input_dim, self.latent_dim, self.latent_dim, hidden_depth=1)
        self.output_layer = nn.Linear(self.latent_dim, num_tasks)

        # Apply weight init AFTER defining layers (apply() visits every
        # submodule, so the encoder's layers are initialised again here).
        self.apply(weight_init)

    def extract_embedding(self, features):
        """Encode raw ``features`` into embeddings with the owned encoder.

        Args:
            features (torch.Tensor): Raw input features accepted by
                ``self.encoder``.

        Returns:
            torch.Tensor: The encoder output, i.e. the embeddings that
            ``forward`` takes as ``query_embedding`` / ``candidate_embeddings``.
        """
        return self.encoder(features)

    def forward(self, query_embedding, candidate_embeddings, anchor_weights=None, attention_mask=None):
        """
        Performs multi-anchor fusion via cross-attention using feature embeddings.

        Args:
            query_embedding (torch.Tensor): Shape (batch, embedding_dim) - Output of FeatureEncoderMLP
            candidate_embeddings (torch.Tensor): Shape (batch, k, embedding_dim) - Output of FeatureEncoderMLP.
                May be non-contiguous.
            anchor_weights (torch.Tensor, optional): Shape (batch, k). Required when
                ``use_anchor_weights`` is set, ignored otherwise. Defaults to None.
            attention_mask (torch.Tensor, optional): Bool tensor, shape (batch, k). True=ignore.
                Passed to the attention layer as ``key_padding_mask``.

        Returns:
            torch.Tensor: Shape (batch, num_tasks).

        Raises:
            ValueError: If anchor weights are enabled and ``anchor_weights`` is
                missing or its second dimension differs from ``num_candidates``.
        """
        # Validate up front so bad arguments fail before any computation.
        if self.use_anchor_weights:
            if anchor_weights is None:
                raise ValueError("Anchor weights required")
            if anchor_weights.shape[1] != self.num_candidates:
                raise ValueError("Incorrect number of anchor weights")

        batch_size = query_embedding.shape[0]
        k = candidate_embeddings.shape[1]

        # 1. Project embeddings to latent dimension if needed
        query_latent = self.query_proj(query_embedding)     # (batch, latent_dim)
        # Flatten anchors for projection. reshape() (rather than view()) also
        # accepts non-contiguous inputs such as transposed or expanded tensors,
        # while still raising when the element count does not match.
        cand_emb_flat = candidate_embeddings.reshape(batch_size * k, self.embedding_dim)
        anchor_latent_flat = self.anchor_proj(cand_emb_flat)
        anchor_latent = anchor_latent_flat.reshape(batch_size, k, self.latent_dim)  # (batch, k, latent_dim)

        # 2. Apply cross-attention: the query is a length-1 sequence attending
        #    over the k projected anchors.
        query_attn = query_latent.unsqueeze(1)  # (batch, 1, latent_dim)

        attn_output, _ = self.attention(query=query_attn,
                                        key=anchor_latent,
                                        value=anchor_latent,
                                        key_padding_mask=attention_mask)

        attn_output = attn_output.squeeze(1)  # (batch, latent_dim)

        # 3. Fuse the query with the attended candidate information.
        fused_input = [query_latent, attn_output]
        if self.use_anchor_weights:
            fused_input.append(anchor_weights)
        fused = torch.cat(fused_input, dim=-1)

        # 4. Apply fusion MLP and output layer
        fused_hidden = self.fusion(fused)  # Activation (ReLU) is inside mlp utility
        output = self.output_layer(fused_hidden)
        return output
