"""
model.py
电路相似度学习模型定义
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, SAGEConv
from typing import Optional


class CircuitGCNEncoder(nn.Module):
    """Graph neural network encoder for circuit graphs, built on GCN convolutions.

    This GCN variant was previously declared under the name
    ``CircuitGraphEncoder``. The GraphSAGE-based class of that same name later
    in this module rebinds the name, which left this version unreachable. Its
    ``super(CircuitGraphEncoder, self)`` call would also have resolved to the
    wrong class. It now has its own name; ``CircuitGraphEncoder`` is unaffected.

    Pipeline: linear node-feature embedding -> ``num_layers`` x
    (GCNConv -> BatchNorm1d -> ReLU -> Dropout, with a residual connection from
    the second layer on) -> concatenated mean/max graph pooling -> two-layer
    MLP head.

    Args:
        node_feature_dim: Size of each input node-feature vector.
        edge_feature_dim: Currently unused; kept for signature compatibility
            (edge features are never passed to the convolutions).
        hidden_dim: Width of node embeddings inside the encoder.
        output_dim: Size of the per-graph output embedding.
        num_layers: Number of GCNConv layers.
        dropout: Dropout probability after each conv layer and in the head.
    """

    def __init__(self,
                 node_feature_dim: int = 7,
                 edge_feature_dim: int = 32,
                 hidden_dim: int = 128,
                 output_dim: int = 256,
                 num_layers: int = 3,
                 dropout: float = 0.2):
        super().__init__()

        self.node_feature_encoder = nn.Linear(node_feature_dim, hidden_dim)

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Mean- and max-pooled graph vectors are concatenated, hence 2 * hidden_dim.
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch=None):
        """Encode a (batch of) graph(s) into one embedding per graph.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Graph connectivity as accepted by ``GCNConv``.
            batch: Optional node-to-graph assignment vector; when ``None`` all
                nodes are treated as a single graph.
        Returns:
            Tensor with one ``output_dim``-sized row per graph.
        """
        x = self.node_feature_encoder(x)

        for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)):
            h = self.dropout(F.relu(bn(conv(x, edge_index))))
            # Residual connection from the second layer on.
            x = h if i == 0 else h + x

        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=1)

        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

class CircuitGraphEncoder(nn.Module):
    """Graph neural network encoder for circuit graphs, built on GraphSAGE convolutions.

    Pipeline: linear node-feature embedding -> conv stack of
    (SAGEConv -> BatchNorm1d -> ReLU -> Dropout, with a residual connection
    from the second layer on) -> concatenated mean/max graph pooling ->
    two-layer MLP head.

    Args:
        node_feature_dim: Size of each input node-feature vector.
        edge_feature_dim: Currently unused; kept for signature compatibility
            (edge features are never passed to the convolutions).
        hidden_dim: Width of node embeddings inside the encoder.
        output_dim: Size of the per-graph output embedding.
        num_layers: Number of SAGEConv layers. At least one layer is always
            built, even if a value below 1 is given.
        dropout: Dropout probability after each conv layer and in the head.
    """

    def __init__(self,
                 node_feature_dim: int = 7,
                 edge_feature_dim: int = 32,
                 hidden_dim: int = 128,
                 output_dim: int = 256,
                 num_layers: int = 3,
                 dropout: float = 0.2):
        super().__init__()

        self.node_feature_encoder = nn.Linear(node_feature_dim, hidden_dim)

        # Every layer maps hidden_dim -> hidden_dim, so the stack is built
        # uniformly. max(1, ...) preserves the "always at least one conv layer"
        # behaviour of the previous first-layer/middle-layer split.
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for _ in range(max(1, num_layers)):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Mean- and max-pooled graph vectors are concatenated, hence 2 * hidden_dim.
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch=None):
        """Encode a (batch of) graph(s) into one embedding per graph.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Graph connectivity as accepted by ``SAGEConv``.
            batch: Optional node-to-graph assignment vector; when ``None`` all
                nodes are treated as a single graph.
        Returns:
            Tensor with one ``output_dim``-sized row per graph.
        """
        x = self.node_feature_encoder(x)

        for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)):
            h = self.dropout(F.relu(bn(conv(x, edge_index))))
            # Residual connection, skipped for the first layer.
            x = h if i == 0 else h + x

        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=1)

        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

class CircuitCNN(nn.Module):
    """CNN encoder for the matrix representation of a circuit.

    Four conv blocks (3x3 Conv -> BatchNorm -> ReLU), with 2x2 max-pooling
    after the first three. Adaptive average pooling then maps the result to a
    fixed 4x4 grid, followed by a two-layer MLP head.

    The three 2x2 max-pools each halve height and width (rounding down), so
    inputs need a spatial size of at least 8x8. Smaller inputs shrink to zero
    size inside the pooling stack and fail there.

    Args:
        input_channels: Number of channels of the input matrix tensor.
        output_dim: Size of the output embedding.
        base_channels: Channel count of the first conv; doubled by each block.
        dropout: Dropout probability in the MLP head. It was previously
            hard-coded to 0.3, which remains the default.
    """

    def __init__(self,
                 input_channels: int = 2,
                 output_dim: int = 256,
                 base_channels: int = 32,
                 dropout: float = 0.3):
        super().__init__()

        # Convolution layers
        self.conv1 = nn.Conv2d(input_channels, base_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(base_channels, base_channels * 2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(base_channels * 2, base_channels * 4, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(base_channels * 4, base_channels * 8, kernel_size=3, padding=1)

        # Batch-normalisation layers
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.bn2 = nn.BatchNorm2d(base_channels * 2)
        self.bn3 = nn.BatchNorm2d(base_channels * 4)
        self.bn4 = nn.BatchNorm2d(base_channels * 8)

        # Pooling layers; adaptive pooling fixes the head's input size
        # independently of the input resolution.
        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))

        # Fully connected head
        self.fc1 = nn.Linear(base_channels * 8 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Encode a batch of circuit matrices.

        Args:
            x: Tensor of shape (batch, input_channels, H, W) with H, W >= 8.
        Returns:
            Tensor of shape (batch, output_dim).
        """
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = F.relu(self.bn4(self.conv4(x)))

        x = self.adaptive_pool(x)  # (batch, base_channels * 8, 4, 4)

        # torch.flatten has reshape semantics. Activations can be
        # non-contiguous (e.g. for channels_last inputs), and .view() raises
        # on such tensors.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


class CrossModalAttention(nn.Module):
    """Cross-modal attention fusion module.

    The GNN and CNN embeddings are treated as a two-token sequence and passed
    through one post-norm Transformer-style block:
    multi-head self-attention + residual + LayerNorm, then a GELU feed-forward
    network + residual + LayerNorm. The two resulting tokens are then averaged
    into a single fused embedding.

    NOTE: no positional or modality embedding is added, and the final
    reduction is an unweighted mean. The output is therefore symmetric in its
    inputs: swapping ``gnn_embedding`` and ``cnn_embedding`` gives the same
    result.

    Args:
        embedding_dim: Size of each input embedding and of the output.
        num_heads: Number of attention heads; must divide ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int = 256, num_heads: int = 8):
        super().__init__()

        assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"

        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        # Query / key / value projections
        self.q_proj = nn.Linear(embedding_dim, embedding_dim)
        self.k_proj = nn.Linear(embedding_dim, embedding_dim)
        self.v_proj = nn.Linear(embedding_dim, embedding_dim)

        # Output projection after merging heads
        self.out_proj = nn.Linear(embedding_dim, embedding_dim)

        # Post-attention and post-FFN layer norms
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Position-wise feed-forward network (4x expansion)
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim * 4, embedding_dim),
            nn.Dropout(0.1)
        )

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, embedding_dim) -> (batch, num_heads, seq, head_dim)."""
        batch_size, seq_len, _ = t.shape
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, gnn_embedding, cnn_embedding):
        """Fuse the two modality embeddings with cross-modal attention.

        Args:
            gnn_embedding: GNN embedding, shape (batch_size, embedding_dim).
            cnn_embedding: CNN embedding, shape (batch_size, embedding_dim).
        Returns:
            Fused embedding, shape (batch_size, embedding_dim).
        """
        # Stack the modalities as a token sequence: (batch, 2, embedding_dim)
        tokens = torch.stack([gnn_embedding, cnn_embedding], dim=1)
        batch_size, seq_len, _ = tokens.shape

        q = self._split_heads(self.q_proj(tokens))
        k = self._split_heads(self.k_proj(tokens))
        v = self._split_heads(self.v_proj(tokens))

        # Per-head scaled dot-product attention across the modality tokens.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # (batch, heads, seq, head_dim)

        # Merge heads back to (batch, seq, embedding_dim) and project.
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.embedding_dim)
        attn_output = self.out_proj(attn_output)

        # Residual + LayerNorm around attention, then around the FFN.
        tokens = self.norm1(tokens + attn_output)
        tokens = self.norm2(tokens + self.ffn(tokens))

        # Unweighted mean over the modality tokens. Any learned cross-modal
        # weighting happens inside the attention step above, not here.
        return tokens.mean(dim=1)


class CircuitDistanceModel(nn.Module):
    """Combined model that fuses GNN (graph) and CNN (matrix) features.

    Each modality is encoded and passed through its own linear projection.
    The two are then fused according to ``fusion_method`` and the result is
    L2-normalised. Distances between outputs can be computed with
    :meth:`compute_distance`.

    Args:
        gnn_encoder: Graph encoder; a default ``CircuitGraphEncoder`` is built
            when ``None``. Its output width must equal ``embedding_dim``
            (it feeds ``gnn_proj``).
        cnn_encoder: Matrix encoder; a default ``CircuitCNN`` is built when
            ``None``. Its output width must equal ``embedding_dim``
            (it feeds ``cnn_proj``).
        embedding_dim: Width of all embeddings in the model.
        fusion_method:
            - ``'attention'``: CrossModalAttention, then ``final_proj``.
            - ``'concat'``: concatenation, then a Linear layer.
            - ``'add'``: sum, then a Linear layer.
            - Any other value: plain sum with no learned fusion layer.
        node_feature_dim: Node feature size for the default graph encoder.
        num_attention_heads: Number of heads for the attention fusion.
    """

    def __init__(self,
                 gnn_encoder: Optional["CircuitGraphEncoder"] = None,
                 cnn_encoder: Optional["CircuitCNN"] = None,
                 embedding_dim: int = 256,
                 fusion_method: str = 'attention',
                 node_feature_dim: int = 7,
                 num_attention_heads: int = 8):
        super().__init__()

        # Build default encoders only when none was supplied. The check uses
        # `is None` rather than `or`: container modules (nn.Sequential, ...)
        # define __len__, so an empty one is falsy and `or` would silently
        # replace it.
        self.gnn_encoder = gnn_encoder if gnn_encoder is not None else CircuitGraphEncoder(
            node_feature_dim=node_feature_dim,
            output_dim=embedding_dim
        )
        self.cnn_encoder = cnn_encoder if cnn_encoder is not None else CircuitCNN(
            output_dim=embedding_dim
        )

        self.fusion_method = fusion_method

        # Fusion layer.
        # NOTE(review): for unrecognised fusion methods this Linear layer is
        # created but never used in forward(). It is kept so that state_dict
        # keys stay stable across fusion methods.
        if fusion_method == 'attention':
            self.fusion_layer = CrossModalAttention(
                embedding_dim=embedding_dim,
                num_heads=num_attention_heads
            )
        elif fusion_method == 'concat':
            self.fusion_layer = nn.Linear(embedding_dim * 2, embedding_dim)
        else:
            self.fusion_layer = nn.Linear(embedding_dim, embedding_dim)

        # Per-modality projections
        self.gnn_proj = nn.Linear(embedding_dim, embedding_dim)
        self.cnn_proj = nn.Linear(embedding_dim, embedding_dim)

        # Final projection. Only applied on the 'attention' path in forward().
        self.final_proj = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, graph_data, matrix_data):
        """Embed a batch of circuits.

        Args:
            graph_data: Graph data object with ``x`` and ``edge_index``
                attributes and, optionally, ``batch``.
            matrix_data: Matrix data, (batch_size, channels, height, width).
        Returns:
            L2-normalised embeddings, one row per circuit.
        """
        gnn_embedding = self.gnn_encoder(
            graph_data.x,
            graph_data.edge_index,
            getattr(graph_data, 'batch', None)
        )
        gnn_embedding = self.gnn_proj(gnn_embedding)

        cnn_embedding = self.cnn_proj(self.cnn_encoder(matrix_data))

        if self.fusion_method == 'attention':
            final_embedding = self.final_proj(self.fusion_layer(gnn_embedding, cnn_embedding))
        elif self.fusion_method == 'concat':
            final_embedding = self.fusion_layer(torch.cat([gnn_embedding, cnn_embedding], dim=1))
        elif self.fusion_method == 'add':
            final_embedding = self.fusion_layer(gnn_embedding + cnn_embedding)
        else:
            final_embedding = gnn_embedding + cnn_embedding

        # L2-normalise each row.
        return F.normalize(final_embedding, p=2, dim=1)

    def compute_distance(self, embedding1, embedding2):
        """Return the row-wise Euclidean distance between two embedding batches.

        For unit-norm outputs of ``forward`` the distance lies in [0, 2].
        """
        return torch.norm(embedding1 - embedding2, p=2, dim=1)