import torch                  
from torch import Tensor  
import html, os, csv     
import pytorch_lightning as pl
import json, os, html
from typing import List, Union
from pathlib import Path
from tqdm.auto import tqdm 
import torch.nn.functional as F
import numpy as np
import os
import json
import html
from typing import List, Union
import numpy as np
import torch
from torch import Tensor




def seed_everything(seed):
    """Seed all RNGs through Lightning and put cuDNN into deterministic mode.

    Args:
        seed: Forwarded unchanged to ``pl.seed_everything``, which is called
            with ``workers=True`` (per Lightning's docs this also seeds
            dataloader worker processes).

    Note:
        The cuDNN flags below remove one common source of run-to-run
        variation; on their own they do not make every torch op deterministic.
    """
    pl.seed_everything(seed, workers=True)

    cudnn = torch.backends.cudnn
    # Request deterministic cuDNN kernels and disable per-shape algorithm
    # autotuning, whose choice can differ from one run to the next.
    cudnn.deterministic = True
    cudnn.benchmark = False



# Highlight colour and thresholds used when highlighting weights.
# The "_C" / "_S" suffixes presumably mean "causal" / "structural" (inferred
# from earlier comments in this file) -- confirm against the code that reads
# these names.
# NOTE(review): this colour was previously commented as "Red for structural
# weights", which contradicts the "_C" suffix shared with THRESHOLD_C; an
# earlier revision used green (0.0, 1.0, 0.0) labelled "causal weights".
# Verify which colour the consumer expects before relying on it.
HIGHLIGHT_COLOR_C = (1.0, 0.0, 0.0)  # RGB triple: red

# Threshold values for the "C" and "S" weights. How they are compared
# (>, >=, abs or not) is decided by the caller, which is not visible here.
THRESHOLD_C = 0.3
THRESHOLD_S = 0.5

def split_weights(
    weights: Union[Tensor, List[Tensor], tuple],
) -> Union[Tensor, List[Tensor]]:
    """Detach weight tensor(s), move them to the CPU and flatten them to 1-D.

    Despite the name, no splitting happens: a single tensor yields a single
    flattened tensor, and a list/tuple yields a list with one flattened
    tensor per element (order preserved).

    Args:
        weights: A ``torch.Tensor``, or a list/tuple of ``torch.Tensor``.

    Returns:
        A 1-D CPU tensor detached from the autograd graph, or a ``list`` of
        such tensors when ``weights`` is a list/tuple.

    Raises:
        TypeError: If ``weights`` (or any element of a list/tuple) is not a
            ``torch.Tensor``.
    """

    def _flatten(w):
        # Shared per-tensor path so list elements get the same type check.
        if not isinstance(w, torch.Tensor):
            raise TypeError(f"Unexpected type: {type(w)}")
        # reshape (not view): `.cpu()` keeps the original strides, so a
        # transposed / sliced tensor is non-contiguous and `.view(-1)` would
        # raise. reshape returns a view when possible and copies otherwise.
        return w.detach().cpu().reshape(-1)

    if isinstance(weights, (list, tuple)):
        return [_flatten(w) for w in weights]
    return _flatten(weights)










