import random

import numpy as np
import torch
from torch_geometric.data import Batch, Data, DataLoader
from torch_geometric.graphgym.config import cfg
from torch_geometric.loader import DataLoader
from tqdm import tqdm

# Opt in to TensorFloat-32 math for CUDA matmuls (cuBLAS; default False in
# PyTorch 1.12+) and for cuDNN kernels (default True). The chained assignment
# sets the matmul flag first, then the cuDNN flag.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True


def perturb_node_features(data, noise_level=0.1):
    """Overwrite every node's feature row with random Gaussian noise.

    Despite the name, the original features are *replaced*, not shifted: a
    random permutation of all node indices is drawn, and each indexed row of
    ``data.x`` receives an independent ``randn * noise_level`` sample.

    Args:
        data: Graph object whose ``x`` attribute is a feature tensor indexed
            as ``(num_nodes, num_features)`` (``size(1)`` is read, so rank >= 2).
        noise_level: Multiplier applied to standard-normal noise.

    Returns:
        The same ``data`` object; ``data.x`` is modified in place.
    """
    x = data.x
    num_nodes, num_features = x.size(0), x.size(1)
    # Same RNG calls as before (Python `random` for the permutation, CPU torch
    # generator for the noise), so seeded runs reproduce earlier results.
    perturb_indices = torch.tensor(
        random.sample(range(num_nodes), num_nodes),
        dtype=torch.long,
        device=x.device,
    )
    noise = torch.randn((num_nodes, num_features)) * noise_level
    if not x.is_floating_point():
        # Integer-valued features receive rounded noise.
        noise = noise.round()
    # Advanced-index assignment rejects mismatched dtypes/devices, so match the
    # destination tensor explicitly (previously only int64 was handled).
    x[perturb_indices] = noise.to(dtype=x.dtype, device=x.device)
    return data


def perturb_node(data_loader, noise_level=0.1):
    """Build a loader whose graphs have had their node features noised.

    Every graph in every batch of ``data_loader`` is passed through
    ``perturb_node_features`` (which modifies it in place).

    Args:
        data_loader: Loader yielding batches that support ``to_data_list()``
            and exposing a ``batch_size`` attribute.
        noise_level: Forwarded to ``perturb_node_features``.

    Returns:
        A new, non-shuffled ``DataLoader`` over the perturbed graphs, with the
        same batch size as the input loader.
    """
    noisy = [
        perturb_node_features(graph, noise_level)
        for batch in tqdm(data_loader, desc="Perturbing nodes")
        for graph in batch.to_data_list()
    ]
    return DataLoader(noisy, batch_size=data_loader.batch_size, shuffle=False)


def _rewire_edge_targets(data, edge_perturbation_rate):
    """Redirect a random subset of ``data``'s edges to new targets, in place.

    Only the target row (``edge_index[1]``) of each selected edge changes, so
    for undirected graphs stored with both directions the reverse edge is left
    as is. Graphs with fewer than two nodes have no alternative target and are
    returned unchanged.
    """
    edge_index = data.edge_index
    num_edges = edge_index.size(1)
    num_nodes = data.num_nodes
    num_perturb_edges = int(num_edges * edge_perturbation_rate)
    if num_perturb_edges == 0 or num_nodes is None or num_nodes < 2:
        return data

    edge_ids = np.random.choice(num_edges, num_perturb_edges, replace=False)
    edge_ids_t = torch.as_tensor(
        edge_ids, dtype=torch.long, device=edge_index.device
    )
    old_targets = edge_index[1, edge_ids_t].cpu().numpy()
    # A non-zero offset modulo num_nodes is uniform over the num_nodes - 1
    # nodes other than the current target -- the same distribution as
    # sampling from setdiff1d(arange(num_nodes), target), without the
    # per-edge O(num_nodes) work.
    offsets = np.random.randint(1, num_nodes, size=num_perturb_edges)
    new_targets = (old_targets + offsets) % num_nodes
    edge_index[1, edge_ids_t] = torch.as_tensor(
        new_targets, dtype=edge_index.dtype, device=edge_index.device
    )
    return data


def perturb_edge(data_loader, edge_perturbation_rate):
    """Build a loader whose graphs have a fraction of their edges rewired.

    For each graph, ``int(num_edges * edge_perturbation_rate)`` distinct edges
    are chosen and their target node is replaced by a different, uniformly
    random node of the same graph.

    Args:
        data_loader: Loader yielding batches that support ``to_data_list()``
            and exposing a ``batch_size`` attribute.
        edge_perturbation_rate: Fraction of edges per graph to rewire; values
            above 1 make ``np.random.choice`` raise ``ValueError``.

    Returns:
        A new, non-shuffled ``DataLoader`` over the perturbed graphs, with the
        same batch size as the input loader.
    """
    perturbed_graphs = [
        _rewire_edge_targets(data, edge_perturbation_rate)
        for batch in tqdm(data_loader, desc="Perturbing edges")
        for data in batch.to_data_list()
    ]
    return DataLoader(
        perturbed_graphs, batch_size=data_loader.batch_size, shuffle=False
    )


def delete_random_graphs(loader, delete_percentage):
    """Return a loader over the loader's dataset with a random subset removed.

    Args:
        loader: Loader exposing ``dataset`` (indexable, with ``len``) and
            ``batch_size``.
        delete_percentage: Fraction of graphs to drop, in ``[0, 1]``. Despite
            the name it is not a percent: ``int(len(dataset) * delete_percentage)``
            graphs are removed, and values above 1 make ``np.random.choice``
            raise ``ValueError``.

    Returns:
        A new, non-shuffled ``DataLoader`` over the remaining graphs, in their
        original order, with the same batch size.
    """
    dataset = loader.dataset
    num_graphs = len(dataset)
    num_delete = int(num_graphs * delete_percentage)
    # A set gives O(1) membership tests; `i in ndarray` scans the whole array
    # on every check, which made the loop quadratic.
    delete_indices = set(
        np.random.choice(num_graphs, num_delete, replace=False).tolist()
    )
    remaining_graphs = [
        dataset[i]
        for i in tqdm(range(num_graphs), desc="Deleting graphs")
        if i not in delete_indices
    ]
    return DataLoader(
        remaining_graphs, batch_size=loader.batch_size, shuffle=False
    )


def filter_loader(loader, ood_classes=None):
    """Return a loader containing only graphs whose label is in a given set.

    Note: despite its name, ``ood_classes`` is the set of labels to *keep*.
    Callers may pass in-distribution classes to build a training split.

    Args:
        loader: Loader yielding batches that support ``to_data_list()`` and
            exposing a ``batch_size`` attribute.
        ood_classes: Labels to keep (tensor, array or sequence). Required.

    Returns:
        A new, non-shuffled ``DataLoader`` over every graph whose ``y``
        contains at least one of the given labels.

    Raises:
        TypeError: if ``ood_classes`` is None (previously an opaque error from
            ``x in None``).
    """
    if ood_classes is None:
        raise TypeError("filter_loader() requires `ood_classes`: the labels to keep")
    # Convert once rather than per graph.
    keep_classes = torch.as_tensor(ood_classes).reshape(1, -1)

    filtered_data_list = []
    for batch in tqdm(loader, desc="OOD"):
        for data in batch.to_data_list():
            y = torch.as_tensor(data.y).reshape(-1, 1)
            # Same test as `data.y in tensor` -- "any element equals any
            # class" -- but broadcast explicitly and on the label's device.
            if bool((y == keep_classes.to(y.device)).any()):
                filtered_data_list.append(data)

    return DataLoader(
        filtered_data_list, batch_size=loader.batch_size, shuffle=False
    )


def remove_ood_classes(loaders, num_ood_classes):
    """Split classes into in-distribution and OOD sets and filter the loaders.

    ``num_ood_classes`` distinct labels are sampled (with ``np.random``) from
    the labels seen in ``loaders[0]``. Then ``loaders[0]`` keeps only the
    remaining (in-distribution) classes, and ``loaders[1]`` and ``loaders[2]``
    keep only the sampled OOD classes.

    Args:
        loaders: Mutable sequence of at least three loaders whose batches
            expose ``y``; entries 0-2 are replaced in place.
        num_ood_classes: Number of classes to hold out; must not exceed the
            number of distinct labels (``np.random.choice`` raises otherwise).

    Returns:
        Tuple ``(loaders, ood_classes)``, where ``ood_classes`` is a CPU
        tensor of the sampled labels.
    """
    all_classes = torch.cat([data.y for data in loaders[0]], dim=0).unique()
    ood_classes = np.random.choice(
        all_classes.cpu().numpy(), num_ood_classes, replace=False
    )
    # Mask instead of rebuilding from a list of 0-dim tensors: keeps sorted
    # order, dtype and device, and yields a correctly typed empty tensor when
    # every class is held out (previously an empty float32 tensor).
    ood_on_device = torch.as_tensor(
        ood_classes, dtype=all_classes.dtype, device=all_classes.device
    )
    not_ood_classes = all_classes[~torch.isin(all_classes, ood_on_device)]
    ood_classes = torch.tensor(ood_classes)
    loaders[0] = filter_loader(loaders[0], ood_classes=not_ood_classes)
    loaders[1] = filter_loader(loaders[1], ood_classes=ood_classes)
    loaders[2] = filter_loader(loaders[2], ood_classes=ood_classes)
    return loaders, ood_classes
