"""
Post-training quantization (PTQ) evaluation on citation datasets; no gradient-based updates are performed.
PTQ method: TOPG
datasets: Cora, CiteSeer, PubMed
"""

import os
import sys
import time
import argparse
from collections import OrderedDict
from tqdm import tqdm

parser = argparse.ArgumentParser()
# experiment
parser.add_argument("--dataset", type=str, default="pubmed",
                    help="citation dataset: cora, citeseer or pubmed")
parser.add_argument("--model", type=str, default="GCN", choices=["GCN", "GAT", "GIN", "GS"],
                    help="GNN architecture (GS = GraphSAGE)")
parser.add_argument("--num_layers", type=int, default=2, help="number of GNN layers")
parser.add_argument("--seed", type=int, default=123, help="random seed for torch (CPU and CUDA)")
# other configurations
parser.add_argument(
    "--path", type=str, default="/datasets/citation/", help="where all datasets live")
parser.add_argument("--cpu", action="store_true", help="force CPU even if CUDA is available")
parser.add_argument("--db_name", type=str, default=None)
parser.add_argument("--check_time", action="store_true")
# PTQ configurations
parser.add_argument("--bit", type=int, default=4,
                    help="quantization bit width (passed to the model as qtype)")
parser.add_argument("--model_ckpt_path", type=str, default="./ckpts/",
                    help="root directory of FP32 checkpoints; files are looked up under <root>/<dataset>/")
parser.add_argument("--model_ckpt_acc", type=str, default="0.7700",
                    help="accuracy suffix in the FP32 checkpoint filename, e.g. GCN_FP32_pubmed_0.7700.pt")
# TOPG WIENER configurations
parser.add_argument("--use_wc", action="store_true", help="forwarded to the model as use_wc")
parser.add_argument("--k", type=int, default=2, help="k forwarded to WienerBook")
parser.add_argument("--percentile", type=float, default=-1, help="forwarded to the model as percentile")


args = parser.parse_args()

import torch
import torch.nn.functional as F

print(args)
# Seed torch's default generator, the current CUDA device, and every CUDA
# device, in that order, with the same value.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

def count_parameters(model):
    """Print every parameter's shape and return the number of trainable elements.

    All parameters are printed, but only those with ``requires_grad`` set
    contribute to the returned count.
    """
    trainable = 0
    for param in model.parameters():
        print(param.shape)
        if param.requires_grad:
            trainable += param.numel()
    return trainable

def accuracy(output, labels):
    """Return the fraction of rows whose arg-max column equals the label.

    ``output`` is scored row-wise (arg-max over dim 1); the result is a 0-d
    float64 tensor.
    """
    predicted = output.argmax(dim=1).type_as(labels)
    hits = (predicted == labels).double().sum()
    return hits / len(labels)


from models.models_absorption import GCN, GAT, GIN, GraphSAGE

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborSampler
from utils.wiener_allocation import WienerBook


ckpt_path = args.model_ckpt_path

# --cpu forces the CPU; otherwise use CUDA whenever it is available.
use_cuda = not args.cpu and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if args.cpu:
    print("torch device: {}".format(device))

# Planetoid root is <path>/<dataset>; os.path.join works whether or not --path
# ends with a separator (plain concatenation silently produced a wrong path).
dataset = Planetoid(os.path.join(args.path, args.dataset), args.dataset, transform=T.NormalizeFeatures())
print("Dataset: {}".format(args.dataset))

# MODEL SELECTION
# --model -> (architecture class, hidden width, sparse_check flag).
# sparse_check is enabled only for GraphSAGE ("GS").
_MODEL_TABLE = {
    "GCN": (GCN, 16, False),
    "GAT": (GAT, 8, False),
    "GIN": (GIN, 16, False),
    "GS": (GraphSAGE, 16, True),
}
if args.model not in _MODEL_TABLE:
    # Fail fast with a usage error rather than an undefined `arch` below.
    parser.error("unknown --model {!r}; choose from {}".format(args.model, ", ".join(_MODEL_TABLE)))
arch, hidden, sparse_check = _MODEL_TABLE[args.model]

model = arch(
    dataset,
    num_layers=args.num_layers,
    hidden=hidden,
    graph_level=False,
    # quantization param
    device=device,
    qtype=args.bit,
    momentum=False,
    use_wc=args.use_wc,
    # SAGE
    sparse_check=sparse_check,
    percentile=args.percentile
)
print(f"model has {count_parameters(model)} parameters")
# get FP32 ckpts
# File: <model_ckpt_path>/<dataset>/<MODEL>[_<N>LAYER]_FP32_<dataset>_<acc>.pt
layer_info = "" if args.num_layers == 2 else "_{}LAYER".format(args.num_layers)
model_ckpt = args.model + layer_info + "_FP32_" + args.dataset + "_" + args.model_ckpt_acc + ".pt"
# os.path.join tolerates --model_ckpt_path with or without a trailing separator.
state = torch.load(os.path.join(ckpt_path, args.dataset, model_ckpt), map_location=device)
# Drop every entry whose key contains "table", "min" or "max" (presumably saved
# quantizer state); the quantizers are reset via reset_quantizers instead.
_SKIPPED_KEY_PARTS = ("table", "min", "max")
new_state = OrderedDict(
    (key, value)
    for key, value in state.items()
    if not any(part in key for part in _SKIPPED_KEY_PARTS)
)
# strict=False because the filtered-out keys are then absent from new_state.
model.load_state_dict(new_state, strict=False)
model.reset_quantizers(args.bit, False, False, device)
model.to(device)
model.eval()
print("checkpoint has been successfully maintained, {}".format(model_ckpt))


# get wiener index
# The grouping is computed from the training split (idx="train").
alloc = WienerBook(dataset_name=args.dataset, k=args.k)
t_start = time.perf_counter()
dataset = alloc.get_wiener_processed_dataset_node_level(dataset, dataset[0].train_mask, idx="train")
t_end = time.perf_counter()
if args.check_time:
    print("wiener preprocessing time: {:.4f}s".format(t_end - t_start))
graph = dataset[0]
# q_group = [all-False bool flag per node, graph.group_num, graph.indegree]
# (presumably per-node group ids and in-degrees — confirm in WienerBook).
q_group = [
    torch.zeros(graph.num_nodes, dtype=torch.bool),
    graph.group_num,
    graph.indegree,
]
q_group = [q.to(device) for q in q_group]
# Two separate dataset[0] lookups, one for evaluation and one for calibration;
# both use the same q_group.
data = dataset[0].to(device)
train_data = dataset[0].to(device)
train_q_group = q_group
if args.model == "GS":
    # One batch covering all nodes. NOTE(review): `sizes` lists one fan-out per
    # layer and is hard-coded for --num_layers 2 — confirm for deeper models.
    train_loader = NeighborSampler(graph.edge_index, sizes=[25, 10], batch_size=graph.num_nodes,
                                   shuffle=True, num_nodes=graph.num_nodes)
    test_loader = NeighborSampler(graph.edge_index, sizes=[-1, -1], batch_size=graph.num_nodes,
                                  shuffle=False, num_nodes=graph.num_nodes)
    print("Neighborsampler generated for GS")


# Single calibration pass (no gradient updates): quantizer parameters are
# unfrozen, observe one no-grad forward pass over the training graph, are then
# frozen, and the frozen model is evaluated on the validation/test masks.
t = tqdm(total=1)  # starts at 0/1 and reaches 1/1 after the single t.update(1)


def _sync_device():
    """Wait for queued CUDA kernels so perf_counter timings cover the real work."""
    if device.type == "cuda":
        torch.cuda.synchronize()


# TRAIN START
best_acc = 0
durations = 0
model.unfreeze_quantization_parameters()
for epoch in range(1):
    total_loss = 0
    total_acc = 0
    duration = 0.0
    # TRAIN SCALES - CALIBRATE
    with torch.no_grad():
        if args.model == "GS":
            for batch_size, n_id, adjs in train_loader:
                # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
                y = train_data.y[n_id[:batch_size]]
                train_mask = train_data.train_mask[n_id[:batch_size]]
                adjs = [adj.to(device) for adj in adjs]
                _sync_device()
                t_start = time.perf_counter()
                out = model((train_data.x[n_id], adjs), train_q_group)
                _sync_device()
                # Accumulate per-batch forward time instead of keeping only the last batch.
                duration += time.perf_counter() - t_start
                out = out[:batch_size]
                loss_train = F.nll_loss(out[train_mask], y[train_mask])
                acc_train = accuracy(out[train_mask], y[train_mask])
                total_loss += loss_train.detach()
                total_acc += acc_train.detach()
            loss_train = total_loss / len(train_loader)
            acc_train = total_acc / len(train_loader)
        else:
            _sync_device()
            t_start = time.perf_counter()
            output = model(train_data, train_q_group)
            loss_train = F.nll_loss(output[train_data.train_mask], train_data.y[train_data.train_mask])
            acc_train = accuracy(output[train_data.train_mask], train_data.y[train_data.train_mask])
            _sync_device()
            duration += time.perf_counter() - t_start
    durations += duration

    model.freeze_quantization_parameters()
    # EVAL
    model.eval()
    with torch.no_grad():
        if args.model == "GS":
            for batch_size, n_id, adjs in test_loader:
                y = data.y[n_id[:batch_size]]
                val_mask = data.val_mask[n_id[:batch_size]]
                test_mask = data.test_mask[n_id[:batch_size]]
                adjs = [adj.to(device) for adj in adjs]
                out = model((data.x[n_id], adjs), q_group)
                # Keep only target-node rows, matching the calibration loop.
                out = out[:batch_size]
                loss_val = F.nll_loss(out[val_mask], y[val_mask])
                acc_val = accuracy(out[val_mask], y[val_mask])
                acc_test = accuracy(out[test_mask], y[test_mask])
        else:
            output = model(data, q_group)
            loss_val = F.nll_loss(output[data.val_mask], data.y[data.val_mask])
            acc_val = accuracy(output[data.val_mask], data.y[data.val_mask])
            acc_test = accuracy(output[data.test_mask], data.y[data.test_mask])

    if acc_test > best_acc:
        best_acc = acc_test.item()
        print("checkpoint/data successfully updated, acc: {}".format(acc_test))

    t.set_postfix(
        {
            "Train_Loss": "{:05.3f}".format(loss_train.item()),
            "Val_Loss": "{:05.3f}".format(loss_val.item()),
            "Train_Acc": "{:05.3f}".format(acc_train.item()),
            "Val_Acc": "{:05.3f}".format(acc_val.item()),
        }
    )
    t.update(1)
t.close()

if torch.cuda.is_available():
    torch.cuda.synchronize()

# --check_time: report the summed wall-clock time of the calibration forward pass(es).
if args.check_time:
    print("calibration time: {:.4f}s".format(durations))
