import logging
import GPUtil
import torch
import torch.nn as nn
from typing import Optional

from torch_geometric.utils import homophily
from models.utils import get_knn_edge_index


def unused_gpu():
    """Return the id of the first GPU whose reported load is exactly 0.

    GPUs are scanned in the order returned by ``GPUtil.getGPUs()``.

    Returns:
        The ``id`` attribute of the first idle GPU, or ``None`` when every
        GPU reports a non-zero load (or no GPU is listed).
    """
    idle_ids = (device.id for device in GPUtil.getGPUs() if device.load == 0)
    return next(idle_ids, None)


def kl_bernoulli(q, p, eps=1e-8) -> torch.Tensor:
    """Element-wise KL divergence ``KL(Bernoulli(p) || Bernoulli(q))``.

    Computes ``p * log(p / q) + (1 - p) * log((1 - p) / (1 - q))`` with the
    convention ``0 * log(0) = 0`` applied to each term separately, so that
    hard probabilities (``p == 0`` or ``p == 1``) still yield the correct
    finite divergence instead of being zeroed out.

    Args:
        q: Tensor of Bernoulli parameters of the second distribution. It is
            clamped to ``[eps, 1 - eps]`` before use.
        p: Bernoulli parameters of the first distribution (tensor or scalar,
            broadcastable against ``q``). Not clamped.
        eps: Clamping margin applied to ``q``.

    Returns:
        Tensor of element-wise KL values; any remaining NaN/inf entries
        are replaced by 0.
    """
    # NOTE: with the default eps and float32 inputs, ``1 - eps`` rounds to
    # 1.0, so q == 1 is not actually clamped from above; the non-finite
    # values this can produce are zeroed by nan_to_num below.
    q = torch.clamp(q, eps, 1 - eps)
    # xlogy(a, b) = a * log(b) but returns 0 where a == 0, which keeps one
    # degenerate term from turning the whole sum into NaN.
    kl = torch.xlogy(p, p / q) + torch.xlogy(1 - p, (1 - p) / (1 - q))
    return torch.nan_to_num(kl, nan=0.0, posinf=0.0, neginf=0.0)


def edge_homophily(adj: torch.Tensor, y: torch.Tensor) -> float:
    """Compute the homophily of the graph described by ``adj`` w.r.t. ``y``.

    The adjacency is converted to an edge index and passed, together with
    the labels, to ``torch_geometric.utils.homophily`` (called with its
    default settings). Both tensors are moved to CPU first.

    Args:
        adj: Adjacency matrix, either sparse COO or dense.
            * Sparse: entries with value > 0 are kept and every kept edge is
              also added in the reverse direction.
            * Dense: every non-zero entry is an edge, used as-is (no reverse
              edges are added, and negative entries count as edges).
        y: Node labels.

    Returns:
        The value returned by ``homophily`` for the extracted edges.
    """
    if adj.is_sparse:
        # values()/indices() are only defined on coalesced sparse tensors;
        # coalesce() is a no-op when the tensor is already coalesced.
        adj = adj.coalesce()
        mask = adj.values() > 0
        edge_index = adj.indices()[:, mask]
        # Add the reverse direction of every kept edge.
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    else:
        edge_index = adj.nonzero().t()
    return homophily(edge_index=edge_index.to('cpu'), y=y.to('cpu'))


def _is_norm(module):
    """Return True if ``module`` is a BatchNorm1d/2d/3d or LayerNorm layer."""
    norm_layer_types = (
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.LayerNorm,
    )
    return isinstance(module, norm_layer_types)


class _RepeatFilter(logging.Filter):
    """Guarantee that every log record has a ``repeat`` attribute.

    The formatter used by ``setup_logger`` references ``%(repeat)s``; records
    logged without ``extra={'repeat': ...}`` get the placeholder ``'-'``.
    """

    def filter(self, record):
        if getattr(record, 'repeat', None) is None:
            record.repeat = '-'
        return True


def setup_logger(log_file=None, console=True):
    """Configure the root logger and return it.

    The root logger is set to INFO, all handlers currently attached to it are
    removed and closed, and fresh handlers are installed that format records
    as ``time | level | REPEAT=<repeat> | message``.

    Args:
        log_file: Optional path of a log file. The file is opened with mode
            ``'w'``, i.e. it is truncated on every call.
        console: If True, also log through a ``logging.StreamHandler``.

    Returns:
        The configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Remove existing handlers to avoid duplicate logs when rerunning, and
    # close them so file handlers from earlier calls release their files.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Drop filters installed by earlier calls so they do not pile up.
    for old_filter in list(logger.filters):
        if isinstance(old_filter, _RepeatFilter):
            logger.removeFilter(old_filter)

    formatter = logging.Formatter('%(asctime)s | %(levelname)s | REPEAT=%(repeat)s | %(message)s')
    repeat_filter = _RepeatFilter()
    # Kept on the logger so records logged directly on the root logger carry
    # ``repeat`` for any handler attached later.
    logger.addFilter(repeat_filter)

    handlers = []
    if console:
        # Console handler (optional)
        handlers.append(logging.StreamHandler())
    if log_file is not None:
        # File handler
        handlers.append(logging.FileHandler(log_file, mode='w'))

    for handler in handlers:
        handler.setFormatter(formatter)
        # Logger-level filters are not consulted for records propagated from
        # child loggers; a handler-level filter sees every record the handler
        # emits, so ``%(repeat)s`` can always be formatted.
        handler.addFilter(repeat_filter)
        logger.addHandler(handler)
    return logger


def get_graph_prior(obs_edge_index: torch.Tensor,
                obs_edge_prob: float,
                non_edge_prob: float,
                n_nodes=None,
                knn_prior_edge_k=0,
                x: Optional[torch.Tensor]=None,
                knn_prior_edge_prob=0.2,
                knn_prior_edge_dist_metric='euclidean'):
    """Build a prior edge probability for every unordered node pair.

    A dense symmetric ``n_nodes x n_nodes`` float32 matrix is filled with
    ``non_edge_prob``; kNN edges (if requested) are set to
    ``knn_prior_edge_prob``; observed edges are then set to
    ``obs_edge_prob``, overriding kNN edges on overlap. Only the strict
    upper triangle (pairs ``i < j``) is returned.

    Args:
        obs_edge_index: Observed edges; row 0 and row 1 hold the two
            endpoints of each edge. Direction is ignored (symmetrized).
        obs_edge_prob: Prior probability assigned to observed edges.
        non_edge_prob: Prior probability assigned to all other pairs.
        n_nodes: Number of nodes. If None, it is inferred as the largest
            index in ``obs_edge_index`` plus one, which requires a non-empty
            ``obs_edge_index`` and misses trailing isolated nodes.
        knn_prior_edge_k: If > 0, add kNN edges computed from ``x`` with
            ``get_knn_edge_index``.
        x: Node features used for the kNN edges; required when
            ``knn_prior_edge_k > 0``.
        knn_prior_edge_prob: Prior probability assigned to kNN edges.
        knn_prior_edge_dist_metric: Metric forwarded to
            ``get_knn_edge_index``.

    Returns:
        Tuple ``(edge_index, edge_values)``: ``edge_index`` holds the
        ``(i, j)`` pairs with ``i < j`` as a 2-row tensor and ``edge_values``
        the matching prior probabilities. Both live on the device of
        ``obs_edge_index``.

    Raises:
        ValueError: If ``knn_prior_edge_k > 0`` but ``x`` is None.
    """
    if knn_prior_edge_k > 0 and x is None:
        raise ValueError("x must be provided when knn_prior_edge_k > 0")

    if n_nodes is None:
        n_nodes = int(obs_edge_index.max()) + 1

    # Build the matrix on the same device as the observed edges; indexing a
    # CPU tensor with accelerator-resident indices is not supported.
    device = obs_edge_index.device if isinstance(obs_edge_index, torch.Tensor) else None
    # NOTE: dense O(n_nodes^2) memory.
    A = torch.full((n_nodes, n_nodes), non_edge_prob,
                   dtype=torch.float32, device=device)

    # Set kNN prior edge probabilities
    if knn_prior_edge_k > 0:
        knn_edge_index = get_knn_edge_index(
            x, k=knn_prior_edge_k, metric=knn_prior_edge_dist_metric)
        knn_src = torch.as_tensor(knn_edge_index[0], device=A.device)
        knn_dst = torch.as_tensor(knn_edge_index[1], device=A.device)
        A[knn_src, knn_dst] = knn_prior_edge_prob
        A[knn_dst, knn_src] = knn_prior_edge_prob  # symmetrize

    # Set observed edge probabilities (overwrite kNN if there is overlap)
    A[obs_edge_index[0], obs_edge_index[1]] = obs_edge_prob
    A[obs_edge_index[1], obs_edge_index[0]] = obs_edge_prob  # symmetrize

    # Keep each unordered pair once; offset=1 excludes the diagonal.
    edge_index = torch.triu_indices(n_nodes, n_nodes, offset=1, device=A.device)
    edge_values = A[edge_index[0], edge_index[1]]  # probabilities
    return edge_index, edge_values