import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.manifold import TSNE
from sklearn.metrics import f1_score, roc_auc_score


def eval_f1(y_true, y_pred):
    """Return the micro-averaged F1 score of the argmax predictions.

    Args:
        y_true: Label tensor. It is detached, moved to CPU and converted to
            a NumPy array before scoring (assumed to be shaped
            ``(num_nodes, 1)`` like the other metrics here -- TODO confirm).
        y_pred: Score tensor; the predicted class is the argmax over the
            last dimension, kept as a trailing singleton dimension.

    Returns:
        The value returned by ``sklearn.metrics.f1_score`` with
        ``average="micro"``.
    """
    y_true = y_true.detach().cpu().numpy()
    y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy()
    # The score is computed once over the whole arrays: it does not depend
    # on any per-column index, so repeating it per label column and
    # averaging would only reproduce the same number.
    return f1_score(y_true, y_pred, average="micro")


def eval_acc(y_true, y_pred):
    """Return the accuracy of argmax predictions, averaged over label columns.

    Args:
        y_true: 2-D label tensor; rows whose label is NaN are ignored for
            that column.
        y_pred: Score tensor; the predicted class is the argmax over the
            last dimension, kept as a trailing singleton dimension.

    Returns:
        The mean of the per-column accuracies as a float.

    Raises:
        ZeroDivisionError: If a column has no labeled rows, or if there are
            no label columns at all.
    """
    labels = y_true.detach().cpu().numpy()
    preds = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy()

    def column_accuracy(col):
        # NaN never equals itself, so comparing the column with itself
        # yields a mask that is False exactly on unlabeled (NaN) rows.
        labeled = labels[:, col] == labels[:, col]
        hits = labels[labeled, col] == preds[labeled, col]
        return float(np.sum(hits)) / len(hits)

    per_column = [column_accuracy(col) for col in range(labels.shape[1])]
    return sum(per_column) / len(per_column)


def eval_rocauc(y_true, y_pred):
    """Return the ROC-AUC averaged over label columns that hold both classes.

    Args:
        y_true: 2-D label tensor. A column is scored only if it contains at
            least one ``1`` and at least one ``0``; NaN entries are ignored.
        y_pred: Score tensor. When ``y_true`` has a single column, the
            softmax probability at class index 1 is used as the score;
            otherwise the values are used as-is, column by column.

    Returns:
        The mean ROC-AUC over the scored columns.

    Raises:
        RuntimeError: If no column contains both a ``1`` and a ``0`` label.
    """
    y_true = y_true.detach().cpu().numpy()
    if y_true.shape[1] == 1:
        # Detach before converting: ``Tensor.numpy()`` refuses tensors that
        # require grad, so this keeps the metric usable outside ``no_grad``.
        y_pred = (
            F.softmax(y_pred, dim=-1)[:, 1]
            .unsqueeze(1)
            .detach()
            .cpu()
            .numpy()
        )
    else:
        y_pred = y_pred.detach().cpu().numpy()

    rocauc_list = []
    for i in range(y_true.shape[1]):
        column = y_true[:, i]
        # ROC-AUC needs both classes to be present in the column.
        if np.sum(column == 1) > 0 and np.sum(column == 0) > 0:
            # NaN never equals itself, so this masks out unlabeled rows.
            is_labeled = column == column
            score = roc_auc_score(column[is_labeled], y_pred[is_labeled, i])
            rocauc_list.append(score)

    if not rocauc_list:
        raise RuntimeError(
            "No positively labeled data available. Cannot compute ROC-AUC."
        )

    return sum(rocauc_list) / len(rocauc_list)


@torch.no_grad()
def evaluate(
    model, dataset, split_idx, eval_func, criterion, threshold, args, sparsity
):
    """Score ``model`` on every split and compute the validation loss.

    The model is switched to eval mode and run once. When ``args.method``
    is ``"nodeformer"`` it is called with the graph features, adjacency
    data, ``args.tau``, ``threshold``, ``args`` and ``sparsity`` and its
    first return value is used; otherwise it is called with ``dataset``.

    Args:
        model: Callable model exposing ``eval()``.
        dataset: Object with a ``label`` tensor and, for nodeformer, a
            ``graph`` mapping holding ``"node_feat"`` and ``"adjs"``.
        split_idx: Mapping with ``"train"``, ``"valid"`` and ``"test"``
            index entries.
        eval_func: Metric called as ``eval_func(labels, outputs)``.
        criterion: Loss called as ``criterion(outputs, targets)``.
        threshold: Forwarded to the nodeformer model.
        args: Namespace providing ``method``, ``dataset`` and ``tau``.
        sparsity: Forwarded to the nodeformer model.

    Returns:
        ``(train_metric, valid_metric, test_metric, valid_loss, out)``.
        ``out`` is the raw model output for the datasets listed in the
        body and the log-softmax of it for every other dataset.
    """
    model.eval()
    if args.method != "nodeformer":
        out = model(dataset)
    else:
        out, _ = model(
            dataset.graph["node_feat"],
            dataset.graph["adjs"],
            args.tau,
            threshold,
            args,
            sparsity,
        )

    # Metrics are evaluated in train, valid, test order.
    train_acc, valid_acc, test_acc = (
        eval_func(dataset.label[split_idx[part]], out[split_idx[part]])
        for part in ("train", "valid", "test")
    )

    valid_idx = split_idx["valid"]
    float_target_datasets = (
        "yelp-chi",
        "deezer-europe",
        "twitch-e",
        "fb100",
        "ogbn-proteins",
    )

    if args.dataset not in float_target_datasets:
        # Class-index targets: the loss is fed log-probabilities.
        out = F.log_softmax(out, dim=1)
        valid_loss = criterion(
            out[valid_idx], dataset.label.squeeze(1)[valid_idx]
        )
        return train_acc, valid_acc, test_acc, valid_loss, out

    # Float targets: single-column labels are expanded to one-hot rows.
    # NOTE(review): presumably paired with a BCE-style criterion -- confirm.
    if dataset.label.shape[1] == 1:
        true_label = F.one_hot(
            dataset.label, dataset.label.max() + 1
        ).squeeze(1)
    else:
        true_label = dataset.label
    targets = true_label.squeeze(1)[valid_idx].to(torch.float)
    valid_loss = criterion(out[valid_idx], targets)

    return train_acc, valid_acc, test_acc, valid_loss, out


@torch.no_grad()
def run_tsne(
    model,
    dataset,
    split_idx,
    args,
    threshold=None,
    sparsity=None,
    *,
    save_path="tsne.png",
):
    """Plot a 2-D t-SNE projection of the model output on the test split.

    The model is switched to eval mode and run once. When ``args.method``
    is ``"nodeformer"`` it is called with the graph features, adjacency
    data, ``args.tau``, ``threshold``, ``args`` and ``sparsity`` and its
    first return value is used; otherwise it is called with ``dataset``.
    The test rows of the output are projected with t-SNE (fixed
    ``random_state=42``), drawn as a scatter plot coloured by label, saved
    to ``save_path`` and then shown.

    Args:
        model: Callable model exposing ``eval()``.
        dataset: Object with a ``label`` tensor and, for nodeformer, a
            ``graph`` mapping holding ``"node_feat"`` and ``"adjs"``.
        split_idx: Mapping with a ``"test"`` index entry.
        args: Namespace providing ``method``, ``dataset`` and ``tau``.
        threshold: Forwarded to the nodeformer model.
        sparsity: Forwarded to the nodeformer model.
        save_path: Where the figure is written; defaults to ``"tsne.png"``
            in the current working directory.
    """
    model.eval()
    if args.method == "nodeformer":
        out, _ = model(
            dataset.graph["node_feat"],
            dataset.graph["adjs"],
            args.tau,
            threshold,
            args,
            sparsity,
        )
    else:
        out = model(dataset)
    test_embeddings = out[split_idx["test"]].cpu().numpy()
    test_labels = dataset.label[split_idx["test"]].cpu().numpy()

    # Fixed seed so repeated runs give the same layout for the same input.
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(test_embeddings)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=test_labels,
        cmap="tab10",
        s=10,
    )
    plt.colorbar(scatter, label="Classes")
    plt.title(f"TSNE Visualization of {args.dataset} Test Data")
    plt.xlabel("TSNE Dimension 1")
    plt.ylabel("TSNE Dimension 2")

    # Save before showing so the file is written even if the window is
    # closed by the user.
    plt.savefig(save_path)

    plt.show()


@torch.no_grad()
def evaluate_cpu(
    model,
    dataset,
    split_idx,
    eval_func,
    criterion,
    threshold,
    args,
    sparsity,
    result=None,
):
    """Score ``model`` on every split after moving it and the labels to CPU.

    The model is switched to eval mode and moved to the CPU device, and
    ``dataset.label`` is replaced by its CPU copy. The model is then called
    with the graph features, the first adjacency entry plus the next
    ``args.rb_order - 1`` entries, ``args.tau``, ``threshold``, ``args``
    and ``sparsity``; its first return value is used as the output.

    Args:
        model: Callable model exposing ``eval()`` and ``to()``.
        dataset: Object with a ``label`` tensor and a ``graph`` mapping
            holding ``"node_feat"`` and ``"adjs"``.
        split_idx: Mapping with ``"train"``, ``"valid"`` and ``"test"``
            index entries.
        eval_func: Metric called as ``eval_func(labels, outputs)``.
        criterion: Loss called as ``criterion(outputs, targets)``.
        threshold: Forwarded to the model.
        args: Namespace providing ``rb_order``, ``tau`` and ``dataset``.
        sparsity: Forwarded to the model.
        result: Not used by this function.

    Returns:
        ``(train_metric, valid_metric, test_metric, valid_loss, out)``.
        ``out`` is the raw model output for the datasets listed in the
        body and the log-softmax of it for every other dataset.
    """
    model.eval()

    cpu = torch.device("cpu")
    model.to(cpu)
    dataset.label = dataset.label.to(cpu)

    all_adjs, feats = dataset.graph["adjs"], dataset.graph["node_feat"]
    # The first adjacency entry is always used; higher orders follow up to
    # ``args.rb_order``.
    adjs = [all_adjs[0]] + [all_adjs[i] for i in range(1, args.rb_order)]
    out, _ = model(feats, adjs, args.tau, threshold, args, sparsity)

    # Metrics are evaluated in train, valid, test order.
    train_acc, valid_acc, test_acc = (
        eval_func(dataset.label[split_idx[part]], out[split_idx[part]])
        for part in ("train", "valid", "test")
    )

    valid_idx = split_idx["valid"]
    float_target_datasets = (
        "yelp-chi",
        "deezer-europe",
        "twitch-e",
        "fb100",
        "ogbn-proteins",
    )

    if args.dataset not in float_target_datasets:
        # Class-index targets: the loss is fed log-probabilities.
        out = F.log_softmax(out, dim=1)
        valid_loss = criterion(
            out[valid_idx], dataset.label.squeeze(1)[valid_idx]
        )
        return train_acc, valid_acc, test_acc, valid_loss, out

    # Float targets: single-column labels are expanded to one-hot rows.
    # NOTE(review): presumably paired with a BCE-style criterion -- confirm.
    if dataset.label.shape[1] == 1:
        true_label = F.one_hot(
            dataset.label, dataset.label.max() + 1
        ).squeeze(1)
    else:
        true_label = dataset.label
    targets = true_label.squeeze(1)[valid_idx].to(torch.float)
    valid_loss = criterion(out[valid_idx], targets)

    return train_acc, valid_acc, test_acc, valid_loss, out
