import argparse
import datetime
import random
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.neighbors import kneighbors_graph
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    to_undirected,
)

from data_utils import adj_mul, load_fixed_splits
from dataset import load_dataset
from eval import eval_acc, eval_f1, eval_rocauc, evaluate, run_tsne
from logger import Logger
from ood import (
    get_ood_split_with_indices,
    get_roc,
    perturb_edges_with_split_indices,
    perturb_features_with_split_indices,
)
from parse import parse_method, parser_add_main_args
from sparse_modules import calculate_nm_sparsity, get_threshold

# Silence all warnings for the whole process.
# NOTE(review): this also hides library warnings about slow or deprecated calls
# (e.g. from torch / sklearn); consider narrowing it to specific categories.
warnings.filterwarnings("ignore")


def fix_seed(seed):
    """Seed every RNG used by this script and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch's CPU and
    CUDA generators. It also disables cuDNN benchmarking (autotuning), so that
    repeated runs with the same seed pick the same kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeding the current CUDA device is redundant with manual_seed_all below,
    # but harmless.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


fix_seed(0)

# Parse the command-line arguments of the pipeline.
parser = argparse.ArgumentParser(description="General Training Pipeline")
parser_add_main_args(parser)
args = parser.parse_args()
print(args)

# Pick the compute device: the requested CUDA device when available, otherwise
# CPU (also used whenever --cpu is given).
if not args.cpu and torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device}")
else:
    device = torch.device("cpu")

# Load and preprocess data.
# Outside the film-like group, edge perturbation is provided by the loader,
# which then returns the perturbed copy of the dataset as its second value.
if args.edge_perturb and args.dataset not in (
    "film",
    "deezer-europe",
    "20news",
    "mini",
):
    dataset, perturb_dataset = load_dataset(
        args.data_dir, args.dataset, args.sub_dataset, args
    )
    perturb_dataset.graph["edge_index"] = perturb_dataset.graph["edge_index"].to(
        device
    )
else:
    dataset, _ = load_dataset(
        args.data_dir, args.dataset, args.sub_dataset, args
    )

# Store labels as a column, (N,) -> (N, 1), and move them to the device.
dataset.label = (
    dataset.label.unsqueeze(1) if dataset.label.dim() == 1 else dataset.label
).to(device)
# Get the train/valid/test splits for all runs.
if args.rand_split:
    split_idx_lst = [
        dataset.get_idx_split(
            train_prop=args.train_prop, valid_prop=args.valid_prop
        )
        for _ in range(args.runs)
    ]
elif args.rand_split_class:
    split_idx_lst = [
        dataset.get_idx_split(
            split_type="class", label_num_per_class=args.label_num_per_class
        )
        for _ in range(args.runs)
    ]
elif args.dataset in (
    "ogbn-proteins",
    "ogbn-arxiv",
    "ogbn-products",
    "amazon2m",
):
    split_idx_lst = [dataset.load_fixed_splits() for _ in range(args.runs)]
else:
    split_idx_lst = load_fixed_splits(
        args.data_dir, dataset, name=args.dataset, protocol=args.protocol
    )

# These datasets get their graph from a kNN graph over node features
# (self-edges included).
if args.dataset in ("mini", "20news"):
    adj_knn = kneighbors_graph(
        dataset.graph["node_feat"], n_neighbors=args.knn_num, include_self=True
    )
    # nonzero() returns a (rows, cols) tuple of arrays. Stack it into one
    # (2, E) array first: building a tensor from a tuple of ndarrays is
    # PyTorch's slow element-wise path.
    edge_index = torch.from_numpy(np.vstack(adj_knn.nonzero())).long()
    dataset.graph["edge_index"] = edge_index

# Basic information of the dataset. The edge count is taken before any
# symmetrization below.
n = dataset.graph["num_nodes"]
e = dataset.graph["edge_index"].shape[1]
d = dataset.graph["node_feat"].shape[1]
# Number of classes: max label + 1 for index labels, column count for one-hot.
c = max(dataset.label.max().item() + 1, dataset.label.shape[1])

print(
    f"dataset {args.dataset} | num nodes {n} | num edge {e} | num node feats {d} | num classes {c}"
)

# Symmetrize the graph unless --directed is set; ogbn-proteins is never
# symmetrized.
if not (args.directed or args.dataset == "ogbn-proteins"):
    dataset.graph["edge_index"] = to_undirected(dataset.graph["edge_index"])

# Move the graph tensors to the compute device.
dataset.graph["edge_index"] = dataset.graph["edge_index"].to(device)
dataset.graph["node_feat"] = dataset.graph["node_feat"].to(device)


# Build the model selected by --method.
model = parse_method(args, dataset, n, c, d, device)

# Loss: BCE-with-logits for these datasets, negative log-likelihood otherwise.
criterion = (
    nn.BCEWithLogitsLoss()
    if args.dataset
    in ("yelp-chi", "deezer-europe", "twitch-e", "fb100", "ogbn-proteins")
    else nn.NLLLoss()
)

# Evaluation metric chosen by --metric; accuracy is the fallback.
eval_func = {"rocauc": eval_rocauc, "f1": eval_f1}.get(args.metric, eval_acc)

logger = Logger(args.runs, args)

model.train()
total_params = sum(param.numel() for param in model.parameters())

print(f"MODEL: {model}")
print(f"Total Parameters: {total_params}")

def _build_relational_adjs(edge_index, num_nodes, order):
    """Return the list of adjacency edge indices used for the relational bias.

    The first entry is ``edge_index`` with self-loops normalized: existing
    ones are removed, then one self-loop per node is added. Each further entry
    is ``adj_mul`` of the previous entry with itself. The list has
    ``max(order, 1)`` entries.
    """
    adj, _ = remove_self_loops(edge_index)
    adj, _ = add_self_loops(adj, num_nodes=num_nodes)
    adjs = [adj]
    for _ in range(order - 1):
        # NOTE(review): presumably the edge index of the sparse product
        # adj @ adj, i.e. successive squarings -- confirm against
        # data_utils.adj_mul.
        adj = adj_mul(adj, adj, num_nodes)
        adjs.append(adj)
    return adjs


# Adj storage for relational bias (edge_index of high-order adjacency).
dataset.graph["adjs"] = _build_relational_adjs(
    dataset.graph["edge_index"], n, args.rb_order
)

if args.edge_perturb and args.dataset not in (
    "film",
    "deezer-europe",
    "20news",
    "mini",
):
    # NOTE(review): unlike the clean graph, the perturbed edge_index is not
    # passed through to_undirected() before this -- confirm that is intended.
    perturb_dataset.graph["adjs"] = _build_relational_adjs(
        perturb_dataset.graph["edge_index"], n, args.rb_order
    )


# For these datasets the distribution shift is built here, not by the loader.
# At most one option is applied, in this precedence order:
# node_perturb, then edge_perturb, then ood.
if args.dataset in ["film", "deezer-europe", "20news", "mini"]:
    if args.node_perturb:
        # Perturb node features of the split nodes; `dataset` is rebound to the
        # dataset returned by the helper.
        dataset, _ = perturb_features_with_split_indices(
            dataset,
            split_idx_lst,
            ood_budget_per_graph=args.ood_budget_per_graph,
            ood_noise_scale=args.ood_noise_scale,
            ood_perturbation_type=args.ood_perturbation_type,
        )
    elif args.edge_perturb:
        # Build the edge-perturbed counterpart used at evaluation time.
        # `dataset` is not rebound here; whether the helper mutates it in place
        # is not visible from this file.
        perturb_dataset = perturb_edges_with_split_indices(
            dataset,
            split_idx_lst,
            ood_budget_per_graph=args.ood_budget_per_graph,
            ood_noise_scale=args.ood_noise_scale,
            ood_perturbation_type=args.ood_perturbation_type,
        )
    elif args.ood:
        # Hold out `num_ood_class` classes as out-of-distribution; the class
        # count returned by the helper is stored on args.
        # NOTE(review): the model has already been constructed with `c` output
        # classes, so setting args.num_classes here does not resize it --
        # confirm which downstream code (e.g. get_roc) reads it.
        dataset, class_num = get_ood_split_with_indices(
            dataset, split_idx_lst, ood_num_left_out_classes=args.num_ood_class
        )
        args.num_classes = class_num

if args.load_model:
    # Evaluate a saved checkpoint once, print the statistics, and exit without
    # training.
    fix_seed(args.seed + 0)
    if (
        args.dataset in ["cora", "citeseer", "pubmed"]
        and args.protocol == "semi"
    ):
        split_idx = split_idx_lst[0]
    else:
        # NOTE(review): this uses the second split (index 1) although the
        # result is logged as run 0, and it raises IndexError when only one
        # split exists -- confirm it matches the split the checkpoint was
        # trained on.
        split_idx = split_idx_lst[1]
    # Kept although load_state_dict() overwrites the weights right after:
    # re-initialisation may consume RNG draws that later evaluation depends on.
    model.reset_parameters()
    # map_location lets a checkpoint saved on one device (e.g. GPU) load on
    # another (e.g. a --cpu run).
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(args.model_dir, map_location=device))
    model.eval()
    print(f"Loaded model from {args.model_dir}")

    with torch.no_grad():
        if args.sm or args.mm:
            # Evaluation queries get_threshold at args.epochs (end of schedule).
            threshold = get_threshold(model, args.epochs, args)
        else:
            threshold = None
        sparsity = None

        if args.ood:
            # NOTE(review): this result is not added to the logger, so
            # print_statistics() below has nothing from this path -- confirm
            # get_roc reports it itself.
            result = get_roc(
                model,
                dataset,
                threshold=threshold,
                args=args,
                sparsity=sparsity,
            )
        else:
            # Evaluate on the perturbed graph when edge perturbation is on.
            # The swap is permanent, which is harmless because the script
            # exits right after.
            if args.edge_perturb and args.dataset not in [
                "film",
                "deezer-europe",
                "20news",
                "mini",
            ]:
                dataset.graph["adjs"] = perturb_dataset.graph["adjs"]
            elif args.edge_perturb and args.dataset in [
                "film",
                "deezer-europe",
                "20news",
                "mini",
            ]:
                dataset.graph["edge_index"] = perturb_dataset.graph[
                    "edge_index"
                ]
            result = evaluate(
                model,
                dataset,
                split_idx,
                eval_func,
                criterion,
                threshold=threshold,
                args=args,
                sparsity=sparsity,
            )
            logger.add_result(0, result[:-1])
            print(
                f"Run: {1}, "
                f"Train: {100 * result[0]:.2f}%, "
                f"Valid: {100 * result[1]:.2f}%, "
                f"Test: {100 * result[2]:.2f}%"
            )
    logger.print_statistics()
    # Equivalent to exit(), without relying on the `site`-injected builtin.
    raise SystemExit

# Training loop
# Datasets whose OOD shift is built in this script rather than by the loader.
_FILM_LIKE_DATASETS = ("film", "deezer-europe", "20news", "mini")
# Datasets trained with BCEWithLogitsLoss (same list as the criterion choice).
_BCE_DATASETS = ("yelp-chi", "deezer-europe", "twitch-e", "fb100", "ogbn-proteins")

# Graph entry that is swapped for its perturbed counterpart during evaluation:
# - the relational-bias adjacency list for loader-perturbed datasets;
# - the raw edge_index for the film-like datasets;
# - None when --edge_perturb is off.
if args.edge_perturb:
    _perturbed_eval_key = (
        "edge_index" if args.dataset in _FILM_LIKE_DATASETS else "adjs"
    )
else:
    _perturbed_eval_key = None


def _training_loss(out, train_idx):
    """Return the supervised loss of raw model outputs ``out`` on ``train_idx``.

    BCE datasets compare logits with float (one-hot if needed) labels. All
    other datasets apply log-softmax and NLL against class-index labels.
    """
    if args.dataset in _BCE_DATASETS:
        if dataset.label.shape[1] == 1:
            true_label = F.one_hot(
                dataset.label, dataset.label.max() + 1
            ).squeeze(1)
        else:
            true_label = dataset.label
        return criterion(
            out[train_idx],
            true_label.squeeze(1)[train_idx].to(torch.float),
        )
    label = dataset.label
    # Labels are normally (N, 1) here, but an OOD split may leave them 1-D;
    # squeeze only a singleton column.
    if label.dim() > 1 and label.size(1) == 1:
        label = label.squeeze(1)
    return criterion(F.log_softmax(out, dim=1)[train_idx], label[train_idx])


for run in range(args.runs):
    fix_seed(args.seed + run)

    # The semi-supervised citation benchmarks reuse one split for every run.
    if (
        args.dataset in ["cora", "citeseer", "pubmed"]
        and args.protocol == "semi"
    ):
        split_idx = split_idx_lst[0]
    else:
        split_idx = split_idx_lst[run]
    train_idx = split_idx["train"].to(device)
    model.reset_parameters()
    optimizer = torch.optim.Adam(
        model.parameters(), weight_decay=args.weight_decay, lr=args.lr
    )
    if args.sm:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs
        )

    best_val = float("-inf")

    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()

        if args.sm or args.mm:
            threshold = get_threshold(model, epoch, args)
        else:
            threshold = None
        sparsity = None

        if args.method == "nodeformer":
            out, link_loss_ = model(
                dataset.graph["node_feat"],
                dataset.graph["adjs"],
                args.tau,
                threshold=threshold,
                args=args,
                sparsity=sparsity,
            )
        else:
            out = model(dataset)

        loss = _training_loss(out, train_idx)
        if args.method == "nodeformer":
            # Subtract the averaged link loss returned by the model.
            loss = loss - args.lamda * sum(link_loss_) / len(link_loss_)
        loss.backward()
        optimizer.step()
        if args.sm:
            scheduler.step()

        if epoch % args.eval_step == 0:
            with torch.no_grad():
                model.eval()

                # Evaluation queries get_threshold at args.epochs (end of
                # schedule), not at the current epoch.
                if args.sm or args.mm:
                    threshold = get_threshold(model, args.epochs, args)
                else:
                    threshold = None
                sparsity = None

                if args.ood:
                    # NOTE(review): this result is neither logged nor printed
                    # here -- confirm get_roc reports it itself.
                    result = get_roc(
                        model,
                        dataset,
                        threshold=threshold,
                        args=args,
                        sparsity=sparsity,
                    )
                else:
                    # Swap in the perturbed graph for evaluation only and
                    # restore the clean one afterwards. Otherwise every later
                    # training epoch (and run) would train on the perturbed
                    # graph.
                    clean_graph_entry = None
                    if _perturbed_eval_key is not None:
                        clean_graph_entry = dataset.graph[_perturbed_eval_key]
                        dataset.graph[_perturbed_eval_key] = (
                            perturb_dataset.graph[_perturbed_eval_key]
                        )
                    try:
                        result = evaluate(
                            model,
                            dataset,
                            split_idx,
                            eval_func,
                            criterion,
                            threshold=threshold,
                            args=args,
                            sparsity=sparsity,
                        )
                        logger.add_result(run, result[:-1])

                        if result[1] > best_val:
                            best_val = result[1]
                            if args.save_model:
                                torch.save(
                                    model.state_dict(),
                                    args.model_dir
                                    + f"{args.dataset}-{args.method}-{args.linear_sparsity}-{100 * result[2]:.2f}.pkl",
                                )
                            if args.tsne:
                                # Runs while the evaluation graph is swapped
                                # in, so the plot reflects the evaluated graph.
                                run_tsne(
                                    model,
                                    dataset,
                                    split_idx,
                                    args=args,
                                    threshold=threshold,
                                    sparsity=sparsity,
                                )

                        print(
                            f"Epoch: {epoch:02d}, "
                            f"Loss: {loss:.4f}, "
                            f"Train: {100 * result[0]:.2f}%, "
                            f"Valid: {100 * result[1]:.2f}%, "
                            f"Test: {100 * result[2]:.2f}%"
                        )
                    finally:
                        if _perturbed_eval_key is not None:
                            dataset.graph[_perturbed_eval_key] = clean_graph_entry

    logger.print_statistics(run)

results = logger.print_statistics()
