from .config import *
from .storage import *
from .utils import *
from .object import *
from .yaml import *
from .prediction import Prediction
from .graphs import *
from .visualization_graph import visualize_graph_embeddings
from dataclasses import dataclass
from typing import Optional
import torch


@dataclass
class Prediction:
    """Container for model predictions along with uncertainty estimates.

    Every field is optional and defaults to ``None``. Callers populate only
    the outputs their model actually produces, e.g.
    ``Prediction(soft=probs, hard=labels)``.

    The fields are declared in a fixed order, and ``@dataclass`` uses that
    order for the positional parameters of the generated ``__init__``.
    Reordering or inserting fields therefore breaks positional construction.
    Prefer keyword arguments when building instances.

    NOTE(review): this definition rebinds the module-level name
    ``Prediction``, which is also imported from ``.prediction`` at the top of
    this module. The imported class is shadowed for anyone importing
    ``Prediction`` from here. Confirm which definition is intended.

    NOTE(review): the ``__eq__`` generated by ``@dataclass`` compares the
    fields with ``==``. For two distinct multi-element tensors, that
    comparison is likely to raise instead of returning a bool, so avoid
    comparing instances with ``==``.
    """
    
    # Basic predictions
    soft: Optional[torch.Tensor] = None  # Soft prediction (probabilities)
    hard: Optional[torch.Tensor] = None  # Hard prediction (class indices)
    
    # Alpha parameters for Dirichlet distribution
    alpha: Optional[torch.Tensor] = None  # Alpha parameters
    
    # Per-sample confidence scores, one field per uncertainty source.
    # Shape and value range are not fixed here; presumably one score per
    # sample — TODO confirm against the producers of these fields.
    sample_confidence_aleatoric: Optional[torch.Tensor] = None
    sample_confidence_epistemic: Optional[torch.Tensor] = None
    sample_confidence_features: Optional[torch.Tensor] = None
    sample_confidence_neighborhood: Optional[torch.Tensor] = None
    sample_confidence_structure: Optional[torch.Tensor] = None
    
    # Confidence scores attached to the prediction itself (aleatoric /
    # epistemic). How they differ from the sample_* scores above is not
    # defined in this class — TODO confirm with the model code.
    prediction_confidence_aleatoric: Optional[torch.Tensor] = None
    prediction_confidence_epistemic: Optional[torch.Tensor] = None
    
    # Additional raw model outputs
    logits: Optional[torch.Tensor] = None  # Raw (pre-normalization) scores
    evidence: Optional[torch.Tensor] = None  # Evidence values; relation to alpha not defined here


def to_one_hot(labels, num_classes, dtype=None):
    """Convert integer class labels to a one-hot encoding.

    Args:
        labels: Integer class labels as a tensor of shape (N,). Any integer
            dtype is accepted. Tensors of any other shape (*S) are also
            accepted and produce an output of shape (*S, num_classes).
        num_classes: Number of classes. Every label must lie in
            [0, num_classes).
        dtype: Data type for the resulting tensor. Defaults to torch.float.

    Returns:
        One-hot encoded labels of shape (*labels.shape, num_classes), i.e.
        (N, num_classes) for 1-D input, on the same device as ``labels``.
        Each label's class position holds 1 and every other position holds 0.
    """
    if dtype is None:
        dtype = torch.float

    y_onehot = torch.zeros(
        (*labels.shape, num_classes), dtype=dtype, device=labels.device
    )
    # scatter_ requires an int64 index. Casting lets int32/uint8/... label
    # tensors work instead of raising; .long() is a no-op for int64 input.
    # Scattering along the last dimension makes the same code handle 1-D
    # labels and higher-rank labels alike.
    y_onehot.scatter_(-1, labels.long().unsqueeze(-1), 1)

    return y_onehot
