import math
from typing import Optional

import numpy as np
import sklearn.metrics as sk
import torch


def get_perturbed_indices(
    data, ood_budget_per_graph, perturb_train_indices=False, **_
):
    """Randomly choose the node indices to perturb.

    If ``data`` has a ``train_mask`` (it is then also expected to have
    ``val_mask`` and ``test_mask``), sampling happens separately within each
    group: validation nodes, test nodes, nodes in none of the three masks,
    and optionally training nodes. Otherwise sampling is over all
    ``data.y.size(0)`` nodes.

    Args:
        data: Graph object. Reads ``train_mask``/``val_mask``/``test_mask``
            when present, otherwise ``y``.
        ood_budget_per_graph: Fraction to sample; each group contributes
            ``int(len(group) * ood_budget_per_graph)`` indices, drawn without
            replacement.
        perturb_train_indices: Also sample from the training nodes.
        **_: Ignored, so a shared config dict can be splatted in.

    Returns:
        list[int]: The sampled node indices.
    """
    if not hasattr(data, "train_mask"):
        num_nodes = data.y.size(0)
        n_perturbed = int(num_nodes * ood_budget_per_graph)
        return np.random.choice(
            range(0, num_nodes), n_perturbed, replace=False
        ).tolist()

    def _mask_to_list(mask):
        # ``flatten`` (not ``squeeze``): a mask with exactly one True entry
        # must still yield a list, not a bare int.
        return mask.nonzero().flatten().tolist()

    remaining_mask = ~data.train_mask & ~data.val_mask & ~data.test_mask
    groups = [
        _mask_to_list(data.val_mask),
        _mask_to_list(data.test_mask),
        _mask_to_list(remaining_mask),
    ]
    if perturb_train_indices:
        groups.append(_mask_to_list(data.train_mask))

    ind_perturbed = []
    for indices in groups:
        n_perturbed = int(len(indices) * ood_budget_per_graph)
        chosen = np.random.choice(indices, n_perturbed, replace=False)
        ind_perturbed.extend(chosen.tolist())
    return ind_perturbed


def get_perturbed_indices_with_split_indices(
    data, split_idx_lst, ood_budget_per_graph, perturb_train_indices=False, **_
):
    """Randomly choose node indices to perturb, using explicit index splits.

    The split is ``split_idx_lst[0]``, a mapping with ``"train"``,
    ``"valid"`` and ``"test"`` index collections (each must support
    ``.tolist()``). Sampling happens separately within the validation nodes,
    the test nodes, the nodes of ``range(data.label.size(0))`` that are in
    no split, and optionally the training nodes. Groups whose sample size
    would be zero are skipped.

    Args:
        data: Graph object. Reads ``data.label`` (node count) and
            ``data.graph["node_feat"]`` (output device).
        split_idx_lst: Sequence whose first element is the split mapping.
        ood_budget_per_graph: Fraction to sample; each group contributes
            ``int(len(group) * ood_budget_per_graph)`` indices, drawn without
            replacement.
        perturb_train_indices: Also sample from the training nodes.
        **_: Ignored, so a shared config dict can be splatted in.

    Returns:
        torch.Tensor: 1-D ``torch.long`` tensor of sampled node indices on
        the device of ``data.graph["node_feat"]``. It may be empty.
    """
    split_indices = split_idx_lst[0]
    assert (
        "train" in split_indices
        and "valid" in split_indices
        and "test" in split_indices
    )

    val_indices = split_indices["valid"].tolist()
    test_indices = split_indices["test"].tolist()
    train_indices = split_indices["train"].tolist()

    all_indices = set(range(data.label.size(0)))
    used_indices = set(val_indices + test_indices + train_indices)
    remaining_indices = list(all_indices - used_indices)

    sample_indices = [val_indices, test_indices, remaining_indices]
    if perturb_train_indices:
        sample_indices.append(train_indices)

    device = data.graph["node_feat"].device

    ind_perturbed = []
    for indices in sample_indices:
        n_perturbed = int(len(indices) * ood_budget_per_graph)
        if n_perturbed > 0:
            chosen = np.random.choice(indices, n_perturbed, replace=False)
            ind_perturbed.extend(chosen.tolist())

    # Explicit dtype: ``torch.tensor([])`` would infer float32, and a float
    # tensor cannot be used as an index by the callers.
    return torch.tensor(ind_perturbed, dtype=torch.long, device=device)


def perturb_features(
    data,
    perturb_train_indices=False,
    ood_budget_per_graph=1.0,
    ood_noise_scale=1.0,
    ood_perturbation_type="bernoulli_0.5",
    ind_perturbed=None,
    **_,
):
    """Return a copy of ``data`` whose node features are perturbed on some nodes.

    Supported ``ood_perturbation_type`` values:

    * ``"gaussian"``: add ``ood_noise_scale * N(0, 1)`` noise to the features.
    * ``"bernoulli_0.5"``: replace the features with ``ood_noise_scale``
      times i.i.d. Bernoulli(0.5) samples.
    * ``"flip"``: take the features, map
      ``int(ood_noise_scale * num_features)`` randomly chosen dimensions of
      each perturbed node to ``1 - x``, and store ``ood_noise_scale`` times
      the result.

    ``data.ood_mask`` / ``data.id_mask`` mark perturbed and unperturbed nodes.
    If the copy already carries ``ood_val_mask``, ``ood_test_mask``,
    ``id_val_mask`` and ``id_test_mask`` (for example from a previous call),
    the new split masks are OR-ed into them. Otherwise, if it has a
    ``train_mask``, they are created from ``val_mask`` / ``test_mask``.

    Args:
        data: Graph object with ``x``, ``y`` and a ``clone()`` method.
        perturb_train_indices: Forwarded to ``get_perturbed_indices``.
        ood_budget_per_graph: Forwarded to ``get_perturbed_indices``.
        ood_noise_scale: Noise / output scale (see above).
        ood_perturbation_type: One of the values listed above.
        ind_perturbed: Node indices to perturb. When ``None`` they are
            sampled with ``get_perturbed_indices``.
        **_: Ignored, so a shared config dict can be splatted in.

    Returns:
        tuple: ``(perturbed_copy, ind_perturbed)``.

    Raises:
        ValueError: If ``ood_perturbation_type`` is not supported.
    """
    dim_features = data.x.size(1)
    data = data.clone()
    if ind_perturbed is None:
        ind_perturbed = get_perturbed_indices(
            data, ood_budget_per_graph, perturb_train_indices
        )
    n_perturbed = len(ind_perturbed)
    device = data.x.device

    # Random noise is drawn on the CPU and then moved next to the features,
    # so the arithmetic below also works when ``data.x`` lives on a GPU.
    if ood_perturbation_type == "gaussian":
        noise = ood_noise_scale * torch.randn((n_perturbed, dim_features))
        noise = noise.to(device)
    elif ood_perturbation_type == "bernoulli_0.5":
        probs = torch.full((n_perturbed, dim_features), 0.5)
        noise = torch.bernoulli(probs).to(device)
    elif ood_perturbation_type == "flip":
        noise = data.x[ind_perturbed].clone()
        num_features = noise.size(1)
        num_flip_dimensions = int(ood_noise_scale * num_features)
        for i in range(n_perturbed):
            flip_indices = np.random.choice(
                num_features, size=num_flip_dimensions, replace=False
            )
            noise[i, flip_indices] = 1 - noise[i, flip_indices]
    else:
        raise ValueError(
            f"perturbation {ood_perturbation_type} is not supported!"
        )

    if ood_perturbation_type == "gaussian":
        # Additive noise.
        data.x[ind_perturbed] = data.x[ind_perturbed] + noise
    else:
        # Replacement: the perturbed rows become the (scaled) noise.
        data.x[ind_perturbed] = ood_noise_scale * noise

    ood_mask = torch.zeros_like(data.y, dtype=torch.bool)
    ood_mask[ind_perturbed] = True
    id_mask = ~ood_mask
    data.ood_mask = ood_mask
    data.id_mask = id_mask

    has_split_masks = all(
        hasattr(data, name)
        for name in (
            "ood_val_mask",
            "ood_test_mask",
            "id_val_mask",
            "id_test_mask",
        )
    )
    if has_split_masks:
        data.ood_val_mask = data.ood_val_mask | (ood_mask & data.val_mask)
        data.id_val_mask = data.id_val_mask | (id_mask & data.val_mask)
        data.ood_test_mask = data.ood_test_mask | (ood_mask & data.test_mask)
        data.id_test_mask = data.id_test_mask | (id_mask & data.test_mask)
    elif hasattr(data, "train_mask"):
        data.ood_val_mask = ood_mask & data.val_mask
        data.id_val_mask = id_mask & data.val_mask
        data.ood_test_mask = ood_mask & data.test_mask
        data.id_test_mask = id_mask & data.test_mask
    return data, ind_perturbed


def perturb_features_with_split_indices(
    data,
    split_idx_lst,
    perturb_train_indices=False,
    ood_budget_per_graph=1.0,
    ood_noise_scale=1.0,
    ood_perturbation_type="gaussian",
    ind_perturbed=None,
    **_,
):
    """Perturb ``data.graph["node_feat"]`` in place on a subset of nodes.

    Unlike ``perturb_features``, ``data`` is not cloned: its feature tensor
    is written in place and new attributes are set on it.

    Supported ``ood_perturbation_type`` values:

    * ``"gaussian"``: add ``ood_noise_scale * N(0, 1)`` noise.
    * ``"flip"``: map ``int(ood_noise_scale * num_features)`` randomly chosen
      dimensions of each perturbed node to ``1 - x``, then store
      ``ood_noise_scale`` times the result.

    Sets ``data.ood_mask`` / ``data.id_mask`` with the shape of
    ``data.label``. ``data.ood_val_mask``, ``data.id_val_mask``,
    ``data.ood_test_mask`` and ``data.id_test_mask`` are those masks
    indexed by the ``"valid"`` / ``"test"`` split indices, so they have one
    entry per split node.

    Args:
        data: Object with ``graph["node_feat"]`` and ``label``.
        split_idx_lst: Sequence whose first element maps ``"train"``,
            ``"valid"`` and ``"test"`` to node indices.
        perturb_train_indices: Forwarded to the index sampler.
        ood_budget_per_graph: Forwarded to the index sampler.
        ood_noise_scale: Noise / output scale (see above).
        ood_perturbation_type: ``"gaussian"`` or ``"flip"``.
        ind_perturbed: Node indices to perturb. When ``None`` they are
            sampled with ``get_perturbed_indices_with_split_indices``.
        **_: Ignored, so a shared config dict can be splatted in.

    Returns:
        tuple: ``(data, ind_perturbed)``; ``data`` is the same object passed in.

    Raises:
        ValueError: If ``ood_perturbation_type`` is not supported.
    """
    node_feat = data.graph["node_feat"]
    dim_features = node_feat.size(1)
    split_indices = split_idx_lst[0]
    assert (
        "train" in split_indices
        and "valid" in split_indices
        and "test" in split_indices
    )

    if ind_perturbed is None:
        ind_perturbed = get_perturbed_indices_with_split_indices(
            data, split_idx_lst, ood_budget_per_graph, perturb_train_indices
        )

    n_perturbed = len(ind_perturbed)
    if ood_perturbation_type == "gaussian":
        noise = ood_noise_scale * torch.randn((n_perturbed, dim_features))
    elif ood_perturbation_type == "flip":
        noise = node_feat[ind_perturbed].clone()
        num_features = noise.size(1)
        num_flip_dimensions = int(ood_noise_scale * num_features)
        for i in range(n_perturbed):
            flip_indices = np.random.choice(
                num_features, size=num_flip_dimensions, replace=False
            )
            noise[i, flip_indices] = 1 - noise[i, flip_indices]
    else:
        raise ValueError(
            f"perturbation {ood_perturbation_type} is not supported!"
        )

    ood_mask = torch.zeros_like(data.label, dtype=bool)
    ood_mask[ind_perturbed] = True
    id_mask = ~ood_mask
    data.ood_mask = ood_mask
    data.id_mask = id_mask

    data.ood_val_mask = ood_mask[split_indices["valid"]]
    data.id_val_mask = id_mask[split_indices["valid"]]
    data.ood_test_mask = ood_mask[split_indices["test"]]
    data.id_test_mask = id_mask[split_indices["test"]]

    if ood_perturbation_type == "gaussian":
        # Gaussian noise is sampled on the CPU; move it to the features'
        # own device rather than assuming CUDA (works on CPU-only hosts and
        # on non-default GPUs).
        noise = noise.to(node_feat.device)
        node_feat[ind_perturbed] = node_feat[ind_perturbed] + noise
    else:
        node_feat[ind_perturbed] = ood_noise_scale * noise

    return data, ind_perturbed


def perturb_edges(
    data,
    edge_perturb_type="change",
    ood_budget_per_graph=1.0,
    **_,
):
    """Return a copy of ``data`` with a random fraction of its edges perturbed.

    ``k = int(ood_budget_per_graph * data.edge_index.size(1))`` edges are
    affected:

    * ``"change"``: ``k`` distinct edges get a uniformly random new target
      node (row 1 of ``edge_index``).
    * ``"add"``: ``k`` edges between uniformly random node pairs are
      appended.
    * ``"remove"``: ``k`` distinct edges are dropped; the remaining edges keep
      their original order.

    Any other ``edge_perturb_type`` returns the clone unchanged.

    Args:
        data: Graph object with ``y`` (node count), ``edge_index`` and a
            ``clone()`` method. All edits are made on ``data.clone()``.
        edge_perturb_type: ``"change"``, ``"add"`` or ``"remove"``.
        ood_budget_per_graph: Fraction of the current edge count to perturb.
        **_: Ignored, so a shared config dict can be splatted in.

    Returns:
        The perturbed clone.
    """
    data = data.clone()
    num_nodes = data.y.size(0)
    num_edges = data.edge_index.size(1)
    num_perturbed_edges = int(ood_budget_per_graph * num_edges)
    device = data.edge_index.device

    if edge_perturb_type == "change":
        ind_perturbed = np.random.choice(
            np.arange(num_edges), num_perturbed_edges, replace=False
        )
        new_targets = np.random.choice(
            np.arange(num_nodes), num_perturbed_edges, replace=True
        )
        data.edge_index[1, ind_perturbed] = torch.from_numpy(new_targets).to(
            device
        )

    elif edge_perturb_type == "add":
        new_sources = np.random.choice(
            np.arange(num_nodes), num_perturbed_edges, replace=True
        )
        new_targets = np.random.choice(
            np.arange(num_nodes), num_perturbed_edges, replace=True
        )
        new_edges = torch.vstack(
            (torch.from_numpy(new_sources), torch.from_numpy(new_targets))
        ).to(device)
        data.edge_index = torch.cat([data.edge_index, new_edges], dim=1)

    elif edge_perturb_type == "remove":
        ind_perturbed = np.random.choice(
            np.arange(num_edges), num_perturbed_edges, replace=False
        )
        # Boolean keep-mask: O(E) instead of a per-edge membership scan of
        # ``ind_perturbed``, and it yields a valid (2, 0) tensor when every
        # edge is removed (an empty float index list would be rejected).
        keep = torch.ones(num_edges, dtype=torch.bool)
        keep[torch.as_tensor(ind_perturbed, dtype=torch.long)] = False
        data.edge_index = data.edge_index[:, keep.to(device)]

    return data


def perturb_edges_with_split_indices(
    data,
    perturb_train_indices=False,
    ood_budget_per_graph=1.0,
    ood_noise_scale=1.0,
    ind_perturbed=None,
    **_,
):
    """Rewire the target endpoint of a random subset of edges, in place.

    ``int(ood_budget_per_graph * num_edges)`` distinct columns of
    ``data.graph["edge_index"]`` are chosen uniformly at random, and their
    target row (row 1) is replaced by uniformly random node ids in
    ``[0, data.label.size(0))``.

    Note: ``perturb_train_indices``, ``ood_noise_scale`` and
    ``ind_perturbed`` are accepted for signature compatibility but are not
    used. The edges to rewire are always sampled fresh.

    Returns:
        The same ``data`` object, whose edge tensor was modified in place.
    """
    edge_index = data.graph["edge_index"]
    node_count = data.label.size(0)
    edge_count = edge_index.size(1)
    budget = int(ood_budget_per_graph * edge_count)

    chosen_edges = np.random.choice(np.arange(edge_count), budget, replace=False)
    rewired_targets = np.random.choice(
        np.arange(node_count), budget, replace=True
    )

    target_row = torch.from_numpy(rewired_targets).to(data.label.device)
    edge_index[1, chosen_edges] = target_row
    return data


def get_ood_split(
    data,
    ood_frac_left_out_classes: float = 0.45,
    ood_num_left_out_classes: Optional[int] = None,
    ood_leave_out_last_classes: Optional[bool] = False,
    ood_left_out_classes: Optional[list] = None,
    **_,
):
    """Build a leave-out-classes OOD split on a copy of ``data``.

    Nodes whose label is a left-out class become OOD. They are removed from
    ``train_mask`` / ``val_mask`` / ``test_mask`` and recorded in
    ``ood_mask``, ``ood_val_mask`` and ``ood_test_mask``. ``id_mask`` is the
    complement of ``ood_mask``. ``id_val_mask`` / ``id_test_mask`` are the
    (same tensor objects as the) pruned ``val_mask`` / ``test_mask``.
    Labels are remapped so the kept classes come first, i.e. they map to
    ``0 .. num_kept - 1``.

    Args:
        data: Graph object with ``y``, ``train_mask``, ``val_mask``,
            ``test_mask`` and a ``clone()`` method.
        ood_frac_left_out_classes: When no count is given, leave out
            ``floor(num_classes * ood_frac_left_out_classes)`` classes.
        ood_num_left_out_classes: Number of classes to leave out.
        ood_leave_out_last_classes: Leave out the highest class ids instead
            of a random subset (``np.random.shuffle``).
        ood_left_out_classes: Explicit classes to leave out. Overrides the
            three options above.
        **_: Ignored, so a shared config dict can be splatted in.

    Returns:
        tuple: ``(split_copy, num_in_distribution_classes)``.
    """
    data = data.clone()

    assert (
        hasattr(data, "train_mask")
        and hasattr(data, "val_mask")
        and hasattr(data, "test_mask")
    )

    num_classes = data.y.max().item() + 1
    classes = np.arange(num_classes)

    if ood_left_out_classes is None:
        if ood_num_left_out_classes is None:
            ood_num_left_out_classes = math.floor(
                num_classes * ood_frac_left_out_classes
            )

        if not ood_leave_out_last_classes:
            np.random.shuffle(classes)

        left_out_classes = classes[
            num_classes - ood_num_left_out_classes : num_classes
        ]  # noqa

    else:
        ood_num_left_out_classes = len(ood_left_out_classes)
        left_out_classes = np.array(ood_left_out_classes)
        # Reorder ``classes`` so the left-out classes are at the end and the
        # kept classes receive the lowest new labels.
        tmp = [c for c in classes if c not in left_out_classes]
        tmp = tmp + [c for c in classes if c in left_out_classes]
        classes = np.array(tmp)

    class_mapping = {int(c): i for i, c in enumerate(classes)}

    left_out = torch.zeros_like(data.y, dtype=bool)
    for c in left_out_classes:
        left_out = left_out | (data.y == c)

    # Computed before the masks are pruned below.
    left_out_val = left_out & data.val_mask
    left_out_test = left_out & data.test_mask

    data.ood_mask = left_out
    data.id_mask = ~left_out

    data.train_mask[left_out] = False
    data.test_mask[left_out] = False
    data.val_mask[left_out] = False

    data.ood_val_mask = left_out_val
    data.ood_test_mask = left_out_test

    data.id_val_mask = data.val_mask
    data.id_test_mask = data.test_mask

    num_classes = num_classes - ood_num_left_out_classes
    # ``torch.tensor`` (not the legacy ``torch.LongTensor`` constructor,
    # which only accepts CPU devices) so CUDA labels work too.
    data.y = torch.tensor(
        [class_mapping[y.item()] for y in data.y],
        dtype=torch.long,
        device=data.y.device,
    )

    return data, num_classes


def get_roc(model, dataset, threshold, args, sparsity):
    """Compute, print and return the OOD-detection AUROC on the test nodes.

    The model is put in eval mode and run once on
    ``dataset.graph["node_feat"]`` and ``dataset.graph["adjs"]``. The first
    element of its output is used as per-node logits, indexed with
    ``dataset.id_test_mask`` (label 0) and ``dataset.ood_test_mask``
    (label 1). Each node's score is its negated maximum logit, so lower
    confidence gives a higher OOD score.

    Args:
        model: Callable model. Called as
            ``model(node_feat, adjs, args.tau, threshold, args, sparsity)``.
        dataset: Object with ``graph["node_feat"]``, ``graph["adjs"]``,
            ``id_test_mask`` and ``ood_test_mask``.
        threshold: Forwarded to the model.
        args: Forwarded to the model; ``args.tau`` is also read.
        sparsity: Forwarded to the model.

    Returns:
        float: The AUROC from ``sklearn.metrics.roc_auc_score``.
    """
    model.eval()
    # Scoring only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(
            dataset.graph["node_feat"],
            dataset.graph["adjs"],
            args.tau,
            threshold,
            args,
            sparsity,
        )
    logits = outputs[0]

    ind_max = logits[dataset.id_test_mask].max(dim=1).values
    ood_max = logits[dataset.ood_test_mask].max(dim=1).values
    ind_scores = -ind_max.cpu().numpy()
    ood_scores = -ood_max.cpu().numpy()

    labels = np.concatenate(
        [np.zeros(ind_scores.shape[0]), np.ones(ood_scores.shape[0])]
    )
    scores = np.concatenate([ind_scores, ood_scores])
    auroc = sk.roc_auc_score(labels, scores)
    print("* AUROC = {}".format(auroc))
    return auroc


def get_ood_split_with_indices(
    data,
    split_idx_lst,
    ood_frac_left_out_classes: float = 0.45,
    ood_num_left_out_classes: Optional[int] = None,
    ood_leave_out_last_classes: Optional[bool] = False,
    ood_left_out_classes: Optional[list] = None,
    **_,
):
    """Mark left-out classes as OOD on ``data`` (in place) and relabel.

    Nodes whose ``data.label`` is a left-out class are OOD. The attributes
    ``ood_train_mask``, ``ood_val_mask`` and ``ood_test_mask`` are set to
    copies of that node mask, and ``id_*_mask`` to their complements. Note
    that these masks span all nodes; they are not restricted to the
    train/valid/test indices. ``ood_test_mask`` / ``id_test_mask`` are
    flattened with ``view(-1)``. ``data.label`` is remapped so the kept
    classes come first, i.e. they map to ``0 .. num_kept - 1``.

    Args:
        data: Object with a ``label`` tensor. It is modified in place.
        split_idx_lst: Sequence whose first element must contain
            ``"train"``, ``"valid"`` and ``"test"`` keys (only checked here).
        ood_frac_left_out_classes: When no count is given, leave out
            ``floor(num_classes * ood_frac_left_out_classes)`` classes.
        ood_num_left_out_classes: Number of classes to leave out.
        ood_leave_out_last_classes: Leave out the highest class ids instead
            of a random subset (``np.random.shuffle``).
        ood_left_out_classes: Explicit classes to leave out. Overrides the
            three options above.
        **_: Ignored, so a shared config dict can be splatted in.

    Returns:
        tuple: ``(data, num_in_distribution_classes)``.
    """
    split_indices = split_idx_lst[0]
    assert (
        "train" in split_indices
        and "valid" in split_indices
        and "test" in split_indices
    )

    num_classes = data.label.max().item() + 1
    classes = np.arange(num_classes)

    if ood_left_out_classes is None:
        # Decide how many (and which) classes are left out.
        if ood_num_left_out_classes is None:
            ood_num_left_out_classes = math.floor(
                num_classes * ood_frac_left_out_classes
            )

        if not ood_leave_out_last_classes:
            # Random permutation, so the tail slice below is a random subset.
            np.random.shuffle(classes)

        left_out_classes = classes[
            num_classes - ood_num_left_out_classes : num_classes
        ]

    else:
        ood_num_left_out_classes = len(ood_left_out_classes)
        left_out_classes = np.array(ood_left_out_classes)
        # Move the left-out classes to the end so the kept classes receive
        # the lowest new labels.
        tmp = [c for c in classes if c not in left_out_classes]
        tmp = tmp + [c for c in classes if c in left_out_classes]
        classes = np.array(tmp)

    class_mapping = {int(c): i for i, c in enumerate(classes)}

    left_out = torch.zeros_like(data.label, dtype=bool)
    for c in left_out_classes:
        left_out = left_out | (data.label == c)

    # NOTE(review): these masks cover every node, not only the nodes in the
    # corresponding split (previous code wrote ``left_out[idx]`` back into a
    # copy of ``left_out``, which was a no-op). If per-split masks were
    # intended, as in ``get_ood_split``, start from zeros and copy only
    # ``left_out[split_indices[...]]`` — confirm against the callers first.
    left_out_train = left_out.clone()
    left_out_valid = left_out.clone()
    left_out_test = left_out.clone().view(-1)

    data.ood_train_mask = left_out_train
    data.ood_val_mask = left_out_valid
    data.ood_test_mask = left_out_test

    data.id_train_mask = ~left_out_train
    data.id_val_mask = ~left_out_valid
    data.id_test_mask = ~left_out_test

    num_classes = num_classes - ood_num_left_out_classes
    data.label = torch.tensor(
        [class_mapping[y.item()] for y in data.label],
        device=data.label.device,
        dtype=torch.long,
    )

    return data, num_classes
