import argparse
import os.path as osp
import time

import json
import torch
import torch.nn as nn
import pyg_lib
import sys
import copy
import os
import seaborn as sns
import time
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import MessagePassing, GCNConv, GATv2Conv, GINConv 
from torch_geometric.utils import degree, add_self_loops
from sklearn import metrics
from layers import SAGEAggregator, SpikeGCNConv, SpikeGCNConvScale, SpikeGCNConv_multi, SpikeGCNConv_OneThreshold, SpikeGCNConvDegree, SpikeGCNConvDegreeFeat, SpikeGCNConvDegreeFeatCluster, SpikeGINConvDegreeFeat, SpikeGATConvDegreeFeat
from utils import (add_selfloops, set_seed, tab_printer)
# from torch.utils.data import DataLoader
from torch_geometric.datasets import Flickr, Reddit, Planetoid, Reddit2, Yelp, TUDataset, GNNBenchmarkDataset
from torch_geometric.utils import to_scipy_sparse_matrix, degree
from tqdm import tqdm
from spikingjelly.clock_driven import encoding, functional
from torch_geometric.loader import NeighborLoader, DataLoader
from torch import tensor
from torch.utils.data import Subset
from matplotlib import pyplot as plt
from utils import rename_folder_with_suffix
import neuron
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier


from torch_geometric.data import Data
from torch_geometric.nn import LayerNorm
from sklearn.model_selection import train_test_split
from torch_geometric.nn.pool import global_mean_pool, global_add_pool
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold

import requests
from bson import ObjectId

class JSONEncoder(json.JSONEncoder):
    """JSON encoder that can additionally serialise ``ObjectId`` values."""

    def default(self, o):
        """Return ``str(o)`` for an ``ObjectId``; defer to the base encoder otherwise."""
        if not isinstance(o, ObjectId):
            # Anything else is handed to the stock implementation.
            return json.JSONEncoder.default(self, o)
        return str(o)


class StreamRedirector:
    """Tee-style writer that forwards everything to a file object and a second stream."""

    def __init__(self, file_object, stdout):
        self.file_object, self.stdout = file_object, stdout

    def write(self, message):
        """Write ``message`` to the file object first, then to ``stdout``."""
        for sink in (self.file_object, self.stdout):
            sink.write(message)

    def flush(self):
        """Flush the file object first, then ``stdout``."""
        for sink in (self.file_object, self.stdout):
            sink.flush()

class SNN(nn.Module):
    """Neighbour-sampling spiking network: stacked ``SAGEAggregator`` + ``neuron.LIF`` pairs.

    One aggregator/LIF pair is created for every width in ``hids`` and one
    final pair of width ``out_features``.

    Args:
        in_features: Input width of the first aggregator.
        out_features: Width of the final layer.
        data_x: Node-feature container; indexed with node ids in ``forward``.
        data: Stored as ``self.data``; not otherwise used by this class.
        hids: Hidden widths. The list is copied and never modified.
        alpha, surrogate: Forwarded to ``neuron.LIF``.
        T: Stored as ``self.T``.
        dropout: Probability for the ``nn.Dropout`` applied between layers.
        bias, aggr, concat: Forwarded to ``SAGEAggregator``.
        sampler, act: Accepted for interface compatibility; unused here.
        sizes: Per-hop neighbour counts used by ``forward`` and ``encode``.

    NOTE(review): ``forward`` calls ``self.sampler(nbr, size)`` and reads the
    module-level names ``device`` and ``args`` (it uses ``args.T``, not
    ``self.T``). None of these are defined by this class, so they must be
    supplied externally before ``forward`` can run -- confirm with callers.
    """

    def __init__(self, in_features, out_features, data_x, data, hids=[32], alpha=1.0, T=1,
                 dropout=0.5, bias=True, aggr='mean', sampler='sage',
                 surrogate='triangle', sizes=[5, 2], concat=False, act='LIF'):

        super().__init__()

        tau = 1.0
        # Build a fresh list rather than appending in place: appending to
        # ``hids`` would grow the shared default (and any caller-owned list)
        # on every instantiation.
        layer_widths = list(hids) + [out_features]
        aggregators, snn = nn.ModuleList(), nn.ModuleList()

        for hid in layer_widths:
            aggregators.append(SAGEAggregator(in_features, hid,
                                              concat=concat, bias=bias,
                                              aggr=aggr))

            snn.append(neuron.LIF(tau, alpha=alpha, surrogate=surrogate))
            # With concatenation the next layer sees twice the width.
            in_features = hid * 2 if concat else hid
        self.data = data
        self.out_features = out_features
        self.data_x = data_x
        self.aggregators = aggregators
        self.dropout = nn.Dropout(dropout)
        self.snn = snn
        self.sizes = sizes
        self.T = T

    def encode(self, nodes, num_nodes, h):
        """Run one pass through all aggregator/LIF pairs.

        Args:
            nodes: Unused; kept for interface compatibility.
            num_nodes: Node counts per sampling hop, used to split the
                intermediate output back into per-hop chunks.
            h: Sequence of per-hop feature tensors (seed nodes first).

        Returns:
            Output of the last LIF layer.
        """
        sizes = self.sizes

        for i, aggregator in enumerate(self.aggregators):
            self_x = h[:-1]
            # Reshape each neighbour chunk to (-1, neighbours-per-node, feature width).
            neigh_x = [n_x.view(-1, sizes[j], h[0].size(-1))
                       for j, n_x in enumerate(h[1:])]

            out = self.snn[i](aggregator(self_x, neigh_x))
            if i != len(sizes) - 1:
                out = self.dropout(out)
                # One hop fewer remains after every layer.
                h = torch.split(out, num_nodes[:-(i + 1)])
        return out

    def forward(self, nodes):
        """Sample neighbourhoods for ``nodes`` and average ``encode`` over ``args.T`` steps."""
        x = self.data_x
        hidden = [x[nodes].to(device)]
        num_nodes = [nodes.size(0)]
        nbr = nodes
        for size in self.sizes:
            nbr = self.sampler(nbr, size)
            num_nodes.append(nbr.size(0))
            hidden.append(x[nbr].to(device))

        for t in range(args.T):
            if t == 0:
                out_spikes_counter = self.encode(nodes, num_nodes, hidden)
            else:
                out_spikes_counter += self.encode(nodes, num_nodes, hidden)

        # Reset the network via the project's ``neuron.reset_net`` before returning.
        neuron.reset_net(self)
        return out_spikes_counter / args.T

class SNNGCNN(torch.nn.Module):
    """Node-level spiking GCN made of ``SpikeGCNConv`` layers.

    Layer layout: ``in_features -> hidden[0] -> ... -> hidden[-1] -> out_features``
    (``len(hidden) + 1`` spiking convolutions) plus one ``BatchNorm1d`` per
    entry of ``hidden``.

    Args:
        in_features: Input feature width.
        out_features: Output width of the last spiking convolution.
        hidden: Hidden widths; must contain at least one entry.
        dropout: Dropout probability applied after every layer but the last.
        bias: Accepted for interface compatibility; unused here.
        T: Number of passes that ``forward`` averages over.
        neuron_type, quantize, aggr, thr: Forwarded to ``SpikeGCNConv``.
        thtr: Forwarded to ``SpikeGCNConv`` as ``threshold_trainable``.
        bn: When true, batch-norm is applied after each hidden layer.
    """

    def __init__(self, in_features, out_features, hidden=[128],
                 dropout=0.5, bias=False, T=15, neuron_type='LIF',
                 quantize=False, bn=False, thtr=False,
                 aggr='mean', thr=0.25):
        super().__init__()

        self.bn = bn
        self.hidden = hidden
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        # Every hidden width feeds the next one; the last feeds out_features.
        next_widths = list(self.hidden[1:]) + [out_features]
        for width_in, width_out in zip(self.hidden, next_widths):
            self.bns.append(nn.BatchNorm1d(width_in))
            self.convs.append(SpikeGCNConv(width_in, width_out, **conv_kwargs))

        self.T = T
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Average ``encode(x, edge_index)`` over ``self.T`` passes, then reset the net."""
        for t in range(self.T):
            if t == 0:
                out_spike_counter = self.encode(x, edge_index)
            else:
                out_spike_counter += self.encode(x, edge_index)

        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, x, edge_index):
        """Encode ``x`` with a ``PoissonEncoder`` and run it through all convolutions."""
        input_encoder = encoding.PoissonEncoder()
        x = input_encoder(x)
        last = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if idx != last:
                x = self.dropout(x)
                # ``bns`` only has entries for hidden layers (one fewer than
                # ``convs``), so the output layer must not index into it.
                if self.bn:
                    x = self.bns[idx](x)

        return x
    

class SNNGCNN_GC(torch.nn.Module):
    """Graph-level classifier: spiking GCN layers, one dense ``GCNConv``, pooling, linear head.

    Layer layout for ``hidden = [h0, ..., hk]``:
    ``SpikeGCNConv`` layers map ``in_features -> h0 -> ... -> h(k-1)``,
    ``GCNConv`` maps ``h(k-1) -> hk`` and ``Linear`` maps ``hk -> out_features``.
    With a single hidden width the ``GCNConv`` maps ``h0 -> h0``.

    Args:
        in_features: Input feature width.
        out_features: Width of the linear head's output.
        hidden: Hidden widths; must contain at least one entry.
        dropout: Dropout probability.
        bias: Accepted for interface compatibility; unused here.
        T: Number of passes that ``forward`` averages over.
        neuron_type, quantize, aggr, thr: Forwarded to ``SpikeGCNConv``.
        thtr: Forwarded to ``SpikeGCNConv`` as ``threshold_trainable``.
        bn: Stored as ``self.bn``; no normalisation layer is created or applied.
        poisson: When true, inputs pass through a ``PoissonEncoder`` first.
    """

    def __init__(self, in_features, out_features, hidden=[128],
                 dropout=0.5, bias=False, T=15, neuron_type='LIF',
                 quantize=False, bn=False, thtr=False,
                 aggr='mean', thr=0.25, poisson=True):
        super().__init__()

        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        # Spiking layers stop one width early; the last transition belongs to
        # ``origconv``. ``range`` is empty for one or two widths, so a
        # single-entry ``hidden`` (the default) no longer indexes past the end.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConv(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        last = len(self.hidden) - 1
        before_last = max(last - 1, 0)
        self.lin = torch.nn.Linear(self.hidden[last], out_features)
        self.T = T
        # Registered but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        self.origconv = GCNConv(self.hidden[before_last], self.hidden[last])

    def forward(self, data):
        """Average ``encode(data)`` over ``self.T`` passes, then reset the net."""
        for t in range(self.T):
            if t == 0:
                out_spike_counter = self.encode(data)
            else:
                out_spike_counter += self.encode(data)

        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, data):
        """Run one pass: optional Poisson encoding, convolutions, mean pooling, linear head.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Output of ``self.lin`` on the pooled features.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            input_encoder = encoding.PoissonEncoder()
            x = input_encoder(x)
        for conv in self.convs:
            x = conv(x, edge_index)
            x = self.dropout(x)
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        x = self.lin(x)
        return x

class SNNGCNN_GC_Degree(torch.nn.Module):
    """Graph-level classifier built on ``SpikeGCNConvDegree`` layers.

    Layer layout for ``hidden = [h0, ..., hk]``:
    ``SpikeGCNConvDegree`` layers map ``in_features -> h0 -> ... -> h(k-1)``,
    ``GCNConv`` maps ``h(k-1) -> hk`` and ``Linear`` maps ``hk -> out_features``.
    With a single hidden width the ``GCNConv`` maps ``h0 -> h0``.

    Args:
        in_features: Input feature width.
        out_features: Width of the linear head's output.
        device: Forwarded to every ``SpikeGCNConvDegree``.
        hidden: Hidden widths; must contain at least one entry.
        dropout: Dropout probability.
        bias: Accepted for interface compatibility; unused here.
        T: Number of passes that ``forward`` averages over.
        neuron_type, quantize, aggr, thr, bins: Forwarded to ``SpikeGCNConvDegree``.
        thtr: Forwarded to ``SpikeGCNConvDegree`` as ``threshold_trainable``.
        bn: Stored as ``self.bn``; no normalisation layer is created or applied.
        poisson: When true, inputs pass through a ``PoissonEncoder`` first.
    """

    def __init__(self, in_features, out_features, device, hidden=[128],
                 dropout=0.5, bias=False, T=15, neuron_type='LIF',
                 quantize=False, bn=False, thtr=False,
                 aggr='mean', thr=0.25, poisson=True, bins=10):
        super().__init__()

        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.bins = bins
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr,
                           bins=self.bins, device=device)
        self.convs.append(SpikeGCNConvDegree(in_features, self.hidden[0], **conv_kwargs))
        # Spiking layers stop one width early; the last transition belongs to
        # ``origconv``. ``range`` is empty for one or two widths, so a
        # single-entry ``hidden`` (the default) no longer indexes past the end.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConvDegree(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        last = len(self.hidden) - 1
        before_last = max(last - 1, 0)
        self.lin = torch.nn.Linear(self.hidden[last], out_features)
        self.T = T
        # Registered but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        self.origconv = GCNConv(self.hidden[before_last], self.hidden[last])
        # One (initially empty) list per spiking layer.
        self.threshold_list = [[] for _ in range(len(self.convs))]

    def forward(self, data):
        """Average ``encode(data)`` over ``self.T`` passes, then reset the net."""
        for t in range(self.T):
            if t == 0:
                out_spike_counter = self.encode(data)
            else:
                out_spike_counter += self.encode(data)

        neuron.reset_net(self)
        return out_spike_counter / self.T

    def encode(self, data):
        """Run one pass: optional Poisson encoding, convolutions, mean pooling, linear head.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Output of ``self.lin`` on the pooled features.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            input_encoder = encoding.PoissonEncoder()
            x = input_encoder(x)
        for conv in self.convs:
            x = conv(x, edge_index)
            x = self.dropout(x)
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        x = self.lin(x)
        return x

class SNNGCNN_GC_Degree_Feat(torch.nn.Module):
    """Graph-level classifier built on ``SpikeGCNConvDegreeFeat`` layers.

    Layer layout for ``hidden = [h0, ..., hk]``:
    ``SpikeGCNConvDegreeFeat`` layers map ``in_features -> h0 -> ... -> h(k-1)``,
    ``GCNConv`` maps ``h(k-1) -> hk`` and ``Linear`` maps ``hk -> out_features``.
    With a single hidden width the ``GCNConv`` maps ``h0 -> h0``.

    Args:
        in_features: Input feature width.
        out_features: Width of the linear head's output.
        device: Forwarded to every spiking layer.
        hidden: Hidden widths; must contain at least one entry.
        dropout: Dropout probability.
        bias: Accepted for interface compatibility; unused here.
        T: Number of passes that ``forward`` averages over.
        neuron_type, quantize, aggr, thr, bins: Forwarded to the spiking layers.
        thtr: Forwarded to the spiking layers as ``threshold_trainable``.
        bn: Stored as ``self.bn``; no normalisation layer is created or applied.
        poisson: When true, inputs pass through a ``PoissonEncoder`` first.
        degree_to_label: Forwarded to the spiking layers only when truthy;
            otherwise the layers use their own default.
    """

    def __init__(self, in_features, out_features, device, hidden=[128],
                 dropout=0.5, bias=False, T=15, neuron_type='LIF',
                 quantize=False, bn=False, thtr=False,
                 aggr='mean', thr=0.25, poisson=True, bins=10, degree_to_label=None):
        super().__init__()

        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.bins = bins
        self.degree_to_label = degree_to_label
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr,
                           bins=self.bins, device=device)
        if self.degree_to_label:
            conv_kwargs['degree_to_label'] = degree_to_label
        self.convs.append(SpikeGCNConvDegreeFeat(in_features, self.hidden[0], **conv_kwargs))
        # Spiking layers stop one width early; the last transition belongs to
        # ``origconv``. ``range`` is empty for one or two widths, so a
        # single-entry ``hidden`` (the default) no longer indexes past the end.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConvDegreeFeat(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        last = len(self.hidden) - 1
        before_last = max(last - 1, 0)
        self.lin = torch.nn.Linear(self.hidden[last], out_features)
        self.T = T
        # Registered but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        self.origconv = GCNConv(self.hidden[before_last], self.hidden[last])
        # One (initially empty) list per spiking layer.
        self.threshold_list = [[] for _ in range(len(self.convs))]

    def forward(self, data):
        """Average ``encode(data)`` over ``self.T`` passes, then reset the net."""
        for t in range(self.T):
            if t == 0:
                out_spike_counter = self.encode(data)
            else:
                out_spike_counter += self.encode(data)

        neuron.reset_net(self)
        return out_spike_counter / self.T

    def encode(self, data):
        """Run one pass: optional Poisson encoding, convolutions, mean pooling, linear head.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Output of ``self.lin`` on the pooled features.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            input_encoder = encoding.PoissonEncoder()
            x = input_encoder(x)
        for conv in self.convs:
            x = conv(x, edge_index)
            x = self.dropout(x)
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        x = self.lin(x)
        return x

class SGAT_Degree_Feat(torch.nn.Module):
    """Graph-level classifier: spiking GAT layers, one ``GATv2Conv``, pooling, linear head.

    The hidden widths are fixed to ``[64, 64, 64]``; the ``hidden`` argument is
    accepted but ignored. This gives two ``SpikeGATConvDegreeFeat`` layers
    (``in_features -> 64 -> 64``), a ``GATv2Conv`` (``64 -> 64``, 4 heads,
    ``concat=False``) and a ``Linear`` (``64 -> out_features``).

    Args:
        in_features: Input feature width.
        out_features: Width of the linear head's output.
        device: Forwarded to every spiking layer.
        hidden: Ignored (see above).
        dropout: Dropout probability.
        bias: Accepted for interface compatibility; unused here.
        T: Number of passes that ``forward`` averages over.
        neuron_type, quantize, aggr, thr, bins: Forwarded to the spiking layers.
        thtr: Forwarded to the spiking layers as ``threshold_trainable``.
        bn: Stored as ``self.bn``; no normalisation layer is created or applied.
        poisson: When true, inputs pass through a ``PoissonEncoder`` first.
        degree_to_label: Forwarded to the spiking layers only when truthy.
    """

    def __init__(self, in_features, out_features, device, hidden=[64, 64, 64],
                 dropout=0.5, bias=False, T=15, neuron_type='LIF',
                 quantize=False, bn=False, thtr=False,
                 aggr='mean', thr=0.25, poisson=True, bins=10, degree_to_label=None):
        super().__init__()

        self.bn = bn
        # Widths are deliberately hard-coded; the caller's ``hidden`` is not used.
        self.hidden = [64, 64, 64]
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.bins = bins
        self.degree_to_label = degree_to_label
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr,
                           bins=self.bins, device=device)
        if self.degree_to_label:
            conv_kwargs['degree_to_label'] = degree_to_label
        self.convs.append(SpikeGATConvDegreeFeat(in_features, self.hidden[0], **conv_kwargs))
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGATConvDegreeFeat(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))

        # Index with the length of the list actually in use; indexing with the
        # ignored argument's length raised IndexError for lists longer than 3.
        last = len(self.hidden) - 1
        self.lin = torch.nn.Linear(self.hidden[last], out_features)
        self.T = T
        self.dropout = nn.Dropout(dropout)
        self.origconv = GATv2Conv(self.hidden[last - 1], self.hidden[last], heads=4, concat=False)
        # One (initially empty) list per spiking layer.
        self.threshold_list = [[] for _ in range(len(self.convs))]

    def forward(self, data):
        """Average ``encode(data)`` over ``self.T`` passes, then reset the net."""
        for t in range(self.T):
            if t == 0:
                out_spike_counter = self.encode(data)
            else:
                out_spike_counter += self.encode(data)
        neuron.reset_net(self)
        return out_spike_counter / self.T

    def encode(self, data):
        """Run one pass: optional Poisson encoding, convolutions, mean pooling, linear head.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Output of ``self.lin`` on the pooled features.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            input_encoder = encoding.PoissonEncoder()
            x = input_encoder(x)
        for conv in self.convs:
            x = conv(x, edge_index)
            x = self.dropout(x)
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        x = self.lin(x)
        return x

class SNNGCNN_GC_Degree_Feat_Spikeonly(torch.nn.Module):
    """Graph-level classifier using only ``SpikeGCNConvDegreeFeat`` layers before pooling.

    Spiking layers map ``in_features -> hidden[0] -> ... -> hidden[-1]``; the
    pooled result goes through a ``Linear`` of width ``hidden[-1] -> out_features``.
    ``bias`` is accepted but unused; ``bn`` is stored but no normalisation is applied.
    """

    def __init__(self, in_features, out_features, device, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25, poisson = True, bins = 10):
        super().__init__()

        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.bins = bins

        # Settings shared by every spiking layer.
        shared = dict(neuron_type=neuron_type, quantize=quantize,
                      threshold_trainable=thtr, aggr=aggr, thr=thr,
                      bins=self.bins, device=device)
        self.convs.append(SpikeGCNConvDegreeFeat(in_features, self.hidden[0], **shared))
        # One further layer per adjacent pair of hidden widths.
        for width_in, width_out in zip(self.hidden, self.hidden[1:]):
            self.convs.append(SpikeGCNConvDegreeFeat(width_in, width_out, **shared))

        self.lin = torch.nn.Linear(self.hidden[-1], out_features)
        self.T = T
        self.dropout = nn.Dropout(dropout)
        self.threshold_list = [[] for _ in self.convs]

    def forward(self, data):
        """Sum ``encode(data)`` over ``self.T`` passes, reset the net, return the mean."""
        for step in range(self.T):
            if step:
                accumulated += self.encode(data)
            else:
                accumulated = self.encode(data)

        neuron.reset_net(self)
        return accumulated / self.T

    def encode(self, data):
        """One pass: optional Poisson encoding, spiking layers, mean pooling, linear head."""
        feats = data.x
        edges = data.edge_index
        if self.poisson:
            feats = encoding.PoissonEncoder()(feats)
        for layer in self.convs:
            feats = self.dropout(layer(feats, edges))
        pooled = global_mean_pool(feats, data.batch)
        return self.lin(self.dropout(pooled))
    
class SNNGCNN_GC_Degree_Feat_Cluster(torch.nn.Module):
    """Graph-level classifier built on ``SpikeGCNConvDegreeFeatCluster`` layers.

    Layer layout for ``hidden = [h0, ..., hk]``:
    spiking layers map ``in_features -> h0 -> ... -> h(k-1)``, ``GCNConv`` maps
    ``h(k-1) -> hk`` and ``Linear`` maps ``hk -> out_features``. With a single
    hidden width the ``GCNConv`` maps ``h0 -> h0``.

    Args:
        in_features: Input feature width.
        out_features: Width of the linear head's output.
        device: Forwarded to every spiking layer.
        hidden: Hidden widths; must contain at least one entry.
        dropout: Dropout probability.
        bias: Accepted for interface compatibility; unused here.
        T: Number of passes that ``forward`` averages over.
        neuron_type, quantize, aggr, thr, bins, degree_to_label: Forwarded to
            the spiking layers.
        thtr: Forwarded to the spiking layers as ``threshold_trainable``.
        bn: Stored as ``self.bn``; no normalisation layer is created or applied.
        poisson: When true, inputs pass through a ``PoissonEncoder`` first.
    """

    def __init__(self, in_features, out_features, device, hidden=[128],
                 dropout=0.5, bias=False, T=15, neuron_type='LIF',
                 quantize=False, bn=False, thtr=False,
                 aggr='mean', thr=0.25, poisson=True, bins=5, degree_to_label=None):
        super().__init__()

        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.bins = bins
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr,
                           bins=self.bins, device=device,
                           degree_to_label=degree_to_label)
        self.convs.append(SpikeGCNConvDegreeFeatCluster(in_features, self.hidden[0], **conv_kwargs))
        # Spiking layers stop one width early; the last transition belongs to
        # ``origconv``. ``range`` is empty for one or two widths, so a
        # single-entry ``hidden`` (the default) no longer indexes past the end.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConvDegreeFeatCluster(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        last = len(self.hidden) - 1
        before_last = max(last - 1, 0)
        self.lin = torch.nn.Linear(self.hidden[last], out_features)
        self.T = T
        # Registered but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        self.origconv = GCNConv(self.hidden[before_last], self.hidden[last])
        # One (initially empty) list per spiking layer.
        self.threshold_list = [[] for _ in range(len(self.convs))]

    def forward(self, data):
        """Average ``encode(data)`` over ``self.T`` passes, then reset the net."""
        for t in range(self.T):
            if t == 0:
                out_spike_counter = self.encode(data)
            else:
                out_spike_counter += self.encode(data)

        neuron.reset_net(self)
        return out_spike_counter / self.T

    def encode(self, data):
        """Run one pass: optional Poisson encoding, convolutions, mean pooling, linear head.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Output of ``self.lin`` on the pooled features.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            input_encoder = encoding.PoissonEncoder()
            x = input_encoder(x)
        for conv in self.convs:
            x = conv(x, edge_index)
            x = self.dropout(x)
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        x = self.lin(x)
        return x

class SNNGIN_GC_Degree_Feat(torch.nn.Module):
    """Graph-level classifier: spiking GIN layers, one ``GINConv``, sum pooling, linear head.

    Layer layout for ``hidden = [h0, ..., hk]``:
    ``SpikeGINConvDegreeFeat`` layers map ``in_features -> h0 -> ... -> h(k-1)``,
    ``GINConv`` (wrapping a ``Linear``) maps ``h(k-1) -> hk`` and ``Linear``
    maps ``hk -> out_features``. With a single hidden width the ``GINConv``
    maps ``h0 -> h0``.

    Args:
        in_features: Input feature width.
        out_features: Width of the linear head's output.
        device: Forwarded to every spiking layer.
        hidden: Hidden widths; must contain at least one entry.
        dropout: Probability for ``self.dropout`` (constructed, but not
            applied in ``encode``).
        bias: Accepted for interface compatibility; unused here.
        T: Number of passes that ``forward`` averages over.
        neuron_type, quantize, aggr, thr, bins: Forwarded to the spiking layers.
        thtr: Forwarded to the spiking layers as ``threshold_trainable``.
        bn: Stored as ``self.bn``; no normalisation layer is created or applied.
        poisson: When true, inputs pass through a ``PoissonEncoder`` first.
        degree_to_label: Forwarded to the spiking layers only when truthy.
    """

    def __init__(self, in_features, out_features, device, hidden=[128],
                 dropout=0.5, bias=False, T=15, neuron_type='LIF',
                 quantize=False, bn=False, thtr=False,
                 aggr='mean', thr=0.25, poisson=True, bins=10, degree_to_label=None):
        super().__init__()

        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.bins = bins
        self.degree_to_label = degree_to_label

        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr,
                           bins=self.bins, device=device)
        if self.degree_to_label:
            conv_kwargs['degree_to_label'] = degree_to_label
        self.convs.append(SpikeGINConvDegreeFeat(in_features, self.hidden[0], **conv_kwargs))
        # Spiking layers stop one width early; the last transition belongs to
        # ``origconv``. ``range`` is empty for one or two widths, so a
        # single-entry ``hidden`` (the default) no longer indexes past the end.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGINConvDegreeFeat(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        last = len(self.hidden) - 1
        before_last = max(last - 1, 0)
        # The GIN layer consumes the last spiking layer's output, so its input
        # width must be hidden[before_last] rather than hidden[last].
        self.origconv = GINConv(torch.nn.Linear(self.hidden[before_last], self.hidden[last]))
        self.lin = torch.nn.Linear(self.hidden[last], out_features)
        self.T = T
        self.dropout = nn.Dropout(dropout)
        # One (initially empty) list per spiking layer.
        self.threshold_list = [[] for _ in range(len(self.convs))]

    def forward(self, data):
        """Average ``encode(data)`` over ``self.T`` passes, then reset the net."""
        for t in range(self.T):
            if t == 0:
                out_spike_counter = self.encode(data)
            else:
                out_spike_counter += self.encode(data)

        neuron.reset_net(self)
        return out_spike_counter / self.T

    def encode(self, data):
        """Run one pass: optional Poisson encoding, all spiking layers, GIN, sum pooling, head.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Output of ``self.lin`` on the pooled features.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            input_encoder = encoding.PoissonEncoder()
            x = input_encoder(x)

        # Walk every spiking layer instead of hard-coding the first two, so
        # any number of hidden widths works.
        h = x
        for conv in self.convs:
            h = conv(h, edge_index)
        h = self.origconv(h, edge_index)
        h = global_add_pool(h, data.batch)
        h = self.lin(h)
        return h
            
    # def encode(self, data):
    #     x, edge_index = data.x, data.edge_index
    #     hid = [-1 for i in range(len(self.convs))]
    #     if self.poisson:
    #         input_encoder = encoding.PoissonEncoder()
    #         x = input_encoder(x)
    #     for idx, conv in enumerate(self.convs):
    #         x = conv(x, edge_index)
    #         hid[idx] = x
    #         # if idx != len(self.convs) -1:
    #         x = self.dropout(x)
    #         # if self.bn:
    #             # x = self.bns[idx](x)        
    #     # x = self.origconv(x, edge_index)
    #     x = global_add_pool(x, data.batch)
    #     # print(x.size())
    #     x = self.dropout(x)
    #     x = self.lin(x)
    #     # x = self.decode(x)
    #     # print(x)
          # return x
    
class SNNGCNN_GC_Spikeonly(torch.nn.Module):
    """Graph-level classifier whose head output is passed through a ``neuron.LIF`` layer.

    ``SpikeGCNConv`` layers map ``in_features -> hidden[0] -> ... -> hidden[-1]``;
    the mean-pooled features go through a ``Linear`` (``hidden[-1] -> out_features``)
    followed by ``self.decode``. ``self.origconv`` is constructed but not used
    in ``encode``. ``bias`` is accepted but unused; ``bn`` is stored but no
    normalisation is applied.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25, poisson = True):
        super().__init__()

        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # Settings shared by every spiking layer.
        shared = dict(neuron_type=neuron_type, quantize=quantize,
                      threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **shared))
        # One further layer per adjacent pair of hidden widths.
        for width_in, width_out in zip(self.hidden, self.hidden[1:]):
            self.convs.append(SpikeGCNConv(width_in, width_out, **shared))

        width_count = len(hidden)
        self.lin = torch.nn.Linear(self.hidden[width_count - 1], out_features)
        self.T = T
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        # Kept as a registered module even though ``encode`` does not call it.
        self.origconv = GCNConv(self.hidden[width_count - 2], self.hidden[width_count - 1])

    def forward(self, data):
        """Sum ``encode(data)`` over ``self.T`` passes, reset the net, return the mean."""
        for step in range(self.T):
            if step:
                accumulated += self.encode(data)
            else:
                accumulated = self.encode(data)

        neuron.reset_net(self)

        return accumulated / self.T

    def encode(self, data):
        """One pass: optional Poisson encoding, spiking layers, mean pooling, head, LIF decode."""
        feats = data.x
        edges = data.edge_index
        if self.poisson:
            feats = encoding.PoissonEncoder()(feats)
        for layer in self.convs:
            feats = self.dropout(layer(feats, edges))
        pooled = global_mean_pool(feats, data.batch)
        logits = self.lin(self.dropout(pooled))
        return self.decode(logits)
    
class SNNGCNN_GC_outspike(torch.nn.Module):
    """Graph-level network: SpikeGCNConv stack, GCNConv, mean pool, linear head, LIF decoder.

    ``forward`` calls ``encode`` ``T`` times on the same batch and returns the
    average of the per-call outputs.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25, poisson = True):
        """Create the layers.

        Args:
            in_features: Input width of the first spiking convolution.
            out_features: Output width of the linear head and size passed to
                the LIF decoder.
            hidden: Non-empty list of hidden widths. All but the last size the
                spiking convolutions; the last is the output width of the
                plain ``GCNConv``. A one-entry list is accepted.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize, aggr, thr: Forwarded to every spiking
                convolution.
            bn: Stored as ``self.bn``; no normalisation layer is created here.
            thtr: Forwarded to the spiking convolutions as
                ``threshold_trainable``.
            poisson: When true, ``encode`` passes the node features through a
                ``PoissonEncoder`` first.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        # The spiking stack stops one hidden width short of the end; ``origconv``
        # produces the last one. range(len - 2) is empty for a one-entry
        # ``hidden`` (the default), so hidden[idx + 1] can never run past the list.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConv(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        self.lin = torch.nn.Linear(self.hidden[-1], out_features)
        self.T = T
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        # With a single hidden width the GCNConv maps that width onto itself.
        gcn_in = self.hidden[-2] if len(self.hidden) > 1 else self.hidden[-1]
        self.origconv = GCNConv(gcn_in, self.hidden[-1])

    def forward(self, data):
        """Return the average of ``self.T`` ``encode`` passes over ``data``."""
        for step in range(self.T):
            step_output = self.encode(data)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, data):
        """Run one time step and return the LIF decoder's output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            x = encoding.PoissonEncoder()(x)
        for conv in self.convs:
            # Dropout follows every spiking convolution.
            x = self.dropout(conv(x, edge_index))
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        x = self.lin(x)
        return self.decode(x)

class SNNGCNN_GC_onethreshold(torch.nn.Module):
    """Graph-level network built from ``SpikeGCNConv_OneThreshold`` layers.

    Pipeline per time step: spiking convolutions, a plain ``GCNConv``, mean
    pooling, dropout and a linear head. ``forward`` averages ``T`` such steps.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25, poisson = True):
        """Create the layers.

        Args:
            in_features: Input width of the first spiking convolution.
            out_features: Output width of the linear head.
            hidden: Non-empty list of hidden widths. All but the last size the
                spiking convolutions; the last is the output width of the
                plain ``GCNConv``. A one-entry list is accepted.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize, aggr, thr: Forwarded to every spiking
                convolution.
            bn: Stored as ``self.bn``; no normalisation layer is created here.
            thtr: Forwarded to the spiking convolutions as
                ``threshold_trainable``.
            poisson: When true, ``encode`` passes the node features through a
                ``PoissonEncoder`` first.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv_OneThreshold(in_features, self.hidden[0], **conv_kwargs))
        # range(len - 2) is empty for a one-entry ``hidden`` (the default), so
        # hidden[idx + 1] can never run past the list.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(
                SpikeGCNConv_OneThreshold(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        self.lin = torch.nn.Linear(self.hidden[-1], out_features)
        self.T = T
        # Created but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        # With a single hidden width the GCNConv maps that width onto itself.
        gcn_in = self.hidden[-2] if len(self.hidden) > 1 else self.hidden[-1]
        self.origconv = GCNConv(gcn_in, self.hidden[-1])

    def forward(self, data):
        """Return the average of ``self.T`` ``encode`` passes over ``data``."""
        for step in range(self.T):
            step_output = self.encode(data)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, data):
        """Run one time step and return the linear head's output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            x = encoding.PoissonEncoder()(x)
        for conv in self.convs:
            # Dropout follows every spiking convolution.
            x = self.dropout(conv(x, edge_index))
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        return self.lin(x)

class SNNGCNN_GC_TMPool(torch.nn.Module):
    """Graph-level spiking GCN that averages its per-step outputs.

    NOTE: a later ``class SNNGCNN_GC_TMPool`` statement in this module reuses
    the name, so that later definition is the one bound after import.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25, poisson = True):
        """Create the layers.

        Args:
            in_features: Input width of the first spiking convolution.
            out_features: Output width of the linear head.
            hidden: Non-empty list of hidden widths. All but the last size the
                spiking convolutions; the last is the output width of the
                plain ``GCNConv``. A one-entry list is accepted.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize, aggr, thr: Forwarded to every spiking
                convolution.
            bn: Stored as ``self.bn``; no normalisation layer is created here.
            thtr: Forwarded to the spiking convolutions as
                ``threshold_trainable``.
            poisson: When true, ``encode`` passes the node features through a
                ``PoissonEncoder`` first.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        # range(len - 2) is empty for a one-entry ``hidden`` (the default), so
        # hidden[idx + 1] can never run past the list.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConv(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        self.lin = torch.nn.Linear(self.hidden[-1], out_features)
        self.T = T
        # Created but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        # With a single hidden width the GCNConv maps that width onto itself.
        gcn_in = self.hidden[-2] if len(self.hidden) > 1 else self.hidden[-1]
        self.origconv = GCNConv(gcn_in, self.hidden[-1])
        # This line used to be the bare expression ``self.out_``, an attribute
        # lookup that fails on a fresh module. It now records the output width
        # under the attribute name the later definition of this class uses.
        self.out_feature = out_features

    def forward(self, data):
        """Return the average of ``self.T`` ``encode`` passes over ``data``."""
        for step in range(self.T):
            step_output = self.encode(data)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, data):
        """Run one time step and return the linear head's output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            x = encoding.PoissonEncoder()(x)
        for conv in self.convs:
            # Dropout follows every spiking convolution.
            x = self.dropout(conv(x, edge_index))
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        return self.lin(x)
    

class SNNGCNN_GC_Multi(torch.nn.Module):
    """Graph-level network built from ``SpikeGCNConv_multi`` layers.

    Pipeline per time step: spiking convolutions, a plain ``GCNConv``, mean
    pooling, dropout and a linear head. ``forward`` averages ``T`` such steps.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25, poisson = True):
        """Create the layers.

        Args:
            in_features: Input width of the first spiking convolution.
            out_features: Output width of the linear head.
            hidden: Non-empty list of hidden widths. All but the last size the
                spiking convolutions; the last is the output width of the
                plain ``GCNConv``. A one-entry list is accepted.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize, aggr, thr: Forwarded to every spiking
                convolution.
            bn: Stored as ``self.bn``; no normalisation layer is created here.
            thtr: Forwarded to the spiking convolutions as
                ``threshold_trainable``.
            poisson: When true, ``encode`` passes the node features through a
                ``PoissonEncoder`` first.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv_multi(in_features, self.hidden[0], **conv_kwargs))
        # range(len - 2) is empty for a one-entry ``hidden`` (the default), so
        # hidden[idx + 1] can never run past the list.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(
                SpikeGCNConv_multi(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        self.lin = torch.nn.Linear(self.hidden[-1], out_features)
        self.T = T
        # Created but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        # With a single hidden width the GCNConv maps that width onto itself.
        gcn_in = self.hidden[-2] if len(self.hidden) > 1 else self.hidden[-1]
        self.origconv = GCNConv(gcn_in, self.hidden[-1])

    def forward(self, data):
        """Return the average of ``self.T`` ``encode`` passes over ``data``."""
        for step in range(self.T):
            step_output = self.encode(data)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, data):
        """Run one time step and return the linear head's output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            x = encoding.PoissonEncoder()(x)
        for conv in self.convs:
            # Dropout follows every spiking convolution.
            x = self.dropout(conv(x, edge_index))
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        return self.lin(x)
    

class SNNGCNN_GC_TMPool(torch.nn.Module):
    """Graph-level spiking GCN that merges time steps with a learned linear map.

    ``forward`` collects the pooled output of every time step, flattens the
    time and feature dimensions together, and feeds the result through
    ``time_dim_transform``, dropout and the linear head.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25, poisson = True):
        """Create the layers.

        Args:
            in_features: Input width of the first spiking convolution.
            out_features: Output width of the linear head.
            hidden: Non-empty list of hidden widths. All but the last size the
                spiking convolutions; the last is the output width of the
                plain ``GCNConv``. A one-entry list is accepted.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of time steps; also fixes the input width of
                ``time_dim_transform`` (``hidden[-1] * T``).
            neuron_type, quantize, aggr, thr: Forwarded to every spiking
                convolution.
            bn: Stored as ``self.bn``; no normalisation layer is created here.
            thtr: Forwarded to the spiking convolutions as
                ``threshold_trainable``.
            poisson: When true, ``encode`` passes the node features through a
                ``PoissonEncoder`` first.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        # range(len - 2) is empty for a one-entry ``hidden`` (the default), so
        # hidden[idx + 1] can never run past the list.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConv(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        self.T = T
        self.lin = torch.nn.Linear(self.hidden[-1], out_features)
        # Created but not applied in ``forward`` or ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        # With a single hidden width the GCNConv maps that width onto itself.
        gcn_in = self.hidden[-2] if len(self.hidden) > 1 else self.hidden[-1]
        self.origconv = GCNConv(gcn_in, self.hidden[-1])
        self.out_feature = out_features
        self.time_dim_transform = torch.nn.Linear(self.hidden[-1] * self.T, self.hidden[-1])

    def forward(self, data):
        """Encode ``data`` ``self.T`` times and classify the concatenated steps.

        Returns:
            Output of ``self.lin`` for the time-merged representation.
        """
        # Each step's output gains a new dimension 1 so the steps can be
        # concatenated along it.
        time_step_outputs = [self.encode(data).unsqueeze(1) for _ in range(self.T)]
        out_time_concat = torch.cat(time_step_outputs, dim=1)

        # Flatten time and feature dimensions into one per row.
        batch_size, _, _ = out_time_concat.size()
        out_flattened = out_time_concat.view(batch_size, -1)

        out_transformed = self.time_dim_transform(out_flattened)
        out = self.lin(self.dropout(out_transformed))

        # Reset is issued once, after all steps have been consumed.
        neuron.reset_net(self)

        return out

    def encode(self, data):
        """Run one time step and return the pooled, dropped-out representation.

        The linear head is applied in ``forward``, not here.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            x = encoding.PoissonEncoder()(x)
        for conv in self.convs:
            # Dropout follows every spiking convolution.
            x = self.dropout(conv(x, edge_index))
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        return self.dropout(x)
    
class SNNGCNN_GC_AM(torch.nn.Module):
    """Graph-level spiking GCN that also owns an ``AdMSoftmaxLoss`` module.

    Pipeline per time step: Poisson encoding (always on), spiking
    convolutions, a plain ``GCNConv``, mean pooling, dropout and a linear
    head. ``forward`` averages ``T`` such steps.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25):
        """Create the layers.

        Args:
            in_features: Input width of the first spiking convolution.
            out_features: Size passed to the LIF decoder. The linear head's
                width is fixed separately (see the note in the body).
            hidden: Non-empty list of hidden widths. All but the last size the
                spiking convolutions; the last is the output width of the
                plain ``GCNConv``. A one-entry list is accepted.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize, aggr, thr: Forwarded to every spiking
                convolution.
            bn: Stored as ``self.bn``; no normalisation layer is created here.
            thtr: Forwarded to the spiking convolutions as
                ``threshold_trainable``.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        # range(len - 2) is empty for a one-entry ``hidden`` (the default), so
        # hidden[idx + 1] can never run past the list.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConv(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        # NOTE(review): the head width is the literal 3, not ``out_features``.
        # Left unchanged because callers may depend on it - confirm.
        self.lin = torch.nn.Linear(self.hidden[-1], 3)
        self.T = T
        # Created but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        # With a single hidden width the GCNConv maps that width onto itself.
        gcn_in = self.hidden[-2] if len(self.hidden) > 1 else self.hidden[-1]
        self.origconv = GCNConv(gcn_in, self.hidden[-1])

        # NOTE(review): relies on a module-level ``train_data`` being defined
        # before this constructor runs; consider passing the class count in.
        self.admloss = AdMSoftmaxLoss(train_data.num_classes, train_data.num_classes
                                 , s=20.0, m=0.3)

    def forward(self, data):
        """Return the average of ``self.T`` ``encode`` passes over ``data``."""
        for step in range(self.T):
            step_output = self.encode(data)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, data):
        """Run one time step and return the linear head's output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index = data.x, data.edge_index
        x = encoding.PoissonEncoder()(x)
        for conv in self.convs:
            # Dropout follows every spiking convolution.
            x = self.dropout(conv(x, edge_index))
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        return self.lin(x)

    
class SNNGCNN_hdld(torch.nn.Module):
    """Two-stream spiking GCN that separates high-degree and low-degree nodes.

    ``convs`` processes ``x`` with the high-degree rows zeroed, ``high_convs``
    processes ``new_x`` with the remaining rows zeroed, and the two results
    are averaged.

    NOTE: this module contains several identical definitions of this class;
    the last one executed is the one the name refers to.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False):
        """Create both convolution stacks and the BatchNorm layers.

        Args:
            in_features: Input width of the first convolution in each stack.
            out_features: Output width of the last convolution in each stack.
            hidden: Non-empty list of hidden widths.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize: Forwarded to every ``SpikeGCNConv``.
            bn: When true, ``encode`` applies BatchNorm to the ``convs`` stream.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.convs = nn.ModuleList()
        self.high_convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        self.high_convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        last = len(self.hidden) - 1
        for idx, width in enumerate(self.hidden):
            self.bns.append(nn.BatchNorm1d(width))
            # The final layer maps to out_features, the others to the next width.
            next_width = out_features if idx == last else self.hidden[idx + 1]
            self.convs.append(SpikeGCNConv(width, next_width, **conv_kwargs))
            self.high_convs.append(SpikeGCNConv(width, next_width, **conv_kwargs))

        self.T = T
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, new_x):
        """Mask the two inputs by node degree and average ``T`` encode passes.

        A node's degree is its number of occurrences in ``edge_index[1]``.
        ``boundary_deg`` is the smallest degree ``d`` for which the number of
        nodes with degree ``<= d`` exceeds 95% of all nodes; nodes above it
        count as high-degree.

        Args:
            x: Node features for the ``convs`` stream; not modified.
            edge_index: Edge index tensor whose second row is counted.
            new_x: Node features for the ``high_convs`` stream; not modified.
                Assumed to have the same number of rows as ``x``.

        Returns:
            The summed per-pass outputs divided by ``self.T``.
        """
        # Passing num_nodes keeps ``deg`` aligned with the rows of ``x`` even
        # when the highest-indexed nodes never appear in edge_index[1].
        deg = degree(edge_index[1], num_nodes=x.size(0), dtype = torch.long)
        counts = torch.bincount(deg)
        cum_counts = torch.cumsum(counts, 0)
        outlier_percentile = 0.95
        boundary = cum_counts[-1] * outlier_percentile
        exceeding = torch.argwhere(cum_counts > boundary).reshape(-1)
        boundary_deg = int(exceeding[0]) if exceeding.numel() > 0 else 0

        high_mask = deg > boundary_deg
        # Work on copies so the caller's feature tensors are left untouched.
        x = x.clone()
        new_x = new_x.clone()
        new_x[~high_mask] = 0
        x[high_mask] = 0

        for step in range(self.T):
            step_output = self.encode(x, edge_index, new_x)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, x, edge_index, new_x):
        """Run one time step through both stacks and average their outputs."""
        input_encoder = encoding.PoissonEncoder()
        x = input_encoder(x)
        new_x = input_encoder(new_x)

        last = len(self.convs) - 1
        for idx, (conv, high_conv) in enumerate(zip(self.convs, self.high_convs)):
            x = conv(x, edge_index)
            new_x = high_conv(new_x, edge_index)
            if idx != last:
                x = self.dropout(x)
                new_x = self.dropout(new_x)
            # ``bns`` holds one layer per hidden width, one fewer than
            # ``convs``, so the final convolution has no norm layer to apply.
            if self.bn and idx < len(self.bns):
                x = self.bns[idx](x)

        return (x + new_x) / 2
    
class SNNGCNN_scale(torch.nn.Module):
    """Spiking GCN built from ``SpikeGCNConvScale`` layers.

    ``forward`` calls ``encode`` ``T`` times and returns the average of the
    per-call outputs.
    """

    def __init__(self, in_features, out_features, num_vertex, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False):
        """Create the convolution stack and the BatchNorm layers.

        Args:
            in_features: Input width of the first convolution.
            out_features: Output width of the last convolution.
            num_vertex: Forwarded to every ``SpikeGCNConvScale``.
            hidden: Non-empty list of hidden widths.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize: Forwarded to every ``SpikeGCNConvScale``.
            bn: When true, ``encode`` applies BatchNorm after each hidden layer.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize)
        # The fourth positional argument is 0 for the first layer and 1 for
        # every later layer.
        self.convs.append(SpikeGCNConvScale(in_features, self.hidden[0], num_vertex, 0, **conv_kwargs))
        last = len(self.hidden) - 1
        for idx, width in enumerate(self.hidden):
            self.bns.append(nn.BatchNorm1d(width))
            next_width = out_features if idx == last else self.hidden[idx + 1]
            self.convs.append(SpikeGCNConvScale(width, next_width, num_vertex, 1, **conv_kwargs))

        self.T = T
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Return the average of ``self.T`` ``encode`` passes."""
        for step in range(self.T):
            step_output = self.encode(x, edge_index)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, x, edge_index):
        """Run one time step through the convolution stack."""
        x = encoding.PoissonEncoder()(x)
        last = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if idx != last:
                x = self.dropout(x)
            # ``bns`` holds one layer per hidden width, one fewer than
            # ``convs``, so the final convolution has no norm layer to apply.
            if self.bn and idx < len(self.bns):
                x = self.bns[idx](x)

        return x
    
class SNNGCNN_hdld(torch.nn.Module):
    """Two-stream spiking GCN that separates high-degree and low-degree nodes.

    ``convs`` processes ``x`` with the high-degree rows zeroed, ``high_convs``
    processes ``new_x`` with the remaining rows zeroed, and the two results
    are averaged.

    NOTE: this module contains several identical definitions of this class;
    the last one executed is the one the name refers to.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False):
        """Create both convolution stacks and the BatchNorm layers.

        Args:
            in_features: Input width of the first convolution in each stack.
            out_features: Output width of the last convolution in each stack.
            hidden: Non-empty list of hidden widths.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize: Forwarded to every ``SpikeGCNConv``.
            bn: When true, ``encode`` applies BatchNorm to the ``convs`` stream.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.convs = nn.ModuleList()
        self.high_convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        self.high_convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        last = len(self.hidden) - 1
        for idx, width in enumerate(self.hidden):
            self.bns.append(nn.BatchNorm1d(width))
            # The final layer maps to out_features, the others to the next width.
            next_width = out_features if idx == last else self.hidden[idx + 1]
            self.convs.append(SpikeGCNConv(width, next_width, **conv_kwargs))
            self.high_convs.append(SpikeGCNConv(width, next_width, **conv_kwargs))

        self.T = T
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, new_x):
        """Mask the two inputs by node degree and average ``T`` encode passes.

        A node's degree is its number of occurrences in ``edge_index[1]``.
        ``boundary_deg`` is the smallest degree ``d`` for which the number of
        nodes with degree ``<= d`` exceeds 95% of all nodes; nodes above it
        count as high-degree.

        Args:
            x: Node features for the ``convs`` stream; not modified.
            edge_index: Edge index tensor whose second row is counted.
            new_x: Node features for the ``high_convs`` stream; not modified.
                Assumed to have the same number of rows as ``x``.

        Returns:
            The summed per-pass outputs divided by ``self.T``.
        """
        # Passing num_nodes keeps ``deg`` aligned with the rows of ``x`` even
        # when the highest-indexed nodes never appear in edge_index[1].
        deg = degree(edge_index[1], num_nodes=x.size(0), dtype = torch.long)
        counts = torch.bincount(deg)
        cum_counts = torch.cumsum(counts, 0)
        outlier_percentile = 0.95
        boundary = cum_counts[-1] * outlier_percentile
        exceeding = torch.argwhere(cum_counts > boundary).reshape(-1)
        boundary_deg = int(exceeding[0]) if exceeding.numel() > 0 else 0

        high_mask = deg > boundary_deg
        # Work on copies so the caller's feature tensors are left untouched.
        x = x.clone()
        new_x = new_x.clone()
        new_x[~high_mask] = 0
        x[high_mask] = 0

        for step in range(self.T):
            step_output = self.encode(x, edge_index, new_x)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, x, edge_index, new_x):
        """Run one time step through both stacks and average their outputs."""
        input_encoder = encoding.PoissonEncoder()
        x = input_encoder(x)
        new_x = input_encoder(new_x)

        last = len(self.convs) - 1
        for idx, (conv, high_conv) in enumerate(zip(self.convs, self.high_convs)):
            x = conv(x, edge_index)
            new_x = high_conv(new_x, edge_index)
            if idx != last:
                x = self.dropout(x)
                new_x = self.dropout(new_x)
            # ``bns`` holds one layer per hidden width, one fewer than
            # ``convs``, so the final convolution has no norm layer to apply.
            if self.bn and idx < len(self.bns):
                x = self.bns[idx](x)

        return (x + new_x) / 2
        
class SNNGCNN_hdld(torch.nn.Module):
    """Two-stream spiking GCN that separates high-degree and low-degree nodes.

    ``convs`` processes ``x`` with the high-degree rows zeroed, ``high_convs``
    processes ``new_x`` with the remaining rows zeroed, and the two results
    are averaged.

    NOTE: this module contains several identical definitions of this class;
    the last one executed is the one the name refers to.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 15, neuron_type = 'LIF',
                 quantize = False, bn = False):
        """Create both convolution stacks and the BatchNorm layers.

        Args:
            in_features: Input width of the first convolution in each stack.
            out_features: Output width of the last convolution in each stack.
            hidden: Non-empty list of hidden widths.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize: Forwarded to every ``SpikeGCNConv``.
            bn: When true, ``encode`` applies BatchNorm to the ``convs`` stream.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.convs = nn.ModuleList()
        self.high_convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        self.high_convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        last = len(self.hidden) - 1
        for idx, width in enumerate(self.hidden):
            self.bns.append(nn.BatchNorm1d(width))
            # The final layer maps to out_features, the others to the next width.
            next_width = out_features if idx == last else self.hidden[idx + 1]
            self.convs.append(SpikeGCNConv(width, next_width, **conv_kwargs))
            self.high_convs.append(SpikeGCNConv(width, next_width, **conv_kwargs))

        self.T = T
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, new_x):
        """Mask the two inputs by node degree and average ``T`` encode passes.

        A node's degree is its number of occurrences in ``edge_index[1]``.
        ``boundary_deg`` is the smallest degree ``d`` for which the number of
        nodes with degree ``<= d`` exceeds 95% of all nodes; nodes above it
        count as high-degree.

        Args:
            x: Node features for the ``convs`` stream; not modified.
            edge_index: Edge index tensor whose second row is counted.
            new_x: Node features for the ``high_convs`` stream; not modified.
                Assumed to have the same number of rows as ``x``.

        Returns:
            The summed per-pass outputs divided by ``self.T``.
        """
        # Passing num_nodes keeps ``deg`` aligned with the rows of ``x`` even
        # when the highest-indexed nodes never appear in edge_index[1].
        deg = degree(edge_index[1], num_nodes=x.size(0), dtype = torch.long)
        counts = torch.bincount(deg)
        cum_counts = torch.cumsum(counts, 0)
        outlier_percentile = 0.95
        boundary = cum_counts[-1] * outlier_percentile
        exceeding = torch.argwhere(cum_counts > boundary).reshape(-1)
        boundary_deg = int(exceeding[0]) if exceeding.numel() > 0 else 0

        high_mask = deg > boundary_deg
        # Work on copies so the caller's feature tensors are left untouched.
        x = x.clone()
        new_x = new_x.clone()
        new_x[~high_mask] = 0
        x[high_mask] = 0

        for step in range(self.T):
            step_output = self.encode(x, edge_index, new_x)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, x, edge_index, new_x):
        """Run one time step through both stacks and average their outputs."""
        input_encoder = encoding.PoissonEncoder()
        x = input_encoder(x)
        new_x = input_encoder(new_x)

        last = len(self.convs) - 1
        for idx, (conv, high_conv) in enumerate(zip(self.convs, self.high_convs)):
            x = conv(x, edge_index)
            new_x = high_conv(new_x, edge_index)
            if idx != last:
                x = self.dropout(x)
                new_x = self.dropout(new_x)
            # ``bns`` holds one layer per hidden width, one fewer than
            # ``convs``, so the final convolution has no norm layer to apply.
            if self.bn and idx < len(self.bns):
                x = self.bns[idx](x)

        return (x + new_x) / 2

class SNNGCNN_GC_Norm(torch.nn.Module):
    """Graph-level spiking GCN with a default of five time steps.

    Pipeline per time step: spiking convolutions, a plain ``GCNConv``, mean
    pooling, dropout and a linear head. ``forward`` averages ``T`` such steps.
    """

    def __init__(self, in_features, out_features, hidden = [128],
                 dropout = 0.5, bias = False, T = 5, neuron_type = 'LIF',
                 quantize = False, bn = False, thtr = False,
                 aggr='mean', thr = 0.25, poisson = True):
        """Create the layers.

        Args:
            in_features: Input width of the first spiking convolution.
            out_features: Output width of the linear head.
            hidden: Non-empty list of hidden widths. All but the last size the
                spiking convolutions; the last is the output width of the
                plain ``GCNConv``. A one-entry list is accepted.
            dropout: Probability given to ``nn.Dropout``.
            bias: Accepted but not used by this class.
            T: Number of ``encode`` passes per ``forward`` call.
            neuron_type, quantize, aggr, thr: Forwarded to every spiking
                convolution.
            bn: Stored as ``self.bn``; no normalisation layer is created here.
            thtr: Forwarded to the spiking convolutions as
                ``threshold_trainable``.
            poisson: When true, ``encode`` passes the node features through a
                ``PoissonEncoder`` first.
        """
        super().__init__()

        # BN will be deprecated with values BNTT, TDBN Layers,
        self.bn = bn
        self.hidden = hidden
        self.poisson = poisson
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        conv_kwargs = dict(neuron_type=neuron_type, quantize=quantize,
                           threshold_trainable=thtr, aggr=aggr, thr=thr)
        self.convs.append(SpikeGCNConv(in_features, self.hidden[0], **conv_kwargs))
        # range(len - 2) is empty for a one-entry ``hidden`` (the default), so
        # hidden[idx + 1] can never run past the list.
        for idx in range(len(self.hidden) - 2):
            self.convs.append(SpikeGCNConv(self.hidden[idx], self.hidden[idx + 1], **conv_kwargs))
        self.lin = torch.nn.Linear(self.hidden[-1], out_features)
        self.T = T
        # Created but not applied in ``encode``.
        self.decode = neuron.LIF(ssize=out_features)
        self.dropout = nn.Dropout(dropout)
        # With a single hidden width the GCNConv maps that width onto itself.
        gcn_in = self.hidden[-2] if len(self.hidden) > 1 else self.hidden[-1]
        self.origconv = GCNConv(gcn_in, self.hidden[-1])

    def forward(self, data):
        """Return the average of ``self.T`` ``encode`` passes over ``data``."""
        for step in range(self.T):
            step_output = self.encode(data)
            if step:
                out_spike_counter += step_output
            else:
                out_spike_counter = step_output

        # Reset is issued once, after the last pass.
        neuron.reset_net(self)

        return out_spike_counter / self.T

    def encode(self, data):
        """Run one time step and return the linear head's output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index = data.x, data.edge_index
        if self.poisson:
            x = encoding.PoissonEncoder()(x)
        for conv in self.convs:
            # Dropout follows every spiking convolution.
            x = self.dropout(conv(x, edge_index))
        x = self.origconv(x, edge_index)
        x = global_mean_pool(x, data.batch)
        x = self.dropout(x)
        return self.lin(x)

def tag2index(dataset):
    """Map every distinct node tag found in ``dataset`` to a column index.

    A node's tag is half the number of times it appears as an edge endpoint
    (in either row of ``edge_index``).

    Args:
        dataset: Iterable of graphs exposing ``edge_index`` and ``num_nodes``.

    Returns:
        Dict from ``int(tag)`` to a position in ``0 .. n_tags - 1``. Positions
        follow set iteration order, so they are not guaranteed to be sorted.
    """
    collected = []
    for graph in dataset:
        endpoints = torch.cat([graph.edge_index[0], graph.edge_index[1]])
        half_counts = torch.bincount(endpoints, minlength=graph.num_nodes) / 2
        # De-duplicate within the graph before adding to the running list.
        collected.extend(set(np.array(half_counts)))
    distinct = list(set(collected))
    return {int(tag): position for position, tag in enumerate(distinct)}

def apply_deg_features(dataset, dataset_name, deg_features=0):
    """Replace node features with degree one-hots or constant ones.

    Args:
        dataset: Indexable collection of graphs exposing ``num_nodes`` (and
            ``edge_index`` when ``deg_features == 1``).
        dataset_name: Compared case-sensitively against the names that get
            constant features.
        deg_features: ``1`` selects the one-hot branch; any other value falls
            through to the name check.

    Returns:
        A new list of the (mutated) graphs when features were assigned,
        otherwise ``dataset`` itself, unchanged.
    """
    constant_feature_names = ('IMDB-BINARY', 'IMDB-MULTI', 'COLLAB', 'REDDIT-BINARY', 'REDDIT-MULTI-5K')

    if deg_features == 1:  # Set node features to the node's degree
        tag2index_dict = tag2index(dataset)
        n_tags = len(tag2index_dict)
        rebuilt = []
        for position in range(len(dataset)):
            graph = dataset[position]
            endpoints = torch.cat([graph.edge_index[0], graph.edge_index[1]])
            # Same tag definition as tag2index: endpoint occurrences halved.
            tags = list(np.array(torch.bincount(endpoints, minlength=graph.num_nodes) / 2))
            one_hot = torch.zeros(graph.num_nodes, n_tags)
            columns = [tag2index_dict[tag] for tag in tags]
            one_hot[[range(graph.num_nodes)], columns] = 1
            graph.x = one_hot
            rebuilt.append(graph)
        return rebuilt

    if dataset_name in constant_feature_names:  # Set node features to 1 if none
        rebuilt = []
        for position in range(len(dataset)):
            graph = dataset[position]
            graph.x = torch.ones((graph.num_nodes, 1))
            rebuilt.append(graph)
        return rebuilt

    return dataset

def dataset_selection(root, dataset_name):
    """Load one of the supported datasets by name.

    The name is matched case-insensitively against the catalogue below. TU
    datasets are additionally passed through ``apply_deg_features`` with
    ``deg_features=0``, so for some of them the returned ``dataset`` is
    whatever that helper returns rather than the ``TUDataset`` itself.

    Args:
        root: Directory under which every dataset gets its own sub-folder.
        dataset_name: Dataset name, e.g. ``"cora"`` or ``"IMDB-BINARY"``.

    Returns:
        tuple: ``(data, dataset)`` with ``data = dataset[0]``. For an
        unrecognised name both entries are ``None`` and a ``UserWarning`` is
        issued. The ``(None, None)`` result is kept (instead of raising)
        because the script also calls this for names it loads by itself
        afterwards (its pattern/cluster branch).
    """
    import warnings

    # lower-cased name -> (loader family, folder below root, canonical name)
    catalog = {
        'reddit': ('reddit', 'Reddit', None),
        'flickr': ('flickr', 'Flickr', None),
        'yelp': ('yelp', 'Yelp', None),
        'cora': ('planetoid', 'Cora', 'Cora'),
        'citeseer': ('planetoid', 'Citeseer', 'CiteSeer'),
        'pubmed': ('planetoid', 'PubMed', 'PubMed'),
        'enzymes': ('tu', 'ENZYMES', 'ENZYMES'),
        'mutag': ('tu', 'MUTAG', 'MUTAG'),
        'proteins': ('tu', 'PROTEINS', 'PROTEINS'),
        'collab': ('tu', 'COLLAB', 'COLLAB'),
        'imdb-binary': ('tu', 'IMDB-BINARY', 'IMDB-BINARY'),
        'imdb-multi': ('tu', 'IMDB-MULTI', 'IMDB-MULTI'),
        'reddit-binary': ('tu', 'REDDIT-BINARY', 'REDDIT-BINARY'),
        'reddit-multi-5k': ('tu', 'REDDIT-MULTI-5K', 'REDDIT-MULTI-5K'),
        'ptc_fm': ('tu', 'PTC_FM', 'PTC_FM'),
        'nci1': ('tu', 'NCI1', 'NCI1'),
        'cifar10': ('benchmark', 'CIFAR10', 'CIFAR10'),
        'mnist': ('benchmark', 'MNIST', 'MNIST'),
    }

    entry = catalog.get(dataset_name.lower())
    if entry is None:
        # The former `assert type(dataset) is not None` could never fail
        # (type(None) is NoneType, not None), so unknown names went unnoticed.
        warnings.warn(
            f"Please select dataset correctly: {dataset_name!r} is not a known dataset",
            stacklevel=2,
        )
        return None, None

    family, folder, canonical = entry
    path = osp.join(root, folder)

    if family == 'planetoid':
        dataset = Planetoid(path, name=canonical)
    elif family == 'tu':
        dataset = TUDataset(path, name=canonical, use_node_attr=True)
        # With deg_features=0 this only changes the datasets that
        # apply_deg_features lists by name; it is a no-op for the others.
        dataset = apply_deg_features(dataset, canonical, 0)
    elif family == 'benchmark':
        dataset = GNNBenchmarkDataset(path, name=canonical)
    else:
        single_graph_loaders = {'reddit': Reddit, 'flickr': Flickr, 'yelp': Yelp}
        dataset = single_graph_loaders[family](path)

    return dataset[0], dataset


def valid_one_batch(data_x, label, edge_index, device):
    """Score the module-level ``model`` on the validation nodes of one graph.

    The model is run ``args.T_val`` times on the full input and its outputs
    are summed; the class with the largest sum is the prediction. Only rows
    selected by ``data.val_mask`` are scored.

    Relies on the module-level names ``model``, ``dataset``, ``args`` and
    ``data``.

    Args:
        data_x: Node feature tensor passed to the model.
        label: Label tensor indexed by ``data.val_mask``.
        edge_index: Edge index passed to the model.
        device: Device on which the accumulator is allocated.

    Returns:
        tuple: ``(macro_f1, micro_f1)`` over the validation nodes.
    """
    model.eval()
    with torch.no_grad():
        spike_totals = torch.zeros((data_x.shape[0], dataset.num_classes)).to(device)
        for _ in range(args.T_val):
            spike_totals += model(data_x, edge_index)

        predicted = torch.cat([spike_totals[data.val_mask].max(1)[1]], dim=0).cpu()
        expected = torch.cat([label[data.val_mask]], dim=0).cpu()

        macro = metrics.f1_score(expected, predicted, average='macro')
        micro = metrics.f1_score(expected, predicted, average='micro')
        return macro, micro
    
    
def valid_all_batch(data_x, label, edge_index, loader, device):
    """Score the module-level ``model`` over every mini-batch of ``loader``.

    For each batch the model is run ``args.T_val`` times and the outputs of
    the first ``batch.batch_size`` rows are summed (presumably the seed nodes
    of a neighbour-sampled batch -- confirm against the loader in use). The
    class with the largest sum is the prediction for that row.

    ``data_x``, ``label`` and ``edge_index`` are accepted but not used here.
    Relies on the module-level names ``model``, ``dataset`` and ``args``.

    Returns:
        tuple: ``(macro_f1, micro_f1)`` over all scored rows.
    """
    model.eval()
    predicted_chunks = []
    target_chunks = []
    with torch.no_grad():
        for batch in tqdm(loader):
            batch.to(device)
            head = batch.batch_size
            features = batch.x
            targets = batch.y[:head]
            spike_totals = torch.zeros((head, dataset.num_classes)).to(device)
            for _ in range(args.T_val):
                spike_totals += model(features, batch.edge_index)[:head]
            predicted_chunks.append(spike_totals.max(1)[1])
            target_chunks.append(targets)

        predicted = torch.cat(predicted_chunks, dim=0).cpu()
        expected = torch.cat(target_chunks, dim=0).cpu()
        macro = metrics.f1_score(expected, predicted, average='macro')
        micro = metrics.f1_score(expected, predicted, average='micro')
        return macro, micro

def valid_gc(loader,device):
    """Evaluate the module-level ``model`` on a graph-classification loader.

    Every batch is moved to ``device`` and passed to the model once; the
    arg-max over dimension 1 of the output is the predicted class. Several
    diagnostic values are printed along the way.

    Args:
        loader: Iterable of batches; ``loader.dataset`` must support ``len``.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(macro_f1, accuracy)``. The second entry is the number of
        correct predictions divided by ``len(loader.dataset)``, not a
        micro-averaged F1 from scikit-learn.
    """
    model.eval()
    predicted_chunks = []
    target_chunks = []
    correct = 0
    with torch.no_grad():
        for batch in tqdm(loader):
            valid_data = batch.to(device)
            scores = model(valid_data)
            predicted_chunks.append(scores.max(1)[1])
            target_chunks.append(batch.y)

            hits = scores.argmax(dim = 1) == valid_data.y
            correct += int(hits.sum())

        logits = torch.cat(predicted_chunks, dim=0).cpu()
        print(logits.size())
        labels = torch.cat(target_chunks, dim=0).cpu()
        print(len(logits))

        # Only the per-class counts are reported; the unique values are unused.
        _, logits_counts = torch.unique(logits, return_counts=True)
        _, labels_counts = torch.unique(labels, return_counts=True)

        metric_macro = metrics.f1_score(labels, logits, average='macro')
        total = len(loader.dataset)
        print(total)
        metric_micro = correct / total

        print(f"Logits: {logits_counts}, labels: {labels_counts}" )
        print(f"correct: {correct}")
        return metric_macro, metric_micro
        
    

def test_one_batch(data_x, label, edge_index, device):
    """Score the module-level ``model`` on the test nodes of one graph.

    The model is run ``args.T_val`` times on the full input and its outputs
    are summed; the class with the largest sum is the prediction. Only rows
    selected by ``data.test_mask`` are scored.

    Relies on the module-level names ``model``, ``dataset``, ``args`` and
    ``data``.

    Args:
        data_x: Node feature tensor passed to the model.
        label: Label tensor indexed by ``data.test_mask``.
        edge_index: Edge index passed to the model.
        device: Device on which the accumulator is allocated.

    Returns:
        tuple: ``(macro_f1, micro_f1)`` over the test nodes.
    """
    model.eval()
    with torch.no_grad():
        spike_totals = torch.zeros((data_x.shape[0], dataset.num_classes)).to(device)
        for _ in range(args.T_val):
            spike_totals += model(data_x, edge_index)

        predicted = torch.cat([spike_totals[data.test_mask].max(1)[1]], dim=0).cpu()
        expected = torch.cat([label[data.test_mask]], dim=0).cpu()

        macro = metrics.f1_score(expected, predicted, average='macro')
        micro = metrics.f1_score(expected, predicted, average='micro')
        return macro, micro


if __name__ == '__main__':
    # wandb.init(project = 'Spiking Neural Network Experiment')
    CURRENT_TASK_TYPE = "GC"
    torch.autograd.set_detect_anomaly(True)
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", nargs="?", default="cora",
                    help="Datasets (Reddit and Flickr only). (default: cora)")
    parser.add_argument("--opt", type  = str, default = 'adamw',
                        help = 'Choose types of optimizer for trianing process')
    parser.add_argument("--id", type = int, default = 5, help = "Experiment ID setting")
    parser.add_argument('--sizes', type=int, nargs='+', default=[-1,-1,-1],
                        help='For Neighborhood sampling, each . (default: full batch)')
    parser.add_argument('--hids', type=int, nargs='+',
                        default=[128,128], help='Hidden units for each layer. (default: [128, 50])')
    parser.add_argument("--aggr", nargs="?", default="add",
                        help="Aggregate function ('mean', 'sum'). (default: 'mean')")
    parser.add_argument("--sampler", nargs="?", default="sage",
                        help="Neighborhood Sampler, including uniform sampler from GraphSAGE ('sage') and random walk sampler ('rw'). (default: 'sage')")
    parser.add_argument("--surrogate", nargs="?", default="sigmoid",
                        help="Surrogate function ('sigmoid', 'triangle', 'arctan', 'mg', 'super'). (default: 'sigmoid')")
    parser.add_argument("--neuron", nargs="?", default="LIF",
                        help="Spiking neuron used for training. (IF, LIF, PLIF, AdaptiveLIF, AdaptiveIF, AdaptivePLIF). (default: LIF")
    parser.add_argument('--lr', type=float, default=5e-3,
                        help='Learning rate for training. (default: 5e-3)')
    parser.add_argument('--alpha', type=float, default=1.0,
                        help='Smooth factor for surrogate learning. (default: 1.0)')
    parser.add_argument('--T', type=int, default=5,
                        help='Number of time steps. (default: 5)')
    parser.add_argument('--T_val', type=int, default=5,
                        help='Number of time steps for validation and test. (default: 5)')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout probability. (default: 0.5)')
    parser.add_argument('--epochs', type=int, default=500,
                        help='Number of training epochs. (default: 500)')
    parser.add_argument('--concat', action='store_true',
                        help='Whether to concat node representation and neighborhood representations. (default: False)')
    parser.add_argument('--seed', type=int, default=7777,
                        help='Random seed for model. (default: 777)')
    parser.add_argument('--quantize', action = 'store_true', default = False,
                        help = 'Quantize for calibration')
    parser.add_argument('--bs', type=int, default=1100,
                        help='Number of batch size for mini batching. (default: fullsize)')
    parser.add_argument('--thtr', action = 'store_true', default = False,)
    parser.add_argument('--db_name', type=str, default='Main')
    parser.add_argument('--no_db', action="store_true")
    parser.add_argument('--thr', type =float, default = 0.25)
    parser.add_argument('--loss', type =str, default = 'ce')
    parser.add_argument('--root', type = str, default = "./data")
    parser.add_argument('--model', type = str, default = 'SNNGCNN_GC')
    parser.add_argument('--no_poisson', action = "store_true", default = False)
    parser.add_argument('--deg_bins', type = int, default = 10)
    parser.add_argument('--num_layers', type = int, default=2)
    
    # parser.add_argument('--exp', required=True, type='int')
    
    
    args = parser.parse_args()
    args.split_seed = 42
    tab_printer(args)
    # assert len(args.hids) == len(args.sizes) - 1, "must be equal!"
    
    # Connect to the experiment database unless --no_db was given. A failed
    # connection is not fatal: the run continues with DB_ERROR set.
    if not args.no_db:
        try:
            mongo = MongoManager.DBHandler(args.db_name)
            DB_ERROR= False
            print('DB prepared')
        except Exception as e:
            # Catch Exception rather than using a bare except so Ctrl-C and
            # SystemExit still propagate, and report why the connection failed.
            mongo = None
            DB_ERROR = True
            print('DB ERROR OCCURED')
            print(e)
    else:
        mongo = None
        DB_ERROR = True
            
    print(args.db_name)
        
    cur_time = int(time.time())
    dt_object = datetime.fromtimestamp(cur_time)
    readable_time = dt_object.strftime('%Y-%m-%d %H:%M:%S')
    
    args.time = cur_time
    args.args_to_string = vars(args).copy()
    args.fold = 0
    args.finish = False
    print(args)
    
    # Set logger with standard output
    current_file_name = stdout = f'experiment/{args.dataset}_neuron{args.neuron}_Tval{args.T_val}_ep{args.epochs}_thr{args.thr}_aggr{args.aggr}_{cur_time}_{args.model}.txt'
    log_file = open(current_file_name, 'w')
    # Save the original stdout so we can still print to the console
    original_stdout = sys.stdout

    # Redirect stdout to your custom class
    sys.stdout = StreamRedirector(log_file, original_stdout)
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    root = args.root
    data, dataset = dataset_selection(root, args.dataset)
    
    # Do for K-fold Cross Validation
    splits = [(0,0)]
    # Try to make fold only 1    
    num_folds = 1
    if args.dataset == 'MNIST' or args.dataset == 'CIFAR10':
        train_data = GNNBenchmarkDataset(osp.join(root, args.dataset), name=args.dataset, split='train')
        val_data = GNNBenchmarkDataset(osp.join(root, args.dataset), name=args.dataset, split='val')
        test_data = GNNBenchmarkDataset(osp.join(root, args.dataset), name=args.dataset, split='test')
        # splits.append([])
        # Additional preprocessing

        # train_data.data.x = torch.cat((train_data.data.x, train_data.data.pos), dim = 1)
        # val_data.data.x = torch.cat((val_data.data.x, val_data.data.pos), dim = 1)
        # test_data.data.x = torch.cat((test_data.data.x, test_data.data.pos), dim = 1)
        
        dataset = train_data
        num_features = train_data.num_features
        num_classes = train_data.num_classes
        val_size = args.bs
        test_size = args.bs
        
    elif args.dataset.lower() == 'pattern' or args.dataset.lower() == 'cluster':
        train_data = GNNBenchmarkDataset(osp.join(root, args.dataset), name=args.dataset, split='train')
        val_data = GNNBenchmarkDataset(osp.join(root, args.dataset), name=args.dataset, split='val')
        test_data = GNNBenchmarkDataset(osp.join(root, args.dataset), name=args.dataset, split='test')
        
    else:
        # Need to do 10-fold Cross Validation for the experiment
        num_folds = 10
        labels = [data.y.item() for data in dataset]
        stratified_kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=args.seed)
        splits = stratified_kfold.split(torch.zeros(len(labels)), labels)
        
        if args.dataset.lower() == 'enzymes':
            num_features = dataset.num_features
            num_classes = dataset.num_classes
        elif args.dataset.lower() == 'mutag':
            num_features = dataset.num_features
            num_classes = dataset.num_classes
        elif args.dataset.lower() == 'proteins':
            num_features = dataset.num_features
            num_classes = dataset.num_classes
        elif args.dataset.lower() == 'nci1':
            num_features = dataset.num_features
            num_classes = dataset.num_classes
        elif args.dataset.lower() == 'imdb_binary':
            num_features = dataset.num_features
            num_classes = dataset.num_classes
        elif args.dataset.lower() == 'reddit-binary':
            num_features = 1
            num_classes = 2
        elif args.dataset.lower() == 'imdb-binary':
            num_features = 1
            num_classes = 2
        elif args.dataset.lower() == 'collab':
            num_features = 1
            num_classes = 3
        elif args.dataset.lower() == 'ptc_fm':
            num_features = dataset.num_features
            num_classes = dataset.num_classes
    
    # Make it cluster mode
    degree_to_label = {}
    
    if "Cluster" in args.model or args.deg_bins == -1:
        all_degrees = []            
        
        for data in dataset:
            if "Cluster" in args.model or "GCN" in args.model:
                edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.num_nodes)
            else:
                edge_index = data.edge_index
            node_degrees = degree(edge_index[0], dtype=torch.long)
            all_degrees.append(node_degrees.numpy())  # Convert to numpy for KMeans
        
        # Concatenate all degree arrays into a single numpy array
        all_degrees = np.concatenate(all_degrees).reshape(-1, 1)  # Reshape for KMeans
        if args.deg_bins == -1:
            unique_degrees = np.unique(all_degrees)
            args.deg_bins = len(unique_degrees)  # Set number of clusters to unique degrees count
        kmeans = KMeans(n_clusters=args.deg_bins, random_state=args.seed).fit(all_degrees)
        cluster_labels = kmeans.labels_

        # Sort centroids and create a new mapping of labels
        centroids = kmeans.cluster_centers_.squeeze()
        sorted_indices = np.argsort(centroids)
        label_map = {old_label: new_label for new_label, old_label in enumerate(sorted_indices)}
        new_labels = np.array([label_map[label] for label in cluster_labels])

        # Group degrees by new labels
        degree_label_mapping = {}

        for ddegree, label in zip(all_degrees.squeeze(), new_labels):
            if label not in degree_label_mapping:
                degree_label_mapping[label] = []
            degree_label_mapping[label].append(ddegree)

        # Print the mapping of degrees to new labels
        for label, degrees in degree_label_mapping.items():
            print(f"Label {label}: Degrees {set(degrees)}")

        # Optionally, print more structured output
        print("\nStructured Mapping:")
        for label in sorted(degree_label_mapping):
            print(f"Label {label}: Degrees {set(degree_label_mapping[label])}")
        # Flatten the grouping into a degree -> label lookup. The loop variable
        # must not be named `degree`: this code runs at module scope, so that
        # name would rebind the `degree` function imported from
        # torch_geometric.utils for everything executed afterwards.
        for label, degrees in degree_label_mapping.items():
            for deg_value in degrees:
                degree_to_label[deg_value] = label

    #### 10-fold CV classification
    #### TODO: Available for Node Classification and Single case of dataset
    
    total_score_dict = {}
    fold_folder_name = ''
    for fold, (train_idx, test_idx) in enumerate(splits):
        
        # Set to save K-fold Cross Validation Single Values
        args.fold = fold
        item_dict = vars(args).copy()
        item_dict['args_to_string'] = str(vars(args))
        print("ARGS TO STRING CHECK")
        print(item_dict)
        
        
        #### Dataset Preparation
        
        if args.dataset.lower() in ['enzymes', 'mutag', 'proteins', "nci1", "ptc_fm"]:
            train_val_dataset = dataset[train_idx.tolist()]
            test_data = dataset[test_idx.tolist()]
            
            train_data = train_val_dataset
            val_data = test_data
        elif args.dataset.lower() in ['reddit-binary','imdb-binary', 'collab']:
            train_data = Subset(dataset, train_idx)
            test_data = Subset(dataset, test_idx)
            val_data = Subset(dataset, test_idx)
            
            
        train_loader=DataLoader(dataset=train_data,batch_size=args.bs,shuffle=True)
        val_loader=DataLoader(dataset=val_data,batch_size=args.bs,shuffle=False)
        test_loader=DataLoader(dataset=test_data,batch_size = args.bs,shuffle=False)
        
        #### Model Preparation
        
        if args.model == 'SNNGCNN_GC':
            model = SNNGCNN_GC(in_features=num_features, out_features = num_classes,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson)).to(device)
            
        elif args.model == 'SNNGCNN_GC_multi':
            model = SNNGCNN_GC_Multi(in_features=num_features, out_features = num_classes,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson)).to(device)
        
        elif args.model == 'SNNGCNN_GC_AM':
            model = SNNGCNN_GC_AM(in_features=num_features, out_features = num_classes,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True).to(device)
        
        elif args.model == "SNNGCNN_GC_one":
            model = SNNGCNN_GC_onethreshold(in_features=num_features, out_features = num_classes,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson)).to(device)
        elif args.model == 'SNNGCNN_GC_outspike':
            model = SNNGCNN_GC_outspike(in_features=num_features, out_features = num_classes,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson)).to(device)
        elif args.model == 'SNNGCNN_GC_Spikeonly':
            model = SNNGCNN_GC_Spikeonly(in_features=num_features, out_features = num_classes,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson)).to(device)
        elif args.model == 'SNNGCNN_GC_Norm':
            model = SNNGCNN_GC_Norm(in_features=num_features, out_features = num_classes,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson)).to(device)
        elif args.model == 'SNNGCNN_GC_Degree':
            model = SNNGCNN_GC_Degree(in_features=num_features, out_features = num_classes, device = device,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                            bins=  args.deg_bins).to(device)
        
        elif args.model == 'SNNGCNN_GC_Degree_Feat':
            if degree_to_label:
                model = SNNGCNN_GC_Degree_Feat(in_features=num_features, out_features = num_classes, device = device,
                                T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                                quantize = args.quantize,  thtr = args.thtr,
                                aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                                bins=  args.deg_bins, degree_to_label = degree_to_label).to(device)
            else:
                model = SNNGCNN_GC_Degree_Feat(in_features=num_features, out_features = num_classes, device = device,
                                T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                                quantize = args.quantize,  thtr = args.thtr,
                                aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                                bins=  args.deg_bins, degree_to_label = None).to(device)
                
        elif args.model == "SNNGCNN_GC_Degree_Feat_Cluster":
            model = SNNGCNN_GC_Degree_Feat_Cluster(in_features=num_features, out_features = num_classes, device = device,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                            bins=  args.deg_bins, degree_to_label=degree_to_label).to(device)
        elif args.model == "SNNGCNN_GC_Degree_Feat_Spikeonly":
            model = SNNGCNN_GC_Degree_Feat_Spikeonly(in_features=num_features, out_features = num_classes, device = device,
                            T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                            quantize = args.quantize,  thtr = args.thtr,
                            aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                            bins=  args.deg_bins).to(device)
            
        elif args.model == "SNNGIN_GC_Degree_Feat_fixed":
            if degree_to_label:
                model = SNNGIN_GC_Degree_Feat(in_features=num_features, out_features = num_classes, device = device,
                                T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                                quantize = args.quantize,  thtr = args.thtr,
                                aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                                bins=  args.deg_bins, degree_to_label = degree_to_label).to(device)
            else:
                model = SNNGIN_GC_Degree_Feat(in_features=num_features, out_features = num_classes, device = device,
                                T = args.T, hidden=args.hids, neuron_type = args.neuron, 
                                quantize = args.quantize,  thtr = args.thtr,
                                aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                                bins=  args.deg_bins, degree_to_label = None).to(device)
        elif args.model == "SGAT_Degree_Feat":
            if degree_to_label:
                model = SGAT_Degree_Feat(in_features=num_features, out_features = num_classes, device = device,
                                T = args.T, hidden=[64,64,64], neuron_type = args.neuron, 
                                quantize = args.quantize,  thtr = args.thtr,
                                aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                                bins=  args.deg_bins, degree_to_label = degree_to_label).to(device)
            else:
                model = SGAT_Degree_Feat(in_features=num_features, out_features = num_classes, device = device,
                                T = args.T, hidden=[64,64,64], neuron_type = args.neuron, 
                                quantize = args.quantize,  thtr = args.thtr,
                                aggr= args.aggr, thr = args.thr, bn = True, poisson = not(args.no_poisson),
                                bins=  args.deg_bins, degree_to_label = None).to(device)
        
        
        print(model)
        
        #### Select Optimizer
        opt_name = args.opt.lower()
        if opt_name == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
        elif opt_name == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        else:
            # The former `assert type(optimizer) is not None` was always true
            # (type(None) is NoneType), so an unknown --opt only surfaced later
            # as an AttributeError on None inside the training loop.
            raise ValueError(f"Optimizer not declared: {args.opt!r} (expected 'adamw' or 'adam')")
        
        
        #### Select Loss function
        loss_name = args.loss.lower()
        if loss_name == 'mse':
            loss_fn = nn.MSELoss()
        elif loss_name == 'ce':
            loss_fn = nn.CrossEntropyLoss()
        else:
            # Without this branch loss_fn stayed undefined (or kept the value
            # of an earlier fold) for any other --loss value.
            raise ValueError(f"Loss function not declared: {args.loss!r} (expected 'mse' or 'ce')")
        ### Training Process
        # Per-fold bookkeeping, reset at the start of every fold.
        best_ep = -1
        best_val_metric = test_metric = 0
        # (macro, micro) test scores at the best validation epoch. This must be
        # initialised here: it is otherwise only assigned once validation
        # improves, so the per-epoch progress print could hit a NameError on
        # the first fold or show the previous fold's value on later folds.
        best_test_metric = (0, 0)
        best_train_metric = 0
        only_test_metric = (0,0)
        start = time.time()
        # Accuracy curves that are plotted after the epoch loop.
        all_test_metric_lst = []
        all_val_metric_lst = []
        all_train_metric_lst = []
        
        cur_exp_folder = f'./pics_and_spike/{args.neuron}_{args.dataset}_{args.epochs}_{args.thr}_{args.T}_bins{args.deg_bins}_{args.model}'
        if fold == 0:
            cur_exp_folder = rename_folder_with_suffix(cur_exp_folder) # Prevent to overwirte
            fold_folder_name = cur_exp_folder
        else:
            cur_exp_folder = fold_folder_name
        
        os.makedirs(cur_exp_folder, exist_ok=True)
        
        for epoch in range(1, args.epochs + 1):
            train_logits = []
            train_labels = []
            model.train()
            total_loss = 0
            train_acc = 0
            out_spikes_counter_frequency_lst = []
            
            if epoch == 1 and not args.no_db:
                    try:
                        mongo.insert_item_one(item_dict)
                        print("INSERTION SUCCESS ")
                        DB_ERROR = False
                    except Exception as e:
                        DB_ERROR = True
                        print("DB INSERT ERROR OCCURED!!!!!!!!!")
                        print(e)
                        
            # Mini-Batch Training for GNN
            # If batch size is bigger than original size it would be reproduce same result
            # If batch size is smaller than original set, it will partitioning into small groups
            for data in tqdm(train_loader):
                data = data.to(device)
                optimizer.zero_grad()
                out_spikes_counter_frequency = model(data)
                
                # for i in range(len(model.convs)):
                    # cur_path=  cur_exp_folder + f'/SpikeConv{i}_firingrate_epoch{epoch}'
                    # model.convs[i].neuron.plot_spike_rate_vs_degree(cur_path)
                # print(torch.unique(out_spikes_counter_frequency))
            
                loss = loss_fn(out_spikes_counter_frequency, data.y)
                if args.neuron == 'BLIF':
                    loss.backward(retain_graph = True)
                else:
                    loss.backward()
                optimizer.step()
                train_logits.append(out_spikes_counter_frequency.max(1)[1])
                train_labels.append(data.y)
                
                out_spikes_counter_frequency_lst += out_spikes_counter_frequency.cpu()
                total_loss += loss.item()
            
            # End-of-epoch bookkeeping: score the training predictions, evaluate on
            # val/test, update the running bests, and log one summary line.

            # Merge the per-batch predicted class indices and labels gathered in the
            # batch loop above into single CPU tensors.
            train_logits = torch.cat(train_logits, dim = 0).cpu()
            train_labels = torch.cat(train_labels, dim = 0).cpu()
            # Per-class prediction counts; within this section they are only
            # referenced by the commented-out debug print below.
            labels_unique_values, labels_counts = torch.unique(train_logits, return_counts=True)
            # print(f'Train logits: {labels_counts}')

            # train_metric is a (macro-F1, micro-F1) tuple; index 1 (micro) is the
            # value tracked and logged throughout the rest of this section.
            train_metric = metrics.f1_score(train_labels, train_logits, average='macro'), metrics.f1_score(train_labels, train_logits, average='micro')
            ## Evaluation
            # NOTE(review): val_metric/test_metric are indexed [0]/[1] and printed as
            # Macro/Micro below, so valid_gc presumably returns the same
            # (macro, micro) pair -- confirm against its definition.
            val_metric = valid_gc(val_loader, device)
            test_metric = valid_gc(test_loader, device)
            
            # val_metric = valid_all_batch(batch_nodes, batch_label, batch_edge_index, val_loader, device)
            # test_metric = valid_all_batch(batch_nodes, batch_label, batch_edge_index, test_loader, device)
            # Repeat this epoch's micro score args.T * len(train_loader) times.
            # NOTE(review): presumably this stretches the per-epoch curves to the same
            # length as the per-step threshold traces they are plotted with -- confirm.
            all_train_metric_lst += ([train_metric[1]] * args.T * len(train_loader)) 
            all_val_metric_lst += ([val_metric[1]] * args.T * len(train_loader))
            all_test_metric_lst += ([test_metric[1]] * args.T * len(train_loader))
            
            # Best training micro score seen so far.
            if train_metric[1] > best_train_metric:
                best_train_metric = train_metric[1]
                
            
            # Best test micro score seen so far, chosen without looking at validation.
            if test_metric[1] > only_test_metric[1]:
                only_test_metric = test_metric
                
            
            # Model selection: when validation improves, remember the test metric and
            # epoch that go with it.
            if val_metric[1] > best_val_metric:
                best_val_metric = val_metric[1]
                best_test_metric = test_metric
                best_ep = epoch
                
                # Save each conv layer's neuron spikes under a per-fold, per-layer
                # path prefix, then reset that neuron's statistics.
                for i in range(len(model.convs)):
                    # print(fold)
                    model.convs[i].neuron.save_neuron_spikes(f'{cur_exp_folder}/K{fold}_Spikeconv{i}')
                    # print(model.convs[i].neuron.spike_counts)
                    model.convs[i].neuron.reset_stat()
            # Timestamp taken after evaluation; the log line reports end - start.
            end = time.time()
            
            # Reset every layer's neuron statistics at the end of each epoch (this is
            # a second reset when the best-validation branch above just ran).
            for i in range(len(model.convs)):
                model.convs[i].neuron.reset_stat()
        
               
            # "Best" is the test metric at the best-validation epoch; "Best Test" is
            # the highest test metric seen regardless of validation.
            print(
                f'Fold {fold} Epoch: {epoch:03d}, Train: {train_metric[1]:.4f} Val: {val_metric[1]:.4f}, Test: {test_metric[1]:.4f}, Loss: {total_loss}\
                \nBest: Macro-{best_test_metric[0]:.4f}, Micro-{best_test_metric[1]:.4f}, Time elapsed {end-start:.2f}s,\
                Best Test : Macro-{only_test_metric[0]:.4f}, Micro-{only_test_metric[1]:.4f}')

        ###### Plotting thresholds ########
        # One figure per fold: the per-layer threshold traces go on the left axis
        # (ax1) and the train/val/test score curves on a twin right axis (ax2)
        # fixed to [0, 1].
        plots_lst = [conv.neuron.v_threshold_values for conv in model.convs]
        timestamp = args.T  # not read in this section; kept in case later code uses it

        fig, ax1 = plt.subplots(figsize=(18,12))
        ax1.set_title(f'{args.dataset} {args.neuron} all Threshold')

        # Pass the target axes explicitly instead of relying on pyplot's
        # "current axes", which silently changes once twinx() is called.
        for i, thresholds in enumerate(plots_lst):
            sns.lineplot(thresholds, label=f'SpikeConv{i}', alpha=0.25, ax=ax1)

        # Axis labels belong to ax1: after twinx() the pyplot-level xlabel/ylabel
        # calls would target ax2, putting the threshold label on the score axis
        # and the x label on the twin's hidden x-axis.
        ax1.set_xlabel('Epoch (ep) TimeStep (t)')
        ax1.set_ylabel('Threshold Values (v)')

        ax2 = ax1.twinx()
        sns.lineplot(all_test_metric_lst, label='Test Acc', color='red', ax=ax2)
        sns.lineplot(all_val_metric_lst, label='Val Acc', color='green', ax=ax2)
        sns.lineplot(all_train_metric_lst, label='Train Acc', color='blue', ax=ax2)
        ax2.set_ylim(0,1)
        ax2.legend()

        fig.tight_layout()
        fig.savefig(f'{cur_exp_folder}/All_Thrshold Values_{epoch}_{args.dataset}_{args.neuron}_{args.thr}_{args.aggr}_K{fold}.png')
        # Close (not just clear) the figure so a new figure created for every fold
        # does not accumulate in memory.
        plt.close(fig)
        
        #### Finish state
        # Report this fold's outcome: mark the run as finished in the DB when it is
        # available, post a summary to the Slack webhook in either case, and store
        # the fold's scores for the cross-validation summary.

        # Same pattern the threshold figure for this fold is saved under, so the
        # path recorded in the DB points at a file that actually exists.
        pic_path = f'{cur_exp_folder}/All_Thrshold Values_{epoch}_{args.dataset}_{args.neuron}_{args.thr}_{args.aggr}_K{fold}.png'
        headers = {'Content-Type': 'application/json'}

        if not args.no_db and not DB_ERROR:
            mongo.update_item_one({"args_to_string":item_dict['args_to_string']}, {"$set":{
                    "finish":True,
                    "pic_path": pic_path}})
            slack_message = f"EXP: All_Thrshold Values_{epoch}_{args.dataset}_{args.neuron}_{args.thr}_{args.aggr}_fold{fold}\n \
            Train: {train_metric[1]:.4f} Val: {val_metric[1]:.4f}, Test: {test_metric[1]:.4f}, Loss: {total_loss}\n\
            Best: Macro-{best_test_metric[0]:.4f}, Micro-{best_test_metric[1]:.4f} \n \
            Best Test : Macro-{only_test_metric[0]:.4f}, Micro-{only_test_metric[1]:.4f} "

        else:
            # DB disabled or unreachable: send the full result record over Slack
            # instead so the fold's numbers are not lost.
            print("The experimental result did not sent to to DB")
            final_info = {"args_to_str": item_dict['args_to_string'], 
                        "training_f1": train_metric[0],
                        "training_acc": train_metric[1], 
                        "best_acc" : best_test_metric,
                        "val_acc": best_val_metric,
                        "only_test_acc": only_test_metric[1],
                        "best_ep" : best_ep,
                        "fold": fold}
            slack_message = json.dumps(final_info)

        # Both branches send the same {'text': ...} payload shape. Passing the
        # already-serialised string directly as `json=` would encode it a second
        # time and post a bare JSON string rather than an object.
        # The notification is auxiliary: bound it with a timeout and do not let a
        # network failure abort the remaining folds.
        try:
            requests.post(slack_my_token, json={'text': slack_message}, headers=headers, timeout=10)
        except requests.RequestException as exc:
            print(f"Slack notification failed: {exc}")

        # (train (macro, micro), best validation micro, test metric at best validation)
        total_score_dict[fold] = (train_metric, best_val_metric, best_test_metric)
    
    #### Aggregate the per-fold scores into cross-validation averages
    # Each total_score_dict entry is (train_metric, best_val_metric, best_test_metric);
    # index 1 of the train/test tuples is the micro score. The *_sum variables
    # hold sums inside the loop and means after the division below.

    train_sum = 0
    val_sum = 0
    test_sum = 0
    
    for key, (train_acc, val_acc, test_acc) in total_score_dict.items():
        print(f"Fold {key} : Train acc {train_acc[1]} Test acc {test_acc}")
        train_sum += train_acc[1]
        val_sum += val_acc
        test_sum += test_acc[1]
    
    train_sum /= num_folds
    val_sum /= num_folds
    test_sum /= num_folds
    print(f"CV {num_folds} : Train acc {train_sum} Test acc {test_sum}")
    
    # Report the actual fold count instead of a hard-coded 10.
    print(f"Total scores for {num_folds}-CV Cross Validation finished")
    
    args.fold = 'all'
    item_dict = vars(args).copy()
    item_dict['args_to_string'] = str(vars(args))
