#!/usr/bin/env python
# coding: utf-8

# Best LR: 0.0001, Best heads1: 18, Best heads2: 1, Best Optimizer: AdamW, Best Weight Decay: 0, Best en_inner: 2048, Best cla_inner: 512, Best num_latent: 2048, Best beta: 4500, Best epoch1: 160, Best epoch2: 65, Best accuracy: 0.882, Best precision: 0.8737130055776394, Best recall: 0.8669810991227285, Best f1: 0.8687265141986422

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, GINConv
from torch_geometric.loader import DataLoader
from torch.optim import Adam
import itertools
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score, accuracy_score
import gc

# Probe CUDA once; the answer drives both device selection and GPU seeding.
_cuda_available = torch.cuda.is_available()

# Fixed seed (42) for the default generator so runs are repeatable.
torch.manual_seed(42)

if _cuda_available:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Seed the current GPU and every other visible GPU with the same value.
if _cuda_available:
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)



class MultiScaleGNN(torch.nn.Module):
    """Variational graph encoder with a linear decoder and a classifier head.

    The input features are projected by ``lin0`` and fed to four parallel
    graph convolutions (GCN, SAGE, GIN, GAT). Their ReLU outputs are
    concatenated and mapped to the latent mean and log-variance. The sampled
    latent code is used both to reconstruct the input features and to
    produce class scores.

    Args:
        num_node_features: Width of the input node feature matrix.
        num_latent: Width of the latent code (``mu`` / ``logvar``).
        num_classes: Number of output scores of the classifier head.
        heads1: Number of attention heads of ``gat1``.
        heads2: Number of attention heads of ``gat2``.
        en_inner: Hidden width of the encoder branches.
        cla_inner: Hidden width of the classifier head.
    """

    def __init__(self, num_node_features, num_latent, num_classes, heads1, heads2, en_inner, cla_inner):
        super().__init__()
        # Width of the concatenated branch outputs: three branches emit
        # en_inner columns each, the GAT branch emits en_inner per head.
        concat_width = en_inner * (heads1 + 3)

        # Input projection.
        self.lin0 = Linear(num_node_features, en_inner)

        # First-level parallel branches (all consumed by ``encode``).
        self.gcn1 = GCNConv(en_inner, en_inner)
        self.sage1 = SAGEConv(en_inner, en_inner)
        self.gin1 = GINConv(Linear(en_inner, en_inner))
        self.gat1 = GATConv(en_inner, en_inner, heads=heads1, dropout=0.6)

        # NOTE(review): these second-level layers are constructed but never
        # called by ``encode`` or ``forward``. They are kept, in the same
        # position, so the set of parameters and the order in which modules
        # are created stay exactly as before.
        self.gcn2 = GCNConv(concat_width, en_inner)
        self.sage2 = SAGEConv(concat_width, en_inner)
        self.gat2 = GATConv(concat_width, en_inner, heads=heads2, dropout=0.6)

        # Latent heads.
        self.lin1 = Linear(concat_width, num_latent)  # mu
        self.lin2 = Linear(concat_width, num_latent)  # logvar

        # Feature reconstruction from the latent code.
        self.decoder_lin = Linear(num_latent, num_node_features)

        # Two stacked linear layers (no activation in between) for class scores.
        self.classifier1 = Linear(num_latent, cla_inner)
        self.classifier2 = Linear(cla_inner, num_classes)

    def encode(self, x, edge_index):
        """Return ``(mu, logvar)`` computed from node features and edges."""
        projected = F.relu(self.lin0(x))

        # Branches are evaluated in the order GCN, SAGE, GIN, GAT and
        # concatenated along the feature dimension in that same order.
        branches = (self.gcn1, self.sage1, self.gin1, self.gat1)
        branch_outputs = [F.relu(conv(projected, edge_index)) for conv in branches]
        merged = torch.cat(branch_outputs, dim=1)

        return self.lin1(merged), self.lin2(merged)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` while training; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def decode(self, z):
        """Linearly map a latent code back to the input feature space."""
        return self.decoder_lin(z)

    def forward(self, data):
        """Run the full model on ``data`` (reads ``data.x``, ``data.edge_index``).

        Returns:
            Tuple ``(recon_x, class_logits, mu, logvar)``. ``class_logits``
            is the raw output of ``classifier2``; no softmax is applied.
        """
        mu, logvar = self.encode(data.x, data.edge_index)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        hidden = self.classifier1(z)
        class_logits = self.classifier2(hidden)
        return recon_x, class_logits, mu, logvar


def loss_function(recon_x, x, mu, logvar):
    """Return the summed reconstruction error plus the KL divergence term.

    Args:
        recon_x: Reconstructed features.
        x: Target features, compared against ``recon_x`` with summed MSE.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``mse_sum + kl``, where
        ``kl = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    reconstruction_error = F.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = kl_terms.sum() * -0.5
    return reconstruction_error + kl

class GATNet(torch.nn.Module):
    """Two-layer graph attention classifier returning log-probabilities.

    Args:
        num_features: Width of the input node feature matrix.
        num_classes: Number of classes; also the width of the output.
        hidden: Per-head output width of the first attention layer.
        heads: Number of attention heads in both layers (default 20, the
            value that was previously hard-coded).
    """

    def __init__(self, num_features, num_classes, hidden, heads=20):
        super().__init__()
        # First layer concatenates its heads, so it emits hidden * heads columns.
        self.gatconv1 = GATConv(num_features, hidden, heads=heads)
        # The output layer must emit exactly num_classes columns. With the
        # default concat=True it would emit num_classes * heads columns, so
        # log_softmax/argmax would range over non-existent classes.
        # concat=False averages the heads instead.
        self.gatconv2 = GATConv(hidden * heads, num_classes, heads=heads, concat=False)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities over ``num_classes`` columns."""
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = F.elu(self.gatconv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.gatconv2(x, edge_index)
        return F.log_softmax(x, dim=1)


# Hyperparameter search space. Every list holds the candidate values for one
# axis of the grid; single-element lists pin that axis to a fixed value.

# Stage-1 optimiser settings.
learning_rates = [1e-4]
weight_decays = [0]

# Attention head counts of the two GAT layers in the stage-1 model.
heads1_options = list(range(17, 20))
heads2_options = [1]

# Layer widths.
en_inners = [2048]
cla_inners = [512]
gcn_inners = [1024]
num_latents = [2048]

# Training lengths: epoch1 for the stage-1 model, epoch2 for the GAT.
epoch1_options = list(range(150, 171, 5))
epoch2_options = list(range(55, 76, 5))

# Loss weights: alpha scales the VAE loss, beta the classification loss.
alpha_options = [1]
beta_options = [4300, 4500, 4600, 4700, 4800, 4900]

# One tuple per evaluated configuration is appended here.
results = []

# Optimiser name -> optimiser class used for the stage-1 model.
optimizer_options = {
    'Adam': torch.optim.Adam,
    'AdamW': torch.optim.AdamW,
}

# Grid search. For every hyperparameter combination:
#   1. train MultiScaleGNN with the VAE loss plus a supervised term,
#   2. concatenate its latent means with the raw node features,
#   3. train GATNet on those features and score it on the test split.

# Load the Cora dataset once. Nothing below modifies `data`, so the same
# object is shared by every configuration.
dataset = Planetoid(root='./Cora', name='Cora')
data = dataset[0].to(device)

search_grid = itertools.product(
    learning_rates, num_latents, heads1_options, heads2_options,
    weight_decays, optimizer_options.keys(), en_inners, cla_inners,
    gcn_inners, alpha_options, beta_options, epoch1_options, epoch2_options)

# NOTE: gcn_inner is part of the grid but is not consumed by either model.
for (lr, num_latent, heads1, heads2, weight_decay, optimizer_name, en_inner,
     cla_inner, gcn_inner, alpha, beta, epoch1, epoch2) in search_grid:
    torch.cuda.empty_cache()
    # Re-seed so every configuration starts from the same generator state.
    torch.manual_seed(42)

    # Initialize and train the MultiScaleGNN model
    model = MultiScaleGNN(
        num_node_features=data.num_features,
        num_latent=num_latent,
        num_classes=dataset.num_classes,
        heads1=heads1,
        heads2=heads2,
        en_inner=en_inner,
        cla_inner=cla_inner,
    ).to(device)

    optimizer = optimizer_options[optimizer_name](model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()
    for _ in range(epoch1):
        optimizer.zero_grad()
        recon_x, class_logits, mu, logvar = model(data)
        vae_loss = loss_function(recon_x, data.x, mu, logvar)
        # The model returns raw logits, so cross_entropy (log_softmax +
        # nll_loss) is required; nll_loss alone expects log-probabilities.
        # Only training labels are used: supervising on every node would
        # leak test labels into the representation evaluated further down.
        supervised_loss = F.cross_entropy(class_logits[data.train_mask], data.y[data.train_mask])
        loss = alpha * vae_loss + beta * supervised_loss
        loss.backward()
        optimizer.step()

    # Extract node representations (latent means) from the trained model
    model.eval()
    with torch.no_grad():
        node_repr, _ = model.encode(data.x, data.edge_index)

    # Raw features and learned representation side by side.
    combined_features = torch.cat([data.x, node_repr], dim=1)

    # Initialize and train the GAT-based model on the combined features
    gat_model = GATNet(num_features=combined_features.size(1), num_classes=dataset.num_classes, hidden=6).to(device)
    gat_optimizer = torch.optim.AdamW(gat_model.parameters(), lr=0.001, weight_decay=0.001)

    gat_model.train()
    for _ in range(epoch2):
        gat_optimizer.zero_grad()
        out = gat_model(combined_features, data.edge_index)
        # GATNet already returns log-probabilities, so nll_loss is correct here.
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        gat_optimizer.step()

    # Evaluation on the test split
    gat_model.eval()
    with torch.no_grad():
        logits = gat_model(combined_features, data.edge_index)
        preds = logits.argmax(dim=1)
        preds_np = preds[data.test_mask].cpu().numpy()
        y_true_np = data.y[data.test_mask].cpu().numpy()

        accuracy = accuracy_score(y_true_np, preds_np)

        precision = precision_score(y_true_np, preds_np, average='macro', zero_division=0)
        recall = recall_score(y_true_np, preds_np, average='macro', zero_division=0)
        f1 = f1_score(y_true_np, preds_np, average='macro', zero_division=0)

        print(f"LR: {lr}, heads1: {heads1}, heads2: {heads2}, Optimizer: {optimizer_name}, Weight Decay: {weight_decay}, en_inner: {en_inner}, cla_inner: {cla_inner}, num_latent: {num_latent}, beta: {beta}, epoch1: {epoch1}, epoch2: {epoch2}, accuracy: {accuracy}, precision: {precision}, recall: {recall}, f1: {f1}", flush=True)
        # Record the results
        results.append((lr, heads1, heads2, optimizer_name, weight_decay, en_inner, cla_inner, num_latent, beta, epoch1, epoch2, accuracy, precision, recall, f1))

    # Drop references to the per-configuration models and tensors so their
    # memory can be reclaimed before the next configuration is built.
    del model, gat_model, optimizer, gat_optimizer
    del node_repr, combined_features, logits, preds
    gc.collect()

# Print the best configuration (index 11 of each result tuple is accuracy).
# NOTE(review): that accuracy is measured on the test split, so the winner is
# selected on test data; consider selecting on data.val_mask instead.
if results:
    best_result = max(results, key=lambda r: r[11])
    print(f"Best LR: {best_result[0]}, Best heads1: {best_result[1]}, Best heads2: {best_result[2]}, Best Optimizer: {best_result[3]}, Best Weight Decay: {best_result[4]}, Best en_inner: {best_result[5]}, Best cla_inner: {best_result[6]}, Best num_latent: {best_result[7]}, Best beta: {best_result[8]}, Best epoch1: {best_result[9]}, Best epoch2: {best_result[10]}, Best accuracy: {best_result[11]}, Best precision: {best_result[12]}, Best recall: {best_result[13]}, Best f1: {best_result[14]}", flush=True)
else:
    # Only reachable when one of the option lists is empty.
    print("No hyperparameter configurations were evaluated.", flush=True)



