import os
import torch
import random
from torch_geometric.data import Data

def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch with *seed*.

    The CUDA generator is additionally seeded, but only when
    ``torch.cuda.is_available()`` reports a usable device. NumPy's global
    RNG is not touched here.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)


def to_device(data: Data, device: torch.device) -> Data:
    """Move the core graph tensors of *data* onto *device* (mutates *data*).

    Only ``x``, ``edge_index``, ``edge_attr`` and ``batch`` are transferred;
    any other attributes stay where they are. ``x``, ``edge_attr`` and
    ``batch`` are optional and are skipped when they are ``None``.

    Side effect: a 1-D ``edge_attr`` gets a trailing singleton dimension,
    so shape ``(E,)`` becomes ``(E, 1)``.

    Args:
        data: Graph object whose tensor attributes are reassigned in place.
        device: Target device for the transferred tensors.

    Returns:
        The same *data* object, for call chaining.
    """
    # Node features may be absent, e.g. for structure-only graphs.
    # Calling .to() on None would raise AttributeError.
    if data.x is not None:
        data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)

    edge_attr = getattr(data, 'edge_attr', None)
    if edge_attr is not None:
        edge_attr = edge_attr.to(device)
        # Downstream code presumably expects 2-D edge features; a bare
        # per-edge scalar vector is promoted to a single-column matrix.
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        data.edge_attr = edge_attr

    batch = getattr(data, 'batch', None)
    if batch is not None:
        data.batch = batch.to(device)
    return data


def make_dirs(path: str):
    """Create the directory *path*, including any missing parent directories.

    Calling this on a directory that already exists is a no-op rather than
    an error.
    """
    os.makedirs(name=path, exist_ok=True)


def get_experiment_name(args):
    """Return an underscore-joined run name encoding key hyperparameters.

    Each component is a short tag immediately followed by the value of the
    matching attribute on *args*, in this order: batch size, learning rate,
    epochs, NCE temperature, augmentation temperature, projection output
    dim, and subsample size.
    """
    tagged_values = (
        ("bs", args.batch_size),
        ("lr", args.lr),
        ("epoch", args.epochs),
        ("tauN", args.tau_nce),
        ("tauG", args.tau_aug),
        ("dim", args.proj_out),
        ("subsamp", args.subsample_size),
    )
    return '_'.join(f"{tag}{value}" for tag, value in tagged_values)