import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_mean_pool

def weight_init(m):
    """Re-initialise a Linear layer in place (orthogonal weights, zero bias).

    Non-``nn.Linear`` modules are ignored, so the function can be handed to
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight.data)
    bias = getattr(m, 'bias', None)
    # Linear layers built with bias=False expose ``bias`` as None.
    if bias is not None:
        bias.data.fill_(0.0)

def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    """Build a feed-forward ``nn.Sequential`` with in-place ReLU activations.

    ``hidden_depth == 0`` yields a single Linear(input_dim -> output_dim).
    Otherwise the network has ``hidden_depth`` Linear+ReLU pairs of width
    ``hidden_dim`` (at least one) followed by a Linear(hidden_dim -> output_dim).

    NOTE(review): ``mlp`` is redefined later in this module; that later
    definition is the one bound when this module finishes importing.
    """
    if hidden_depth == 0:
        return nn.Sequential(nn.Linear(input_dim, output_dim))

    layers = []
    fan_in = input_dim
    # A negative depth still produces exactly one hidden layer.
    for _ in range(max(hidden_depth, 1)):
        layers.extend((nn.Linear(fan_in, hidden_dim), nn.ReLU(inplace=True)))
        fan_in = hidden_dim
    layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)

class MlpPredictor(nn.Module):
    """Plain feed-forward predictor: one MLP trunk mapping observations to outputs."""

    def __init__(self, input_dim, output_dim, hidden_dim, hidden_depth):
        super().__init__()
        # mlp() expects (input, hidden, output, depth); note the order differs from ours.
        self.trunk = mlp(input_dim, hidden_dim, output_dim, hidden_depth)
        # Unused in this class; kept for callers that may populate/read it.
        self.outputs = {}
        # Orthogonal weights / zero biases for every Linear in the trunk.
        self.apply(weight_init)

    def forward(self, obs):
        """Apply the trunk to ``obs`` and return its output."""
        trunk = self.trunk
        return trunk(obs)

class BilinearPredictor(nn.Module):
    """
    Bilinear Transduction: dot product of non-linear embeddings of data point and its delta.

    For each output index ``o`` the prediction is the inner product of a
    ``feature_dim``-vector computed from ``obs`` and one computed from ``deltas``::

        pred[b, o] = sum_f obs_emb[b, o, f] * delta_emb[b, f, o]
    """

    def __init__(self, input_dim, output_dim, hidden_dim, feature_dim, hidden_depth):
        super().__init__()
        self.output_dim = output_dim
        self.feature_dim = feature_dim
        # Each trunk emits output_dim * feature_dim values, reshaped in forward().
        self.obs_trunk = mlp(input_dim, hidden_dim, feature_dim * output_dim, hidden_depth)
        self.delta_trunk = mlp(input_dim, hidden_dim, feature_dim * output_dim, hidden_depth)

    def forward(self, obs, deltas):
        """Return per-output bilinear scores.

        Args:
            obs: Input to ``obs_trunk`` (presumably last dim ``input_dim``).
            deltas: Input to ``delta_trunk``; must yield the same number of rows as ``obs``.

        Returns:
            Tensor of shape (N, output_dim), N being the number of flattened rows.
        """
        ob_embedding = self.obs_trunk(obs).view(-1, self.output_dim, self.feature_dim)
        delta_embedding = self.delta_trunk(deltas).view(-1, self.feature_dim, self.output_dim)
        # Same result as torch.diagonal(ob_embedding @ delta_embedding, dim1=1, dim2=2),
        # but only the diagonal terms are contracted instead of materialising an
        # (output_dim x output_dim) matrix per row and discarding its off-diagonal part.
        return torch.einsum('bof,bfo->bo', ob_embedding, delta_embedding)


def weight_init(m):
    """Initialise an ``nn.Linear`` in place: orthogonal weight matrix, zero bias.

    Modules of any other type are left untouched, so this can be passed to
    ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        weight, bias = m.weight.data, getattr(m, 'bias', None)
        nn.init.orthogonal_(weight)
        if bias is not None:  # bias=False layers store None here
            bias.data.fill_(0.0)

def mlp(input_dim, hidden_dim, output_dim, hidden_depth, activation=nn.ReLU):
    """Build a feed-forward ``nn.Sequential``.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Width of the final Linear layer.
        hidden_depth: Number of hidden Linear+activation pairs; 0 gives a single
            Linear(input_dim -> output_dim). Negative values behave like 1.
        activation: Zero-argument callable returning a fresh activation module,
            instantiated once per hidden layer (default ``nn.ReLU``).
    """
    if hidden_depth == 0:
        return nn.Sequential(nn.Linear(input_dim, output_dim))

    # Consecutive pairs of this width list give the (in, out) of each hidden Linear.
    widths = [input_dim] + [hidden_dim] * max(hidden_depth, 1)
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers += [nn.Linear(fan_in, fan_out), activation()]
    layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)

# --- Multi-anchor GNN predictor (cross-attention over selected anchors) ---

class MultiAnchorGNNPredictor(nn.Module):
    """
    GNN-based predictor using multiple anchors selected via a transducer,
    fusing information with multi-head attention.
    Optionally incorporates anchor weights into the fusion layer.

    Args:
        pretrained_encoder: Graph encoder used by ``extract_embedding``. Either it
            exposes ``gnn`` and ``pool`` attributes, or it is itself callable with
            keyword arguments ``x``, ``edge_index``, ``edge_attr``, ``batch``.
        num_tasks (int): Width of the output layer (number of prediction targets).
        latent_dim (int): Embedding width used by attention and fusion; must be
            divisible by ``num_heads`` (an ``nn.MultiheadAttention`` requirement).
        num_heads (int): Number of attention heads.
        use_anchor_weights (bool): If True, ``forward`` requires per-anchor weights
            and concatenates them to the fusion input.
        num_candidates (int): Number of anchor weights (k) expected per example
            when ``use_anchor_weights`` is True.
    """

    def __init__(self, pretrained_encoder, num_tasks, latent_dim, num_heads=4,
                 use_anchor_weights=False, num_candidates=3):
        super().__init__()
        self.encoder = pretrained_encoder
        self.latent_dim = latent_dim
        self.num_tasks = num_tasks
        self.use_anchor_weights = use_anchor_weights
        self.num_candidates = num_candidates

        # Cross-attention: the query embedding attends over the candidate embeddings.
        self.attention = nn.MultiheadAttention(latent_dim, num_heads, batch_first=True)

        # Fusion input = [query, attended context] (+ k anchor weights when enabled).
        fusion_input_dim = latent_dim * 2
        if self.use_anchor_weights:
            fusion_input_dim += self.num_candidates

        self.fusion = nn.Linear(fusion_input_dim, latent_dim)
        self.output_layer = nn.Linear(latent_dim, num_tasks)

        # Expose the embedding extractor as obs_trunk (a bound-method attribute,
        # not a registered submodule).
        self.obs_trunk = self.extract_embedding

    def extract_embedding(self, x, edge_index, edge_attr, batch):
        """
        Use the pretrained encoder to compute graph-level latent embeddings.

        If the encoder exposes ``gnn`` and ``pool``, node representations are
        computed with ``gnn(x, edge_index, edge_attr)`` and pooled with
        ``pool(node_repr, batch)``. Otherwise the encoder itself is called with
        keyword arguments and is expected to return graph embeddings.
        Adjust this if your encoder has a different structure.
        """
        if hasattr(self.encoder, 'gnn') and hasattr(self.encoder, 'pool'):
            node_repr = self.encoder.gnn(x, edge_index, edge_attr)
            return self.encoder.pool(node_repr, batch)
        # Fallback: the encoder object must itself produce graph-level embeddings.
        return self.encoder(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)

    def forward(self, query_embedding, candidate_embeddings, anchor_weights=None, attention_mask=None):
        """
        Performs multi-anchor fusion via cross-attention.

        Args:
            query_embedding (torch.Tensor): Shape (batch, latent_dim)
            candidate_embeddings (torch.Tensor): Shape (batch, k, latent_dim)
            anchor_weights (torch.Tensor, optional): Shape (batch, k). Required when
                ``use_anchor_weights`` is True; cast to ``query_embedding``'s dtype.
            attention_mask (torch.Tensor, optional): Bool tensor, shape (batch, k).
                True indicates key should be ignored. A row with every key masked
                receives a zero attention context instead of NaN.

        Returns:
            torch.Tensor: Predictions of shape (batch, num_tasks)

        Raises:
            ValueError: If anchor weights are required but missing or mis-shaped.
        """
        # Validate anchor weights before doing any compute.
        if self.use_anchor_weights:
            if anchor_weights is None:
                raise ValueError("Anchor weights must be provided when use_anchor_weights is True")
            if anchor_weights.dim() != 2:
                raise ValueError(
                    f"Expected anchor weights of shape (batch, {self.num_candidates}), "
                    f"got shape {tuple(anchor_weights.shape)}"
                )
            if anchor_weights.shape[1] != self.num_candidates:
                raise ValueError(f"Received {anchor_weights.shape[1]} anchor weights, but expected {self.num_candidates}")

        # A row whose keys are all masked makes the attention softmax all -inf -> NaN.
        # Unmask such rows for the attention call and zero their context afterwards.
        # (Only bool masks are handled; other mask types are passed through unchanged.)
        no_valid_key = None
        if attention_mask is not None and attention_mask.dtype == torch.bool:
            no_valid_key = attention_mask.all(dim=-1, keepdim=True)  # (batch, 1)
            attention_mask = attention_mask & ~no_valid_key

        query_attn = query_embedding.unsqueeze(1)  # (batch, 1, latent_dim)
        attn_output, _ = self.attention(query=query_attn,
                                        key=candidate_embeddings,
                                        value=candidate_embeddings,
                                        key_padding_mask=attention_mask)
        attn_output = attn_output.squeeze(1)  # (batch, latent_dim)

        if no_valid_key is not None:
            attn_output = attn_output.masked_fill(no_valid_key, 0.0)

        # Fuse the query with the attended candidate information.
        fused_input = [query_embedding, attn_output]
        if self.use_anchor_weights:
            # Match dtype so the concatenation feeds the Linear layer a consistent dtype.
            fused_input.append(anchor_weights.to(dtype=query_embedding.dtype))

        fused = torch.cat(fused_input, dim=-1)  # (batch, latent_dim*2 [+ k])

        fused_hidden = F.relu(self.fusion(fused))  # (batch, latent_dim)
        return self.output_layer(fused_hidden)  # (batch, num_tasks)
