import argparse
import os.path as osp
import time

import psutil
import torch
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from torch import Tensor
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.models import GAT, GraphSAGE #, SGFormer
from torch_geometric.nn import ChebConv
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    to_undirected,
)
import wandb 
from Stable.EulerConv import Euler_ChebConv

# Command-line configuration. ArgumentDefaultsHelpFormatter only prints the
# "(default: ...)" suffix for options that have a help string, so every option
# below carries one.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
    '--dataset',
    type=str,
    default='ogbn-arxiv',
    choices=['ogbn-papers100M', 'ogbn-products', 'ogbn-arxiv'],
    help='Dataset name.',
)
parser.add_argument(
    '--dataset_dir',
    type=str,
    default='./data',
    help='Root directory of dataset.',
)
parser.add_argument(
    "--model",
    # Lower-case user input so e.g. "--model GAT" matches the choices below.
    type=str.lower,
    # Spelled in lower case so the default is literally one of ``choices``
    # (argparse would lower-case 'GAT' through ``type`` anyway).
    default='gat',
    choices=['sage', 'gat', 'sgformer', 'chebnet', 'eulerchebnet'],
    help="Model used for training",
)

parser.add_argument('-e', '--epochs', type=int, default=140,
                    help='Number of training epochs.')
parser.add_argument('--K', type=int, default=4,
                    help='Chebyshev filter size K passed to ChebConv / '
                    'Euler_ChebConv (chebnet and eulerchebnet only).')
parser.add_argument('--seed', type=int, default=1234,
                    help='Random seed passed to seed_everything.')
parser.add_argument('--num_layers', type=int, default=5,
                    help='Number of message-passing layers; also the number '
                    'of neighbor-sampling hops.')
parser.add_argument('--num_heads', type=int, default=1,
                    help='number of heads for GAT model.')
parser.add_argument('-b', '--batch_size', type=int, default=1024,
                    help='Number of seed nodes per mini-batch.')
parser.add_argument('--num_workers', type=int, default=12,
                    help='Number of DataLoader worker processes.')
parser.add_argument('--fan_out', type=int, default=10,
                    help='number of neighbors in each layer')
parser.add_argument('--hidden_channels', type=int, default=256,
                    help='Hidden feature dimension.')
parser.add_argument('--step_size', type=float, default=0.1,
                    help='Step size passed to Euler_ChebConv '
                    '(eulerchebnet only).')
parser.add_argument('--lr', type=float, default=0.001,
                    help='Adam learning rate.')
parser.add_argument('--wd', type=float, default=0.0,
                    help='Adam weight decay.')
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout probability forwarded to the model '
                    'constructor (not used by eulerchebnet).')
parser.add_argument(
    '--use_directed_graph',
    action='store_true',
    help='Whether or not to use directed graph',
)
parser.add_argument(
    '--add_self_loop',
    action='store_true',
    help='Whether or not to add self loop',
)
args = parser.parse_args()

wall_clock_start = time.perf_counter()

if (args.dataset == 'ogbn-papers100M'
        and (psutil.virtual_memory().total / (1024**3)) < 390):
    print('Warning: may not have enough RAM to run this example.')
    print('Consider upgrading RAM if an error occurs.')
    print('Estimated RAM Needed: ~390GB.')

print(f'Training {args.dataset} with {args.model} model.')

seed_everything(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = args.epochs
num_layers = args.num_layers
num_workers = args.num_workers
num_hidden_channels = args.hidden_channels
batch_size = args.batch_size
root = osp.join(args.dataset_dir, args.dataset)
print('The root is: ', root)
dataset = PygNodePropPredDataset(name=args.dataset, root=root)
split_idx = dataset.get_idx_split()
data = dataset[0]

if not args.use_directed_graph:
    data.edge_index = to_undirected(data.edge_index, reduce='mean')
if args.add_self_loop:
    # Remove first so nodes that already have a self loop do not get a second.
    data.edge_index, _ = remove_self_loops(data.edge_index)
    data.edge_index, _ = add_self_loops(data.edge_index,
                                        num_nodes=data.num_nodes)

# Only node features and labels are moved to the device; the graph structure
# stays where it is for sampling.
data.to(device, 'x', 'y')


def _make_loader(input_nodes: Tensor, shuffle: bool) -> NeighborLoader:
    """Build a NeighborLoader over ``data`` seeded at ``input_nodes``.

    Samples ``args.fan_out`` neighbors per node for each of the
    ``num_layers`` hops.
    """
    return NeighborLoader(
        data,
        input_nodes=input_nodes,
        num_neighbors=[args.fan_out] * num_layers,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # torch's DataLoader raises if persistent_workers=True while
        # num_workers == 0, so only keep workers alive when there are any.
        persistent_workers=num_workers > 0,
        # disjoint=True gives every seed its own subgraph plus a ``batch``
        # vector, which the sgformer branches of train()/test() pass on.
        disjoint=args.model == "sgformer",
    )


train_loader = _make_loader(split_idx['train'], shuffle=True)
# Visiting order does not change evaluation accuracy, so do not shuffle.
val_loader = _make_loader(split_idx['valid'], shuffle=False)
test_loader = _make_loader(split_idx['test'], shuffle=False)


def train(epoch: int) -> tuple[float, float]:
    """Run one training epoch of the global ``model`` over ``train_loader``.

    Loss and accuracy are computed only on the seed nodes of each mini-batch
    (the first ``batch.batch_size`` rows of the model output).

    Args:
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        ``(loss, approx_acc)``. ``loss`` is the mean of the per-batch losses.
        ``approx_acc`` is the fraction of training seed nodes predicted
        correctly during the epoch, measured in train mode while the weights
        are still being updated, hence approximate.
    """
    model.train()
    num_train = split_idx['train'].size(0)

    total_loss = 0.0
    total_correct = 0
    with tqdm(total=num_train) as pbar:
        pbar.set_description(f'Epoch {epoch:02d}')
        for batch in train_loader:
            optimizer.zero_grad()
            num_seeds = batch.batch_size
            edge_index = batch.edge_index.to(device)
            if args.model == "sgformer":
                out = model(batch.x, edge_index, batch.batch.to(device))
            else:
                out = model(batch.x, edge_index)
            out = out[:num_seeds]

            # view(-1) rather than squeeze(): squeeze() turns a single-seed
            # batch's [1, 1] label tensor into a 0-d scalar that
            # cross_entropy rejects.
            y = batch.y[:num_seeds].view(-1).to(torch.long)
            loss = F.cross_entropy(out, y)
            loss.backward()
            optimizer.step()

            total_loss += float(loss)
            total_correct += int(out.argmax(dim=-1).eq(y).sum())
            pbar.update(num_seeds)

    loss = total_loss / len(train_loader)
    approx_acc = total_correct / num_train
    return loss, approx_acc


@torch.no_grad()
def test(loader: NeighborLoader) -> float:
    """Return the accuracy of the global ``model`` on the seed nodes of ``loader``.

    Only the first ``batch.batch_size`` nodes of each sampled subgraph are the
    seed (target) nodes of that mini-batch. The remaining rows are sampled
    neighbors, which can belong to other splits, so they are excluded from
    the score.

    Args:
        loader: A NeighborLoader seeded at the split to evaluate.

    Returns:
        Fraction of seed nodes whose arg-max prediction equals the label.
    """
    model.eval()

    total_correct = total_examples = 0
    for batch in loader:
        batch = batch.to(device)
        # Previously len(batch.n_id), i.e. every sampled node, which scored
        # neighbors too and mismatched the sgformer slice below.
        num_seeds = batch.batch_size
        if args.model == "sgformer":
            out = model(batch.x, batch.edge_index, batch.batch)
        else:
            out = model(batch.x, batch.edge_index)
        pred = out[:num_seeds].argmax(dim=-1)
        y = batch.y[:num_seeds].view(-1).to(torch.long)

        total_correct += int((pred == y).sum())
        total_examples += y.size(0)

    return total_correct / total_examples


# ``torch``, ``F``, ``ChebConv`` and ``device`` are already imported/defined at
# the top of the script; only the batch-norm layer used by the model classes
# below still has to be brought into scope.
from torch.nn import BatchNorm1d

class ChebNet(torch.nn.Module):
    """Stack of ``ChebConv`` layers mapping node features to class logits.

    Every layer except the last is followed by leaky ReLU and dropout.

    Args:
        in_channels: Input node-feature dimension.
        hidden_channels: Width of every intermediate layer.
        num_layers: Total number of ``ChebConv`` layers (must be >= 1).
        out_channels: Output (logit) dimension.
        K: Chebyshev filter size forwarded to every ``ChebConv``.
        dropout: Dropout probability applied after each hidden layer.

    Raises:
        ValueError: If ``num_layers`` < 1.
    """

    def __init__(self, in_channels, hidden_channels, num_layers, out_channels,
                 K, dropout):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        self.convs = torch.nn.ModuleList()
        # NOTE(review): batch-norm layers are built (and counted in the
        # parameter total) but forward() does not apply them; kept so
        # existing state dicts still load.
        self.bns = torch.nn.ModuleList()
        self.dropout = dropout

        if num_layers == 1:
            # A single layer maps input features straight to the logits
            # (previously this case silently built two layers).
            self.convs.append(ChebConv(in_channels, out_channels, K=K))
            return

        # Input layer: input_dim -> hidden_dim
        self.convs.append(ChebConv(in_channels, hidden_channels, K=K))
        self.bns.append(BatchNorm1d(hidden_channels))

        # Hidden layers: hidden_dim -> hidden_dim
        for _ in range(num_layers - 2):
            self.convs.append(ChebConv(hidden_channels, hidden_channels, K=K))
            self.bns.append(BatchNorm1d(hidden_channels))

        # Output layer: hidden_dim -> output_dim (no activation / dropout)
        self.convs.append(ChebConv(hidden_channels, out_channels, K=K))

    def forward(self, x, edge_index):
        """Return raw logits; presumably x is [num_nodes, in_channels]."""
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = F.leaky_relu(x)
                # Honour the constructor's dropout (was hard-coded to 0.25).
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class EulerChebNet(torch.nn.Module):
    """Linear encoder, then ``Euler_ChebConv`` layers, then a linear classifier.

    Between consecutive ``Euler_ChebConv`` layers the hidden state goes
    through ReLU, batch-norm and dropout. The output of the final conv is fed
    directly to the classifier.

    Args:
        in_channels: Input node-feature dimension.
        hidden_channels: Width of the encoder output, every conv layer, and
            the classifier input.
        num_layers: Number of ``Euler_ChebConv`` layers.
        out_channels: Output (logit) dimension.
        K: Forwarded to every ``Euler_ChebConv``.
        step_size: Forwarded to every ``Euler_ChebConv``.
        dissipation_force: Forwarded to every ``Euler_ChebConv``.
        dropout: Dropout probability between conv layers. It was previously
            hard-coded, and that value (0.2) remains the default.
    """

    def __init__(self, in_channels, hidden_channels, num_layers, out_channels,
                 K, step_size, dissipation_force, dropout=0.2):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        self.dropout = dropout

        # Input linear layer: input_dim -> hidden_dim
        self.lin = torch.nn.Linear(in_channels, hidden_channels)

        # Hidden layers: hidden_dim -> hidden_dim
        for _ in range(num_layers):
            self.convs.append(
                Euler_ChebConv(hidden_channels, hidden_channels, K=K,
                               step_size=step_size,
                               dissipation_force=dissipation_force))
            # NOTE(review): the batch-norm paired with the final conv is
            # never applied in forward().
            self.bns.append(BatchNorm1d(hidden_channels))

        # Output layer: hidden_dim -> output_dim
        self.classify = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return raw logits; presumably x is [num_nodes, in_channels]."""
        x = self.lin(x)
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = F.relu(x)
                x = self.bns[i](x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return self.classify(x)

def get_model(model_name: str) -> torch.nn.Module:
    """Instantiate the model selected by ``--model``.

    Sizes come from the global ``dataset`` (feature/class counts) and the
    parsed command-line ``args``.

    Args:
        model_name: One of 'gat', 'chebnet', 'eulerchebnet', 'sage',
            'sgformer'.

    Returns:
        The freshly constructed (un-moved) model.

    Raises:
        ImportError: If 'sgformer' is requested but the installed
            torch_geometric does not provide ``SGFormer``.
        ValueError: For any other model name.
    """
    if model_name == 'gat':
        model = GAT(
            in_channels=dataset.num_features,
            hidden_channels=num_hidden_channels,
            num_layers=num_layers,
            out_channels=dataset.num_classes,
            dropout=args.dropout,
            heads=args.num_heads,
        )
    elif model_name == 'chebnet':
        model = ChebNet(
            in_channels=dataset.num_features,
            hidden_channels=num_hidden_channels,
            num_layers=num_layers,
            out_channels=dataset.num_classes,
            K=args.K,
            dropout=args.dropout,
        )
    elif model_name == 'eulerchebnet':
        model = EulerChebNet(
            in_channels=dataset.num_features,
            hidden_channels=num_hidden_channels,
            num_layers=num_layers,
            out_channels=dataset.num_classes,
            K=args.K,
            step_size=args.step_size,
            dissipation_force=0.1,
        )
    elif model_name == 'sage':
        model = GraphSAGE(
            in_channels=dataset.num_features,
            hidden_channels=num_hidden_channels,
            num_layers=num_layers,
            out_channels=dataset.num_classes,
            dropout=args.dropout,
        )
    elif model_name == 'sgformer':
        # SGFormer is not imported at module level (it only exists in newer
        # torch_geometric releases), so import it on demand instead of
        # failing with a bare NameError.
        try:
            from torch_geometric.nn.models import SGFormer
        except ImportError as err:
            raise ImportError(
                'SGFormer is not available in the installed torch_geometric; '
                'upgrade torch_geometric to use --model sgformer.') from err
        model = SGFormer(
            in_channels=dataset.num_features,
            hidden_channels=num_hidden_channels,
            out_channels=dataset.num_classes,
            trans_num_heads=args.num_heads,
            trans_dropout=args.dropout,
            gnn_num_layers=num_layers,
            gnn_dropout=args.dropout,
        )
    else:
        raise ValueError(f'Unsupported model type: {model_name}')

    return model


model = get_model(args.model).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.wd,
)
# mode='max' because the monitored quantity is validation accuracy.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                       patience=5)

print(f'Total time before training begins took '
      f'{time.perf_counter() - wall_clock_start:.4f}s')
print('Training...')

run_name = (f'{args.model}_{args.hidden_channels}hidden_'
            f'{args.num_layers}lays_{args.K}K')
wandb.init(project="OGB_arxiv2025", name=run_name)

# Record the command-line configuration and model size with the run.
wandb.config.update(args)
wandb.config.update({"Total params": sum(p.numel() for p in model.parameters())})

times = []
train_times = []
inference_times = []
best_val = 0.
for epoch in range(1, num_epochs + 1):
    train_start = time.perf_counter()
    loss, acc = train(epoch)
    train_times.append(time.perf_counter() - train_start)

    inference_start = time.perf_counter()
    val_acc = test(val_loader)
    inference_times.append(time.perf_counter() - inference_start)

    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, '
          f'Train Time: {train_times[-1]:.4f}s')
    print(f'Val: {val_acc * 100.0:.2f}%,')

    current_lrs = [group['lr'] for group in optimizer.param_groups]
    # One wandb.log call per epoch: each call advances wandb's step counter by
    # default, so separate calls would scatter an epoch's metrics over
    # several steps.
    wandb.log({
        "Train Acc": 100 * acc,
        "Val Acc": 100 * val_acc,
        "Loss": loss,
        "Epoch": epoch,
        "lr": current_lrs[0],
    })

    best_val = max(best_val, val_acc)
    times.append(time.perf_counter() - train_start)
    for lr in current_lrs:
        print('lr:')
        print(lr)
    scheduler.step(val_acc)

print(f'Average Epoch Time on training: '
      f'{torch.tensor(train_times).mean():.4f}s')
print(f'Average Epoch Time on inference: '
      f'{torch.tensor(inference_times).mean():.4f}s')
print(f'Average Epoch Time: {torch.tensor(times).mean():.4f}s')
print(f'Median Epoch Time: {torch.tensor(times).median():.4f}s')
print(f'Best Validation Accuracy: {100.0 * best_val:.2f}%')

# NOTE(review): this evaluates the final-epoch weights, not the weights that
# achieved ``best_val``.
print('Testing...')
test_final_acc = test(test_loader)
print(f'Test Accuracy: {100.0 * test_final_acc:.2f}%')
print(f'Total Program Runtime: '
      f'{time.perf_counter() - wall_clock_start:.4f}s')
wandb.log({"Test Acc": 100 * test_final_acc})
wandb.finish()