import argparse
import warnings

import numpy as np
import torch as th
import torch.nn as nn

warnings.filterwarnings("ignore")

from dataset import process_dataset
from model import MVGRL, LogReg

# Command-line interface for the MVGRL node-classification example.
# NOTE: arguments are parsed at import time, so `args` is a module-level
# namespace that the rest of the script reads directly.
parser = argparse.ArgumentParser(description="mvgrl")

parser.add_argument(
    "--dataname", type=str, default="cora", help="Name of dataset."
)
parser.add_argument(
    "--gpu",
    type=int,
    default=0,
    # The help text previously claimed the default was -1, which did not
    # match the actual default of 0.
    help="GPU index. Default: 0. Set to -1 to use cpu.",
)
parser.add_argument("--epochs", type=int, default=500, help="Training epochs.")
parser.add_argument(
    "--patience",
    type=int,
    default=20,
    help="Patient epochs to wait before early stopping.",
)
parser.add_argument(
    "--lr1", type=float, default=0.001, help="Learning rate of mvgrl."
)
parser.add_argument(
    "--lr2", type=float, default=0.01, help="Learning rate of linear evaluator."
)
parser.add_argument(
    "--wd1", type=float, default=0.0, help="Weight decay of mvgrl."
)
parser.add_argument(
    "--wd2", type=float, default=0.0, help="Weight decay of linear evaluator."
)
parser.add_argument(
    "--epsilon",
    type=float,
    default=0.01,
    help="Edge mask threshold of diffusion graph.",
)
parser.add_argument(
    "--hid_dim", type=int, default=512, help="Hidden layer dim."
)

args = parser.parse_args()

# Resolve the torch device string once and stash it on `args`.
# Any negative index selects the CPU (not only -1), so an accidental
# "--gpu -2" no longer produces an invalid "cuda:-2" device string.
# The CPU is also used when CUDA is unavailable.
if args.gpu >= 0 and th.cuda.is_available():
    args.device = "cuda:{}".format(args.gpu)
else:
    args.device = "cpu"

# File the best encoder weights are written to during pre-training.
CHECKPOINT_PATH = "model.pkl"


def pretrain_encoder(
    model,
    graph,
    diff_graph,
    feat,
    edge_weight,
    lbl,
    n_node,
    epochs,
    patience,
    lr,
    weight_decay,
    checkpoint_path=CHECKPOINT_PATH,
):
    """Train ``model`` with a BCE-with-logits objective and early stopping.

    Each epoch the rows of ``feat`` are randomly permuted to build a second
    feature matrix, and the model is called as
    ``model(graph, diff_graph, feat, shuffled_feat, edge_weight)``; its output
    is scored against ``lbl``.  Whenever the training loss improves, the
    model's ``state_dict`` is saved to ``checkpoint_path``.

    Args:
        model: Encoder to optimise (an ``nn.Module``).
        graph, diff_graph, edge_weight: Passed through to ``model`` unchanged.
        feat: Node feature tensor; row ``i`` belongs to node ``i``.
        lbl: Target tensor for the BCE loss, same shape as the model output.
        n_node: Number of rows of ``feat`` to permute.
        epochs: Maximum number of training epochs.
        patience: Stop once this many consecutive epochs fail to improve.
        lr, weight_decay: Adam hyper-parameters.
        checkpoint_path: Where the best weights are written.

    Returns:
        ``True`` if at least one checkpoint was written during this call,
        ``False`` otherwise (e.g. ``epochs == 0`` or the loss was never
        below ``inf``, as happens with NaN).
    """
    optimizer = th.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    loss_fn = nn.BCEWithLogitsLoss()

    best = float("inf")
    cnt_wait = 0
    checkpoint_saved = False
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Indexing keeps the result on feat's device, so no extra transfer
        # is needed for the shuffled copy.
        shuf_idx = np.random.permutation(n_node)
        shuf_feat = feat[shuf_idx, :]

        out = model(graph, diff_graph, feat, shuf_feat, edge_weight)
        loss = loss_fn(out, lbl)

        loss.backward()
        optimizer.step()

        # Work with a plain float from here on: `best` must not hold on to a
        # tensor from a previous iteration.
        loss_value = loss.item()
        print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, loss_value))

        if loss_value < best:
            best = loss_value
            cnt_wait = 0
            th.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            cnt_wait += 1

        if cnt_wait == patience:
            print("Early stopping")
            break

    return checkpoint_saved


def linear_evaluation(
    train_embs,
    train_labels,
    test_embs,
    test_labels,
    make_classifier,
    device,
    lr,
    weight_decay,
    n_runs=5,
    n_epochs=300,
):
    """Fit a fresh classifier on fixed embeddings several times and score it.

    Args:
        train_embs, train_labels: Training inputs and integer class targets.
        test_embs, test_labels: Held-out inputs and integer class targets.
        make_classifier: Zero-argument callable returning a new, untrained
            ``nn.Module`` that maps embeddings to class logits.
        device: Device the classifier is moved to.
        lr, weight_decay: Adam hyper-parameters for the classifier.
        n_runs: Number of independent classifiers to train.
        n_epochs: Full-batch optimisation steps per classifier.

    Returns:
        1-D tensor of length ``n_runs`` holding test accuracy in percent.
    """
    loss_fn = nn.CrossEntropyLoss()
    accs = []
    for _ in range(n_runs):
        classifier = make_classifier().to(device)
        opt = th.optim.Adam(
            classifier.parameters(), lr=lr, weight_decay=weight_decay
        )

        classifier.train()
        for _epoch in range(n_epochs):
            opt.zero_grad()
            logits = classifier(train_embs)
            loss = loss_fn(logits, train_labels)
            loss.backward()
            opt.step()

        classifier.eval()
        # Scoring needs no gradients; skip building the autograd graph.
        with th.no_grad():
            logits = classifier(test_embs)
            preds = th.argmax(logits, dim=1)
            acc = th.sum(preds == test_labels).float() / test_labels.shape[0]
        accs.append(acc * 100)

    return th.stack(accs)


def main():
    """Pre-train the MVGRL encoder, then report linear-probe test accuracy.

    Reads its configuration from the module-level ``args`` namespace and
    prints the mean and standard deviation of the test accuracy (percent)
    over the linear-evaluation runs.
    """
    print(args)

    # Step 1: Prepare data =================================================================== #
    (
        graph,
        diff_graph,
        feat,
        label,
        train_idx,
        _val_idx,  # the validation split is not used by this script
        test_idx,
        edge_weight,
    ) = process_dataset(args.dataname, args.epsilon)
    n_feat = feat.shape[1]
    # Counted before `label` is moved to the device, while NumPy can read it.
    n_classes = np.unique(label).shape[0]

    graph = graph.to(args.device)
    diff_graph = diff_graph.to(args.device)
    feat = feat.to(args.device)
    edge_weight = th.tensor(edge_weight).float().to(args.device)
    label = label.to(args.device)

    train_idx = train_idx.to(args.device)
    test_idx = test_idx.to(args.device)

    # BCE targets: 2 * n_node ones followed by 2 * n_node zeros.
    # NOTE(review): presumably this matches the order in which the model
    # emits scores for the original vs. shuffled features - confirm in model.py.
    n_node = graph.number_of_nodes()
    lbl = th.cat((th.ones(n_node * 2), th.zeros(n_node * 2))).to(args.device)

    # Step 2: Create model =================================================================== #
    model = MVGRL(n_feat, args.hid_dim)
    model = model.to(args.device)

    # Step 3 + 4: Training components and training epochs ==================================== #
    checkpoint_saved = pretrain_encoder(
        model,
        graph,
        diff_graph,
        feat,
        edge_weight,
        lbl,
        n_node,
        epochs=args.epochs,
        patience=args.patience,
        lr=args.lr1,
        weight_decay=args.wd1,
    )

    if checkpoint_saved:
        # map_location keeps the load working when the checkpoint was written
        # on a different device than the one selected for this run.
        model.load_state_dict(
            th.load(CHECKPOINT_PATH, map_location=args.device)
        )
    else:
        # Do not silently pick up a stale checkpoint from an earlier run.
        print("No checkpoint was saved; evaluating the current weights")

    # Inference only: switch off training-mode behaviour (relevant if the
    # encoder contains dropout/normalisation layers) and gradient tracking.
    model.eval()
    with th.no_grad():
        embeds = model.get_embedding(graph, diff_graph, feat, edge_weight)

    train_embs = embeds[train_idx]
    test_embs = embeds[test_idx]
    train_labels = label[train_idx]
    test_labels = label[test_idx]

    # Step 5:  Linear evaluation ========================================================== #
    accs = linear_evaluation(
        train_embs,
        train_labels,
        test_embs,
        test_labels,
        make_classifier=lambda: LogReg(args.hid_dim, n_classes),
        device=args.device,
        lr=args.lr2,
        weight_decay=args.wd2,
    )
    print(accs.mean().item(), accs.std().item())


if __name__ == "__main__":
    main()
