import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGPooling, SAGEConv
from torch_geometric.utils import scatter
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
import itertools
import gc
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import random
import numpy as np
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Best results recorded from earlier runs of this hyperparameter search
# (same format as the final print statement of this script):
# Best accuracy: 0.7310 with hyperparameters {'head1': 10, 'head2': 4, 'head3': 2, 'hidden_channels': 512, 'qs': 1.0, 'ks': 1.1, 'vs': 1.0}
# Best accuracy: 0.7260 with hyperparameters {'head1': 13, 'head2': 2, 'head3': 1, 'hidden_channels': 512, 'qs': 1.0, 'ks': 1.1, 'vs': 1.0}
# Best accuracy: 0.7310 with hyperparameters {'head1': 12, 'head2': 1, 'head3': 3, 'hidden_channels': 480, 'qs': 1.0, 'ks': 1.1, 'vs': 1.0}
# Best accuracy: 0.7380 with hyperparameters {'head1': 12, 'head2': 1, 'head3': 3, 'hidden_channels': 480, 'qs': 0.9, 'ks': 1.0, 'vs': 1.3}


class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation, applied elementwise: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class EdgeAttention(nn.Module):
    """Multi-head scorer that maps a pair of feature rows to one scalar per pair.

    ``x_i`` supplies the queries and ``x_j`` the keys and values. Each head
    compares its query/key projections with a temperature-scaled dot product.
    The per-pair scores are softmax-normalized across heads and used to weight
    each head's value vector. The concatenated heads are reduced to a single
    output by a two-layer MLP with a Swish activation.

    Args:
        in_channels: Feature size of ``x_i`` / ``x_j``; also the per-head size.
        heads: Number of attention heads.
        scaling_factors: Optional ``(query, key, value)`` divisors applied to the
            projections. If given, they are stored as-is as fixed values. If
            ``None``, three learnable scalar parameters initialized to 1 are used.
    """
    def __init__(self, in_channels, heads=2, scaling_factors=None):
        super(EdgeAttention, self).__init__()
        self.heads = heads
        self.in_channels = in_channels
        # Projections for all heads at once: in_channels -> heads * in_channels.
        self.query = nn.Linear(in_channels, in_channels * heads, bias=False)
        self.key = nn.Linear(in_channels, in_channels * heads, bias=False)
        self.value = nn.Linear(in_channels, in_channels * heads, bias=False)
        
        # Learnable temperature (initialized to 0.5) that divides the scores.
        # NOTE(review): it is not clamped, so training could drive it toward 0
        # or below; confirm this is acceptable.
        self.temperature = nn.Parameter(torch.ones(1) * 0.5)
        
        if scaling_factors is not None:
            # NOTE(review): `assert` is stripped under `python -O`, so this
            # length check disappears in optimized runs.
            assert len(scaling_factors) == 3, "Provide a scaling factor for each of query, key, and value."
            # Fixed (non-trainable) divisors.
            self.query_scaling = scaling_factors[0]
            self.key_scaling = scaling_factors[1]
            self.value_scaling = scaling_factors[2]
        else:
            # Trainable divisors, initialized to 1 (i.e. no scaling at start).
            self.query_scaling = nn.Parameter(torch.ones(1))
            self.key_scaling = nn.Parameter(torch.ones(1))
            self.value_scaling = nn.Parameter(torch.ones(1))
    
        # Reduces the concatenated heads (heads * in_channels) to one value per pair.
        self.output_layer = nn.Sequential(
            nn.Linear(in_channels * heads, in_channels, bias=False),
            Swish(),
            nn.Linear(in_channels, 1, bias=False)
        )
        # Dropout on the normalized head weights (active only in training mode).
        self.attention_dropout = nn.Dropout(p=0.1)
    
        # Overwrite nn.Linear's default init for the q/k/v projections. The
        # output MLP keeps the default init. Statement order here determines
        # RNG consumption under a fixed seed.
        nn.init.xavier_uniform_(self.query.weight)
        nn.init.xavier_uniform_(self.key.weight)
        nn.init.xavier_uniform_(self.value.weight)
        
    def forward(self, x_i, x_j):
        """Score each row pair ``(x_i[n], x_j[n])``.

        Args:
            x_i: Query-side features; assumed ``(N, in_channels)``. The
                ``view(N, heads, in_channels)`` below requires that size.
            x_j: Key/value-side features with the same number of rows ``N``.

        Returns:
            Tensor of shape ``(N, 1)``: one raw (pre-activation) score per pair.
        """
        N = x_i.size(0)
        H = self.heads
        D = self.in_channels

        # Split each projection into H heads of size D, then divide by the
        # configured query/key/value scaling.
        q = self.query(x_i).view(N, H, D) / (self.query_scaling)
        k = self.key(x_j).view(N, H, D) / (self.key_scaling)
        v = self.value(x_j).view(N, H, D) / (self.value_scaling)

        # Per-head dot product q.k, scaled by sqrt(D) and the learned temperature -> (N, H).
        attention_scores = torch.einsum('nhd,nhd->nh', q, k) / (D ** 0.5 * self.temperature)
        # Softmax is over the HEAD axis: for each pair, the H head weights sum to 1.
        # NOTE(review): with heads=1 every weight is exactly 1 (before dropout),
        # so query, key and temperature do not influence the output at all.
        # Confirm this is intended.
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.attention_dropout(attention_probs)

        # Scale each head's value vector by its weight (no sum over heads),
        # then concatenate the heads -> (N, H * D).
        weighted_values = torch.einsum('nh,nhd->nhd', attention_probs, v)
        weighted_values = weighted_values.view(N, H * D)

        # MLP reduction to a single score per pair -> (N, 1).
        edge_weights = self.output_layer(weighted_values)
        return edge_weights

class KAGNNConv(nn.Module):
    """Message-passing layer that gates neighbor features with learned edge attention.

    For every edge ``(row[e], col[e])`` in ``edge_index``, an ``EdgeAttention``
    module scores the pair ``(x[row[e]], x[col[e]])``. The sigmoid of that
    score scales the neighbor feature ``x[col[e]]``. The scaled features are
    max-aggregated onto the ``row`` nodes, passed through a two-layer MLP
    (in_channels -> in_channels -> out_channels) and added to a residual
    branch of the input.

    Args:
        in_channels: Input feature size.
        out_channels: Output feature size.
        heads: Number of attention heads in the edge scorer.
        scaling_factors: Optional fixed ``(query, key, value)`` divisors passed
            to ``EdgeAttention``; ``None`` makes them learnable.
    """

    def __init__(self, in_channels, out_channels, heads=2, scaling_factors=None):
        super(KAGNNConv, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(),
            nn.Linear(in_channels, out_channels)
        )
        self.edge_attention = EdgeAttention(in_channels, heads=heads, scaling_factors=scaling_factors)
        # The residual add needs matching widths. When they already match, keep
        # the parameter-free identity, so existing checkpoints and seeded
        # initialization are unchanged. Otherwise project the input; before
        # this change, in_channels != out_channels raised a shape error in
        # forward().
        if in_channels == out_channels:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x, edge_index, return_attention_weights=False):
        """Run one round of attention-gated max aggregation.

        Args:
            x: Node features, ``(num_nodes, in_channels)``.
            edge_index: ``(2, num_edges)`` index tensor; row 0 is the receiving
                (aggregation) node and row 1 the neighbor whose feature is sent.
            return_attention_weights: If true, also return the per-edge gates.

        Returns:
            ``out`` of shape ``(num_nodes, out_channels)``, or
            ``(out, gates)`` where ``gates`` is the ``(num_edges, 1)`` sigmoid
            edge weights.
        """
        row, col = edge_index
        edge_attention_scores = self.edge_attention(x[row], x[col])
        # Squash raw scores into (0, 1) gates, one per edge.
        edge_attention_scores = torch.sigmoid(edge_attention_scores).view(-1, 1)
        x_weighted = x[col] * edge_attention_scores
        # Elementwise max over each node's gated neighbor features.
        # NOTE(review): nodes that never appear in `row` get scatter's fill value
        # for empty groups (0 in current PyG releases; confirm for your version).
        out = scatter(x_weighted, row, dim=0, dim_size=x.size(0), reduce='max')
        out = self.mlp(out) + self.residual(x)

        if return_attention_weights:
            return out, edge_attention_scores
        else:
            return out

class KolmogorovArnoldNetwork(nn.Module):
    """Outer-product feature encoder.

    The input is projected twice (``psi`` and ``phi``) and squashed with tanh.
    The per-sample outer product of the two projections is flattened to
    ``hidden_dim ** 2`` features and mapped back to ``hidden_dim`` by an MLP
    (through an ``input_dim``-wide bottleneck), followed by LayerNorm.

    Note: the first layer of ``fc_combine`` holds ``hidden_dim**2 * input_dim``
    weights, so it dominates the parameter count for large ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Construction order is kept fixed: each layer draws its initial weights
        # from the global RNG, so reordering would change seeded runs.
        self.psi = nn.Linear(input_dim, hidden_dim)
        self.phi = nn.Linear(input_dim, hidden_dim)
        self.fc_combine = nn.Sequential(
            nn.Linear(hidden_dim * hidden_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, hidden_dim),
        )
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Encode ``x`` (2-D, ``(batch, input_dim)``).

        Returns:
            ``(encoded, psi_output, phi_output)``, each of shape ``(batch, hidden_dim)``.
        """
        batch = x.size(0)
        psi_act = self.psi(x).tanh()
        phi_act = self.phi(x).tanh()
        # (batch, hidden, hidden) outer product, flattened per sample.
        pairwise = torch.einsum('bi,bj->bij', psi_act, phi_act)
        fused = self.fc_combine(pairwise.reshape(batch, -1))
        return self.ln(fused), psi_act, phi_act
        

class HierarchicalKAGNN(nn.Module):
    """Node classifier: KAN feature encoder followed by five graph layers.

    Pipeline: ``KolmogorovArnoldNetwork`` (in_channels -> hidden_channels),
    then KAGNNConv -> GCNConv -> KAGNNConv -> GATConv -> KAGNNConv. A ReLU
    follows every graph layer except the last. An MLP head with dropout then
    produces ``out_channels`` logits per node.

    Args:
        in_channels: Raw node feature size.
        hidden_channels: Width of every hidden layer.
        out_channels: Number of output classes.
        scaling_factors: ``(query, key, value)`` divisors shared by all three
            KAGNNConv layers, or ``None`` for learnable ones.
        head1, head2, head3: Attention heads of the 1st/2nd/3rd KAGNNConv.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, scaling_factors, head1=3, head2=3, head3=4):
        super().__init__()
        width = hidden_channels
        # Keep this construction order: the layers initialize their weights
        # randomly when built, so reordering would change seeded runs.
        self.kan = KolmogorovArnoldNetwork(in_channels, width)
        self.conv1 = KAGNNConv(width, width, head1, scaling_factors=scaling_factors)
        self.conv2 = GCNConv(width, width)
        self.conv3 = KAGNNConv(width, width, head2, scaling_factors=scaling_factors)
        self.conv4 = GATConv(width, width, dropout=0.6)
        self.conv5 = KAGNNConv(width, width, head3, scaling_factors=scaling_factors)

        # Shape-preserving for the 2-D (nodes, features) tensors produced above.
        self.flatten = nn.Flatten()

        self.fc_final = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(width, out_channels),
        )

    def forward(self, x, edge_index, return_attention_weights=False):
        """Compute class logits for every node.

        Returns:
            ``(logits, psi_output, phi_output)``. When
            ``return_attention_weights`` is true, returns
            ``(logits, [scores1, scores2, scores3], psi_output, phi_output)``
            instead, with one per-edge score tensor per KAGNNConv layer.
        """
        h, psi_output, phi_output = self.kan(x)

        # (layer, yields_attention) in execution order.
        pipeline = (
            (self.conv1, True),
            (self.conv2, False),
            (self.conv3, True),
            (self.conv4, False),
            (self.conv5, True),
        )
        final_position = len(pipeline) - 1
        edge_scores = []
        for position, (layer, yields_attention) in enumerate(pipeline):
            if yields_attention:
                h, scores = layer(h, edge_index, return_attention_weights=True)
                edge_scores.append(scores)
            else:
                h = layer(h, edge_index)
            # No activation after the last graph layer.
            if position != final_position:
                h = F.relu(h)

        logits = self.fc_final(self.flatten(h))

        if not return_attention_weights:
            return logits, psi_output, phi_output
        return logits, edge_scores, psi_output, phi_output


def train():
    """Run one full-batch optimization step on the training nodes.

    Reads the module-level ``model``, ``optimizer``, ``criterion`` and
    ``data`` created by the search loop in this script.

    Returns:
        The training loss for this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    logits, _psi, _phi = model(data.x, data.edge_index)
    train_nodes = data.train_mask
    step_loss = criterion(logits[train_nodes], data.y[train_nodes])
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def test(mask=None):
    """Evaluate the module-level ``model`` on ``data`` and return metrics.

    Runs a full-graph forward pass in eval mode (dropout disabled) without
    recording an autograd graph, then scores the nodes selected by ``mask``.

    Args:
        mask: Boolean node mask selecting the nodes to score. Defaults to
            ``data.test_mask`` (the original behavior); pass e.g.
            ``data.val_mask`` to score the validation split instead.

    Returns:
        ``(accuracy, f1, precision, recall)``. f1, precision and recall are
        macro-averaged; undefined per-class values count as 0.
    """
    if mask is None:
        mask = data.test_mask

    model.eval()
    # Evaluation needs no gradients. Without no_grad the forward pass keeps
    # every intermediate activation alive for backprop, which wastes a lot of
    # (GPU) memory for a model this size.
    with torch.no_grad():
        out, _, _ = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)

    preds_np = pred[mask].cpu().numpy()
    y_true_np = data.y[mask].cpu().numpy()
    accuracy = accuracy_score(y_true_np, preds_np)
    precision = precision_score(y_true_np, preds_np, average='macro', zero_division=0)
    recall = recall_score(y_true_np, preds_np, average='macro', zero_division=0)
    f1 = f1_score(y_true_np, preds_np, average='macro', zero_division=0)

    return accuracy, f1, precision, recall

# Hyperparameter search space: every combination (Cartesian product) is trained once.
search_space = {
    'head1': [12],
    'head2': [1],
    'head3': [3],
    'hidden_channels': [480],
    'query_scalings': [0.8, 0.85, 0.9, 0.95, 1.0],
    'key_scalings': [0.9, 0.95, 1.0, 1.05, 1.1],
    'value_scalings': [1.2, 1.25, 1.3, 1.35, 1.4],
}

best_acc = 0.0
best_hyperparameters = None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Citeseer once. The graph is the same for every configuration, so
# re-reading it from disk each iteration is wasted work. `data` must stay a
# module-level global because train()/test() read it from module scope.
dataset = Planetoid(root='./Citeseer', name='Citeseer')
data = dataset[0].to(device)

for values in itertools.product(*search_space.values()):
    # Look values up by key rather than by tuple position, so reordering or
    # adding keys in `search_space` cannot silently mis-assign them.
    config = dict(zip(search_space.keys(), values))
    head1, head2, head3 = config['head1'], config['head2'], config['head3']
    hidden_channels = config['hidden_channels']
    qs = config['query_scalings']
    ks = config['key_scalings']
    vs = config['value_scalings']

    scaling_factors = [qs, ks, vs]

    print(f"Training with head1={head1}, head2={head2}, head3={head3}, hidden_channels={hidden_channels}, qs={qs}, ks={ks}, vs={vs}", flush=True)

    # Reseed before building each model so every configuration starts from
    # the same RNG state.
    seed = 640
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model = HierarchicalKAGNN(in_channels=data.x.size(1),
                              hidden_channels=hidden_channels,
                              out_channels=dataset.num_classes,
                              scaling_factors=scaling_factors,
                              head1=head1, head2=head2, head3=head3).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(50):
        loss = train()
        # Use isfinite rather than `== float('inf')`: NaN and -inf never
        # compare equal to +inf, so that check missed them.
        if not np.isfinite(loss):
            print(f'Stopping training due to NaN or infinite values at epoch {epoch}', flush=True)
            break
        acc, f1, precision, recall = test()
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Accuracy: {acc:.4f}, precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}', flush=True)

        # NOTE(review): configurations and epochs are selected on *test*
        # accuracy, which makes the reported best accuracy optimistic. The
        # unbiased protocol is to select on the validation split (e.g.
        # data.val_mask) and report test accuracy for the chosen model only.
        if acc > best_acc:
            best_acc = acc
            best_hyperparameters = {
                'head1': head1,
                'head2': head2,
                'head3': head3,
                'hidden_channels': hidden_channels,
                'qs': qs,
                'ks': ks,
                'vs': vs
            }
            if acc > 0.70:
                torch.save(model.state_dict(), './best_model_Citeseer.pth')
                print(f"New best model saved at epoch {epoch} with accuracy {acc:.4f}", flush=True)

    # Free this configuration's model and optimizer before building the next one.
    del model, optimizer
    gc.collect()
    if torch.cuda.is_available():
        # Return cached blocks to the driver so the next (large) model can allocate.
        torch.cuda.empty_cache()

# Print the best result.
print(f"Best accuracy: {best_acc:.4f} with hyperparameters {best_hyperparameters}")
