import os
import shutil
import warnings

import hydra
import omegaconf
import torch
import torch_geometric.transforms as T
from torch_geometric.seed import seed_everything

from .postprocessing_utils import hash_tensor


def build_transform(transform_cfg: omegaconf.DictConfig):
    """
    Assemble a composed transform from a sequence of transform configurations.

    Every entry of ``transform_cfg`` is instantiated through
    ``hydra.utils.instantiate`` with ``_recursive_=False`` (nested configs are
    handed to the target as-is instead of being instantiated by hydra), and the
    resulting objects are chained in the order they appear.

    Args:
        transform_cfg (omegaconf.DictConfig): Iterable of per-transform
            configurations; each element is passed to hydra individually.
            NOTE(review): the code iterates over this value, so it is presumably
            a list-like config despite the annotation -- confirm with callers.

    Returns:
        torch_geometric.transforms.Compose: The composed transformation pipeline.
    """
    steps = [
        hydra.utils.instantiate(step_cfg, _recursive_=False)
        for step_cfg in transform_cfg
    ]
    return T.Compose(steps)


def get_dataset(cfg: omegaconf.DictConfig):
    """
    Get the dataset (and, if configured, an evaluator) based on the configuration.

    If ``cfg.data.data_seed`` is set, ``seed_everything`` is called with it
    before the dataset is built. Optional ``transform`` / ``pre_transform``
    entries in ``cfg.dataset`` are turned into composed pipelines via
    ``build_transform`` and forwarded to the dataset constructor.

    The first instantiation attempt runs with ``UserWarning`` escalated to an
    error. If hydra reports an ``InstantiationException``, the directory
    ``<cfg.dataset.root>/<cfg.dataset.name>`` is deleted (when it exists) and the
    dataset is instantiated once more, this time with default warning handling.

    Args:
        cfg (omegaconf.DictConfig): The configuration for building the dataset.
            Reads ``cfg.data`` (``data_seed``, ``evaluator``) and ``cfg.dataset``.

    Returns:
        The constructed dataset when ``cfg.data.evaluator`` is absent or None,
        otherwise a tuple ``(dataset, evaluator)``.

    Raises:
        hydra.errors.InstantiationException: If the second instantiation attempt
            fails as well.
    """
    seed = cfg.data.get("data_seed", None)
    if seed is not None:
        seed_everything(seed)

    transform = (
        build_transform(cfg.dataset.transform) if "transform" in cfg.dataset else None
    )
    pre_transform = (
        build_transform(cfg.dataset.pre_transform)
        if "pre_transform" in cfg.dataset
        else None
    )

    def _instantiate_dataset():
        # Single place that builds the dataset so both attempts stay in sync.
        return hydra.utils.instantiate(
            cfg.dataset, pre_transform=pre_transform, transform=transform
        )

    try:
        # Escalate UserWarnings so that a warning raised while building the
        # dataset surfaces as a failure and triggers the clean rebuild below.
        # NOTE(review): presumably targets warnings about stale processed data
        # on disk -- confirm which warning this is meant to catch.
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            dataset = _instantiate_dataset()
    except hydra.errors.InstantiationException:
        # Config values loaded from YAML are plain strings, for which the `/`
        # operator is undefined; os.path.join works for both str and Path values.
        dataset_path = os.path.join(str(cfg.dataset.root), str(cfg.dataset.name))
        # If the dataset directory exists, remove it
        if os.path.exists(dataset_path):
            shutil.rmtree(dataset_path)
        # Now, try to instantiate the dataset again
        dataset = _instantiate_dataset()

    evaluator_cfg = cfg.data.get("evaluator", None)
    if evaluator_cfg is None:
        return dataset

    evaluator = hydra.utils.instantiate(evaluator_cfg)
    return dataset, evaluator


def get_subset(seen_subsets, d, alpha):
    """
    Sample a random boolean mask over ``d`` elements that was not returned before.

    Every element is included independently with probability ``alpha``
    (Bernoulli sampling). Samples are redrawn until one is found whose hash is
    not yet contained in ``seen_subsets``.

    Args:
        seen_subsets (set or None): Hashes of previously generated subsets. The
            hash of the returned subset is added to it in place. If None, no
            uniqueness check is performed and the first sample is returned.
        d (int): The total number of elements to choose from.
        alpha (float): Probability in [0, 1] of including each element.

    Returns:
        torch.Tensor: A boolean tensor of shape ``(d,)`` representing the subset.

    Raises:
        RuntimeError: If only one outcome is possible (``d == 0`` or ``alpha`` is
            0 or 1) and that outcome has already been seen, since no new subset
            could ever be drawn.

    Note:
        In the non-degenerate case the loop still does not terminate once every
        reachable subset is contained in ``seen_subsets``.
    """
    # Force a floating dtype: an integer alpha (0 or 1) would otherwise yield an
    # integer tensor, which torch.bernoulli does not accept as probabilities.
    p = torch.full((d,), alpha, dtype=torch.get_default_dtype())

    if seen_subsets is None:
        # Nothing to deduplicate against.
        return torch.bernoulli(p).bool()

    # With no elements or a probability of exactly 0 or 1 every draw is identical.
    single_outcome = d == 0 or alpha in (0, 1)

    while True:
        subset = torch.bernoulli(p)
        subset_hash = hash_tensor(subset)
        if subset_hash not in seen_subsets:
            seen_subsets.add(subset_hash)
            return subset.bool()
        if single_outcome:
            # Redrawing can never produce a different subset; fail loudly
            # instead of spinning forever.
            raise RuntimeError(
                "Cannot draw a new subset: the only possible outcome for "
                f"d={d}, alpha={alpha} has already been seen."
            )


def get_appropriate_edge_index(data):
    """
    Return the connectivity stored on ``data``.

    ``edge_index`` takes precedence; ``adj_t`` is only used when ``edge_index``
    is not present.

    Args:
        data (torch_geometric.data.Data): The input data object.

    Returns:
        torch.Tensor: The value of ``data.edge_index`` if present, otherwise the
        value of ``data.adj_t``.

    Raises:
        ValueError: If the input data contains neither 'edge_index' nor 'adj_t'.
    """
    # Keys are probed in priority order; the first one present wins.
    for key in ("edge_index", "adj_t"):
        if key in data:
            return getattr(data, key)

    raise ValueError("The input data must contain either 'edge_index' or 'adj_t'.")
