import argparse
import random
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    subgraph,
    to_undirected,
)

from data_utils import adj_mul, load_fixed_splits
from dataset import load_dataset
from eval import eval_acc, eval_f1, eval_rocauc, evaluate_cpu
from logger import Logger
from ood import get_roc
from parse import parse_method, parser_add_main_args
from sparse_modules import calculate_nm_sparsity, get_threshold

warnings.filterwarnings("ignore")


# NOTE: for consistent data splits, see data_utils.rand_train_test_idx
def fix_seed(seed):
    """Seed every random number generator this script relies on.

    The same ``seed`` is handed to Python's ``random`` module, NumPy's
    global RNG, PyTorch's default generator and PyTorch's CUDA generator.
    cuDNN is additionally switched to its deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_rng in seeders:
        seed_rng(seed)
    torch.backends.cudnn.deterministic = True


# Command-line configuration: the project-level helper registers all
# options on the parser before the arguments are parsed and echoed.
parser = argparse.ArgumentParser(description="General Training Pipeline")
parser_add_main_args(parser)
args = parser.parse_args()
print(args)

# Seed the RNGs before any data is loaded or any model is built.
fix_seed(args.seed)

# Device selection: the requested CUDA device is used only when the cpu flag
# is not set and CUDA is available; every other case falls back to the CPU.
if not args.cpu and torch.cuda.is_available():
    device = torch.device("cuda:" + str(args.device))
else:
    device = torch.device("cpu")

# Load and preprocess data
dataset, _ = load_dataset(args.data_dir, args.dataset, args.sub_dataset, args)

# Promote a 1-D label tensor to 2-D by appending a trailing axis.
if dataset.label.dim() == 1:
    dataset.label = dataset.label.unsqueeze(1)

# Choose how a single split is produced; the first three strategies are
# invoked once per run, the last one loads all splits in a single call.
if args.rand_split:

    def _make_split():
        # Random split controlled by the train/valid proportions.
        return dataset.get_idx_split(
            train_prop=args.train_prop, valid_prop=args.valid_prop
        )

elif args.rand_split_class:

    def _make_split():
        # Random split with a fixed number of labelled nodes per class.
        return dataset.get_idx_split(
            split_type="class", label_num_per_class=args.label_num_per_class
        )

elif args.dataset in [
    "ogbn-proteins",
    "ogbn-arxiv",
    "ogbn-products",
    "amazon2m",
]:

    def _make_split():
        # These datasets expose their own fixed-split loader.
        return dataset.load_fixed_splits()

else:
    _make_split = None

if _make_split is None:
    # NOTE(review): `dataset` is passed positionally and `dataset=` is also
    # given as a keyword; this only works if the helper's second positional
    # parameter has a different name — confirm against
    # data_utils.load_fixed_splits.
    split_idx_lst = load_fixed_splits(
        args.data_dir, dataset, dataset=args.dataset, protocol=args.protocol
    )
else:
    split_idx_lst = [_make_split() for _ in range(args.runs)]

# Symmetrize the edge list unless the graph is declared directed or the
# dataset is ogbn-proteins.
if not (args.directed or args.dataset == "ogbn-proteins"):
    dataset.graph["edge_index"] = to_undirected(dataset.graph["edge_index"])

edge_index = dataset.graph["edge_index"]
x = dataset.graph["node_feat"]

# Basic sizes: n = number of nodes, d = number of node features.
n = dataset.graph["num_nodes"]
d = x.shape[1]

# Number of classes: the larger of (max label value + 1), which covers
# integer-coded labels, and the label width, which covers one-hot labels.
label_span = dataset.label.max().item() + 1
label_width = dataset.label.shape[1]
c = max(label_span, label_width)

print(
    f"num nodes {n} | num edges {edge_index.size(1)} | num classes {c} | num node feats {d}"
)

# Build the model selected on the command line
model = parse_method(args, dataset, n, c, d, device)

# Loss function: BCE-with-logits for the datasets listed below, negative
# log-likelihood for every other dataset.
bce_datasets = (
    "yelp-chi",
    "deezer-europe",
    "twitch-e",
    "fb100",
    "ogbn-proteins",
)
criterion = (
    nn.BCEWithLogitsLoss() if args.dataset in bce_datasets else nn.NLLLoss()
)

# Performance metric: ROC-AUC or F1 when requested, accuracy otherwise
metric_table = {"rocauc": eval_rocauc, "f1": eval_f1}
eval_func = metric_table.get(args.metric, eval_acc)

logger = Logger(args.runs, args)

model.train()
print("MODEL:", model)

# First-order adjacency: drop any existing self loops, then add exactly one
# self loop per node, and keep the result on the training device.
adj, _ = remove_self_loops(edge_index)
adj, _ = add_self_loops(adj, num_nodes=n)
adj = adj.to(device)
adjs = [adj]

# Each further entry is adj_mul applied to the previous entry with itself,
# for a total of rb_order adjacency tensors.
for _ in range(args.rb_order - 1):
    adj = adj_mul(adj, adj, n).to(device)
    adjs.append(adj)
dataset.graph["adjs"] = adjs

# For the datasets below, single-column integer labels are expanded to
# one-hot targets; labels that already have several columns are used as-is.
if args.dataset in (
    "yelp-chi",
    "deezer-europe",
    "twitch-e",
    "fb100",
    "ogbn-proteins",
):
    true_label = (
        F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1)
        if dataset.label.shape[1] == 1
        else dataset.label
    )

# Training loop
#
# For every run: pick the split, re-initialise the model, then train with
# shuffled mini-batches of training nodes. Each mini-batch is the subgraph
# induced by its nodes in every adjacency tensor of `adjs`. Evaluation runs
# every ninth epoch (epoch 0, 9, 18, ...).

# Datasets trained with BCE-with-logits against `true_label`; all other
# datasets use NLL on log-softmax outputs with `dataset.label`.
uses_bce_targets = args.dataset in (
    "yelp-chi",
    "deezer-europe",
    "twitch-e",
    "fb100",
    "ogbn-proteins",
)

# The node features never change, so move them to the training device once
# instead of on every mini-batch.
x = x.to(device)

for run in range(args.runs):
    # cora/citeseer/pubmed under the "semi" protocol reuse the first split
    # for every run; everything else uses one split per run.
    if (
        args.dataset in ["cora", "citeseer", "pubmed"]
        and args.protocol == "semi"
    ):
        split_idx = split_idx_lst[0]
    else:
        split_idx = split_idx_lst[run]
    train_idx = split_idx["train"].to(device)
    num_train = train_idx.size(0)

    model.reset_parameters()
    optimizer = torch.optim.Adam(
        model.parameters(), weight_decay=args.weight_decay, lr=args.lr
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[100, 200], gamma=0.5
    )
    best_val = float("-inf")
    # Ceiling division. The previous `num_train // batch_size + 1` produced
    # an extra, empty mini-batch whenever num_train was an exact multiple of
    # the batch size, which meant one optimizer step on zero samples.
    num_batch = (num_train + args.batch_size - 1) // args.batch_size

    for epoch in range(args.epochs):
        # NOTE(review): the model is moved to `device` at the start of every
        # epoch, presumably because evaluate_cpu relocates it — confirm.
        model.to(device)
        model.train()

        if uses_bce_targets:
            true_label = true_label.to(device)
        else:
            dataset.label = dataset.label.to(device)

        # A threshold is only computed when the sm or mm option is enabled;
        # sparsity is always passed as None.
        if args.sm or args.mm:
            threshold = get_threshold(model, epoch, args)
        else:
            threshold = None
        sparsity = None

        # Fresh random order of the training nodes for this epoch.
        idx = torch.randperm(num_train)
        for i in range(num_batch):
            idx_i = train_idx[
                idx[i * args.batch_size : (i + 1) * args.batch_size]
            ]  # noqa
            x_i = x[idx_i]
            # Induced subgraph of the batch nodes for every adjacency order,
            # relabelled to local indices 0..len(idx_i)-1.
            adjs_i = [
                subgraph(idx_i, adj_k, num_nodes=n, relabel_nodes=True)[0].to(
                    device
                )
                for adj_k in adjs
            ]
            optimizer.zero_grad()
            out_i, link_loss_ = model(
                x_i,
                adjs_i,
                args.tau,
                threshold=threshold,
                args=args,
                sparsity=sparsity,
            )
            if uses_bce_targets:
                loss = criterion(
                    out_i, true_label.squeeze(1)[idx_i].to(torch.float)
                )
            else:
                out_i = F.log_softmax(out_i, dim=1)
                loss = criterion(out_i, dataset.label.squeeze(1)[idx_i])
            # Subtract the mean link-loss term scaled by lamda. Skipped when
            # the model returns no terms, which would otherwise divide by
            # zero.
            if len(link_loss_) > 0:
                loss -= args.lamda * sum(link_loss_) / len(link_loss_)
            loss.backward()
            optimizer.step()
            # The LR schedule is only advanced for ogbn-proteins, once per
            # mini-batch.
            if args.dataset == "ogbn-proteins":
                scheduler.step()

        if epoch % 9 == 0:
            # Evaluation asks for the threshold at index `args.epochs`
            # rather than at the current epoch.
            if args.sm or args.mm:
                threshold = get_threshold(model, args.epochs, args)
            else:
                threshold = None
            sparsity = None

            if args.ood:
                result = get_roc(
                    model,
                    dataset,
                    threshold=threshold,
                    args=args,
                    sparsity=sparsity,
                )
            else:
                result = evaluate_cpu(
                    model,
                    dataset,
                    split_idx,
                    eval_func,
                    criterion,
                    threshold=threshold,
                    args=args,
                    sparsity=sparsity,
                )
                # Everything except the last element of the result is logged.
                logger.add_result(run, result[:-1])

                # result[1] is the validation score; checkpoint on
                # improvement.
                if result[1] > best_val:
                    best_val = result[1]
                    if args.save_model:
                        torch.save(
                            model.state_dict(),
                            args.model_dir
                            + f"{args.dataset}-{args.method}.pkl",
                        )

                print(
                    f"Epoch: {epoch:02d}, "
                    f"Loss: {loss:.4f}, "
                    f"Train: {100 * result[0]:.2f}%, "
                    f"Valid: {100 * result[1]:.2f}%, "
                    f"Test: {100 * result[2]:.2f}%"
                )
        # NOTE(review): this runs once per epoch, not once per run; that
        # looks like an indentation slip, and nothing is logged for the run
        # when args.ood is set. Behaviour is kept as-is — confirm intent.
        logger.print_statistics(run)

results = logger.print_statistics()
