import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGPooling, SAGEConv
from torch_geometric.utils import scatter
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
import random
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Global reproducibility: seed Python, NumPy and every torch RNG, then pin
# cuDNN to deterministic (non-autotuned) kernels.
seed = 640
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    torch.manual_seed,
):
    _seed_fn(seed)
del _seed_fn
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Divisors applied to the query / key / value projections of EdgeAttention
# when passed as its `scaling_factors` argument.
scaling_factors = [torch.tensor(value) for value in (0.9, 1.0, 1.3)]

class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` applied element-wise.

    Numerically the same function as ``torch.nn.SiLU``.
    """

    def forward(self, x):
        return x * torch.sigmoid(x)

class EdgeAttention(nn.Module):
    """Produce one raw score per edge from the features of its two endpoints.

    ``x_i`` and ``x_j`` must have the same number of rows (one row per edge).
    Each is projected into ``heads`` slices of width ``in_channels``: queries
    from ``x_i``, keys and values from ``x_j``. The per-head dot product q.k,
    divided by ``sqrt(in_channels) * temperature``, is softmax-normalised across
    the *head* axis. It then weights the values, and the concatenated heads
    are reduced to a single scalar per edge by ``output_layer``.

    Note: because the softmax runs over heads, ``heads=1`` gives every edge a
    probability of exactly 1 (before dropout).

    Args:
        in_channels: input feature width, also used as the per-head width.
        heads: number of attention heads.
        scaling_factors: optional sequence of three divisors for q, k and v.
            When ``None``, three learnable scalars initialised to 1 are used.
    """

    def __init__(self, in_channels, heads=2, scaling_factors=None):
        super().__init__()
        self.heads = heads
        self.in_channels = in_channels
        projected = in_channels * heads

        # Creation order determines state_dict layout and RNG draw order; keep it.
        self.query = nn.Linear(in_channels, projected, bias=False)
        self.key = nn.Linear(in_channels, projected, bias=False)
        self.value = nn.Linear(in_channels, projected, bias=False)

        self.temperature = nn.Parameter(torch.ones(1) * 0.5)

        if scaling_factors is None:
            # Learnable per-projection divisors.
            self.query_scaling = nn.Parameter(torch.ones(1))
            self.key_scaling = nn.Parameter(torch.ones(1))
            self.value_scaling = nn.Parameter(torch.ones(1))
        else:
            assert len(scaling_factors) == 3, "Provide a scaling factor for each of query, key, and value."
            # Stored as given; plain tensors are not registered as parameters.
            self.query_scaling, self.key_scaling, self.value_scaling = scaling_factors

        self.output_layer = nn.Sequential(
            nn.Linear(projected, in_channels, bias=False),
            Swish(),
            nn.Linear(in_channels, 1, bias=False),
        )
        self.attention_dropout = nn.Dropout(p=0.1)

        for projection in (self.query, self.key, self.value):
            nn.init.xavier_uniform_(projection.weight)

    def forward(self, x_i, x_j):
        """Return a ``(rows, 1)`` tensor of raw (un-squashed) edge scores."""
        n_rows = x_i.size(0)
        per_head = (n_rows, self.heads, self.in_channels)

        q = self.query(x_i).view(per_head) / (self.query_scaling)
        k = self.key(x_j).view(per_head) / (self.key_scaling)
        v = self.value(x_j).view(per_head) / (self.value_scaling)

        # Scaled dot product per (row, head) with a learnable temperature.
        logits = torch.einsum('nhd,nhd->nh', q, k) / (self.in_channels ** 0.5 * self.temperature)
        weights = self.attention_dropout(F.softmax(logits, dim=-1))

        mixed = torch.einsum('nh,nhd->nhd', weights, v)
        mixed = mixed.view(n_rows, self.heads * self.in_channels)
        return self.output_layer(mixed)

class KAGNNConv(nn.Module):
    """Message-passing layer that gates neighbour features with edge attention.

    For every edge ``(row, col)`` in ``edge_index``, the feature ``x[col]`` is
    multiplied by ``sigmoid(EdgeAttention(x[row], x[col]))``. The gated messages
    are max-aggregated onto ``row``, passed through a two-layer MLP, and added
    to a residual of the input.

    Args:
        in_channels: input feature width.
        out_channels: output feature width.
        heads: number of EdgeAttention heads.
        scaling_factors: forwarded to EdgeAttention (fixed q/k/v divisors or None).
    """

    def __init__(self, in_channels, out_channels, heads=2, scaling_factors=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(),
            nn.Linear(in_channels, out_channels)
        )
        self.edge_attention = EdgeAttention(in_channels, heads=heads, scaling_factors=scaling_factors)
        # An identity skip only matches `mlp(out)` when the widths agree. Otherwise
        # project the input so the residual sum is well-formed. The equal-width
        # case (the one used in this file) keeps the original Identity and
        # state_dict keys.
        if in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x, edge_index, return_attention_weights=False):
        """Aggregate gated neighbour features.

        Returns ``out`` of shape ``(x.size(0), out_channels)``. When
        ``return_attention_weights`` is True, returns ``(out, gates)``, where
        ``gates`` is the ``(num_edges, 1)`` sigmoid gate of each edge.
        """
        row, col = edge_index
        # One gate in (0, 1) per edge, computed from both endpoint features.
        gates = torch.sigmoid(self.edge_attention(x[row], x[col])).view(-1, 1)
        messages = x[col] * gates
        out = scatter(messages, row, dim=0, dim_size=x.size(0), reduce='max')
        out = self.mlp(out) + self.residual(x)

        if return_attention_weights:
            return out, gates
        return out


class KolmogorovArnoldNetwork(nn.Module):
    """Pairwise-product feature encoder.

    Two tanh-activated linear maps, ``psi`` and ``phi``, project the input to
    ``hidden_dim``. Their outer product (``hidden_dim ** 2`` values per row) is
    passed through a two-layer MLP and then LayerNorm.

    Note: ``fc_combine`` starts with ``Linear(hidden_dim * hidden_dim, input_dim)``,
    so its parameter count grows quadratically with ``hidden_dim``.

    The input may have any number of leading dimensions; only the last one
    (``input_dim``) is transformed.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.psi = nn.Linear(input_dim, hidden_dim)
        self.phi = nn.Linear(input_dim, hidden_dim)
        self.fc_combine = nn.Sequential(
            nn.Linear(hidden_dim * hidden_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, hidden_dim)
        )
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Encode ``x`` of shape ``(*lead, input_dim)``.

        Returns:
            ``(out, psi_output, phi_output)``, each of shape ``(*lead, hidden_dim)``.
            ``psi_output`` / ``phi_output`` are the tanh activations of the two
            branches; ``out`` is the LayerNorm-ed combination.
        """
        psi_output = torch.tanh(self.psi(x))
        phi_output = torch.tanh(self.phi(x))
        # Outer product per row, flattened: (*lead, H, H) -> (*lead, H*H).
        # The ellipsis form equals 'bi,bj->bij' for 2-D input but also accepts
        # 1-D or higher-rank input.
        combined = torch.einsum('...i,...j->...ij', psi_output, phi_output).flatten(start_dim=-2)
        out = self.ln(self.fc_combine(combined))
        return out, psi_output, phi_output


class HierarchicalKAGNN(nn.Module):
    """Node classifier: KAN feature encoder followed by five graph convolutions.

    Pipeline (all hidden widths ``hidden_channels``):
        KAN -> KAGNNConv -> GCNConv -> KAGNNConv -> GATConv -> KAGNNConv -> MLP head
    A ReLU follows every convolution except the last one.

    ``head1`` / ``head2`` / ``head3`` are the EdgeAttention head counts of the
    first, second and third KAGNNConv layers respectively.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, scaling_factors=scaling_factors, head1=12, head2=1, head3=3):
        super().__init__()
        hidden = hidden_channels
        # Attribute names and creation order define the state_dict keys and the
        # RNG draw order at construction; keep them stable.
        self.kan = KolmogorovArnoldNetwork(in_channels, hidden)
        self.conv1 = KAGNNConv(hidden, hidden, head1, scaling_factors=scaling_factors)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = KAGNNConv(hidden, hidden, head2, scaling_factors=scaling_factors)
        self.conv4 = GATConv(hidden, hidden, dropout=0.6)
        self.conv5 = KAGNNConv(hidden, hidden, head3, scaling_factors=scaling_factors)

        self.flatten = nn.Flatten()

        self.fc_final = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(hidden, out_channels),
        )

    def forward(self, x, edge_index, return_attention_weights=False):
        """Run the full pipeline.

        Returns ``(logits, psi_output, phi_output)``. When
        ``return_attention_weights`` is True, returns
        ``(logits, [gates1, gates2, gates3], psi_output, phi_output)`` instead,
        where each ``gates`` entry is the sigmoid edge gate of one KAGNNConv.
        """
        h, psi_output, phi_output = self.kan(x)

        # (layer, returns edge gates?, apply ReLU afterwards?)
        stages = (
            (self.conv1, True, True),
            (self.conv2, False, True),
            (self.conv3, True, True),
            (self.conv4, False, True),
            (self.conv5, True, False),
        )
        edge_gates = []
        for layer, yields_gates, activate in stages:
            if yields_gates:
                h, gates = layer(h, edge_index, return_attention_weights=True)
                edge_gates.append(gates)
            else:
                h = layer(h, edge_index)
            if activate:
                h = F.relu(h)

        logits = self.fc_final(self.flatten(h))

        if not return_attention_weights:
            return logits, psi_output, phi_output
        return logits, edge_gates, psi_output, phi_output

model_path = './best_model_Citeseer.pth'
# Load the Citeseer Planetoid dataset and move its graph to the target device.
dataset = Planetoid(root='./Citeseer', name='Citeseer')
data = dataset[0].to(device)

model = HierarchicalKAGNN(
    in_channels=data.x.size(1),
    hidden_channels=480,
    out_channels=dataset.num_classes,
    scaling_factors=scaling_factors,
).to(device)
# map_location remaps stored tensors to `device`, so a checkpoint saved on GPU
# also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


def visualize_kan_outputs(psi_output, phi_output, features=None,
                          path='visualize_kan_outputs_Citeseer.png'):
    """Save three side-by-side heatmaps: input features, psi output, phi output.

    Args:
        psi_output: ``psi_output`` tensor returned by the model.
        phi_output: ``phi_output`` tensor returned by the model.
        features: matrix shown in the first panel; defaults to the
            module-level ``data.x``.
        path: PNG file to write.
    """
    if features is None:
        features = data.x

    panels = (
        (features, 'Original Features Heatmap of Citeseer Dataset'),
        (psi_output, 'Psi Output'),
        (phi_output, 'Phi Output'),
    )
    # NOTE(review): a 200x60-inch canvas yields a very large PNG at default dpi.
    fig = plt.figure(figsize=(200, 60))
    try:
        for idx, (matrix, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, idx)
            sns.heatmap(matrix.detach().cpu().numpy(), cmap='viridis')
            plt.title(title)
            if idx == 1:
                plt.xlabel('Features')
                plt.ylabel('Nodes')
        plt.savefig(path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    # Report the file that was actually written.
    print(f"KA visualization saved to '{path}'")

# Inference only: no_grad skips building an autograd graph for the full-graph
# forward pass, saving memory. The outputs are used only for plotting.
with torch.no_grad():
    _, psi_output, phi_output = model(data.x, data.edge_index)
visualize_kan_outputs(psi_output, phi_output)



def visualize_feature_distribution(labels, psi_output, phi_output,
                                   path='visualize_feature_distribution_Citeseer.png'):
    """Save 2-D t-SNE scatter plots of the psi and phi outputs, coloured by label.

    Args:
        labels: per-row class labels (tensor), used as scatter colours.
        psi_output: ``psi_output`` tensor returned by the model.
        phi_output: ``phi_output`` tensor returned by the model.
        path: PNG file to write.
    """
    # A single estimator is refit for each embedding; the fixed random_state
    # makes each fit deterministic.
    tsne = TSNE(n_components=2, random_state=42)
    panels = (('Psi', psi_output), ('Phi', phi_output))
    embeddings = [tsne.fit_transform(t.detach().cpu().numpy()) for _, t in panels]
    colours = labels.cpu().numpy()

    fig = plt.figure(figsize=(20, 8))
    try:
        for idx, ((name, _), emb) in enumerate(zip(panels, embeddings), start=1):
            plt.subplot(1, 2, idx)
            plt.scatter(emb[:, 0], emb[:, 1], c=colours, label=f'{name} Output', cmap='viridis', s=5)
            plt.colorbar(label='Class')
            plt.title(f'{name} Output Distribution')
            plt.xlabel('Component 1')
            plt.ylabel('Component 2')
        plt.savefig(path)
    finally:
        plt.close(fig)
    # Report the file that was actually written.
    print(f"KA visualization saved to '{path}'")

# t-SNE view of the psi/phi activations, coloured by the node labels data.y.
visualize_feature_distribution(labels=data.y, psi_output=psi_output, phi_output=phi_output)



def visualize_similarity_matrix(psi_output, phi_output,
                                path='visualize_similarity_matrix_Citeseer.png'):
    """Save a heatmap of the dot-product matrix ``psi_output @ phi_output.T``.

    For 2-D inputs of shape ``(N, H)`` the plotted matrix is ``N x N``.

    Args:
        psi_output: ``psi_output`` tensor returned by the model.
        phi_output: ``phi_output`` tensor returned by the model.
        path: PNG file to write.
    """
    # Detach first so the matmul never records an autograd graph.
    similarity_matrix = torch.matmul(psi_output.detach(), phi_output.detach().T)

    fig = plt.figure(figsize=(6, 6))
    try:
        sns.heatmap(similarity_matrix.cpu().numpy(), cmap='coolwarm')
        plt.title('Similarity Matrix between Psi and Phi')
        plt.savefig(path)
    finally:
        plt.close(fig)
    # Report the file that was actually written.
    print(f"KA visualization saved to '{path}'")

# Heatmap of psi_output @ phi_output.T.
visualize_similarity_matrix(psi_output=psi_output, phi_output=phi_output)


