"""
Code to produce checkpoints (ckpts) for TOPGQ.
"""

"""
This script trains FP32 (full-precision) models on the Planetoid citation
datasets Cora, CiteSeer and PubMed and, with --save_model, saves their checkpoints.
"""

import os
import sys
import time
import argparse
from collections import OrderedDict
from tqdm import tqdm

parser = argparse.ArgumentParser()
# experiment
parser.add_argument("--dataset", type=str, default="cora",
                    help="Planetoid citation dataset: cora, citeseer or pubmed")
# Restricting --model to the architectures the script can build turns an
# obscure NameError later on into a clear argparse usage error.
parser.add_argument("--model", type=str, default="GCN", choices=["GCN", "GAT", "GIN", "GS"],
                    help="model architecture (GS = GraphSAGE)")
parser.add_argument("--num_layers", type=int, default=2)
parser.add_argument("--seed", type=int, default=123)
parser.add_argument("--epochs", type=int, default=200)
# NOTE(review): --batch_size is not read anywhere in the visible script
# (GraphSAGE samplers use the whole graph as one batch).
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--wd", type=float, default=4e-5, help="Adam weight decay")
parser.add_argument("--lr_decay_factor", type=float, default=0.5,
                    help="ReduceLROnPlateau factor applied to the learning rate")
parser.add_argument("--lr_decay_step_size", type=int, default=25,
                    help="ReduceLROnPlateau patience, in epochs without val-loss improvement")
parser.add_argument("--dropout", type=float, default=0.1)
# other configurations
parser.add_argument(
    "--path", type=str, default="/datasets/citation/", help="where all datasets live")
parser.add_argument("--cpu", action="store_true")
parser.add_argument("--check_time", action="store_true")
parser.add_argument("--save_model", action="store_true", help="save_model")
parser.add_argument(
    "--save_path", type=str, default="./ckpts/", help="where the model ckpts live")

args = parser.parse_args()

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F

print(args)
# Seed the CPU generator, the current CUDA device and every CUDA device
# with the same value so runs are repeatable.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

from models.models_FP32 import GCN, GAT, GIN, GraphSAGE
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborSampler


def count_parameters(model):
    """Print each parameter's name and shape, then return the trainable size.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        int: total number of elements over parameters with ``requires_grad``.
    """
    trainable = 0
    for param_name, param in model.named_parameters():
        print(param_name, param.shape)
        if param.requires_grad:
            trainable += param.numel()
    return trainable

def accuracy(output, labels):
    """Return the fraction of predictions that match ``labels``.

    Args:
        output: score tensor with classes along dim 1; the predicted class is
            the index of the maximum along that dim.
        labels: target class indices, compared elementwise with the predictions.

    Returns:
        0-d float64 tensor: number of correct predictions divided by ``len(labels)``.
    """
    predicted = torch.max(output, dim=1).indices.type_as(labels)
    n_correct = (predicted == labels).double().sum()
    return n_correct / len(labels)

# Checkpoints go to <save_path>/<dataset>. os.path.join works whether or not
# --save_path / --path were given with a trailing slash (plain string
# concatenation produced e.g. "./ckptscora").
save_path = os.path.join(args.save_path, args.dataset)
print("saving ckpt to ... {}".format(save_path))


if args.cpu:
    device = torch.device("cpu")
    print("torch device: {}".format(device))
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = Planetoid(os.path.join(args.path, args.dataset), args.dataset,
                    transform=T.NormalizeFeatures())
print("Dataset: {}".format(args.dataset))
# Passed unchanged to every model forward call.
# NOTE(review): presumably per-layer quantization settings, all None for the
# FP32 run — confirm against models_FP32.
q_group = [None, None, None]

# MODEL SELECTION
# Each architecture has its own hidden width. GraphSAGE ("GS") additionally
# trains on sampled neighbourhoods and evaluates on full ones.
sparse_check = False
if args.model == "GCN":
    arch, hidden = GCN, 16
elif args.model == "GAT":
    arch, hidden = GAT, 8
elif args.model == "GIN":
    arch, hidden = GIN, 16
elif args.model == "GS":
    arch, hidden = GraphSAGE, 16
    sparse_check = True
    graph = dataset[0]
    # 25/10 sampled neighbours for the default 2-layer net, 5 per hop otherwise.
    train_sizes = [25, 10] if args.num_layers == 2 else [5] * args.num_layers
    # batch_size == num_nodes: each epoch is one batch covering every node.
    train_loader = NeighborSampler(graph.edge_index, sizes=train_sizes,
                                   batch_size=graph.num_nodes, shuffle=True,
                                   num_nodes=graph.num_nodes)
    # sizes=-1 keeps all neighbours at each hop for evaluation.
    test_loader = NeighborSampler(graph.edge_index, sizes=[-1] * args.num_layers,
                                  batch_size=graph.num_nodes, shuffle=False,
                                  num_nodes=graph.num_nodes)
    print("Neighborsampler generated for GS")
else:
    # Without this, an unsupported model would surface as a NameError on `arch`.
    raise ValueError(
        f"unknown --model {args.model!r}; expected one of GCN, GAT, GIN, GS")
# Residual connections are only enabled for networks deeper than two layers.
residual = args.num_layers > 2

model = arch(
    dataset,
    num_layers=args.num_layers,
    hidden=hidden,
    graph_level=False,
    residual=residual,
    # quantization param (qtype=32 for this FP32 script — presumably the
    # bit-width; confirm in models_FP32)
    device=device,
    qtype=32,
    momentum=False,
    dropout=args.dropout,
    # SAGE
    sparse_check=sparse_check
)

print(f"model has {count_parameters(model)} parameters")
model.to(device).reset_parameters()
data = dataset[0].to(device)
optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
# Multiply the LR by lr_decay_factor once val loss has not improved for
# lr_decay_step_size epochs, never going below 1e-5.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases and emits a warning.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=args.lr_decay_factor,
                              patience=args.lr_decay_step_size, min_lr=0.00001, verbose=True)
def _adjs_to_device(adjs):
    """Return NeighborSampler adjacency blocks as a list moved to ``device``.

    With a single sampling hop, PyG's NeighborSampler yields one ``EdgeIndex``
    instead of a list. Iterating that named tuple would call ``.to()`` on its
    fields (including the ``size`` tuple), so it is wrapped in a list first.
    """
    if not isinstance(adjs, list):
        adjs = [adjs]
    return [adj.to(device) for adj in adjs]


def _sync_for_timing():
    """Wait for queued CUDA work when --check_time is set so timings are real."""
    if args.check_time and device.type == "cuda":
        torch.cuda.synchronize()


# One tick per epoch. Starting at 0 makes the bar end exactly at args.epochs.
pbar = tqdm(total=args.epochs)

# TRAIN START
best_acc = 0              # test accuracy at the last saved/accepted checkpoint
best_loss = float("inf")  # lowest validation loss seen so far
durations = 0.0           # summed training-step wall time over all epochs
for epoch in range(1, args.epochs + 1):
    # TRAIN
    model.train()
    duration = 0.0
    if args.model == "GS":
        total_loss = 0
        total_acc = 0
        for batch_size, n_id, adjs in train_loader:
            # The first batch_size entries of n_id are the target nodes.
            y = data.y[n_id[:batch_size]]
            train_mask = data.train_mask[n_id[:batch_size]]
            adjs = _adjs_to_device(adjs)
            _sync_for_timing()
            t_start = time.perf_counter()
            optimizer.zero_grad()
            out = model((data.x[n_id], adjs), q_group)
            out = out[:batch_size]
            loss_train = F.nll_loss(out[train_mask], y[train_mask])
            acc_train = accuracy(out[train_mask], y[train_mask])
            loss_train.backward()
            optimizer.step()
            _sync_for_timing()
            # Summed over batches so multi-batch epochs are timed in full.
            duration += time.perf_counter() - t_start
            total_loss += loss_train.detach()
            total_acc += acc_train.detach()

        loss_train = total_loss / len(train_loader)
        acc_train = total_acc / len(train_loader)
    else:
        _sync_for_timing()
        t_start = time.perf_counter()
        optimizer.zero_grad()
        output = model(data, q_group)
        loss_train = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
        acc_train = accuracy(output[data.train_mask], data.y[data.train_mask])
        loss_train.backward()
        optimizer.step()
        _sync_for_timing()
        duration = time.perf_counter() - t_start

    durations += duration

    # EVAL
    model.eval()
    with torch.no_grad():
        if args.model == "GS":
            for batch_size, n_id, adjs in test_loader:
                y = data.y[n_id[:batch_size]]
                val_mask = data.val_mask[n_id[:batch_size]]
                test_mask = data.test_mask[n_id[:batch_size]]
                adjs = _adjs_to_device(adjs)
                out = model((data.x[n_id], adjs), q_group)
                # NOTE(review): metrics come from the last batch only; this is
                # fine while test_loader yields a single full-graph batch.
                loss_val = F.nll_loss(out[val_mask], y[val_mask])
                acc_val = accuracy(out[val_mask], y[val_mask])
                acc_test = accuracy(out[test_mask], y[test_mask])
        else:
            output = model(data, q_group)
            loss_val = F.nll_loss(output[data.val_mask], data.y[data.val_mask])
            acc_val = accuracy(output[data.val_mask], data.y[data.val_mask])
            acc_test = accuracy(output[data.test_mask], data.y[data.test_mask])

    scheduler.step(loss_val)
    # NOTE(review): a checkpoint is kept when val loss improves AND test
    # accuracy beats the previous best, so best_acc is selected on the test
    # split and is not an unbiased test estimate.
    if loss_val < best_loss:
        best_loss = loss_val
        if acc_test > best_acc:
            best_acc = acc_test.item()
            if args.save_model:
                layer_info = "" if args.num_layers == 2 else "{}LAYER_".format(args.num_layers)
                pt_name = f'{args.model}_{layer_info}FP32_{args.dataset}_{acc_test:.4f}.pt'
                # torch.save does not create missing parent directories.
                os.makedirs(save_path, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(save_path, pt_name))
            print("checkpoint/data successfully updated, acc: {:.4f}".format(best_acc))

    pbar.set_postfix(
    {
        "Train_Loss": "{:05.3f}".format(loss_train.item()),
        "Val_Loss": "{:05.3f}".format(loss_val.item()),
        "Train_Acc": "{:05.3f}".format(acc_train.item()),
        "Val_Acc": "{:05.3f}".format(acc_val.item()),
    })
    pbar.update(1)
pbar.close()

if torch.cuda.is_available():
    torch.cuda.synchronize()

if args.check_time:
    # max(..., 1) guards against ZeroDivisionError when --epochs is 0.
    print("Average training time per epoch: {:.4f}s".format(durations / max(args.epochs, 1)))
print("Best_Acc Checkpointed: {}".format(best_acc))




