import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    GCNConv, GATv2Conv, TransformerConv,
    GINConv, SAGEConv
)

class GNN(nn.Module):
    """Node encoder with a selectable backbone and an optional prediction head.

    Args:
        in_dim: Number of input node features.
        hidden_dim: Width of the hidden layers.
        out_dim: Output width of the prediction head. For ``'gat'`` and
            ``'transformer'`` it is also the width of the final
            message-passing layer.
        model_type: Backbone name, case-insensitive: ``'mlp'``, ``'gcn'``,
            ``'gat'``, ``'transformer'``, ``'gin'`` or ``'sage'``.
        num_layers: Number of encoder layers (must be >= 1).
        heads: Attention heads used by every non-final ``'gat'`` /
            ``'transformer'`` layer.
        dropout: Dropout probability applied after every encoder layer.
        mode: ``'embedder'`` makes ``forward`` return the encoder output;
            ``'predictor'`` additionally applies a linear head.

    Raises:
        ValueError: If ``mode`` or ``model_type`` is unknown, or if
            ``num_layers < 1``.
    """

    _MODES = ('embedder', 'predictor')

    def __init__(
        self,
        in_dim,
        hidden_dim,
        out_dim,
        model_type='gcn',
        num_layers=2,
        heads=4,
        dropout=0.2,
        mode='embedder',  # or 'predictor'
    ):
        super().__init__()
        # Validate up front: an unknown mode would otherwise only fail in
        # forward() when self.head is missing.
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode: {mode!r}; expected one of {self._MODES}")
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.model_type = model_type.lower()
        self.mode = mode
        self.dropout = dropout
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.encoder = nn.ModuleList()

        # Width of the last encoder layer's output; the head is sized from it.
        enc_out_dim = hidden_dim

        if self.model_type in ('mlp', 'gcn', 'sage'):
            # Uniform stacks: in_dim -> hidden_dim -> ... -> hidden_dim,
            # exactly num_layers layers.
            if self.model_type == 'mlp':
                layer_cls = nn.Linear
            elif self.model_type == 'gcn':
                layer_cls = GCNConv
            else:
                layer_cls = SAGEConv
            d_in = in_dim
            for _ in range(num_layers):
                self.encoder.append(layer_cls(d_in, hidden_dim))
                d_in = hidden_dim

        elif self.model_type == 'gat':
            # Non-final layers concatenate `heads` outputs, so the next layer's
            # input width is hidden_dim * heads; a single-head final layer
            # maps to out_dim.
            d_in = in_dim
            for _ in range(num_layers - 1):
                self.encoder.append(GATv2Conv(d_in, hidden_dim, heads=heads, concat=True))
                d_in = hidden_dim * heads
            self.encoder.append(GATv2Conv(d_in, out_dim, heads=1, concat=False))
            enc_out_dim = out_dim

        elif self.model_type == 'transformer':
            # Same layout as 'gat': multi-head hidden layers, single-head output.
            d_in = in_dim
            for _ in range(num_layers - 1):
                self.encoder.append(TransformerConv(d_in, hidden_dim, heads=heads))
                d_in = hidden_dim * heads
            self.encoder.append(TransformerConv(d_in, out_dim, heads=1))
            enc_out_dim = out_dim

        elif self.model_type == 'gin':
            # Project inputs to hidden_dim once so every GINConv MLP is
            # hidden_dim -> hidden_dim.
            self.input_proj = nn.Linear(in_dim, hidden_dim)
            for _ in range(num_layers):
                mlp = nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                )
                self.encoder.append(GINConv(mlp))

        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        if self.mode == 'predictor':
            # Sized from the encoder's actual output width (out_dim for
            # 'gat'/'transformer', hidden_dim otherwise) to avoid a shape
            # mismatch when out_dim != hidden_dim.
            self.head = nn.Linear(enc_out_dim, out_dim)

    def forward(self, x, edge_index):
        """Encode node features and, in ``'predictor'`` mode, apply the head.

        Args:
            x: Node feature matrix fed to the first layer.
            edge_index: Graph connectivity passed to every message-passing
                layer; ignored when ``model_type == 'mlp'``.

        Returns:
            The encoder output in ``'embedder'`` mode, otherwise
            ``self.head`` applied to it.
        """
        if self.model_type == 'gin':
            x = self.input_proj(x)
        for layer in self.encoder:
            x = layer(x) if self.model_type == 'mlp' else layer(x, edge_index)
            # ReLU + dropout follow every layer, including the last one, so
            # the embeddings returned in 'embedder' mode are non-negative.
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        if self.mode == 'embedder':
            return x
        return self.head(x)



def build_mlp(input_dim, hidden_dim, output_dim, num_layers):
    """Build a ReLU MLP made of ``num_layers`` linear layers.

    With ``num_layers == 1`` the result is a single
    ``Linear(input_dim, output_dim)``; otherwise the widths are
    ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim`` with a
    ReLU between consecutive linear layers (none after the last).

    Args:
        input_dim: Input feature width.
        hidden_dim: Width of every hidden layer (unused when num_layers == 1).
        output_dim: Output width.
        num_layers: Number of linear layers (must be >= 1).

    Returns:
        An ``nn.Sequential`` of the layers described above.

    Raises:
        ValueError: If ``num_layers < 1`` (such a value used to silently
            produce a two-layer network).
    """
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")

    dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
    layers = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        if i > 0:
            # Nonlinearity between consecutive linear layers only.
            layers.append(nn.ReLU())
        layers.append(nn.Linear(d_in, d_out))
    return nn.Sequential(*layers)


class GraphCBM(nn.Module):
    """Concept-bottleneck classifier on top of a GNN embedder.

    Every embedding produced by ``gnn`` is compared by cosine similarity with
    every projected concept text embedding. The resulting similarity matrix
    (one column per concept) is fed to an MLP classifier.

    Args:
        gnn: Encoder whose ``mode`` attribute is ``'embedder'``; called as
            ``gnn(data.x, data.edge_index)``. Its output width must equal
            ``hidden_dim`` for the similarity to be defined.
        text_emb_dim: Width of the raw concept text embeddings.
        n_concepts: Number of concepts; input width of the classifier.
        hidden_dim: Projection width for text embeddings and hidden width of
            the classifier MLP.
        out_dim: Number of classifier outputs.
        mlp_layers: Number of linear layers in the classifier.

    Raises:
        ValueError: If ``gnn`` is not in ``'embedder'`` mode.
    """

    def __init__(self,
                 gnn: nn.Module,
                 text_emb_dim: int,
                 n_concepts: int,
                 hidden_dim: int,
                 out_dim: int,
                 mlp_layers: int = 2):
        super().__init__()

        # Explicit check (not `assert`) so it survives `python -O`.
        if getattr(gnn, "mode", None) != "embedder":
            raise ValueError("GNN must be in 'embedder' mode")
        self.gnn = gnn

        # Project text embeddings into the GNN embedding space.
        self.text_proj = nn.Linear(text_emb_dim, hidden_dim)

        # MLP over the per-concept cosine similarities.
        self.classifier = build_mlp(input_dim=n_concepts, hidden_dim=hidden_dim,
                                    output_dim=out_dim, num_layers=mlp_layers)

    def forward(self, data, text_emb):
        """Classify from the similarity between GNN embeddings and concepts.

        Args:
            data: Object exposing ``x`` and ``edge_index`` for ``gnn``.
            text_emb: Concept text embeddings, one row per concept
                (presumably ``[n_concepts, text_emb_dim]`` — confirm with
                the caller).

        Returns:
            ``self.classifier`` applied to the embedding-by-concept cosine
            similarity matrix.
        """
        # NOTE(review): no pooling is applied, so outputs are per GNN output
        # row (per node). If graph-level outputs are intended, pool here.
        node_emb = self.gnn(data.x, data.edge_index)
        concept_emb = self.text_proj(text_emb)

        # Pairwise cosine similarity between every embedding row and every
        # concept. Normalising first and then multiplying avoids
        # materialising the [rows, concepts, dim] tensor that a broadcast
        # F.cosine_similarity builds. eps matches cosine_similarity's default.
        node_unit = F.normalize(node_emb, p=2, dim=-1, eps=1e-8)
        concept_unit = F.normalize(concept_emb, p=2, dim=-1, eps=1e-8)
        cosine_sim = node_unit @ concept_unit.transpose(-2, -1)

        return self.classifier(cosine_sim)


class MLP(torch.nn.Module):
    """Feed-forward classifier that returns log-probabilities.

    ``num_layers`` linear layers map
    ``in_channels -> hidden_channels -> ... -> out_channels``. Each non-final
    layer is followed by ReLU and dropout, and the output goes through
    ``log_softmax`` over dim 1. This assumes a 2-D ``[batch, features]``
    input — TODO confirm with callers.

    Args:
        in_channels: Input feature width.
        out_channels: Number of classes.
        num_layers: Number of linear layers (must be >= 1).
        hidden_channels: Width of every hidden layer.
        dropout: Dropout probability between layers.

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(self, in_channels, out_channels, num_layers, hidden_channels=64, dropout=0):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        self.layers = torch.nn.ModuleList()
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.layers.append(torch.nn.Linear(d_in, d_out))
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x):
        """Return class log-probabilities for ``x`` (normalised over dim 1)."""
        for layer in self.layers[:-1]:
            # ReLU between layers; without it the stack collapses to a single
            # linear map.
            x = self.dropout(F.relu(layer(x)))
        x = self.layers[-1](x)
        return F.log_softmax(x, dim=1)
