import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGPooling, SAGEConv
from torch_geometric.utils import scatter
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
import itertools
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import random
import numpy as np

seed = 520


def _seed_everything(value):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic."""
    random.seed(value)
    np.random.seed(value)
    torch.cuda.manual_seed(value)
    torch.cuda.manual_seed_all(value)
    torch.manual_seed(value)
    # Prefer run-to-run reproducibility over cuDNN autotuned speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(seed)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fixed divisors applied to the query, key and value projections (in that order).
scaling_factors = [torch.tensor(v) for v in (1.0, 1.31, 0.85)]

class Swish(nn.Module):
    """Element-wise Swish activation, ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class EdgeAttention(nn.Module):
    """Multi-head attention scorer that maps each edge to one scalar.

    Row ``n`` of ``x_i`` and row ``n`` of ``x_j`` are the two endpoints of
    edge ``n``. Queries are projected from ``x_i``; keys and values from
    ``x_j``. Each of the ``heads`` heads yields one scaled dot-product score
    per edge. The scores are softmax-normalised across the heads of that edge
    (not across neighbours), dropped out, and used to weight the per-head
    values. The concatenated head outputs are reduced to a single number per
    edge by a small Swish MLP.

    Args:
        in_channels: feature width ``D`` of ``x_i``/``x_j`` (also the per-head width).
        heads: number of attention heads ``H``.
        scaling_factors: optional sequence of exactly three divisors for the
            query, key and value projections. Plain numbers/tensors are kept
            fixed; ``nn.Parameter`` entries stay trainable. When ``None``,
            three learnable scalars initialised to 1 are created instead.
    """

    def __init__(self, in_channels, heads=2, scaling_factors=None):
        super(EdgeAttention, self).__init__()
        self.heads = heads
        self.in_channels = in_channels
        self.query = nn.Linear(in_channels, in_channels * heads, bias=False)
        self.key = nn.Linear(in_channels, in_channels * heads, bias=False)
        self.value = nn.Linear(in_channels, in_channels * heads, bias=False)

        # Learnable softmax temperature, initialised to 0.5.
        self.temperature = nn.Parameter(torch.ones(1) * 0.5)

        if scaling_factors is not None:
            assert len(scaling_factors) == 3, "Provide a scaling factor for each of query, key, and value."
            names = ("query_scaling", "key_scaling", "value_scaling")
            for name, factor in zip(names, scaling_factors):
                if isinstance(factor, nn.Parameter):
                    # Caller-supplied Parameters stay registered (and trainable).
                    setattr(self, name, factor)
                else:
                    # Register fixed factors as buffers so model.to(device) and
                    # dtype casts move them along with the rest of the module.
                    # As plain attributes they stayed on the CPU, which only
                    # worked for 0-dim tensors. persistent=False keeps them out
                    # of state_dict, so existing checkpoints still load strictly.
                    self.register_buffer(name, torch.as_tensor(factor), persistent=False)
        else:
            self.query_scaling = nn.Parameter(torch.ones(1))
            self.key_scaling = nn.Parameter(torch.ones(1))
            self.value_scaling = nn.Parameter(torch.ones(1))

        # Reduces the concatenated H*D head outputs to one score per edge.
        self.output_layer = nn.Sequential(
            nn.Linear(in_channels * heads, in_channels, bias=False),
            Swish(),
            nn.Linear(in_channels, 1, bias=False)
        )
        self.attention_dropout = nn.Dropout(p=0.1)

        nn.init.xavier_uniform_(self.query.weight)
        nn.init.xavier_uniform_(self.key.weight)
        nn.init.xavier_uniform_(self.value.weight)

    def forward(self, x_i, x_j):
        """Score N paired endpoint rows.

        Args:
            x_i: (N, in_channels) query-side endpoint features.
            x_j: (N, in_channels) key/value-side endpoint features.

        Returns:
            Tensor of shape (N, 1) with one unnormalised score per edge.
        """
        N = x_i.size(0)
        H = self.heads
        D = self.in_channels

        q = self.query(x_i).view(N, H, D) / (self.query_scaling)
        k = self.key(x_j).view(N, H, D) / (self.key_scaling)
        v = self.value(x_j).view(N, H, D) / (self.value_scaling)

        # One scaled dot-product per (edge, head); softmax runs over the head axis.
        attention_scores = torch.einsum('nhd,nhd->nh', q, k) / (D ** 0.5 * self.temperature)
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.attention_dropout(attention_probs)

        weighted_values = torch.einsum('nh,nhd->nhd', attention_probs, v)
        # reshape (not view): does not rely on einsum returning a contiguous tensor.
        weighted_values = weighted_values.reshape(N, H * D)

        edge_weights = self.output_layer(weighted_values)
        return edge_weights

class KAGNNConv(nn.Module):
    """Graph layer that aggregates attention-weighted neighbour features.

    For every edge ``(row, col)`` in ``edge_index`` an ``EdgeAttention`` score
    is computed from ``x[row]`` and ``x[col]`` and squashed with a sigmoid.
    The messages ``x[col] * score`` are min-reduced onto node ``row``. The
    result goes through a two-layer MLP and is added to a residual branch
    computed from the layer input.

    Args:
        in_channels: input feature width.
        out_channels: output feature width.
        heads: number of heads in the edge-attention scorer.
        scaling_factors: forwarded unchanged to ``EdgeAttention``.
    """

    def __init__(self, in_channels, out_channels, heads=2, scaling_factors=None):
        super(KAGNNConv, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(),
            nn.Linear(in_channels, out_channels)
        )
        self.edge_attention = EdgeAttention(in_channels, heads=heads, scaling_factors=scaling_factors)
        # The residual is added to the (N, out_channels) MLP output, so it must map
        # in_channels -> out_channels. A bare Identity only works when the widths
        # match; otherwise project with a bias-free Linear.
        if in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x, edge_index, return_attention_weights=False):
        """Apply the layer.

        Args:
            x: (N, in_channels) node features.
            edge_index: index tensor unpacked into ``row, col`` (presumably
                shape (2, E) as in PyG -- confirm against caller).
            return_attention_weights: when True, also return the per-edge
                sigmoid scores of shape (E, 1).

        Returns:
            ``out`` of shape (N, out_channels), or ``(out, scores)`` when
            ``return_attention_weights`` is True.
        """
        row, col = edge_index
        edge_attention_scores = self.edge_attention(x[row], x[col])
        # Squash to (0, 1); shape (E, 1) broadcasts over feature columns.
        edge_attention_scores = torch.sigmoid(edge_attention_scores).view(-1, 1)
        x_weighted = x[col] * edge_attention_scores
        out = scatter(x_weighted, row, dim=0, dim_size=x.size(0), reduce='min')
        out = self.mlp(out) + self.residual(x)

        if return_attention_weights:
            return out, edge_attention_scores
        return out


class KolmogorovArnoldNetwork(nn.Module):
    """Encode features through a pairwise (outer-product) interaction layer.

    Two tanh-activated projections ``psi(x)`` and ``phi(x)``, each of width
    ``hidden_dim``, are combined row-wise by an outer product. The product is
    flattened to ``hidden_dim ** 2`` columns, passed through a two-layer MLP
    (``hidden_dim**2 -> input_dim -> hidden_dim``) and layer-normalised.

    Note: the flattened interaction has ``hidden_dim ** 2`` columns per row, so
    activation memory and the first combine layer grow quadratically with
    ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.psi = nn.Linear(input_dim, hidden_dim)
        self.phi = nn.Linear(input_dim, hidden_dim)
        squared = hidden_dim * hidden_dim
        self.fc_combine = nn.Sequential(
            nn.Linear(squared, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, hidden_dim),
        )
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Return ``(encoded, psi_activation, phi_activation)`` for 2-D ``x`` of shape (batch, input_dim)."""
        psi_act = self.psi(x).tanh()
        phi_act = self.phi(x).tanh()
        batch = x.size(0)
        pairwise = torch.einsum('bi,bj->bij', psi_act, phi_act)
        encoded = self.ln(self.fc_combine(pairwise.reshape(batch, -1)))
        return encoded, psi_act, phi_act


class HierarchicalKAGNN(nn.Module):
    """Node classifier: KAN encoder, five stacked graph layers, MLP head.

    Pipeline: KolmogorovArnoldNetwork -> KAGNNConv -> GCNConv -> KAGNNConv
    -> GATConv -> KAGNNConv -> fc_final. A ReLU follows every graph layer
    except the last one.

    Args:
        in_channels: input node-feature width.
        hidden_channels: width shared by every hidden layer.
        out_channels: logit width (number of classes).
        scaling_factors: q/k/v divisors forwarded to each KAGNNConv; the
            default is the module-level ``scaling_factors`` list bound when
            the class is defined.
        head1, head2, head3: attention-head counts for conv1, conv3, conv5.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, scaling_factors=scaling_factors, head1=8, head2=2, head3=4):
        super().__init__()
        hid = hidden_channels
        self.kan = KolmogorovArnoldNetwork(in_channels, hid)
        self.conv1 = KAGNNConv(hid, hid, head1, scaling_factors=scaling_factors)
        self.conv2 = GCNConv(hid, hid)
        self.conv3 = KAGNNConv(hid, hid, head2, scaling_factors=scaling_factors)
        self.conv4 = GATConv(hid, hid, dropout=0.6)
        self.conv5 = KAGNNConv(hid, hid, head3, scaling_factors=scaling_factors)

        # Effectively a no-op for the 2-D node-feature matrix it receives.
        self.flatten = nn.Flatten()

        self.fc_final = nn.Sequential(
            nn.Linear(hid, hid),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(hid, out_channels),
        )

    def forward(self, x, edge_index, return_attention_weights=False):
        """Run the full pipeline.

        Returns ``(logits, psi_output, phi_output)``, or
        ``(logits, [att1, att2, att3], psi_output, phi_output)`` when
        ``return_attention_weights`` is True, where att1..att3 are the edge
        scores of conv1, conv3 and conv5.
        """
        h, psi_output, phi_output = self.kan(x)

        h, att1 = self.conv1(h, edge_index, return_attention_weights=True)
        h = F.relu(self.conv2(F.relu(h), edge_index))

        h, att2 = self.conv3(h, edge_index, return_attention_weights=True)
        h = F.relu(self.conv4(F.relu(h), edge_index))

        h, att3 = self.conv5(h, edge_index, return_attention_weights=True)

        logits = self.fc_final(self.flatten(h))

        if not return_attention_weights:
            return logits, psi_output, phi_output
        return logits, [att1, att2, att3], psi_output, phi_output
        
# --- Data: the Cora citation graph (Planetoid), moved to the compute device ---
dataset = Planetoid(root='./Cora', name='Cora')
data = dataset[0].to(device)

# (epoch, test_accuracy) history appended to by the training loop.
results = []

# --- Model, optimiser and loss ---
model = HierarchicalKAGNN(
    in_channels=data.x.size(1),
    hidden_channels=384,
    out_channels=dataset.num_classes,
    scaling_factors=scaling_factors,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)
criterion = nn.CrossEntropyLoss()
best_acc = 0.0

def train():
    """Run one full-batch optimisation step on the training nodes.

    Returns:
        The cross-entropy loss over ``data.train_mask`` as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    logits, _psi, _phi = model(data.x, data.edge_index)
    mask = data.train_mask
    batch_loss = criterion(logits[mask], data.y[mask])
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

def test(mask=None):
    """Evaluate the global ``model`` on ``data`` and return classification metrics.

    Args:
        mask: boolean node mask selecting the nodes to score. Defaults to
            ``data.test_mask``, so ``test()`` behaves as before; pass e.g.
            ``data.val_mask`` to score another split.

    Returns:
        ``(accuracy, f1, precision, recall)``. f1/precision/recall are
        macro-averaged; ill-defined per-class scores (zero division) are 0.
    """
    if mask is None:
        mask = data.test_mask
    model.eval()
    # Pure inference: without no_grad the forward pass records an autograd
    # graph (holding every intermediate activation) that is never used.
    with torch.no_grad():
        out, _, _ = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)

    preds_np = pred[mask].cpu().numpy()
    y_true_np = data.y[mask].cpu().numpy()
    accuracy = accuracy_score(y_true_np, preds_np)
    precision = precision_score(y_true_np, preds_np, average='macro', zero_division=0)
    recall = recall_score(y_true_np, preds_np, average='macro', zero_division=0)
    f1 = f1_score(y_true_np, preds_np, average='macro', zero_division=0)

    return accuracy, f1, precision, recall

# --- Training: 40 full-batch epochs, evaluated on the test split after each ---
for epoch in range(40):
    loss = train()
    acc, f1, precision, recall = test()
    results.append((epoch, acc))
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Accuracy: {acc:.4f}, precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}')

    # Checkpoint only when accuracy beats both the running best and the 0.83 floor.
    # NOTE(review): selection uses the *test* split; consider a validation mask
    # so the reported best accuracy is not optimistically biased.
    if acc > max(best_acc, 0.83):
        best_acc = acc
        torch.save(model.state_dict(), './best_model_v4.pth')
        print(f"New best model saved at epoch {epoch} with accuracy {acc:.4f}")

best_result = max(results, key=lambda record: record[1])
print(f"Best epochs: {best_result[0]}, Best accuracy: {best_result[1]:.4f}")
