#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import math
import time
import torch
import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models_calib import GCN
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from mygraphconv_calib import MyGraphConv


def gen_model(args):
    """Construct a new ``GCN`` from the parsed command-line options.

    The input feature size (``in_feats``) and the number of classes
    (``n_classes``) are read from module-level globals, which ``main``
    assigns before any model is built.

    Args:
        args: Namespace providing ``n_hidden``, ``n_layers``, ``dropout``
            and ``use_linear``.

    Returns:
        The freshly constructed ``GCN`` instance.
    """
    # Positional order expected by the GCN constructor as used in this file.
    ctor_args = (
        in_feats,
        args.n_hidden,
        n_classes,
        args.n_layers,
        F.relu,
        args.dropout,
        args.use_linear,
    )
    return GCN(*ctor_args)


def compute_acc(pred, labels, evaluator):
    """Score predictions with ``evaluator`` and return its ``"acc"`` entry.

    Args:
        pred: Tensor of per-class scores; the arg-max over the last
            dimension (kept as a trailing axis) is used as the prediction.
        labels: Ground-truth labels, forwarded unchanged as ``"y_true"``.
        evaluator: Object whose ``eval`` method accepts a dict with the keys
            ``"y_pred"`` and ``"y_true"`` and returns a mapping that
            contains ``"acc"``.

    Returns:
        Whatever the evaluator stores under ``"acc"``.
    """
    predicted_class = pred.argmax(dim=-1, keepdim=True)
    scores = evaluator.eval({"y_pred": predicted_class, "y_true": labels})
    return scores["acc"]


@th.no_grad()
def evaluate(model, graph, labels, test_idx, evaluator):
    """Return the accuracy of ``model`` on the nodes selected by ``test_idx``.

    The model is switched to eval mode and run once on the whole graph with
    gradient tracking disabled; only the rows indexed by ``test_idx`` are
    scored.

    Args:
        model: Callable as ``model(graph, feat)``.
        graph: Graph whose ``ndata["feat"]`` holds the node features.
        labels: Ground-truth labels, indexable by ``test_idx``.
        test_idx: Index of the nodes to score.
        evaluator: Evaluator forwarded to ``compute_acc``.

    Returns:
        The accuracy value produced by ``compute_acc``.
    """
    model.eval()
    logits = model(graph, graph.ndata["feat"])
    test_logits, test_labels = logits[test_idx], labels[test_idx]
    return compute_acc(test_logits, test_labels, evaluator)


def _load_pretrained(args, checkpoint_path):
    """Build a model, load ``checkpoint_path`` into it and move it to ``device``."""
    model = gen_model(args)
    # Deserialize straight onto the selected device. A hard-coded 'cuda'
    # target makes --cpu runs fail on machines without CUDA.
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False: keys that do not match between the checkpoint and the
    # model are ignored instead of raising.
    model.load_state_dict(state_dict, strict=False)
    return model.to(device)


def _reset_batchnorm_message_stats(model):
    """Call ``model.reset_message_stats`` on every ``nn.BatchNorm1d`` submodule."""
    # An explicit loop is used instead of ``any(...)``: ``any`` stops at the
    # first truthy return value and could leave later modules untouched.
    for module in model.modules():
        if isinstance(module, nn.BatchNorm1d):
            model.reset_message_stats(module)


def run(
    args, graph, labels, test_idx, evaluator
):
    """Calibrate the child model from its two parents and print its test accuracy.

    Steps, in order:

    1. Load the two parent models and the child model from the fixed
       checkpoint files ``pretrained_model_a.pt``, ``pretrained_model_b.pt``
       and ``child_model_pmc.pt``.
    2. Put both parents in train mode, reset the statistics of their
       ``nn.BatchNorm1d`` modules through ``reset_message_stats`` and run one
       gradient-free forward pass over the full graph (presumably to
       re-estimate the running statistics -- confirm against ``GCN``).
    3. For every ``MyGraphConv`` layer, hand the child layer the average of
       the parents' ``lfnorm.running_mean`` and the square of the average of
       the parents' running standard deviations, then enable calibration.
    4. Repeat the reset + forward pass for the child and evaluate it.

    Args:
        args: Parsed command-line options forwarded to ``gen_model``.
        graph: Graph whose ``ndata["feat"]`` holds the node features.
        labels: Ground-truth node labels.
        test_idx: Index of the nodes used for the accuracy computation.
        evaluator: Evaluator forwarded to ``evaluate``.

    Returns:
        None. The accuracy is printed to stdout.
    """
    model_a = _load_pretrained(args, 'pretrained_model_a.pt')
    model_b = _load_pretrained(args, 'pretrained_model_b.pt')
    model_avg = _load_pretrained(args, 'child_model_pmc.pt')

    feat = graph.ndata["feat"]

    # Train mode for the parents' statistics-collecting forward pass.
    model_a.train()
    model_b.train()

    _reset_batchnorm_message_stats(model_a)
    _reset_batchnorm_message_stats(model_b)

    with torch.no_grad():
        _ = model_a(graph, feat)
        _ = model_b(graph, feat)

    # All three models come from gen_model(args), so their modules() are
    # traversed in lockstep and matching layers line up position by position.
    for conv_a, conv_avg, conv_b in zip(
        model_a.modules(), model_avg.modules(), model_b.modules()
    ):
        if not isinstance(conv_a, MyGraphConv):
            continue
        parent_mean = (
            conv_a.lfnorm.running_mean + conv_b.lfnorm.running_mean
        ) / 2
        # Average the standard deviations (not the variances), then square
        # the result to get back to a variance.
        parent_std = (
            conv_a.lfnorm.running_var.sqrt() + conv_b.lfnorm.running_var.sqrt()
        ) / 2
        conv_avg.calibrate_message_stats(
            parent_mean=parent_mean,
            parent_var=parent_std ** 2,
        ).enable_calibrating()

    model_avg.train()

    _reset_batchnorm_message_stats(model_avg)

    with torch.no_grad():
        _ = model_avg(graph, feat)

    test_acc = evaluate(model_avg, graph, labels, test_idx, evaluator)

    print("Test Accs (CMC): ", test_acc)


def count_parameters(args):
    """Count the trainable scalar parameters of a freshly built model.

    Args:
        args: Parsed command-line options forwarded to ``gen_model``.

    Returns:
        The sum of ``np.prod(p.size())`` over every parameter whose
        ``requires_grad`` flag is set (``0`` if there is none).
    """
    total = 0
    for param in gen_model(args).parameters():
        if param.requires_grad:
            total = total + np.prod(param.size())
    return total


def _parse_cli_args():
    """Declare the command-line interface and return the parsed namespace."""
    parser = argparse.ArgumentParser(
        "GCN on OGBN-Arxiv",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # Device selection.
    add("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
    add("--gpu", type=int, default=0, help="GPU device ID.")

    # Training-schedule and model options.
    add("--n-runs", type=int, default=1, help="running times")
    add("--n-epochs", type=int, default=1000, help="number of epochs")
    add("--use-linear", action="store_true", help="Use linear layer.")
    add("--lr", type=float, default=0.005, help="learning rate")
    add("--n-layers", type=int, default=5, help="number of layers")
    add("--n-hidden", type=int, default=256, help="number of hidden units")
    add("--dropout", type=float, default=0, help="dropout rate")
    add("--wd", type=float, default=0, help="weight decay")

    # Reporting options.
    add("--log-every", type=int, default=20, help="log every LOG_EVERY epochs")
    add("--plot-curves", action="store_true", help="plot learning curves")

    return parser.parse_args()


def main():
    """Script entry point: prepare ogbn-arxiv and run the calibration twice.

    Sets the module-level globals ``device``, ``in_feats`` and ``n_classes``
    that the model-building helpers read, then calls ``run`` once for each of
    the index files ``test_idx_1.pt`` and ``test_idx_2.pt``.
    """
    global device, in_feats, n_classes

    args = _parse_cli_args()

    # --cpu takes precedence over --gpu.
    device = th.device("cpu") if args.cpu else th.device("cuda:%d" % args.gpu)

    # Dataset and its matching evaluator.
    dataset = DglNodePropPredDataset(name="ogbn-arxiv")
    evaluator = Evaluator(name="ogbn-arxiv")
    graph, labels = dataset[0]

    # Add the reverse of every existing edge.
    src_nodes, dst_nodes = graph.all_edges()
    graph.add_edges(dst_nodes, src_nodes)

    # Drop any existing self-loops, then add self-loops back.
    print(f"Total edges before adding self-loop {graph.num_edges()}")
    graph = graph.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {graph.num_edges()}")

    in_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    graph.create_formats_()

    labels = labels.to(device)
    graph = graph.to(device)

    # Same pipeline, once per stored test-index file.
    for index_file in ('test_idx_1.pt', 'test_idx_2.pt'):
        test_idx = torch.load(index_file).to(device)
        run(args, graph, labels, test_idx, evaluator)

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
