
import os
import json
import yaml
import time
import random
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
import numpy as np
import torch

def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + every CUDA device).

    Only RNG states are set; deterministic algorithm selection (e.g. cuDNN
    flags) is not touched here.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

def get_logger(name: str = "gatv2_ns3_ids", level: int = logging.INFO) -> logging.Logger:
    """Return the named logger, attaching a stderr stream handler on first use.

    Level, handler and format are configured only while the logger has no
    handlers; later calls return it unchanged and ignore ``level``.
    """
    log = logging.getLogger(name)
    if not log.handlers:
        formatter = logging.Formatter(
            "[%(asctime)s] %(levelname)s - %(name)s - %(message)s", "%H:%M:%S"
        )
        handler = logging.StreamHandler()
        handler.setLevel(level)
        handler.setFormatter(formatter)
        log.setLevel(level)
        log.addHandler(handler)
    return log

def load_yaml(path: str) -> Dict[str, Any]:
    """Parse the YAML file at ``path`` with ``yaml.safe_load``.

    The file is opened in binary mode so PyYAML detects the encoding itself
    (UTF-8 / UTF-16 via BOM, per the YAML spec) instead of relying on the
    platform's locale-dependent default text encoding.

    ``safe_load`` returns ``None`` for an empty document, so callers may get
    ``None`` despite the annotation.
    """
    with open(path, "rb") as f:
        return yaml.safe_load(f)

def ensure_dir(path: str):
    """Create directory ``path`` (including parents) unless it already exists.

    A falsy ``path`` (``""`` or ``None``) is a no-op, so this is safe to call
    with ``os.path.dirname`` of a bare filename.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

class Timer:
    """Stopwatch reporting seconds since construction or the last ``reset()``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the stopwatch."""
        # Wall-clock start time, kept because callers may read ``t0`` directly.
        self.t0 = time.time()
        # Intervals are measured on a monotonic clock so system clock changes
        # (NTP sync, manual adjustment) cannot make elapsed() jump or go negative.
        self._start = time.perf_counter()

    def elapsed(self) -> float:
        """Return seconds elapsed since construction or the last ``reset()``."""
        return time.perf_counter() - self._start

@dataclass
class GraphData:
    """One graph sample: node features, connectivity, optional edge features, node labels.

    The shape/dtype notes on the fields are the author's annotations; nothing
    in this class validates them.
    """
    x: torch.Tensor               # node features, [N, F]
    edge_index: torch.Tensor      # [2, E], long; presumably (source, target) rows — confirm
    edge_attr: Optional[torch.Tensor]  # edge features [E, Fe], or None when absent
    y_node: torch.Tensor          # [N] int64 per-node labels
    graph_id: str                 # identifier of the originating graph (meaning not visible here — confirm)
    window_idx: int               # presumably the index of the window this graph was built from — confirm

def to_device(g: 'GraphData', device: torch.device) -> 'GraphData':
    """Return a new ``GraphData`` whose tensors are moved to ``device``.

    ``graph_id`` and ``window_idx`` are carried over as-is and a missing
    ``edge_attr`` stays ``None``. ``Tensor.to`` returns the same tensor when it
    already lives on ``device``, so tensors may be shared with ``g``.
    """
    node_feats = g.x.to(device)
    edges = g.edge_index.to(device)
    edge_feats = g.edge_attr.to(device) if g.edge_attr is not None else None
    labels = g.y_node.to(device)
    return GraphData(
        x=node_feats,
        edge_index=edges,
        edge_attr=edge_feats,
        y_node=labels,
        graph_id=g.graph_id,
        window_idx=g.window_idx,
    )

def collate_graphs(graphs: List['GraphData']) -> List['GraphData']:
    """Collate without batching: hand back the very same list of graphs.

    No merging or padding happens, so each graph is processed individually
    downstream (presumably used as a DataLoader ``collate_fn`` — confirm).
    """
    batch = graphs
    return batch

def predictive_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Shannon entropy (nats) of the softmax distribution over the last dimension.

    Args:
        logits: unnormalised scores; classes along the last dim.

    Returns:
        Tensor of entropies with the last dimension reduced away.
    """
    probs = torch.softmax(logits, dim=-1)
    # Clamping before the log keeps zero probabilities (e.g. from -inf logits)
    # finite, so they contribute 0 with finite gradients. 1e-9 is not
    # representable in float16 (it rounds to 0 -> log(0) = -inf -> 0 * -inf =
    # NaN), so never go below the dtype's smallest normal number. For
    # float32/float64/bfloat16 this is exactly the 1e-9 clamp.
    eps = max(1e-9, torch.finfo(probs.dtype).tiny)
    ent = -(probs * torch.log(probs.clamp_min(eps))).sum(dim=-1)
    return ent

def attention_entropy(attn: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy of attention weights grouped by ``index``.

    Row ``attn[i]`` belongs to group ``index[i]`` (e.g. all edges sharing a
    target node, if ``index`` holds edge targets — assumption, confirm against
    callers). Each group is renormalised to sum to 1, its entropy
    ``-sum(p * log p)`` is computed, and the mean over all groups is returned.

    Args:
        attn: attention values with one row per edge, shape ``[E]`` or
            ``[E, ...]``. Trailing dimensions are pooled into the group's
            distribution (normalised jointly, not per trailing slice).
        index: 1-D group id per row of ``attn``, shape ``[E]``.

    Returns:
        0-dim tensor with ``attn``'s dtype/device; 0.0 when ``attn`` is empty.

    Raises:
        ValueError: if ``index`` is not 1-D with one entry per row of ``attn``.
    """
    num_edges = attn.shape[0]
    if num_edges == 0:
        return attn.new_zeros(())
    if index.dim() != 1 or index.shape[0] != num_edges:
        raise ValueError(
            f"index must have shape [{num_edges}] to match attn, got {tuple(index.shape)}"
        )

    # Map arbitrary group ids to contiguous 0..G-1. Rows are paired with
    # index positionally; index values are labels, never positions into attn.
    uniq, group = torch.unique(index, return_inverse=True)
    num_groups = uniq.numel()

    flat = attn.reshape(num_edges, -1)
    # Total attention mass of each group, used to renormalise its members.
    group_mass = flat.new_zeros(num_groups).index_add(0, group, flat.sum(dim=1))
    p = flat / (group_mass[group].unsqueeze(1) + 1e-9)
    row_ent = -(p * (p + 1e-9).log()).sum(dim=1)
    group_ent = flat.new_zeros(num_groups).index_add(0, group, row_ent)
    return group_ent.mean()

def save_json(path: str, obj: Dict[str, Any]):
    """Write ``obj`` to ``path`` as JSON indented by 2 spaces.

    Missing parent directories are created first; a bare filename is written
    relative to the current working directory.
    """
    target_dir, _ = os.path.split(path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with open(path, "w") as out:
        json.dump(obj, fp=out, indent=2)

def load_json(path: str) -> Dict[str, Any]:
    """Parse the JSON file at ``path``.

    The file is read as bytes so ``json`` detects UTF-8/16/32 (and skips a
    UTF-8 BOM) itself, instead of decoding with the platform's
    locale-dependent default text encoding.
    """
    with open(path, "rb") as f:
        return json.load(f)
