import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGPooling, SAGEConv
from torch_geometric.utils import scatter
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
import random
import numpy as np

# One fixed seed is pushed into every random number generator this script
# touches (Python, NumPy, PyTorch CPU and CUDA) so repeated runs line up.
seed = 640
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fixed divisors for the query, key and value projections (in that order)
# of every EdgeAttention module built with these scaling factors.
scaling_factors = [torch.tensor(factor) for factor in (0.9, 1.0, 1.3)]

class Swish(nn.Module):
    """Parameter-free activation that multiplies each element by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` element-wise, with the same shape as ``x``."""
        gate = torch.sigmoid(x)
        return x * gate

class EdgeAttention(nn.Module):
    """Score each edge from the features of its two endpoints.

    Queries are projected from ``x_i``; keys and values are projected from
    ``x_j``. Every head yields one query-key dot product per edge, the
    softmax is taken across the heads of that edge, and the head-weighted
    values are reduced to a single scalar per edge by a small MLP.

    Args:
        in_channels: feature size of each endpoint vector.
        heads: number of attention heads.
        scaling_factors: optional sequence of exactly three values used as
            divisors for the query, key and value projections (in that
            order). They are stored as given. When omitted, three learnable
            scalars initialised to 1 are created instead.
    """

    def __init__(self, in_channels, heads=2, scaling_factors=None):
        super().__init__()
        self.heads = heads
        self.in_channels = in_channels

        # Each projection emits `heads` blocks of `in_channels` features.
        projected_dim = in_channels * heads
        self.query = nn.Linear(in_channels, projected_dim, bias=False)
        self.key = nn.Linear(in_channels, projected_dim, bias=False)
        self.value = nn.Linear(in_channels, projected_dim, bias=False)

        # Learnable temperature that rescales the dot-product scores.
        self.temperature = nn.Parameter(torch.ones(1) * 0.5)

        if scaling_factors is None:
            # No fixed divisors supplied: learn them, starting from 1.
            self.query_scaling = nn.Parameter(torch.ones(1))
            self.key_scaling = nn.Parameter(torch.ones(1))
            self.value_scaling = nn.Parameter(torch.ones(1))
        else:
            assert len(scaling_factors) == 3, "Provide a scaling factor for each of query, key, and value."
            self.query_scaling = scaling_factors[0]
            self.key_scaling = scaling_factors[1]
            self.value_scaling = scaling_factors[2]

        # Collapses the concatenated heads to one scalar per edge.
        self.output_layer = nn.Sequential(
            nn.Linear(projected_dim, in_channels, bias=False),
            Swish(),
            nn.Linear(in_channels, 1, bias=False),
        )
        self.attention_dropout = nn.Dropout(p=0.1)

        # Re-initialise the three projections with Xavier-uniform weights.
        for projection in (self.query, self.key, self.value):
            nn.init.xavier_uniform_(projection.weight)

    def forward(self, x_i, x_j):
        """Return one unnormalised score per edge.

        Args:
            x_i: features that produce the queries; the first dimension
                indexes edges and the last holds ``in_channels`` features.
            x_j: features that produce the keys and values, same layout.

        Returns:
            Tensor with one row per edge and a single column.
        """
        num_edges = x_i.size(0)
        num_heads = self.heads
        head_dim = self.in_channels

        queries = self.query(x_i).view(num_edges, num_heads, head_dim) / self.query_scaling
        keys = self.key(x_j).view(num_edges, num_heads, head_dim) / self.key_scaling
        values = self.value(x_j).view(num_edges, num_heads, head_dim) / self.value_scaling

        # Per-head dot product, scaled by sqrt(head_dim) and the temperature.
        score_scale = head_dim ** 0.5 * self.temperature
        head_scores = torch.einsum('nhd,nhd->nh', queries, keys) / score_scale

        # The softmax runs over the heads of a single edge, not over edges.
        head_probs = self.attention_dropout(F.softmax(head_scores, dim=-1))

        # Weight every head's value vector, then lay the heads side by side.
        mixed = torch.einsum('nh,nhd->nhd', head_probs, values)
        flattened = mixed.view(num_edges, num_heads * head_dim)

        return self.output_layer(flattened)

class KAGNNConv(nn.Module):
    """Graph convolution that gates neighbour features with learned edge attention.

    For every edge ``(row, col)`` the features of node ``col`` are multiplied
    by a sigmoid gate computed by ``EdgeAttention`` from the two endpoints.
    The gated messages are max-reduced per ``row`` node, passed through a
    two-layer MLP and added to the untouched input features.

    Args:
        in_channels: size of the incoming node features.
        out_channels: size of the MLP output. The result is summed with the
            input features, so the two sizes must be broadcast-compatible.
        heads: number of heads handed to ``EdgeAttention``.
        scaling_factors: forwarded unchanged to ``EdgeAttention``.
    """

    def __init__(self, in_channels, out_channels, heads=2, scaling_factors=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(),
            nn.Linear(in_channels, out_channels),
        )
        self.edge_attention = EdgeAttention(
            in_channels, heads=heads, scaling_factors=scaling_factors
        )
        # Skip connection: the input is passed through unchanged.
        self.residual = nn.Identity()

    def forward(self, x, edge_index, return_attention_weights=False):
        """Run one round of gated max-aggregation.

        Args:
            x: node feature matrix, one row per node.
            edge_index: pair of index tensors ``(row, col)``; messages flow
                from ``col`` into ``row``.
            return_attention_weights: when true, also return the per-edge gates.

        Returns:
            The updated node features, or ``(features, gates)`` where
            ``gates`` has one row per edge and a single column in ``(0, 1)``.
        """
        receivers, senders = edge_index
        sender_features = x[senders]

        # Squash the raw edge scores into (0, 1) gates, one column per edge.
        raw_scores = self.edge_attention(x[receivers], sender_features)
        gates = torch.sigmoid(raw_scores).view(-1, 1)

        # Element-wise max over the gated messages arriving at each receiver.
        messages = sender_features * gates
        pooled = scatter(messages, receivers, dim=0, dim_size=x.size(0), reduce='max')

        updated = self.mlp(pooled) + self.residual(x)

        if not return_attention_weights:
            return updated
        return updated, gates


class KolmogorovArnoldNetwork(nn.Module):
    """Feature encoder built from two tanh branches joined by an outer product.

    The input is sent through two independent linear layers (``psi`` and
    ``phi``), each followed by ``tanh``. Their per-sample outer product is
    flattened, reduced by a two-layer MLP to ``hidden_dim`` features and
    layer-normalised.

    Args:
        input_dim: size of the incoming feature vectors.
        hidden_dim: size of each branch output and of the final encoding.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.psi = nn.Linear(input_dim, hidden_dim)
        self.phi = nn.Linear(input_dim, hidden_dim)
        # The outer product of two hidden_dim vectors has hidden_dim^2 entries.
        outer_dim = hidden_dim * hidden_dim
        self.fc_combine = nn.Sequential(
            nn.Linear(outer_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, hidden_dim),
        )
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Encode ``x`` and expose both branch activations.

        Args:
            x: 2-D tensor with one sample per row and ``input_dim`` columns.

        Returns:
            Tuple ``(out, psi_output, phi_output)``: the layer-normalised
            encoding and the two tanh branch outputs, each with
            ``hidden_dim`` columns.
        """
        batch_size = x.size(0)
        psi_branch = torch.tanh(self.psi(x))
        phi_branch = torch.tanh(self.phi(x))

        # Pair every psi feature with every phi feature, sample by sample.
        pairwise = torch.einsum('bi,bj->bij', psi_branch, phi_branch)
        flattened = pairwise.reshape(batch_size, -1)

        encoded = self.ln(self.fc_combine(flattened))
        return encoded, psi_branch, phi_branch


class HierarchicalKAGNN(nn.Module):
    """Node classifier stacking a KAN encoder, five graph layers and an MLP head.

    The graph layers alternate attention-gated ``KAGNNConv`` blocks with a
    ``GCNConv`` and a ``GATConv``:
    KAGNNConv -> GCNConv -> KAGNNConv -> GATConv -> KAGNNConv.

    Args:
        in_channels: size of the raw node features.
        hidden_channels: width used by the encoder and every graph layer.
        out_channels: number of values produced per node by the head.
        scaling_factors: query/key/value divisors handed to each
            ``KAGNNConv``; defaults to the module-level ``scaling_factors``.
        head1: attention heads of the first ``KAGNNConv``.
        head2: attention heads of the second ``KAGNNConv``.
        head3: attention heads of the third ``KAGNNConv``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, scaling_factors=scaling_factors, head1=12, head2=1, head3=3):
        super().__init__()
        self.kan = KolmogorovArnoldNetwork(in_channels, hidden_channels)

        self.conv1 = KAGNNConv(hidden_channels, hidden_channels, head1, scaling_factors=scaling_factors)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = KAGNNConv(hidden_channels, hidden_channels, head2, scaling_factors=scaling_factors)
        self.conv4 = GATConv(hidden_channels, hidden_channels, dropout=0.6)
        self.conv5 = KAGNNConv(hidden_channels, hidden_channels, head3, scaling_factors=scaling_factors)

        self.flatten = nn.Flatten()

        # Classification head applied to every node embedding.
        self.fc_final = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, x, edge_index, return_attention_weights=False):
        """Classify every node of the graph.

        Args:
            x: raw node feature matrix, one row per node.
            edge_index: pair of edge index tensors passed to every graph layer.
            return_attention_weights: when true, also return the per-edge
                gates of the three ``KAGNNConv`` layers.

        Returns:
            ``(out, psi_output, phi_output)`` by default, or
            ``(out, [gates1, gates2, gates3], psi_output, phi_output)`` when
            ``return_attention_weights`` is true. ``psi_output`` and
            ``phi_output`` are the two branch activations of the KAN encoder.
        """
        hidden, psi_output, phi_output = self.kan(x)

        # The attention gates are always collected and only returned on request.
        hidden, gates_first = self.conv1(hidden, edge_index, return_attention_weights=True)
        hidden = F.relu(hidden)

        hidden = F.relu(self.conv2(hidden, edge_index))

        hidden, gates_second = self.conv3(hidden, edge_index, return_attention_weights=True)
        hidden = F.relu(hidden)

        hidden = F.relu(self.conv4(hidden, edge_index))

        # No activation after the last graph layer.
        hidden, gates_third = self.conv5(hidden, edge_index, return_attention_weights=True)

        out = self.fc_final(self.flatten(hidden))

        if not return_attention_weights:
            return out, psi_output, phi_output
        return out, [gates_first, gates_second, gates_third], psi_output, phi_output

model_path = './best_model_Citeseer.pth'

# Load the Citeseer Planetoid dataset and move its first graph to the device.
dataset = Planetoid(root='./Citeseer', name='Citeseer')
data = dataset[0].to(device)

# Rebuild the network with the hyper-parameters the checkpoint expects.
model = HierarchicalKAGNN(
    in_channels=data.x.size(1),
    hidden_channels=480,
    out_channels=dataset.num_classes,
    scaling_factors=scaling_factors,
).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint written from a GPU run still loads when only a CPU is available.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

def visualize_attention_weights(model, data, layer_idx):
    """Draw the graph with edges coloured by one layer's attention and save it as a PNG.

    The model is run once in evaluation mode without gradients. The selected
    layer's per-edge attention values are min-max normalised to ``[0, 1]``
    and mapped through the ``Blues`` colormap. The figure is written to
    ``attention_weights_<layer_idx>_Citeseer.png`` in the working directory.

    Args:
        model: module whose call
            ``model(data.x, data.edge_index, return_attention_weights=True)``
            returns a 4-tuple with the list of per-layer attention tensors
            in second position.
        data: object exposing ``x`` and ``edge_index`` tensors.
        layer_idx: 1-based index of the attention layer to draw.

    Raises:
        IndexError: if ``layer_idx`` is outside ``1..number of layers``.
    """
    model.eval()
    with torch.no_grad():
        _, attention_weights_list, _, _ = model(
            data.x, data.edge_index, return_attention_weights=True
        )

    # Reject 0 and negative values explicitly: `layer_idx - 1` would otherwise
    # wrap around and silently pick a different layer.
    num_layers = len(attention_weights_list)
    if not 1 <= layer_idx <= num_layers:
        raise IndexError(
            f"layer_idx must be between 1 and {num_layers}, got {layer_idx}"
        )

    attention_weights = attention_weights_list[layer_idx - 1].cpu().numpy().flatten()
    edge_index = data.edge_index.cpu().numpy()

    # Normalise attention_weights to the [0, 1] range.
    lowest, highest = attention_weights.min(), attention_weights.max()
    if highest != lowest:
        attention_weights = (attention_weights - lowest) / (highest - lowest)
    else:
        # All weights are equal: draw every edge with the same full weight.
        attention_weights = np.ones_like(attention_weights)

    # The graph is undirected, so the two directions of an edge collapse into
    # one and the weight of the direction added last is the one kept.
    G = nx.Graph()
    G.add_weighted_edges_from(zip(edge_index[0], edge_index[1], attention_weights))

    try:
        pos = nx.spring_layout(G)
    except ValueError:
        pos = nx.random_layout(G)  # If spring_layout fails, use random layout

    # Take the colours from the graph's own edges. The raw per-directed-edge
    # array has a different length and order than G.edges(), so using it
    # directly would paint edges with weights that belong to other edges.
    weighted_edges = list(G.edges(data='weight'))
    edgelist = [(u, v) for u, v, _ in weighted_edges]
    edge_weights = np.array([w for _, _, w in weighted_edges], dtype=float)
    edge_colors = plt.cm.Blues(edge_weights)

    pngname = f'attention_weights_{layer_idx}_Citeseer.png'

    plt.figure(figsize=(200, 200))
    try:
        nx.draw_networkx_edges(G, pos, edgelist=edgelist, edge_color=edge_colors, width=4.0)
        nx.draw_networkx_nodes(G, pos, node_size=1200)
        nx.draw_networkx_labels(G, pos, font_size=10, font_color='white')
        plt.title(f'Attention Weights Visualization for Layer {layer_idx}')
        plt.savefig(pngname)
    finally:
        # Always release the (very large) figure, even if drawing/saving fails.
        plt.close()

    print(f"Attention weights visualization saved to '{pngname}'")



# Save one figure for each of the three attention layers (1-based indices).
for layer in (1, 2, 3):
    visualize_attention_weights(model, data, layer_idx=layer)

