import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import scatter
from torch_geometric.datasets import Planetoid
import random
import numpy as np

seed = 640

# Seed every random-number source this script relies on (Python, NumPy and
# PyTorch on CPU/CUDA) and force deterministic cuDNN kernels so runs repeat.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    torch.manual_seed,
):
    _seed_fn(seed)
del _seed_fn

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fixed (non-learned) divisors for the query, key and value projections,
# in that order; handed to the model's EdgeAttention layers.
scaling_factors = [torch.tensor(factor) for factor in (0.9, 1.0, 1.3)]

class Swish(nn.Module):
    """Element-wise Swish activation, ``x * sigmoid(x)``; has no parameters."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class EdgeAttention(nn.Module):
    """Multi-head dot-product attention that produces one raw score per row pair.

    Queries are projected from ``x_i``; keys and values from ``x_j``. Each
    projection is divided by its scaling factor. Per-head scores are divided by
    ``sqrt(in_channels) * temperature`` and softmax-normalised across the *head*
    axis. They then weight the values, and the concatenated heads are reduced
    by ``output_layer`` (Linear -> Swish -> Linear) to a single score.

    Args:
        in_channels: feature size ``D`` of each row of ``x_i`` / ``x_j``.
        heads: number of attention heads ``H``.
        scaling_factors: optional sequence of exactly three divisors for the
            query, key and value projections. They are stored as given and
            are not learned. When ``None``, each divisor is a learnable scalar
            initialised to 1.
    """

    def __init__(self, in_channels, heads=2, scaling_factors=None):
        super().__init__()
        self.heads = heads
        self.in_channels = in_channels
        projected = in_channels * heads
        self.query = nn.Linear(in_channels, projected, bias=False)
        self.key = nn.Linear(in_channels, projected, bias=False)
        self.value = nn.Linear(in_channels, projected, bias=False)

        # Learnable softmax temperature (initial value 0.5).
        self.temperature = nn.Parameter(0.5 * torch.ones(1))

        if scaling_factors is None:
            self.query_scaling = nn.Parameter(torch.ones(1))
            self.key_scaling = nn.Parameter(torch.ones(1))
            self.value_scaling = nn.Parameter(torch.ones(1))
        else:
            assert len(scaling_factors) == 3, "Provide a scaling factor for each of query, key, and value."
            self.query_scaling, self.key_scaling, self.value_scaling = scaling_factors

        hidden = in_channels
        self.output_layer = nn.Sequential(
            nn.Linear(projected, hidden, bias=False),
            Swish(),
            nn.Linear(hidden, 1, bias=False),
        )
        self.attention_dropout = nn.Dropout(p=0.1)

        for projection in (self.query, self.key, self.value):
            nn.init.xavier_uniform_(projection.weight)

    def forward(self, x_i, x_j):
        """Return raw (pre-activation) scores of shape ``(N, 1)`` for ``N`` paired rows."""
        rows, n_heads, head_dim = x_i.size(0), self.heads, self.in_channels
        per_head = (rows, n_heads, head_dim)

        queries = self.query(x_i).view(*per_head) / self.query_scaling
        keys = self.key(x_j).view(*per_head) / self.key_scaling
        values = self.value(x_j).view(*per_head) / self.value_scaling

        # One dot product per head; note the softmax normalises over heads.
        scores = torch.einsum('nhd,nhd->nh', queries, keys) / (head_dim ** 0.5 * self.temperature)
        head_weights = self.attention_dropout(F.softmax(scores, dim=-1))

        mixed = torch.einsum('nh,nhd->nhd', head_weights, values).view(rows, n_heads * head_dim)
        return self.output_layer(mixed)

class KAGNNConv(nn.Module):
    """Attention-gated max-aggregation layer with an MLP and an identity residual.

    For each column ``(receiver, sender)`` of ``edge_index``, the sender's
    features ``x[sender]`` are multiplied by ``sigmoid`` of the
    :class:`EdgeAttention` score of ``(x[receiver], x[sender])``. These gated
    messages are max-reduced per receiver with ``scatter(..., reduce='max')``,
    passed through a two-layer MLP, and added to ``x``. Because the residual is
    an identity, the addition needs ``out_channels == in_channels``.
    """

    def __init__(self, in_channels, out_channels, heads=2, scaling_factors=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(),
            nn.Linear(in_channels, out_channels),
        )
        self.edge_attention = EdgeAttention(in_channels, heads=heads, scaling_factors=scaling_factors)
        self.residual = nn.Identity()

    def forward(self, x, edge_index, return_attention_weights=False):
        """Return updated node features, plus the ``(E, 1)`` sigmoid edge gates when requested."""
        receivers, senders = edge_index[0], edge_index[1]
        sender_features = x[senders]
        raw_scores = self.edge_attention(x[receivers], sender_features)
        gates = torch.sigmoid(raw_scores).view(-1, 1)
        aggregated = scatter(sender_features * gates, receivers, dim=0, dim_size=x.size(0), reduce='max')
        out = self.mlp(aggregated) + self.residual(x)
        return (out, gates) if return_attention_weights else out


class KolmogorovArnoldNetwork(nn.Module):
    """Encode features through the flattened outer product of two tanh branches.

    ``psi`` and ``phi`` are parallel ``tanh(Linear(input_dim, hidden_dim))``
    branches. Their per-row outer product (``hidden_dim ** 2`` features) goes
    through ``fc_combine`` (Linear -> ReLU -> Linear back to ``hidden_dim``),
    and the result is layer-normalised.

    Memory note: ``fc_combine[0]`` holds ``hidden_dim ** 2 * input_dim``
    weights, and ``forward`` materialises a ``(batch, hidden_dim ** 2)``
    tensor. Both grow quadratically with ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.psi = nn.Linear(input_dim, hidden_dim)
        self.phi = nn.Linear(input_dim, hidden_dim)
        pair_features = hidden_dim * hidden_dim
        self.fc_combine = nn.Sequential(
            nn.Linear(pair_features, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, hidden_dim),
        )
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Return ``(encoded, psi_output, phi_output)``, each with ``hidden_dim`` columns."""
        rows = x.size(0)
        psi_output = torch.tanh(self.psi(x))
        phi_output = torch.tanh(self.phi(x))
        outer = torch.einsum('bi,bj->bij', psi_output, phi_output)
        encoded = self.fc_combine(outer.reshape(rows, -1))
        return self.ln(encoded), psi_output, phi_output


class HierarchicalKAGNN(nn.Module):
    """Node classifier: KAN encoder, five graph convolutions, then an MLP head.

    The convolution stack is KAGNNConv (``head1`` heads) -> GCNConv ->
    KAGNNConv (``head2``) -> GATConv (dropout 0.6) -> KAGNNConv (``head3``).
    Every stage keeps ``hidden_channels`` features, and a ReLU follows every
    stage except the last. ``scaling_factors`` is forwarded to each KAGNNConv;
    it defaults to the module-level ``scaling_factors`` list.

    ``forward`` returns ``(out, psi_output, phi_output)``. When
    ``return_attention_weights`` is true it returns
    ``(out, [gates1, gates2, gates3], psi_output, phi_output)`` instead, where
    ``gatesK`` are the sigmoid edge gates of the K-th KAGNNConv and ``out`` is
    the output of ``fc_final``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, scaling_factors=scaling_factors, head1=12, head2=1, head3=3):
        super().__init__()
        hidden = hidden_channels
        self.kan = KolmogorovArnoldNetwork(in_channels, hidden)
        self.conv1 = KAGNNConv(hidden, hidden, head1, scaling_factors=scaling_factors)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = KAGNNConv(hidden, hidden, head2, scaling_factors=scaling_factors)
        self.conv4 = GATConv(hidden, hidden, dropout=0.6)
        self.conv5 = KAGNNConv(hidden, hidden, head3, scaling_factors=scaling_factors)

        # Leaves an (N, C) node-feature matrix unchanged.
        self.flatten = nn.Flatten()

        self.fc_final = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(hidden, out_channels),
        )

    def forward(self, x, edge_index, return_attention_weights=False):
        h, psi_output, phi_output = self.kan(x)

        # (layer, whether it is a KAGNNConv that can report its edge gates)
        stages = (
            (self.conv1, True),
            (self.conv2, False),
            (self.conv3, True),
            (self.conv4, False),
            (self.conv5, True),
        )
        last = len(stages) - 1
        attention = []
        for position, (conv, reports_gates) in enumerate(stages):
            if reports_gates:
                h, gates = conv(h, edge_index, return_attention_weights=True)
                attention.append(gates)
            else:
                h = conv(h, edge_index)
            if position < last:
                h = F.relu(h)

        out = self.fc_final(self.flatten(h))
        if return_attention_weights:
            return out, attention, psi_output, phi_output
        return out, psi_output, phi_output

model_path = './best_model_Citeseer.pth'
# Load the Citeseer citation graph
dataset = Planetoid(root='./Citeseer', name='Citeseer')
data = dataset[0]
data = data.to(device)

# Instantiate the model and restore its trained weights. map_location lets a
# checkpoint saved from a GPU run be loaded on a CPU-only machine.
model = HierarchicalKAGNN(in_channels=data.x.size(1), hidden_channels=480, out_channels=dataset.num_classes, scaling_factors=scaling_factors).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Evaluate and visualize the model
with torch.no_grad():
    classification_output, attention_weights_list, psi_output, phi_output = model(data.x, data.edge_index, return_attention_weights=True)

# The model returns one score per directed edge, i.e. per column of edge_index.
directed_attention_weights = [aw.cpu().numpy().flatten() for aw in attention_weights_list]
edge_index = data.edge_index.cpu().numpy()

# Debug: Print attention weights to verify they are different
for idx, attention_weights in enumerate(directed_attention_weights):
    print(f"Attention weights for layer {idx+1}: {attention_weights[:100]}")  # Print first 100 weights for each layer

# Create an undirected graph for visualization. Add all nodes up front so
# that nodes without any edge still exist in G and can be used as a target.
src_nodes = edge_index[0].tolist()
dst_nodes = edge_index[1].tolist()
G = nx.Graph()
G.add_nodes_from(range(data.num_nodes))
G.add_edges_from(zip(src_nodes, dst_nodes))

# The visualize_* functions below pair attention_weights[i] with the i-th edge
# of G.edges(). G holds one edge per unordered node pair, in its own order,
# which does not match the column order of edge_index. So re-index every
# layer onto G.edges(); when a pair occurs in both directions, keep the
# larger of its scores.
edge_position = {frozenset(e): i for i, e in enumerate(G.edges())}
undirected_index = np.array(
    [edge_position[frozenset(pair)] for pair in zip(src_nodes, dst_nodes)], dtype=np.int64
)
attention_weights_list = []
for directed in directed_attention_weights:
    per_edge = np.full(G.number_of_edges(), -np.inf, dtype=directed.dtype)
    np.maximum.at(per_edge, undirected_index, directed)  # unbuffered, so repeated indices all count
    attention_weights_list.append(per_edge)
nx.set_edge_attributes(G, dict(zip(G.edges(), attention_weights_list[0])), 'weight')

def visualize_key_paths_multi_hop(G, node_idx, attention_weights_list, threshold=0.5, hops=2, *, show_edge_labels=False):
    """Draw the ``hops``-hop neighbourhood of ``node_idx`` once per attention layer.

    Args:
        G: graph whose edges the weights refer to.
        node_idx: target node, drawn red, at the centre of the neighbourhood.
        attention_weights_list: one sequence per layer holding one weight per
            edge of ``G``, in ``G.edges()`` order. Each layer's weights are
            written to G's ``'weight'`` edge attribute, so G keeps the last
            layer's weights afterwards.
        threshold: edges whose weight exceeds it are overlaid in orange with
            width ``4 * weight``.
        hops: radius of the neighbourhood around ``node_idx``.
        show_edge_labels: also annotate every edge with its weight (2 decimals).

    Raises:
        ValueError: if ``attention_weights_list`` is empty.

    The figure has one panel per layer; all panels share one structure-only
    layout. It is saved to ``visualize_key_paths_multi_hop_Citeseer.png``.
    """
    if not attention_weights_list:
        raise ValueError("attention_weights_list must contain at least one layer")
    out_file = 'visualize_key_paths_multi_hop_Citeseer.png'

    # The neighbourhood depends only on G's structure, never on the weights,
    # so compute it and its layout once and share them across all panels.
    nodes_to_include = {node_idx}
    frontier = {node_idx}
    for _ in range(hops):
        frontier = {nbr for node in frontier for nbr in G.neighbors(node)} - nodes_to_include
        nodes_to_include |= frontier
    subgraph = G.subgraph(nodes_to_include)
    pos = nx.spring_layout(subgraph, weight=None)

    # squeeze=False keeps `axes` 2-D, so axes[0] is indexable even for one layer.
    fig, axes = plt.subplots(1, len(attention_weights_list), figsize=(18 * 4, 6 * 4), squeeze=False)

    for idx, (ax, attention_weights) in enumerate(zip(axes[0], attention_weights_list)):
        nx.set_edge_attributes(G, {e: attention_weights[i] for i, e in enumerate(G.edges())}, 'weight')
        nx.draw(subgraph, pos, ax=ax, with_labels=True, node_size=1600, font_size=10, font_color='white', node_color='blue')
        nx.draw_networkx_nodes(subgraph, pos, ax=ax, nodelist=[node_idx], node_color='red', node_size=1600)

        # Overlay this layer's key edges so the panels actually differ.
        key_edges = [(u, v, w) for u, v, w in subgraph.edges(data='weight') if w > threshold]
        if key_edges:
            nx.draw_networkx_edges(subgraph, pos, ax=ax, edgelist=[(u, v) for u, v, _ in key_edges],
                                   width=[4 * w for _, _, w in key_edges], edge_color='orange')
        if show_edge_labels:
            edge_labels = {(u, v): f"{w:.2f}" for u, v, w in subgraph.edges(data='weight')}
            nx.draw_networkx_edge_labels(subgraph, pos, edge_labels=edge_labels, ax=ax)

        ax.set_title(f'Attention Layer {idx+1}')

    fig.savefig(out_file)
    plt.close(fig)
    print(f"Attention weights visualization saved to '{out_file}'")


# Node whose neighbourhood the visualizations below focus on.
target_node = 1701
# Uncomment to render one panel per attention layer:
#visualize_key_paths_multi_hop(G, target_node, attention_weights_list, threshold=0.7, hops=2)


def visualize_key_paths_multi_hop_colored_arcs(G, node_idx, attention_weights_list, threshold=0.5, hops=2):
    """Overlay every layer's key edges around ``node_idx`` on a single plot.

    Args:
        G: graph whose edges the weights refer to; the layout covers all of G.
        node_idx: target node, drawn red.
        attention_weights_list: one sequence per layer holding one weight per
            edge of ``G``, in ``G.edges()`` order. Each layer's weights are
            written to G's ``'weight'`` edge attribute, so G keeps the last
            layer's weights afterwards.
        threshold: within the ``hops``-hop neighbourhood, a layer's edges with
            weight above it are drawn in that layer's colour (blue, green,
            orange, cycled) with width equal to the weight.
        hops: radius of the neighbourhood around ``node_idx``.

    An edge that is key in more than one layer is drawn as an arc whose
    curvature grows with the layer index, so overlapping layers stay visible.
    The figure is saved to
    ``visualize_key_paths_multi_hop_colored_arcs_Citeseer.png``.
    """
    from collections import Counter

    out_file = 'visualize_key_paths_multi_hop_colored_arcs_Citeseer.png'
    colors = ['blue', 'green', 'orange']  # Different colors for different layers
    pos = nx.spring_layout(G)  # Compute positions for all nodes for consistent layout

    fig, ax = plt.subplots(figsize=(10, 10))

    # The neighbourhood depends only on G's structure, so compute it once.
    nodes_to_include = {node_idx}
    frontier = {node_idx}
    for _ in range(hops):
        frontier = {nbr for node in frontier for nbr in G.neighbors(node)} - nodes_to_include
        nodes_to_include |= frontier
    subgraph = G.subgraph(nodes_to_include)

    # Draw the neighbourhood first, with light edges, so the coloured key
    # edges drawn afterwards are not painted over.
    nx.draw(subgraph, pos, with_labels=True, node_size=800, font_size=9, font_color='white',
            node_color='blue', edge_color='lightgray', ax=ax)
    nx.draw_networkx_nodes(subgraph, pos, nodelist=[node_idx], node_color='red', node_size=1000, ax=ax)

    # For each layer: {edge: weight} of its key edges inside the neighbourhood.
    key_edges_per_layer = []
    for attention_weights in attention_weights_list:
        nx.set_edge_attributes(G, {e: attention_weights[i] for i, e in enumerate(G.edges())}, 'weight')
        key_edges_per_layer.append({(u, v): w for u, v, w in subgraph.edges(data='weight') if w > threshold})

    # G is a simple graph, so "multi-edges" means edges that are key in several layers.
    layers_per_edge = Counter(edge for key_edges in key_edges_per_layer for edge in key_edges)

    for idx, key_edges in enumerate(key_edges_per_layer):
        color = colors[idx % len(colors)]
        straight = [e for e in key_edges if layers_per_edge[e] == 1]
        shared = [e for e in key_edges if layers_per_edge[e] > 1]
        if straight:
            nx.draw_networkx_edges(subgraph, pos, edgelist=straight, width=[key_edges[e] for e in straight],
                                   edge_color=color, ax=ax)
        if shared:
            # Undirected graphs honour connectionstyle only when drawn as
            # FancyArrowPatches, hence arrows=True with a plain '-' style.
            rad = 0.2 * (idx * 2)  # Vary radius based on layer index
            nx.draw_networkx_edges(subgraph, pos, edgelist=shared, width=[key_edges[e] for e in shared],
                                   edge_color=color, ax=ax, arrows=True, arrowstyle='-',
                                   connectionstyle=f'arc3,rad={rad}', node_size=800)

    fig.savefig(out_file)
    plt.close(fig)
    print(f"Attention weights visualization saved to '{out_file}'")


# Visualize key paths for a target node with multi-hop
visualize_key_paths_multi_hop_colored_arcs(G, target_node, attention_weights_list, threshold=0.7, hops=2)

def visualize_key_paths_multi_hop_colored_nodes(G, node_idx, attention_weights_list, threshold=0.5, hops=2):
    """Colour the ``hops``-hop neighbourhood of ``node_idx`` by attention layer.

    Args:
        G: graph whose edges the weights refer to; the layout covers all of G.
        node_idx: target node, drawn red.
        attention_weights_list: one sequence per layer holding one weight per
            edge of ``G``, in ``G.edges()`` order. Each layer's weights are
            written to G's ``'weight'`` edge attribute, so G keeps the last
            layer's weights afterwards.
        threshold: a neighbourhood edge is "key" for a layer when its weight
            exceeds it. Key edges are drawn in grey with width equal to the
            weight.
        hops: radius of the neighbourhood around ``node_idx``.

    A node is credited to layer ``k`` when it is an endpoint of a key edge of
    layer ``k``. Node colours are:
    - red for the target;
    - blue, green or orange for nodes credited to exactly layer 0, 1 or 2;
    - purple for nodes credited to several layers;
    - grey for all other nodes.
    The figure is saved to
    ``visualize_key_paths_multi_hop_colored_nodes_Citeseer.png``.
    """
    out_file = 'visualize_key_paths_multi_hop_colored_nodes_Citeseer.png'
    layer_colors = {0: 'blue', 1: 'green', 2: 'orange'}
    pos = nx.spring_layout(G)  # Compute positions for all nodes for consistent layout

    fig, ax = plt.subplots(figsize=(12 * 2, 12 * 2))

    # The neighbourhood depends only on G's structure, so compute it once.
    nodes_to_include = {node_idx}
    frontier = {node_idx}
    for _ in range(hops):
        frontier = {nbr for node in frontier for nbr in G.neighbors(node)} - nodes_to_include
        nodes_to_include |= frontier
    subgraph = G.subgraph(nodes_to_include)

    # The neighbourhood is identical for every layer, so layer membership must
    # come from that layer's key edges, not from the neighbourhood itself.
    node_layers = {node: set() for node in subgraph}
    key_edges_per_layer = []
    for idx, attention_weights in enumerate(attention_weights_list):
        nx.set_edge_attributes(G, {e: attention_weights[i] for i, e in enumerate(G.edges())}, 'weight')
        key_edges = [(u, v, w) for u, v, w in subgraph.edges(data='weight') if w > threshold]
        for u, v, _ in key_edges:
            node_layers[u].add(idx)
            node_layers[v].add(idx)
        key_edges_per_layer.append(key_edges)

    def node_color(node):
        if node == node_idx:
            return 'red'
        layers = node_layers[node]
        if len(layers) > 1:
            return 'purple'  # Nodes in multiple layers
        if len(layers) == 1:
            return layer_colors.get(next(iter(layers)), 'gray')
        return 'gray'  # not touched by any key edge

    # Draw the neighbourhood first, with light edges, so the key edges drawn
    # afterwards are not painted over.
    nx.draw(subgraph, pos, with_labels=True, node_size=1600, font_size=10, font_color='white',
            node_color=[node_color(node) for node in subgraph.nodes()], edge_color='lightgray', ax=ax)
    for key_edges in key_edges_per_layer:
        if key_edges:
            nx.draw_networkx_edges(subgraph, pos, edgelist=[(u, v) for u, v, _ in key_edges],
                                   width=[w for _, _, w in key_edges], edge_color='gray', ax=ax)

    fig.savefig(out_file)
    plt.close(fig)
    print(f"Attention weights visualization saved to '{out_file}'")


# Uncomment to colour the neighbourhood nodes by attention layer:
#visualize_key_paths_multi_hop_colored_nodes(G, target_node, attention_weights_list, threshold=0.7, hops=2)
