# -*- coding: utf-8 -*-
"""train_gnn.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1NuqW-aGowo1ilzVQDVn241Ch72SN2ulU
"""

import os
import subprocess
import sys

# Install the PyTorch Geometric stack. These were IPython ``!pip3`` shell
# escapes, which are a syntax error when the exported file runs as plain
# Python; ``sys.executable -m pip`` also guarantees the packages land in the
# interpreter that is running this script. Exit codes are not checked, which
# keeps the original best-effort behaviour (a missing package still surfaces
# at import time below).
# NOTE(review): the torch_scatter wheel index is pinned to torch 2.4.0 / CUDA
# 12.1 -- confirm it matches the runtime's torch build.
subprocess.run([sys.executable, "-m", "pip", "install", "torch_geometric"])
subprocess.run(
    [sys.executable, "-m", "pip", "install", "torch_scatter",
     "-f", "https://pytorch-geometric.com/whl/torch-2.4.0%2Bcu121.html"]
)

# Mount Google Drive (Colab only) and switch to the project folder so that the
# relative ``./raw_data/...`` paths used by this script resolve. This restores
# the ``%cd`` magic that the notebook exporter had commented out.
from google.colab import drive
drive.mount('/content/drive/')
os.chdir('/content/drive/MyDrive/GNN/')


import numpy as np
import math
import time

import seaborn as sns
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import torch_geometric.nn as tgnn
from torch_scatter import scatter

from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

class Metric(object):
    """Accumulate relative-error statistics over batches of predictions.

    For each (prediction ``p``, ground truth ``g``) pair the absolute
    percentage error ``|p - g| / (|g| + 1e-6)`` is stored, and the pair is
    counted against every bound in ``errbnd_val`` that it falls within.
    ``get()`` summarises everything seen so far.
    """

    # Added to |g| so that a zero ground truth does not divide by zero.
    _EPS = 0.000001

    def __init__(self):
        # Single accumulator covering every update() call.
        self.all = self.init_pack()

    def init_pack(self):
        """Return a fresh, empty accumulator dictionary."""
        return {
            'cnt': 0,                                   # number of pairs seen
            'apes': [],                                 # absolute percentage error of each pair
            'errbnd_cnt': np.array([0.0, 0.0, 0.0]),    # pairs within each error bound
            'errbnd_val': np.array([0.3, 0.1, 0.05]),   # error bounds: 30 %, 10 %, 5 %
        }

    def update_pack(self, ps, gs, pack):
        """Fold predictions ``ps`` and ground truths ``gs`` into ``pack``.

        Args:
            ps: predictions; any array-like of numbers (it is flattened).
            gs: ground truths; must hold the same number of values as ``ps``.
            pack: accumulator created by ``init_pack``; updated in place.

        Raises:
            ValueError: if ``ps`` and ``gs`` hold different numbers of values.
        """
        preds = np.asarray(ps, dtype=np.float64).reshape(-1)
        golds = np.asarray(gs, dtype=np.float64).reshape(-1)
        if preds.shape != golds.shape:
            raise ValueError(
                f"ps and gs must have the same length, got {preds.size} and {golds.size}")
        apes = np.abs(preds - golds) / (np.abs(golds) + self._EPS)
        # Compare every error with every bound at once -> (n, n_bounds) mask,
        # then count the hits per bound.
        pack['errbnd_cnt'] += (apes[:, None] <= pack['errbnd_val']).sum(axis=0)
        pack['apes'].extend(apes.tolist())
        pack['cnt'] += preds.size

    def measure_pack(self, pack):
        """Summarise ``pack``.

        Returns:
            ``(mape, frac_0, frac_1, frac_2, cnt)``. ``mape`` is the mean
            absolute percentage error. ``frac_k`` is the fraction of pairs whose
            error is <= ``pack['errbnd_val'][k]`` (0.3, 0.1, 0.05 by default).
            When no pairs have been recorded, the four statistics are ``nan``.
        """
        cnt = pack['cnt']
        if cnt == 0:
            # Avoid 0/0 and mean-of-empty RuntimeWarnings; the values were nan anyway.
            return np.nan, np.nan, np.nan, np.nan, cnt
        within = pack['errbnd_cnt'] / cnt
        return np.mean(pack['apes']), within[0], within[1], within[2], cnt

    def update(self, ps, gs):
        """Add a batch of predictions ``ps`` with ground truths ``gs``."""
        self.update_pack(ps, gs, self.all)

    def get(self):
        """Return ``(mape, frac_0.3, frac_0.1, frac_0.05, cnt)`` over all updates."""
        return self.measure_pack(self.all)


def _load_split(prefix):
    """Load one model family's saved tensors and return them as NumPy arrays.

    Reads ``./raw_data/<prefix>.{node,edge,global,total_energy}.pt`` and returns
    ``(node, edge, global, total_energy)`` in that order.
    """
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all these files are expected to hold.
    return tuple(
        torch.load(f"./raw_data/{prefix}.{part}.pt", weights_only=True).numpy()
        for part in ("node", "edge", "global", "total_energy")
    )


# "old" is the training family; every other family is used as a test set.
node_train, edge_train, data_train, y_train = _load_split("old")
node_test, edge_test, data_test, y_test = _load_split("gemma")
node_test2, edge_test2, data_test2, y_test2 = _load_split("gemma2")
node_test3, edge_test3, data_test3, y_test3 = _load_split("bloom")
node_test4, edge_test4, data_test4, y_test4 = _load_split("qwen2")
node_test5, edge_test5, data_test5, y_test5 = _load_split("llama")
node_test6, edge_test6, data_test6, y_test6 = _load_split("mixtral")


# Standardise inputs and targets with statistics fitted on the training split
# only, so that every test family is expressed in training units.


def _transform_last_axis(fitted, arr):
    """Apply ``fitted`` to each feature vector along ``arr``'s last axis, keeping its shape."""
    return fitted.transform(arr.reshape(-1, arr.shape[-1])).reshape(arr.shape)


# Global (per-graph) features: one row per sample.
data_scaler = StandardScaler()
data_scaler.fit(data_train)
data_train, data_test, data_test2, data_test3, data_test4, data_test5, data_test6 = (
    data_scaler.transform(d)
    for d in (data_train, data_test, data_test2, data_test3, data_test4, data_test5, data_test6)
)

# Node features: every node vector (all leading axes flattened together) is
# treated as one observation.
node_scaler = StandardScaler()
node_scaler.fit(node_train.reshape(-1, node_train.shape[-1]))
node_train, node_test, node_test2, node_test3, node_test4, node_test5, node_test6 = (
    _transform_last_axis(node_scaler, n)
    for n in (node_train, node_test, node_test2, node_test3, node_test4, node_test5, node_test6)
)

# Targets: reshaped to a single column, because StandardScaler expects 2-D input.
# (Previously fit_transform's result was discarded and the data transformed a
# second time; fitting once is sufficient.)
y_train, y_test, y_test2, y_test3, y_test4, y_test5, y_test6 = (
    y.reshape(-1, 1)
    for y in (y_train, y_test, y_test2, y_test3, y_test4, y_test5, y_test6)
)
y_scaler = StandardScaler()
y_scaler.fit(y_train)
y_train, y_test, y_test2, y_test3, y_test4, y_test5, y_test6 = (
    y_scaler.transform(y)
    for y in (y_train, y_test, y_test2, y_test3, y_test4, y_test5, y_test6)
)

# Later code refers to the target scaler by its old name ``scaler``.
scaler = y_scaler



class GraphDataset(Dataset):
  """In-memory PyG dataset of (graph, global-feature) pairs.

  Sample ``i`` is built from entry ``i`` of each input:
  - ``node[i]`` becomes the node-feature matrix ``x`` (float32);
  - ``edge[i]`` becomes ``edge_index`` (int64);
  - ``label[i]`` becomes the target ``y`` (float32);
  - ``data[i]`` becomes a separate float32 tensor returned alongside the graph.

  Args:
    node, edge, data, label: sequences indexable along axis 0 with one entry
      per sample (NumPy arrays in this script).

  Raises:
    ValueError: if the four inputs do not have the same number of samples.
  """
  def __init__(self, node, edge, data, label):
    # root="./", no transform / pre_transform.
    super(GraphDataset, self).__init__("./", None, None)
    n = len(node)
    if not (len(edge) == len(data) == len(label) == n):
      raise ValueError(
          f"GraphDataset inputs disagree on sample count: node={n}, "
          f"edge={len(edge)}, data={len(data)}, label={len(label)}")

    self.node_list = []
    self.edge_list = []
    self.data_list = []
    self.label_list = []
    self.graph_list = []

    # Slice each sample straight out of the input arrays. ``np.array`` copies
    # and casts in one step. The former ``.tolist()`` round-trip boxed every
    # element as a Python object first, which is slow and memory-hungry.
    for node_i, edge_i, data_i, label_i in zip(node, edge, data, label):
      self.node_list.append(np.array(node_i, dtype=np.float32))
      self.edge_list.append(np.array(edge_i, dtype=np.int64))
      self.data_list.append(torch.from_numpy(np.array(data_i, dtype=np.float32)).type(torch.float))
      self.label_list.append(np.array(label_i, dtype=np.float32))

    self.process_data()

  def process_data(self):
    """Wrap every stored sample as a ``Data`` graph in ``self.graph_list``."""
    for node_arr, edge_arr, label_arr in zip(self.node_list, self.edge_list, self.label_list):
      mdata = Data(x=torch.from_numpy(node_arr).type(torch.float),
                   edge_index=torch.from_numpy(edge_arr).type(torch.long),
                   y=torch.from_numpy(label_arr).type(torch.float))
      self.graph_list.append(mdata)

  def get(self, idx):
    """Return ``(graph, global_features)`` for sample ``idx``."""
    return self.graph_list[idx], self.data_list[idx]

  def len(self):
    """Return the number of samples."""
    return len(self.node_list)


# One in-memory graph dataset per split: the "old" training family first,
# followed by the six test families in the same order as the raw-data loads.
train_dataset = GraphDataset(node_train, edge_train, data_train, y_train)
(
    test_dataset,
    test_dataset2,
    test_dataset3,
    test_dataset4,
    test_dataset5,
    test_dataset6,
) = [
    GraphDataset(*arrays)
    for arrays in (
        (node_test, edge_test, data_test, y_test),
        (node_test2, edge_test2, data_test2, y_test2),
        (node_test3, edge_test3, data_test3, y_test3),
        (node_test4, edge_test4, data_test4, y_test4),
        (node_test5, edge_test5, data_test5, y_test5),
        (node_test6, edge_test6, data_test6, y_test6),
    )
]

def init_tensor(tensor, init_type, nonlinearity):
    """Initialise ``tensor`` in place using the scheme named by ``init_type``.

    Args:
        tensor: tensor to fill in place; ``None`` makes this a no-op.
        init_type: one of ``'thomas'``, ``'kaiming_normal_in'``,
            ``'kaiming_normal_out'``, ``'kaiming_uniform_in'``,
            ``'kaiming_uniform_out'``, ``'orthogonal'``; ``None`` is a no-op.
        nonlinearity: name passed to the Kaiming initialisers or to
            ``nn.init.calculate_gain`` (unused by ``'thomas'``).

    Returns:
        None.

    Raises:
        ValueError: for an unrecognised ``init_type``.
    """
    if tensor is None or init_type is None:
        return

    def _thomas(t, _unused_nl):
        # Uniform in +-1/sqrt(size of the last dimension).
        bound = 1. / math.sqrt(t.size(-1))
        nn.init.uniform_(t, -bound, bound)

    # Ordered (name, initialiser) pairs; matched with == as before.
    schemes = (
        ('thomas', _thomas),
        ('kaiming_normal_in',
         lambda t, nl: nn.init.kaiming_normal_(t, mode='fan_in', nonlinearity=nl)),
        ('kaiming_normal_out',
         lambda t, nl: nn.init.kaiming_normal_(t, mode='fan_out', nonlinearity=nl)),
        ('kaiming_uniform_in',
         lambda t, nl: nn.init.kaiming_uniform_(t, mode='fan_in', nonlinearity=nl)),
        ('kaiming_uniform_out',
         lambda t, nl: nn.init.kaiming_uniform_(t, mode='fan_out', nonlinearity=nl)),
        ('orthogonal',
         lambda t, nl: nn.init.orthogonal_(t, gain=nn.init.calculate_gain(nl))),
    )
    for scheme_name, apply_scheme in schemes:
        if init_type == scheme_name:
            apply_scheme(tensor, nonlinearity)
            return
    raise ValueError(f'Unknown initialization type: {init_type}')


class GraphNet(torch.nn.Module):
  """Two-layer GNN + MLP regressor producing one scalar per graph.

  Node features pass through two ``gnn_layer`` convolutions (each followed by
  ReLU and dropout). They are pooled per graph with ``torch_scatter.scatter``
  using ``reduce_func``. The pooled vector is concatenated with a linear
  projection of the per-graph static features. Two fully connected layers and
  a final ``Linear(fc_hidden, 1)`` then produce the prediction.

  NOTE: attribute names and their registration order define the
  ``state_dict`` keys and the order in which ``_initialize_weights`` visits
  (and randomly initialises) the Linear layers; keep them stable.
  """
  def __init__(self, num_node_features=44, gnn_layer="SAGEConv", num_other_features=33, gnn_hidden=512, fc_hidden=512, reduce_func="max",):
    """Build the layers.

    Args:
      num_node_features: width of each node's input feature vector.
      gnn_layer: name of a class in ``torch_geometric.nn``, looked up with getattr.
      num_other_features: width of the per-graph static feature vector.
      gnn_hidden: output width of both graph convolutions and of the static
        feature projection.
      fc_hidden: width of the two fully connected layers.
      reduce_func: reduction passed to ``scatter`` when pooling node embeddings
        into one vector per graph (e.g. "max", "sum").
    """
    super(GraphNet, self).__init__()

    self.reduce_func = reduce_func
    self.num_node_features = num_node_features
    self.num_other_features = num_other_features
    self.gnn_layer_func = getattr(tgnn, gnn_layer)
    # normalize=True is forwarded to the layer constructor, so the chosen
    # gnn_layer must accept that keyword -- TODO confirm for layers other than SAGEConv.
    self.graph_conv_1 = self.gnn_layer_func(num_node_features, gnn_hidden, normalize=True)

    self.graph_conv_2 = self.gnn_layer_func(gnn_hidden, gnn_hidden, normalize=True)
    self.gnn_drop_1 = nn.Dropout(p=0.05)
    self.gnn_drop_2 = nn.Dropout(p=0.05)
    self.gnn_relu1 = nn.ReLU()
    self.gnn_relu2 = nn.ReLU()

    # Projects the static (per-graph) features to the same width as the pooled GNN output.
    self.norm_sf_linear = nn.Linear(num_other_features, gnn_hidden)
    self.norm_sf_drop = nn.Dropout(p=0.05)
    self.norm_sf_relu = nn.ReLU()

    # Input is [pooled GNN embedding | projected static features], each gnn_hidden wide.
    self.fc_1 = nn.Linear(gnn_hidden + gnn_hidden, fc_hidden)
    self.fc_2 = nn.Linear(fc_hidden, fc_hidden)
    self.fc_drop_1 = nn.Dropout(p=0.05)
    self.fc_drop_2 = nn.Dropout(p=0.05)
    self.fc_relu1 = nn.ReLU()
    self.fc_relu2 = nn.ReLU()
    self.predictor = nn.Linear(fc_hidden, 1)
    self._initialize_weights()

  def _initialize_weights(self):
    """Apply the "thomas" uniform init to the weight and bias of every nn.Linear.

    For a 1-D bias, ``init_tensor`` uses ``size(-1)``, i.e. the layer's output
    width, for the bound. Graph-convolution modules are skipped (the ``elif``
    branch is a no-op), so they keep their own default initialisation.
    """
    for m in self.modules():
      if isinstance(m, nn.Linear):
        init_tensor(m.weight, "thomas", "relu")
        init_tensor(m.bias, "thomas", "relu")
      elif isinstance(m, self.gnn_layer_func):
        pass

  def forward(self, graph, data):
    """Predict one value per graph in the batch.

    Args:
      graph: batched PyG graph providing ``x``, ``edge_index`` and ``batch``.
      data: static per-graph features; presumably shaped
        (num_graphs, num_other_features) -- TODO confirm against the loader.

    Returns:
      Tensor whose last dimension is 1 (one prediction per row of ``data``).
    """
    x, A = graph.x, graph.edge_index
    x = self.graph_conv_1(x, A)
    x = self.gnn_relu1(x)
    x = self.gnn_drop_1(x)

    x = self.graph_conv_2(x, A)
    x = self.gnn_relu2(x)
    x = self.gnn_drop_2(x)
    # Pool node embeddings into one row per graph, indexed by graph.batch.
    gnn_feat = scatter(x, graph.batch, dim=0, reduce=self.reduce_func)

    # Static branch applies dropout before ReLU (the reverse of the GNN branch).
    static_feature = self.norm_sf_linear(data)
    static_feature = self.norm_sf_drop(static_feature)
    static_feature = self.norm_sf_relu(static_feature)
    x = torch.cat([gnn_feat, static_feature], dim=1)

    x = self.fc_1(x)
    x = self.fc_relu1(x)
    x = self.fc_drop_1(x)
    x = self.fc_2(x)
    x = self.fc_relu2(x)
    feat = self.fc_drop_2(x)
    x = self.predictor(feat)

    return x




# Prefer the first CUDA device and fall back to the CPU when none is present.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Training hyper-parameters -------------------------------------------
EPOCHS = 500
BATCH_SIZE = 1024
LEARNING_RATE = 0.01

# Input widths taken from the standardised arrays: axis 2 of the node array
# (assumed (samples, nodes, features) -- confirm) and the column count of the
# global-feature matrix.
NUM_FEATURES = node_train.shape[2]
NUM_STATIC_FEATURES = data_train.shape[1]

# Only the training batches are shuffled. The evaluation loaders keep dataset
# order, so predictions line up with the rows of each test split.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader, test_loader2, test_loader3, test_loader4, test_loader5, test_loader6 = [
    DataLoader(dataset=split, batch_size=BATCH_SIZE)
    for split in (test_dataset, test_dataset2, test_dataset3,
                  test_dataset4, test_dataset5, test_dataset6)
]

model = GraphNet(
    num_node_features=NUM_FEATURES,
    num_other_features=NUM_STATIC_FEATURES,
    gnn_layer="SAGEConv",
    gnn_hidden=512,
    fc_hidden=512,
    reduce_func="sum",
)
model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Exponential decay: lr = LEARNING_RATE * 0.95 ** step_count.
# NOTE(review): if stepped once per epoch, the rate is about 3.5e-7 by epoch 200
# and about 7e-14 by epoch 500, so the later epochs barely change the weights.
# Confirm EPOCHS and the decay factor are intended.
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda2])


# Main training loop: one pass over train_loader per epoch, followed by a
# single scheduler step, so the learning rate decays per epoch, not per batch.
#
# The former per-batch ``Metric`` bookkeeping has been removed. It copied
# every batch to the CPU and called ``get()`` (a mean over an ever-growing
# list) on every iteration, which is quadratic work per epoch, and its result
# was never used. The mean batch loss is shown on the progress bar instead.
progress = tqdm(range(1, EPOCHS+1))
for e in progress:
  model.train()
  epoch_loss = 0.0
  num_batches = 0

  for graph_train_batch, data_train_batch in train_loader:
    # Reshape targets to (num_graphs, 1) so they match the model's output for MSELoss.
    graph_train_batch.y = graph_train_batch.y.view(-1, 1)
    graph_train_batch, data_train_batch = graph_train_batch.to(device), data_train_batch.to(device)
    optimizer.zero_grad()
    y_train_pred = model(graph_train_batch, data_train_batch)
    train_loss = criterion(y_train_pred, graph_train_batch.y)
    train_loss.backward()
    optimizer.step()
    epoch_loss += train_loss.item()
    num_batches += 1

  # Mean batch MSE (in standardised target units) and the lr used this epoch.
  progress.set_postfix(loss=epoch_loss / max(num_batches, 1), lr=scheduler.get_last_lr()[0])
  scheduler.step()

torch.save(model.state_dict(), "graph_model.pt")
#model.load_state_dict(torch.load("graph_model.pt", weights_only=True))

def _evaluate(loader, family):
  """Run ``model`` over ``loader``, print relative-error metrics, return predictions.

  Metrics are computed in ORIGINAL target units. The model is trained against
  targets standardised by ``scaler`` (the module-level StandardScaler fitted
  on ``y_train``). Standardised targets cluster around zero, so a relative
  error measured there (``|p - g| / |g|``) is meaningless. Both predictions
  and ground truths are therefore inverse-transformed first. Reported numbers
  are not comparable with runs made before this change.

  Args:
    loader: DataLoader yielding ``(graph_batch, data_batch)`` pairs.
    family: name printed in the report line.

  Returns:
    ``(predictions, metric)``. ``predictions`` is the list of raw model
    outputs, still in standardised units so that ``key_all.txt`` keeps its
    previous contents. ``metric`` is the filled ``Metric``.
  """
  predictions = []
  metric = Metric()
  model.eval()
  with torch.no_grad():
    for graph_batch, data_batch in loader:
      graph_batch.y = graph_batch.y.view(-1, 1)
      graph_batch, data_batch = graph_batch.to(device), data_batch.to(device)
      out = model(graph_batch, data_batch).cpu().numpy()
      target = graph_batch.y.cpu().numpy()
      predictions.extend(out[:, 0].tolist())
      metric.update(
          scaler.inverse_transform(np.asarray(out, dtype=np.float64))[:, 0].tolist(),
          scaler.inverse_transform(np.asarray(target, dtype=np.float64))[:, 0].tolist(),
      )
  acc, err, err1, err2, _ = metric.get()
  print("Graph neural network {} test - MAPE:{:.5f} ErrBnd(0.3):{:.5f} ErrBnd(0.1):{:.5f} ErrBnd(0.05):{:.5f}".format(family, acc, err, err1, err2))
  return predictions, metric


gemma_list, metric4 = _evaluate(test_loader, "gemma")
gemma2_list, metric4 = _evaluate(test_loader2, "gemma2")
bloom_list, metric4 = _evaluate(test_loader3, "bloom")
qwen2_list, metric4 = _evaluate(test_loader4, "qwen2")
llama_list, metric4 = _evaluate(test_loader5, "llama")
mixtral_list, metric4 = _evaluate(test_loader6, "mixtral")
# Keep the module-level summary names bound to the last family, as before.
acc, err, err1, err2, cnt = metric4.get()


# Dump every test prediction, one value per line, in the fixed family order
# gemma, gemma2, bloom, qwen2, llama, mixtral. The lists hold raw model
# outputs; since the model is trained on standardised targets these are
# presumably in standardised units -- NOTE(review): confirm the consumer of
# key_all.txt expects that.
# The ``with`` block closes the file, so no explicit close() is needed.
with open('key_all.txt', 'w', encoding='utf-8') as f:
  for family_predictions in (gemma_list, gemma2_list, bloom_list,
                             qwen2_list, llama_list, mixtral_list):
    f.writelines(f"{item}\n" for item in family_predictions)