import os
import sys
sys.path.insert(0, os.getcwd())
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from utils.utils import load_json
from circuit.circuit_manager import transform_adj_to_edge_list

class MultiHeadAttentionGNN(nn.Module):
    """Bank of ``GATConv`` layers, one per edge type, with concatenated outputs.

    Each layer is created with ``heads=num_heads`` and a per-head width of
    ``out_channels // num_heads``.
    """

    def __init__(self, in_channels, out_channels, num_heads, num_edge_types, add_self_loops, dropout):
        super().__init__()
        per_head_width = out_channels // num_heads
        layers = []
        for _ in range(num_edge_types):
            layers.append(
                GATConv(
                    in_channels,
                    per_head_width,
                    heads=num_heads,
                    add_self_loops=add_self_loops,
                    dropout=dropout,
                )
            )
        self.convs = nn.ModuleList(layers)

    def forward(self, ops, adjs):
        """Apply layer ``k`` to ``(ops, adjs[k])`` and join the results on dim 1.

        Args:
            ops: Node feature input, passed unchanged to every layer.
            adjs: Indexable collection; ``adjs[k]`` is the connectivity given
                to the ``k``-th layer.

        Returns:
            The per-layer outputs concatenated with ``torch.cat(..., dim=1)``.
        """
        # Index ``adjs`` by position (rather than zipping) so that a too-short
        # ``adjs`` fails loudly instead of silently skipping layers.
        per_type = [conv(ops, adjs[k]) for k, conv in enumerate(self.convs)]
        return torch.cat(per_type, dim=1)

class MAHTEncoder(nn.Module):
    """Variational encoder built on top of ``MultiHeadAttentionGNN``.

    The GNN output (width ``out_channels * num_edge_types``) is projected by two
    linear heads into a mean and a log-variance, each of width ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, num_heads=1, num_edge_types=7, add_self_loops=True, dropout=0):
        super().__init__()
        self.base_gnn = MultiHeadAttentionGNN(
            in_channels, out_channels, num_heads, num_edge_types, add_self_loops, dropout
        )
        concat_width = out_channels * num_edge_types
        self.mu = nn.Linear(concat_width, out_channels)
        self.logvar = nn.Linear(concat_width, out_channels)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            # Outside training the latent is deterministic.
            return mu
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def forward(self, ops, edge_indices):
        """Encode the inputs.

        Args:
            ops: Node feature input forwarded to the base GNN.
            edge_indices: Per-edge-type connectivity forwarded to the base GNN.

        Returns:
            Tuple ``(mu, logvar, z)`` where ``z`` comes from ``reparameterize``.
        """
        hidden = self.base_gnn(ops, edge_indices)
        mean = self.mu(hidden)
        log_variance = self.logvar(hidden)
        latent = self.reparameterize(mean, log_variance)
        return mean, log_variance, latent

class FeatureDecoder(nn.Module):
    """Two-layer MLP that maps latent vectors back to feature space.

    Layout: Linear(embedding_dim, 128) -> ReLU -> Dropout -> Linear(128, feature_dim)
    -> Sigmoid, so every output value is squashed by a sigmoid.
    """

    def __init__(self, embedding_dim, feature_dim, dropout):
        super().__init__()
        hidden_width = 128
        stack = [
            nn.Linear(embedding_dim, hidden_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, feature_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*stack)

    def forward(self, z):
        """Return the reconstructed features for latent input ``z``."""
        return self.decoder(z)

class EdgeDecoder(nn.Module):
    """Per-edge-type MLPs that score edges from the embeddings of their endpoints.

    One independent decoder is kept for each edge type. Each decoder takes the
    concatenated source/target embeddings (width ``2 * num_features``) and emits
    a single sigmoid-activated score per edge.
    """

    def __init__(self, num_features, num_edge_types, dropout):
        super().__init__()
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(num_features * 2, num_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(num_features, 1),
                nn.Sigmoid(),
            )
            for _ in range(num_edge_types)
        ])

    def forward(self, z, edge_indices):
        """Score the given edges of every edge type.

        Args:
            z: Node embeddings, indexed along dim 0 by the node ids in each
                edge index.
            edge_indices: Iterable of edge-index tensors, one per edge type.
                Each non-empty tensor must unpack into exactly two rows
                ``(src, tgt)``.

        Returns:
            A list with one tensor per (decoder, edge index) pair: the decoder
            output for that type's edges, or an empty ``(0, 1)`` tensor when the
            type has no edges.
        """
        outputs = []
        for decoder, edge_index in zip(self.decoders, edge_indices):
            if edge_index.numel() == 0:
                # Build the placeholder from ``z`` so it shares both device AND
                # dtype with the real predictions (a plain ``torch.zeros`` would
                # always be float32, even for a double/half model).
                outputs.append(z.new_zeros((0, 1)))
                continue
            src, tgt = edge_index
            pair_features = torch.cat([z[src], z[tgt]], dim=1)
            outputs.append(decoder(pair_features))
        return outputs

class MultiChannelGNN(nn.Module):
    """Graph VAE: a ``MAHTEncoder`` followed by an edge decoder and a feature decoder."""

    def __init__(self, in_channels, out_channels, num_heads, num_edge_types, add_self_loops, dropout):
        super().__init__()
        self.encoder = MAHTEncoder(
            in_channels, out_channels, num_heads, num_edge_types, add_self_loops, dropout
        )
        self.edge_decoder = EdgeDecoder(out_channels, num_edge_types, dropout)
        self.feature_decoder = FeatureDecoder(out_channels, in_channels, dropout)

    def forward(self, ops, edge_indices):
        """Encode the inputs, then reconstruct edges and node features from the latent.

        Args:
            ops: Node feature input for the encoder.
            edge_indices: Per-edge-type connectivity, used by both the encoder
                and the edge decoder.

        Returns:
            Tuple ``(mu, logvar, recon_edges, recon_features)``. The sampled
            latent itself is not returned.
        """
        mean, log_variance, latent = self.encoder(ops, edge_indices)
        # Edge decoder runs before the feature decoder; keep this order so the
        # dropout RNG stream is consumed the same way on every run.
        edges_hat = self.edge_decoder(latent, edge_indices)
        features_hat = self.feature_decoder(latent)
        return mean, log_variance, edges_hat, features_hat

class VAELoss(nn.Module):
    """Sum of feature BCE, per-edge-type BCE and a KL-divergence term."""

    def __init__(self):
        super().__init__()

    def forward(self, recon_features, original_features, recon_edges, original_edges, mu, logvar):
        """Compute the total loss.

        Args:
            recon_features: Predicted features (probabilities) for BCE.
            original_features: Target features, same shape as ``recon_features``.
            recon_edges: Iterable of per-edge-type predicted edge probabilities.
            original_edges: Iterable of per-edge-type edge targets; each entry
                must have the same number of elements as its prediction and is
                reshaped to the prediction's shape.
            mu: Latent means.
            logvar: Latent log-variances.

        Returns:
            ``feature_bce + sum(edge_bce per non-empty type) + kl`` as a scalar tensor.

        Raises:
            ValueError: If a non-empty prediction/target pair differs in element count.
        """
        # Feature reconstruction loss.
        feature_reconstruction_loss = F.binary_cross_entropy(recon_features, original_features, reduction='mean')

        # Edge reconstruction loss, accumulated over edge types.
        edge_reconstruction_loss = 0
        for recon_edge, original_edge in zip(recon_edges, original_edges):
            # Skip edge types that have no edges.
            if original_edge.numel() == 0 or recon_edge.numel() == 0:
                continue
            # Validate BEFORE reshaping: once ``view_as`` has succeeded the sizes
            # are equal by construction, so a check placed after it can never fire.
            if recon_edge.numel() != original_edge.numel():
                raise ValueError(f"Size mismatch: recon_edge size {recon_edge.size()} vs original_edge size {original_edge.size()}")
            # Cast the target to the prediction's dtype (not unconditionally to
            # float32) so non-float32 models are handled consistently.
            target = original_edge.view_as(recon_edge).to(recon_edge.dtype)
            edge_reconstruction_loss = edge_reconstruction_loss + F.binary_cross_entropy(recon_edge, target, reduction='mean')

        # KL divergence loss (mean-reduced over all latent elements).
        kl_divergence_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        return feature_reconstruction_loss + edge_reconstruction_loss + kl_divergence_loss

if __name__ == "__main__":
    # Demo: train MultiChannelGNN on the 4-qubit circuit dataset for a few
    # epochs, printing the loss of the last batch of each epoch.

    # Example hyperparameters.
    num_nodes = 12  # informational only; not used below
    in_channels = 17  # node feature dimension
    out_channels = 32  # encoder output dimension
    num_heads = 4
    num_edge_types = 7
    dropout = 0.1
    num_epochs = 10

    # Build the example graph data.
    # The path is assembled with os.path.join rather than a hard-coded
    # backslash literal: "\d" is an invalid escape sequence and the
    # backslash-separated form only resolves on Windows.
    data = load_json(os.path.join("circuit", "data", "data_4_qubits.json"))
    circuit_edges = [transform_adj_to_edge_list(sample['adj_matrix_group']) for sample in data]
    data_list = [
        Data(x=torch.tensor(sample['gate_matrix'], dtype=torch.float), edge_index=edges)
        for sample, edges in zip(data, circuit_edges)
    ]
    # Batch the graphs with a DataLoader.
    loader = DataLoader(data_list, batch_size=32, shuffle=False)
    # Define the model, optimizer and loss.
    model = MultiChannelGNN(in_channels, out_channels, num_heads, num_edge_types, True, dropout)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_fn = VAELoss()

    # Training loop.
    model.train()
    for epoch in range(num_epochs):
        for batch in loader:
            optimizer.zero_grad()
            mu, logvar, recon_edges, recon_features = model(batch.x, batch.edge_index)
            # Original edges per type (this assumes batch.edge_index holds the
            # edges of every edge type, indexable by type).
            # NOTE(review): depends on what transform_adj_to_edge_list returns and
            # on how the DataLoader collates it -- confirm.
            original_edges = [batch.edge_index[i] for i in range(num_edge_types)]
            # Build targets shaped like recon_edges: every listed edge exists,
            # so its target is 1.
            original_edges_flatten = []
            for edge_index in original_edges:
                src, _ = edge_index
                original_edges_flatten.append(
                    torch.ones((src.size(0), 1), dtype=torch.float32, device=edge_index.device)
                )

            # Compute the VAE reconstruction loss.
            loss = loss_fn(recon_features, batch.x, recon_edges, original_edges_flatten, mu, logvar)

            loss.backward()
            optimizer.step()
            # Dump the reconstructed features during the final epoch.
            if epoch == num_epochs - 1:
                print(recon_features)
        print(f'Epoch {epoch}, Loss: {loss.item()}')