import os
import torch
import numpy as np
import random
from omegaconf import DictConfig, OmegaConf, open_dict
from time import time
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from torch import nn
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
import transformers
from env import CACHE
import sympy as sp
import warnings
warnings.filterwarnings("ignore")
import re
from scipy.spatial.distance import cdist

def set_matmul_precision():
    """
    Lower float32 matmul precision to 'medium' on GPUs known to have Tensor Cores.

    The name of CUDA device 0 is matched against a list of GPU model
    substrings. On a match, ``torch.set_float32_matmul_precision('medium')``
    is called; otherwise the default precision is kept. Without CUDA the
    function does nothing.
    """
    if not torch.cuda.is_available():
        return

    gpu_name = torch.cuda.get_device_name(0)
    # Substrings identifying GPU families with Tensor Cores.
    tensor_core_markers = ('V100', 'A100', 'A6000', 'A5000', 'A4000', 'RTX', 'T4', 'H100')

    has_tensor_cores = False
    for marker in tensor_core_markers:
        if marker in gpu_name:
            has_tensor_cores = True
            break

    if has_tensor_cores:
        torch.set_float32_matmul_precision('medium')
        print(f"Setting matmul precision to 'medium' for {gpu_name}")
    else:
        print(f"Using default matmul precision for {gpu_name}")

def set_seed(seed: int):
    """
    Seed every random number generator the project relies on.

    Exports ``PYTHONHASHSEED`` and seeds Python's ``random``, NumPy's global
    RNG, PyTorch's CPU RNG, the current CUDA device and all CUDA devices.

    Args:
        seed: Seed value applied everywhere.
    """
    print(f"Seed set to {seed}")
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def is_valid_experiment(cfg: "DictConfig"):
    """
    Check that the dataset/model combination in ``cfg`` is supported.

    The rule enforced here is that the ``dcr`` and ``cmr`` models cannot be
    used on datasets whose ``metadata.task`` is ``'regression'``.

    Args:
        cfg: Experiment config exposing ``dataset.metadata.task`` and
            ``model.metadata.name``.

    Raises:
        ValueError: If the combination is not supported.
    """
    if cfg.dataset.metadata.task != 'regression':
        return
    model_name = cfg.model.metadata.name
    if model_name in ('dcr', 'cmr'):
        # Adjacent literals keep the message free of source-code indentation
        # (a backslash line continuation inside the string would embed it).
        raise ValueError(
            "The experiment is not valid. Please check the configuration.\n"
            f"{model_name} cannot be applied to regression tasks."
        )

def set_loggers(cfg):
    """
    Build the experiment loggers.

    ``cfg.note`` is first normalised in place: ``None`` becomes ``"_"`` and
    any other value is converted to ``str``. Then:

    * a ``WandbLogger`` is created only when both ``cfg.wandb.project`` and
      ``cfg.wandb.entity`` are set. It is named
      ``<dataset>.<model>.<memory_size>.<seed>.<unix time>`` and grouped as
      ``<dataset>_<model>_<memory_size>_<note>``. It is tagged with those
      values and receives ``parse_hyperparams(cfg)`` as hyperparameters.
    * a ``CSVLogger`` writes to ``logs/experiment_metrics`` with no version
      sub-folder.

    Returns:
        Tuple ``(wandb_logger, csv_logger)``; ``wandb_logger`` is ``None``
        when W&B is not configured.
    """
    # Normalise the note: None becomes the placeholder "_" (otherwise the
    # group name below would end in the text "None"); other values -> str.
    with open_dict(cfg):
        cfg.update(
            note="_" if cfg.note is None else str(cfg.note)
        )

    name = f"{cfg.dataset.metadata.name}.{cfg.model.metadata.name}.{str(cfg.memory_size)}.{cfg.seed}.{int(time())}"

    # Computed once: used for the group name and logged as hyperparameters.
    hyperparams = parse_hyperparams(cfg)
    group = "{dataset}_{model}_{memory_size}_{note}".format(**hyperparams)

    # W&B tags, with None values dropped.
    tags = [
        tag
        for tag in (cfg.dataset.metadata.name, cfg.model.metadata.name, cfg.note, str(cfg.memory_size))
        if tag is not None
    ]

    wandb_logger = None
    if cfg.wandb.project is not None and cfg.wandb.entity is not None:
        wandb_logger = WandbLogger(project=cfg.wandb.project,
                                   entity=cfg.wandb.entity,
                                   name=name,
                                   group=group,
                                   tags=tags)
        wandb_logger.log_hyperparams(hyperparams)

    csv_logger = CSVLogger("logs/",
                           name="experiment_metrics",
                           version="")  # Use empty string to avoid version folders
    return wandb_logger, csv_logger

def parse_hyperparams(cfg: DictConfig):
    """
    Collect the hyperparameters that identify a run.

    Returns:
        dict with keys ``dataset``, ``model``, ``memory_size``, ``seed``,
        ``hydra_cfg`` and ``note``, in that order. ``hydra_cfg`` holds the
        whole config converted to plain Python containers with
        ``OmegaConf.to_container``.
    """
    return dict(
        dataset=cfg.dataset.metadata.name,
        model=cfg.model.metadata.name,
        memory_size=cfg.memory_size,
        seed=cfg.seed,
        hydra_cfg=OmegaConf.to_container(cfg),
        note=cfg.note,
    )

def get_backbone_latent_size(backbone):
    """
    Return the size of the feature vector produced by a backbone.

    * Text backbones (``bert-base-uncased`` and the two supported
      sentence-transformers models) have a known hidden size that is
      returned directly, without loading any model.
    * ResNet backbones (``resnet18`` ... ``resnet152``) are loaded with
      pretrained weights and their final layer is removed. They are probed
      with one random 1x3x224x224 input; the flattened output size is
      returned.
    * Names containing ``'vit'`` or ``'dino'`` are loaded with
      ``transformers.ViTModel`` and probed the same way. The size of the
      first token of ``last_hidden_state`` is returned.

    Args:
        backbone: Backbone identifier.

    Returns:
        int: Latent dimension of the backbone.

    Raises:
        ValueError: If the backbone is not recognised.
    """
    # Hidden sizes of the supported text encoders. Returning them directly
    # avoids downloading/loading a whole model just to read a constant.
    text_latent_sizes = {
        'bert-base-uncased': 768,
        'sentence-transformers/all-mpnet-base-v2': 768,
        'sentence-transformers/all-MiniLM-L6-v2': 384,
    }
    if backbone in text_latent_sizes:
        return text_latent_sizes[backbone]

    resnet_names = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')
    if backbone in resnet_names:
        builders = {
            'resnet18': resnet18,
            'resnet34': resnet34,
            'resnet50': resnet50,
            'resnet101': resnet101,
            'resnet152': resnet152,
        }
        model = builders[backbone](pretrained=True)
        # Drop the final classification layer to keep the pooled features.
        model = nn.Sequential(*list(model.children())[:-1])
        model.eval()
        # The probe only needs the output shape: no autograd graph required.
        with torch.no_grad():
            test = model(torch.randn((1, 3, 224, 224)))
        latent_dim = test.flatten(start_dim=1).shape[1]
    elif 'vit' in backbone or 'dino' in backbone:
        model = transformers.ViTModel.from_pretrained(backbone)
        model.eval()
        with torch.no_grad():
            test = model(torch.randn((1, 3, 224, 224)))
        # Embedding size of the token at position 0.
        latent_dim = test.last_hidden_state[:, 0, :].shape[1]
    else:
        raise ValueError(f"Image backbone {backbone} not recognized.")

    # Release the probe model and its output.
    del model
    del test
    torch.cuda.empty_cache()
    return latent_dim

def get_type_from_name(dataset_name):
    """
    Infer the data modality of a dataset from its name.

    Returns:
        ``'symbolic_regression'`` for names starting with ``'feynman_'``,
        ``'video'`` for ``synthetic_motion`` (pre-embedded), ``'image'``
        for the known image datasets, and ``'text'`` for everything else.
    """
    image_datasets = (
        'mnist_addition', 'cub', 'cub_incomplete',
        'awa2', 'awa2_incomplete', 'xor', 'celeba',
        'cifar10', 'cifar100', 'mnist_arithmetic', 'mnist_arithmetic_hard',
        'pendulum', 'dsprites', 'dsprites_simple', 'dsprites_complex',
        'mnist_exponential',
    )

    if dataset_name.startswith(('feynman_',)):
        modality = 'symbolic_regression'
    elif dataset_name == 'synthetic_motion':
        modality = 'video'
    elif dataset_name in image_datasets:
        modality = 'image'
    else:
        # Text is the fallback modality.
        modality = 'text'
    return modality

def get_batch_from_loader(train_loader, device):
    """
    Return the first batch yielded by ``train_loader``.

    For a CUDA device (e.g. an integer index) the batch is drawn inside
    ``torch.cuda.device(device)``. For ``'cpu'`` the default tensor type is
    temporarily forced to ``torch.FloatTensor``, so tensors created while
    loading stay on the CPU.

    The default dtype is restored in all cases, even when fetching the batch
    raises (e.g. ``StopIteration`` on an empty loader).

    Args:
        train_loader: Any iterable of batches (typically a DataLoader).
        device: CUDA device index/name, or ``'cpu'`` (string or
            ``torch.device``).

    Returns:
        The first batch of ``train_loader``.
    """
    from contextlib import nullcontext

    is_cpu = str(device) == 'cpu'
    # torch.cuda.device rejects CPU devices with a ValueError, so only enter
    # it for actual CUDA devices.
    device_ctx = nullcontext() if is_cpu else torch.cuda.device(device)
    with device_ctx:
        original_dtype = torch.get_default_dtype()
        try:
            if is_cpu:
                # Temporarily set default tensor type to CPU to avoid
                # automatic GPU allocation.
                torch.set_default_tensor_type('torch.FloatTensor')
            return next(iter(train_loader))
        finally:
            torch.set_default_dtype(original_dtype)

def setup_encoder(cfg: DictConfig, input_size: int, backbone_latent_size: int) -> DictConfig:
    """
    Write the encoder configuration into ``cfg.model.params.encoder``.

    * With ``cfg.extract_embeddings`` the embeddings are pre-computed, so an
      ``MLPEncoder`` with a ``FlattenTransform`` is configured. The input
      size is forced to 2 for the ``xor`` dataset.
    * Otherwise the encoder depends on the dataset and backbone:
      - ``LinearEncoder`` for ``toy`` datasets;
      - ``VitEncoder``, ``ResNetEncoder`` or ``TransformerEncoder`` based on
        the backbone name.
      ``mnist_addition`` always uses ``MNISTTransform`` as input transform.

    In every case ``input_size`` is added to the encoder config. For
    transformer (sentence-transformers/BERT) backbones, the backbone name is
    also stored under ``type``.

    Args:
        cfg: Experiment config; modified in place.
        input_size: Flattened input size measured on a data batch.
        backbone_latent_size: Output size of the encoder.

    Returns:
        The same ``cfg`` object.

    Raises:
        ValueError: If a non-toy dataset uses a backbone that none of the
            branches handles.
    """
    # Backbone name forwarded to TransformerEncoder as `type`; None otherwise.
    encoder_type = None
    dataset_name = cfg.dataset.metadata.name

    if cfg.extract_embeddings:
        # Extracted embeddings mean we are NOT fine-tuning a pre-trained
        # backbone during training, so an MLP encoder is enough.
        if dataset_name == 'xor':
            input_size = 2
        cfg.model.params.encoder = {
            '_target_': 'src.models.encoders.mlp.MLPEncoder',
            'output_size': backbone_latent_size,
            'activation': cfg.activation,
            'input_transform': {
                '_target_': 'src.models.encoders.transform.FlattenTransform',
            },
            'dropout': 0.1
        }

    else:
        if get_type_from_name(dataset_name) == 'image':
            backbone = cfg.img_backbone_name
        else:
            backbone = cfg.text_backbone_name

        if cfg.dataset.metadata.data_type == 'toy':
            target = 'src.models.encoders.linear.LinearEncoder'
            transform = None
        elif 'vit' in backbone:
            target = 'src.models.encoders.vit.VitEncoder'
            transform = {
                '_target_': "src.models.encoders.transform.VitTransform",
                'flatten': False
            }
        elif 'resnet' in backbone:
            target = 'src.models.encoders.resnet.ResNetEncoder'
            transform = {
                '_target_': "src.models.encoders.transform.ImageTransform",
                'flatten': False
            }
        elif ('sentence-transformers' in backbone) or ('bert' in backbone):
            target = 'src.models.encoders.transformer.TransformerEncoder'
            transform = None
            encoder_type = backbone
        else:
            # Fail early with a clear message instead of leaving `target` unset.
            raise ValueError(
                f"Backbone '{backbone}' is not supported by setup_encoder "
                f"for dataset '{dataset_name}'."
            )

        # mnist_addition always uses its own preprocessing.
        if dataset_name == 'mnist_addition':
            transform = {
                '_target_': "src.models.encoders.transform.MNISTTransform",
                'flatten': False
            }

        # Fine-tuning a pre-trained model: use the backbone-specific encoder.
        cfg.model.params.encoder = {
            '_target_': target,
            'output_size': backbone_latent_size,  # keep the full embedding size
            'input_transform': transform,
        }

    with open_dict(cfg):
        cfg.model.params.encoder.update(
            input_size=input_size,
        )
        if encoder_type is not None:
            cfg.model.params.encoder.update(
                type=encoder_type
            )

    return cfg

def allow_hard_concepts(dataset_name):
    """
    Return True if ``dataset_name`` is one of the datasets that currently
    support hard concepts, False otherwise.
    """
    # Datasets that do not allow hard concepts so far: cebab, mnist_arithmetic.
    supported_datasets = (
        'xor',
        'cub',
        'mnist_addition',
        'awa2',
        'cub_incomplete',
        'awa2_incomplete',
        'cifar10',
        'cifar100',
    )
    is_supported = dataset_name in supported_datasets
    return is_supported


def update_config_from_data(cfg: DictConfig, train_loader, c_names,
                            y_names, c_groups, csv_log_dir) -> DictConfig:
    """
    Update the config with information derived from the data.

    One batch is drawn from ``train_loader`` to measure the flattened input
    size, and the backbone latent size is resolved. Then ``cfg.engine`` and
    ``cfg.model.params`` are filled with:

    * concept/label names and concept groups;
    * the task and concept types;
    * the known equations (``cfg.dataset.equations`` when present and not
      ``None``, otherwise ``None``);
    * the encoder settings, via ``setup_encoder``.

    Args:
        cfg: Experiment config; modified in place and returned.
        train_loader: Loader whose first batch is used to infer the input size.
        c_names: Concept names.
        y_names: Label names; their count becomes the model output size.
        c_groups: Concept groups. A plain ``dict`` is copied; any other value
            (including ``None``) is passed through unchanged.
        csv_log_dir: Directory of the CSV logger, stored in ``cfg.engine``.

    Returns:
        The updated config.
    """
    dataset_name = cfg.dataset.metadata.name
    is_image = get_type_from_name(dataset_name) == 'image'

    # cfg.gpus[0] is used to fetch the probe batch and is also forwarded to
    # the model params as `device` below.
    device = cfg.gpus[0]
    batch = get_batch_from_loader(train_loader, device)

    if is_image:
        x = batch['x']
        data_type = 'image'
    else:
        # Without pre-extracted embeddings, batch['x'] is a mapping whose
        # 'input_ids' entry determines the input size.
        x = batch['x'] if cfg.extract_embeddings else batch['x']['input_ids']
        data_type = 'text'

    input_size = torch.prod(torch.tensor(x.shape[1:])).item()

    # Release the probe batch before the backbone is loaded.
    del batch
    del x
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    n_labels = len(y_names)

    if isinstance(c_groups, dict):
        c_groups = dict(c_groups)

    if is_image:
        backbone = cfg.img_backbone_name
    elif dataset_name == 'mawps':
        backbone = "bert-base-uncased"
    else:
        backbone = cfg.text_backbone_name
    backbone_latent_size = cfg.dataset.latent_size if cfg.extract_embeddings else get_backbone_latent_size(backbone)

    with open_dict(cfg):
        # Read once: passed to the engine as `true_equations` and to the
        # model as `known_equations`.
        equations = cfg.dataset.equations if 'equations' in cfg.dataset and cfg.dataset.equations is not None else None

        cfg.engine.update(
            c_names=c_names,
            y_name=y_names,
            csv_log_dir=csv_log_dir,
            data_type=data_type,
            dataset_name=dataset_name,
            scale_variables=cfg.scale_variables if 'scale_variables' in cfg else False,
            true_equations=equations
        )

        hard_concepts = cfg.hard_concepts
        concept_type = prepare_concept_type(cfg.dataset.metadata.concept_type, cfg.engine.c_names, dataset_name)

        cfg.model.params.update(
            output_size=n_labels,
            c_names=c_names,
            y_names=y_names,
            task=cfg.dataset.metadata.task,
            c_groups=c_groups,
            backbone_latent_size=backbone_latent_size,
            concept_type=concept_type,
            hard_concepts=hard_concepts,
            known_equations=equations,
            disjoint_training=cfg.disjoint_training,
            device=device
        )

        cfg = setup_encoder(cfg, input_size, backbone_latent_size)

    return cfg

def prepare_concept_type(c_types, c_names, dataset_name):
    """
    Expand the concept type specification to one entry per concept.

    Args:
        c_types: Either a list of per-concept types or a single type.
        c_names: Concept names; only their count is used. ``len`` is always
            taken, so this must be sized.
        dataset_name: Dataset name. For ``mnist_exponential`` a single type
            is returned unchanged.

    Returns:
        ``c_types`` itself when it is a list or the dataset is
        ``mnist_exponential``. Otherwise a list repeating ``c_types`` once
        per concept.
    """
    concept_count = len(c_names)
    if isinstance(c_types, list) or dataset_name == 'mnist_exponential':
        return c_types
    return [c_types for _ in range(concept_count)]

def generate_data_path(cfg):
    """
    Build the cache paths where the dataset tensors are stored.

    Layout::

        <CACHE>/stored_tensors/<embeddings|raw>/<dataset>/<variant>/seed_<seed>[_<pct>]/{train,val,test}.pt

    ``<variant>`` depends on the dataset type:

    * image and video datasets: the image backbone name (``/`` replaced by ``_``);
    * text datasets: the text backbone name (``/`` replaced by ``_``);
    * symbolic-regression datasets: ``latent<latent_dim>_noise<noise_std>``,
      with dots removed; defaults are 4 and 0.0.

    ``_<pct>`` is appended only when ``cfg.dataset.loader.concept_percentage``
    is not ``None``, with dots removed.

    Returns:
        Tuple ``(data_path, train_path, val_path, test_path)``.
    """
    data_path = os.path.join(str(CACHE),
                             'stored_tensors',
                             'embeddings' if cfg.extract_embeddings else 'raw',  # whether it contains embeddings or not
                             cfg.dataset.metadata.name)

    dataset_type = get_type_from_name(cfg.dataset.metadata.name)
    if dataset_type in ('image', 'video'):
        # Both image and video datasets use img_backbone_name.
        img_backbone_name = cfg.img_backbone_name.replace('/', '_')
        data_path += f"/{img_backbone_name}"
    elif dataset_type == 'text':
        text_backbone_name = cfg.text_backbone_name.replace('/', '_')
        data_path += f"/{text_backbone_name}"
    elif dataset_type == 'symbolic_regression':
        latent_dim = cfg.dataset.loader.get('latent_dim', 4)
        noise_std = cfg.dataset.loader.get('noise_std', 0.0)
        data_path += f"/latent{latent_dim}_noise{str(noise_std).replace('.', '')}"

    data_path += f"/seed_{cfg.seed}"

    concept_percentage = cfg.dataset.loader.concept_percentage
    if concept_percentage is not None:
        data_path += f"_{str(concept_percentage).replace('.', '')}"

    train_path = f"{data_path}/train.pt"
    val_path = f"{data_path}/val.pt"
    test_path = f"{data_path}/test.pt"

    return data_path, train_path, val_path, test_path

def standardize_tensor(tensor, dim=0):
    """
    Z-score ``tensor`` along ``dim``.

    Args:
        tensor: Input tensor to standardize.
        dim: Dimension reduced to compute the mean and standard deviation.

    Returns:
        Tuple ``(standardized, mean, std)`` where ``standardized`` is
        ``(tensor - mean) / (std + 1e-8)``. The epsilon avoids division by
        zero for constant slices. ``std`` uses ``tensor.std``'s default
        estimator.

    NOTE(review): ``mean``/``std`` are computed without ``keepdim``, so the
    broadcast against ``tensor`` lines up for ``dim=0`` but not in general
    for other dimensions.
    """
    centre = tensor.mean(dim=dim)
    spread = tensor.std(dim=dim)
    return (tensor - centre) / (spread + 1e-8), centre, spread


def sanitize_concept_names(concept_names):
    """
    Create SymPy-safe concept names for a list of concept names.

    Returns a tuple ``(safe_names, mapping)``:

    * ``safe_names`` is a list of distinct names usable as SymPy variable
      identifiers, aligned with ``concept_names``;
    * ``mapping`` is a dict mapping original_name -> safe_name. For repeated
      originals, the last occurrence wins.

    Behavior:
      - Replace any run of characters that are not ASCII alphanumeric or
        underscore with an underscore.
      - Collapse repeated underscores and strip leading/trailing underscores.
      - Use ``"x"`` if nothing is left.
      - If the resulting name starts with a digit, prefix it with an underscore.
      - Ensure deterministic uniqueness. The first occurrence of a sanitized
        name keeps it; later occurrences get "_1", "_2", ... If a suffixed
        candidate is already taken (e.g. by a concept literally named
        "a_1"), the suffix keeps increasing until an unused name is found.

    This preserves readable concept names (e.g. "eye-color" -> "eye_color",
    "age (yrs)" -> "age_yrs") rather than mapping to generic names like
    c0, c1.
    """
    invalid_chars = re.compile(r'[^0-9A-Za-z_]+')
    repeated_underscores = re.compile(r'__+')

    safe_names = []
    mapping = {}
    next_suffix = {}   # sanitized base name -> next suffix to try
    used = set()       # every safe name emitted so far

    for name in concept_names:
        orig = name
        s = invalid_chars.sub('_', str(name))
        s = repeated_underscores.sub('_', s).strip('_')

        # If empty after stripping, use placeholder
        if s == '':
            s = 'x'

        # Identifiers cannot start with a digit (s is ASCII at this point).
        if s[0].isdigit():
            s = f"_{s}"

        base = s
        suffix = next_suffix.get(base, 0)
        candidate = base if suffix == 0 else f"{base}_{suffix}"
        # Skip names already emitted, including ones produced by a
        # different base (e.g. "a" + "_1" vs. an original "a_1").
        while candidate in used:
            suffix += 1
            candidate = f"{base}_{suffix}"
        next_suffix[base] = suffix + 1
        used.add(candidate)

        safe_names.append(candidate)
        mapping[orig] = candidate

    return safe_names, mapping

def subsample_for_input_coverage(X, subsample_size, random_state=42):
    """
    Subsample data points to maintain good coverage of the input space.

    Runs k-means with ``subsample_size`` clusters and, for every cluster
    centre, keeps the index of the nearest actual sample (Euclidean
    distance).

    Args:
        X: Input data (array or tensor) of shape (n_samples, n_features).
        subsample_size: Number of samples to select (number of clusters).
        random_state: Random seed passed to ``KMeans``.

    Returns:
        indices: Array of ``subsample_size`` indices into ``X``, or
        ``np.arange(n_samples)`` when ``subsample_size >= n_samples``.
        Two centres can share the same nearest sample, so the array may
        contain repeated indices.
    """
    n_samples = X.shape[0]

    # Nothing to reduce: keep everything (and skip the sklearn import).
    if subsample_size >= n_samples:
        return np.arange(n_samples)

    # Imported lazily: sklearn is only needed when clustering actually runs.
    from sklearn.cluster import KMeans

    X_np = X.cpu().numpy() if torch.is_tensor(X) else X

    kmeans = KMeans(n_clusters=subsample_size, random_state=random_state, n_init=1, verbose=1)
    kmeans.fit(X_np)

    # Nearest real sample to every cluster centre. Note: the dense distance
    # matrix has shape (subsample_size, n_samples).
    distances = cdist(kmeans.cluster_centers_, X_np, metric="euclidean")
    return distances.argmin(axis=1)

def subsample_based_on_y_hat(y_target, subsample_size, random_state=None):
    """
    Perform density-aware subsampling for continuous target values.

    How the selection is built:

    * Targets are split into equal-width bins. The bin count comes from the
      Freedman-Diaconis width, clamped to 10-100 bins; 20 bins are used when
      the interquartile range is 0.
    * Each non-empty bin contributes a share proportional to its size, and
      at least one sample.
    * If the result is short of ``subsample_size``, it is topped up with
      random leftover samples. If it overshoots, it is randomly trimmed.

    Args:
        y_target: Target values (tensor or array-like); flattened to 1-D.
        subsample_size: Number of samples to select.
        random_state: Optional seed. ``None`` (default) draws from NumPy's
            global RNG. An int uses a private ``np.random.RandomState``, so
            the selection is reproducible without touching the global state.

    Returns:
        indices: Array of distinct indices to keep from the original dataset.
        All indices are returned when ``subsample_size >= n_samples``.
    """
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # Convert to numpy if needed
    if torch.is_tensor(y_target):
        y_np = y_target.cpu().numpy().flatten()
    else:
        y_np = np.array(y_target).flatten()

    n_samples = len(y_np)

    # If subsample_size >= n_samples, return all indices
    if subsample_size >= n_samples:
        return np.arange(n_samples)

    # Determine number of bins using Freedman-Diaconis rule
    q75, q25 = np.percentile(y_np, [75, 25])
    iqr = q75 - q25
    if iqr > 0:
        bin_width = 2 * iqr / (n_samples ** (1 / 3))
        n_bins = max(10, min(100, int((y_np.max() - y_np.min()) / bin_width)))
    else:
        n_bins = 20  # Default if IQR is 0

    # Bin the target values; clipping puts the maximum into the last bin.
    bin_edges = np.linspace(y_np.min(), y_np.max(), n_bins + 1)
    bin_indices = np.clip(np.digitize(y_np, bin_edges) - 1, 0, n_bins - 1)

    # Proportional allocation, with at least one sample per non-empty bin.
    bin_counts = np.bincount(bin_indices, minlength=n_bins)
    samples_per_bin = np.floor(bin_counts / n_samples * subsample_size).astype(int)
    non_empty_bins = bin_counts > 0
    samples_per_bin[non_empty_bins] = np.maximum(samples_per_bin[non_empty_bins], 1)

    selected_indices = []
    for bin_idx in range(n_bins):
        bin_sample_indices = np.where(bin_indices == bin_idx)[0]
        if len(bin_sample_indices) == 0:
            continue
        n_to_select = min(samples_per_bin[bin_idx], len(bin_sample_indices))
        if n_to_select > 0:
            selected = rng.choice(bin_sample_indices, size=n_to_select, replace=False)
            selected_indices.extend(selected)

    # If we haven't reached subsample_size, randomly sample more from remaining pool
    if len(selected_indices) < subsample_size:
        remaining_needed = subsample_size - len(selected_indices)
        remaining_indices = list(set(range(n_samples)) - set(selected_indices))
        if len(remaining_indices) > 0:
            additional = rng.choice(remaining_indices,
                                    size=min(remaining_needed, len(remaining_indices)),
                                    replace=False)
            selected_indices.extend(additional)

    # The one-per-bin minimum can overshoot: randomly trim back.
    if len(selected_indices) > subsample_size:
        selected_indices = rng.choice(selected_indices, size=subsample_size, replace=False)

    return np.array(selected_indices)

def subsampling(X_memory):
    """
    Select exactly one random sample per unique row (pattern) of ``X_memory``.

    Rows are grouped by exact equality in a single ``np.unique`` pass
    instead of rescanning the whole array for every unique row. Groups are
    visited in ``np.unique(..., axis=0)`` order (sorted patterns). One
    member of each group is drawn with NumPy's global RNG.

    Args:
        X_memory: Array or tensor of shape (n_samples, n_features), e.g. with
            binary features (0 or 1).

    Returns:
        indices: Integer array with one index per unique row, aligned with
        the sorted unique rows. Empty input yields an empty integer array.
    """
    if torch.is_tensor(X_memory):
        X_np = X_memory.cpu().numpy()
    else:
        X_np = np.asarray(X_memory)

    if X_np.shape[0] == 0:
        return np.array([], dtype=np.intp)

    _, group_ids = np.unique(X_np, axis=0, return_inverse=True)
    # Some NumPy 2.0.x releases return a non-1-D inverse when `axis` is given.
    group_ids = group_ids.reshape(-1)

    # A stable sort keeps each group's members in ascending index order, so
    # the draws match a row-by-row np.where scan for the same RNG state.
    order = np.argsort(group_ids, kind='stable')
    group_sizes = np.bincount(group_ids)
    groups = np.split(order, np.cumsum(group_sizes)[:-1])

    selected_indices = [np.random.choice(members, size=1, replace=False)[0] for members in groups]
    return np.array(selected_indices)

def _import_pysr_regressor():
    """Import PySR lazily and return ``PySRRegressor``, hiding GPUs from Julia meanwhile."""
    # Lazy import PySR only when needed: avoids Julia initialization
    # conflicts with PyTorch at import time.
    import pysr

    # Temporarily hide GPUs from Julia; restored even if initialization fails.
    original_cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    try:
        try:
            pysr.julia_helpers.init_julia(julia_project=None, quiet=False)
        except Exception as e:
            # Best effort: typically Julia is already initialized.
            print(f"Julia initialization warning: {e}")
        from pysr import PySRRegressor
    finally:
        if original_cuda_visible is not None:
            os.environ['CUDA_VISIBLE_DEVICES'] = original_cuda_visible
        else:
            os.environ.pop('CUDA_VISIBLE_DEVICES', None)
    return PySRRegressor


def _drop_non_finite_rows(X_memory, y_target, memory_idx, output_name):
    """Drop samples whose concepts (rows of X_memory) or target contain NaN/Inf."""
    finite_x = torch.isfinite(X_memory).all(dim=1)
    if not bool(finite_x.all()):
        print(f"WARNING: X_memory contains NaN or Inf values for memory {memory_idx}, output '{output_name}'. Cleaning data.")
        X_memory = X_memory[finite_x]
        y_target = y_target[finite_x]

    finite_y = torch.isfinite(y_target)
    if not bool(finite_y.all()):
        print(f"WARNING: y_target contains NaN or Inf values for memory {memory_idx}, output '{output_name}'. Cleaning data.")
        X_memory = X_memory[finite_y]
        y_target = y_target[finite_y]

    return X_memory, y_target


def _rename_pysr_variables(expr, c_names):
    """Simultaneously rename PySR's default variables x0, x1, ... to the matching concept names."""
    replacements = {}
    for symbol in expr.free_symbols:
        match = re.fullmatch(r'x(\d+)', getattr(symbol, 'name', ''))
        if match is not None:
            position = int(match.group(1))
            if position < len(c_names):
                replacements[symbol] = sp.Symbol(c_names[position])
    # xreplace substitutes all symbols at once, so a concept literally named
    # like another PySR variable (e.g. 'x1') is not renamed a second time.
    return expr.xreplace(replacements)


def symbolic_regression(
        stored_concepts,
        stored_targets,
        stored_selector_probs,
        memory_size,
        output_size,
        c_names,
        y_names,
        device,
        pysr_params,
        task,
        disjoint_training,
        subsample_size=2000,
    ):
    """
    Fit one symbolic equation per (memory slot, output) pair with PySR.

    For every output and memory slot:

    1. Take the samples whose selector argmax over the memory dimension is
       that slot.
    2. Optionally subsample them:
       - unique concept patterns for disjoint classification;
       - target-stratified sampling for joint classification;
       - uniform sampling for regression.
       The last two apply only above ``subsample_size`` samples.
    3. Drop rows containing NaN/Inf.
    4. Fit a ``PySRRegressor``. Keep the lowest-loss equation for
       classification or the highest-score equation for regression, with
       variables renamed from x0, x1, ... to ``c_names``.

    Slots without samples get the equation ``0``. If fitting fails, the mean
    target value is used (``0`` when no samples remain).

    Args:
        stored_concepts: Tensor of concept values, indexed with a per-sample
            mask (rows presumably (n_samples, n_concepts)).
        stored_targets: Tensor of targets; 1-D targets are treated as a
            single output.
        stored_selector_probs: Tensor of selector probabilities of shape
            (batch, memory_size, n_outputs, n_samples).
        memory_size: Number of memory slots.
        output_size: Number of outputs.
        c_names: Concept names used to rename PySR variables.
        y_names: Output names; missing entries fall back to ``y_<idx>``.
        device: Unused; kept for interface compatibility.
        pysr_params: Keyword arguments forwarded to ``PySRRegressor``.
        task: ``'classification'`` or ``'regression'``.
        disjoint_training: Whether disjoint training is used.
        subsample_size: Sample-count threshold and target size for
            subsampling (default 2000).

    Returns:
        ``{memory_idx: {output_name: sympy_expression}}``.

    Raises:
        ValueError: If no data was stored, or an equation is missing at the end.
    """
    if len(stored_concepts) == 0:
        raise ValueError("No stored data for fine-tuning. Run forward passes with store_for_finetuning=True first.")

    PySRRegressor = _import_pysr_regressor()

    # Move all data to CPU to avoid GPU conflicts with Julia
    stored_concepts = stored_concepts.cpu()
    stored_targets = stored_targets.cpu()
    stored_selector_probs = stored_selector_probs.cpu()

    # Handle different target shapes
    if stored_targets.ndim == 1:
        stored_targets = stored_targets.reshape(-1, 1)

    # Structure: {memory_idx: {output_name: sympy_equation}}
    all_equations = {i: {} for i in range(memory_size)}

    for output_idx in range(output_size):
        output_name = y_names[output_idx] if output_idx < len(y_names) else f"y_{output_idx}"

        # Slicing the output axis leaves (batch, memory_size, n_samples); the
        # argmax over dim=1 (memory) is the same for every memory slot, so
        # compute it once per output.
        selected_memory = stored_selector_probs[:, :, output_idx, :].argmax(dim=1).flatten()

        for memory_idx in range(memory_size):
            memory_mask = (selected_memory == memory_idx).numpy()
            n_samples_for_memory = memory_mask.sum()

            if n_samples_for_memory == 0:
                print(f"Memory slot {memory_idx}, output '{output_name}' has no samples. Using zero equation.")
                all_equations[memory_idx][output_name] = sp.sympify("0")
                continue

            X_memory = stored_concepts[memory_mask]  # [n_samples_memory, n_concepts]
            y_target = stored_targets[memory_mask, output_idx]  # [n_samples_memory]

            if task == 'classification' and disjoint_training:
                indices = subsampling(X_memory)
            elif task == 'classification' and not disjoint_training and n_samples_for_memory > subsample_size:
                indices = subsample_based_on_y_hat(y_target, subsample_size)
            elif task == 'regression' and n_samples_for_memory > subsample_size:
                indices = np.random.choice(n_samples_for_memory, size=subsample_size, replace=False)
            else:
                indices = None

            if indices is not None:
                X_memory = X_memory[indices]
                y_target = y_target[indices]

            X_memory_clean, y_target = _drop_non_finite_rows(X_memory, y_target, memory_idx, output_name)

            print(f"\nFitting PySR for output '{output_name}' (memory slot {memory_idx})...")
            print(f"  Input shape: {X_memory_clean.shape}, Target shape: {y_target.shape}")

            try:
                model = PySRRegressor(
                    **pysr_params,
                    verbosity=1,
                    progress=True,
                )
                model.fit(X_memory_clean.cpu().numpy(), y_target.cpu().numpy())

                equations_df = model.equations_
                print(f"  Pareto front has {len(equations_df)} equations")

                if task == 'classification':
                    # For classification: select lowest loss (best accuracy)
                    best_eq_row = equations_df.nsmallest(1, 'loss').iloc[0]
                else:
                    # For regression: select highest score (balance loss and complexity)
                    best_eq_row = equations_df.nlargest(1, 'score').iloc[0]

                sympy_eq = _rename_pysr_variables(best_eq_row['sympy_format'], c_names)
                all_equations[memory_idx][output_name] = sympy_eq

                print(f"  ✓ Best equation: {sympy_eq}")
                print(f"    Loss: {best_eq_row['loss']:.6f}")
                print(f"    Complexity: {best_eq_row['complexity']}")
                print(f"    Score: {best_eq_row['score']:.6f}")

                del model

            except Exception as e:
                print(f"  ✗ ERROR fitting PySR for memory {memory_idx}, output '{output_name}':")
                print(f"    {type(e).__name__}: {str(e)}")
                print(f"    Using fallback constant equation (mean value)")
                if len(y_target) == 0:
                    print(f"    No samples available. Using zero as fallback.")
                    all_equations[memory_idx][output_name] = sp.sympify("0")
                else:
                    # .float(): torch.mean is undefined for integer/bool tensors.
                    mean_val = float(y_target.float().mean())
                    all_equations[memory_idx][output_name] = sp.sympify(str(mean_val))
                    print(f"    Fallback equation: {mean_val}")

    # Verify all equations are present
    for memory_idx in range(memory_size):
        for output_idx in range(output_size):
            output_name = y_names[output_idx] if output_idx < len(y_names) else f"y_{output_idx}"
            if output_name not in all_equations[memory_idx]:
                raise ValueError(f"Missing equation for memory {memory_idx}, output {output_name}")

    return all_equations


def multiple_symbolic_regression(
        stored_concepts, 
        stored_targets, 
        stored_selector_probs,
        memory_size,
        output_size,
        c_names,
        y_names,
        device,
        pysr_params_list,
        task,
        disjoint_training,
        subsample_size=2000,
    ):
    """
    Run symbolic regression multiple times with different PySR configurations.

    For every configuration in ``pysr_params_list`` and every (memory slot,
    output) pair, the samples whose selector argmax picks that memory slot are
    collected, optionally subsampled, stripped of NaN/Inf rows and fitted with a
    fresh ``PySRRegressor``. PySR/Julia is imported and initialized once, with
    GPUs hidden from Julia, and all fits then run in that single session to
    work around the instability of repeatedly instantiating PySRRegressor.

    Args:
        stored_concepts: Tensor of stored concept predictions; rows are
            selected with a boolean sample mask and later checked per row
            (``all(dim=1)``), so a 2-D (n_samples, n_concepts) layout is assumed.
        stored_targets: Tensor of stored target predictions; a 1-D tensor is
            reshaped to (n_samples, 1), then column ``output_idx`` is used.
        stored_selector_probs: Tensor of memory selector probabilities, indexed
            here as (batch, memory_size, n_outputs, n_samples).
        memory_size: Number of memory slots.
        output_size: Number of outputs.
        c_names: List of concept names; PySR variables ``x0, x1, ...`` are
            renamed to these (simultaneously) in the returned expressions.
        y_names: List of output names; outputs past its length are named
            ``y_{idx}``.
        device: Not used by this function; kept for interface compatibility.
        pysr_params_list: List of dicts of ``PySRRegressor`` keyword arguments.
            ``verbosity`` defaults to 1 and ``progress`` to True; a config may
            override either.
        task: 'classification' or 'regression'. Classification keeps the
            lowest-loss equation, regression the highest-score one.
        disjoint_training: Whether disjoint training is used; selects the
            subsampling strategy for classification.
        subsample_size: Sample-count threshold above which data is subsampled
            (for non-disjoint classification and for regression), and the
            number of samples kept in that case. Defaults to 2000.

    Returns:
        List of equation dictionaries, one per configuration.
        Each element is: {memory_idx: {output_name: sympy_equation}}.
        Slots with no samples get ``0``; failed fits fall back to the mean
        target value as a constant (or ``0`` if no samples remain).

    Raises:
        ValueError: If ``stored_concepts`` is empty, or if an equation is
            missing for some (memory slot, output) pair after fitting.
    """
    # Validate before paying for the PySR import / Julia start-up.
    if len(stored_concepts) == 0:
        raise ValueError("No stored data for symbolic regression. Run forward passes with store_for_finetuning=True first.")

    # Hide GPUs from Julia while PySR is imported (lazily, only when this
    # function runs) and Julia is initialized. PySR versions built on juliacall
    # may start Julia at import time, so the import sits inside this window.
    # The caller's CUDA visibility is always restored, even on failure.
    original_cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    try:
        import pysr
        try:
            pysr.julia_helpers.init_julia(julia_project=None, quiet=False)
            print("✓ Julia initialized successfully")
        except Exception as e:
            # Best effort: Julia may already be initialized, or this PySR
            # version may not provide julia_helpers.init_julia.
            print(f"Julia initialization warning: {e}")
        from pysr import PySRRegressor
    finally:
        if original_cuda_visible is not None:
            os.environ['CUDA_VISIBLE_DEVICES'] = original_cuda_visible
        else:
            os.environ.pop('CUDA_VISIBLE_DEVICES', None)

    # Keep all data on CPU (and detached) so Julia never competes for the GPU
    # and the .numpy() conversions before fitting cannot fail on tensors
    # that track gradients.
    stored_concepts = stored_concepts.detach().cpu()
    stored_targets = stored_targets.detach().cpu()
    stored_selector_probs = stored_selector_probs.detach().cpu()

    # Handle different target shapes
    if stored_targets.ndim == 1:
        stored_targets = stored_targets.reshape(-1, 1)

    # Model defaults that a configuration may override. Merging them (rather
    # than passing them next to **pysr_params) avoids a "multiple values for
    # keyword argument" TypeError that would turn every fit into a fallback.
    default_model_kwargs = {'verbosity': 1, 'progress': True}

    # PySR names features x0, x1, ...; map them to concept names in one
    # simultaneous substitution so renamed symbols are never renamed again.
    rename_map = {sp.Symbol(f'x{i}'): sp.Symbol(c_name) for i, c_name in enumerate(c_names)}

    def output_label(idx):
        return y_names[idx] if idx < len(y_names) else f"y_{idx}"

    # List to store equation sets for each configuration
    all_equation_sets = []

    for config_idx, pysr_params in enumerate(pysr_params_list):
        print("\n" + "="*70)
        print(f"SYMBOLIC REGRESSION - Configuration {config_idx + 1}/{len(pysr_params_list)}")
        print("="*70)
        print(f"Parameters: {pysr_params}")
        print("="*70 + "\n")

        model_kwargs = {**default_model_kwargs, **pysr_params}

        # Structure: {memory_idx: {output_name: sympy_equation}}
        all_equations = {i: {} for i in range(memory_size)}

        # Each output is processed independently.
        for output_idx in range(output_size):
            output_name = output_label(output_idx)

            # Winning memory slot per sample for this output, computed once per
            # output. The flattened (batch, n_samples) order is assumed to line
            # up with the rows of stored_concepts — TODO confirm for n_samples > 1.
            slot_winner = stored_selector_probs[:, :, output_idx, :].argmax(dim=1).flatten()

            for memory_idx in range(memory_size):
                memory_mask = slot_winner == memory_idx
                n_samples_for_memory = int(memory_mask.sum())

                # Filter concepts and targets for this memory slot and output
                X_memory = stored_concepts[memory_mask]
                y_target = stored_targets[memory_mask, output_idx]

                # Skip if no samples for this memory slot
                if n_samples_for_memory == 0:
                    print(f"Memory slot {memory_idx}, output '{output_name}' has no samples. Using zero equation.")
                    all_equations[memory_idx][output_name] = sp.sympify("0")
                    continue

                if task == 'classification' and disjoint_training:
                    indices = subsampling(X_memory)
                elif task == 'classification' and not disjoint_training and n_samples_for_memory > subsample_size:
                    indices = subsample_based_on_y_hat(y_target, subsample_size)
                elif task == 'regression' and n_samples_for_memory > subsample_size:
                    indices = np.random.choice(n_samples_for_memory, size=subsample_size, replace=False)
                else:
                    indices = None

                if indices is not None:
                    X_memory = X_memory[indices]
                    y_target = y_target[indices]

                # Drop samples with NaN/Inf inputs, then samples with NaN/Inf targets.
                finite_rows = torch.isfinite(X_memory).all(dim=1)
                if not bool(finite_rows.all()):
                    print(f"WARNING: X_memory contains NaN or Inf values for memory {memory_idx}, output '{output_name}'. Cleaning data.")
                    X_memory = X_memory[finite_rows]
                    y_target = y_target[finite_rows]

                finite_targets = torch.isfinite(y_target)
                if not bool(finite_targets.all()):
                    print(f"WARNING: y_target contains NaN or Inf values for memory {memory_idx}, output '{output_name}'. Cleaning data.")
                    X_memory = X_memory[finite_targets]
                    y_target = y_target[finite_targets]

                print(f"\nFitting PySR for output '{output_name}' (memory slot {memory_idx})...")
                print(f"  Input shape: {X_memory.shape}, Target shape: {y_target.shape}")

                try:
                    # A fresh regressor per equation
                    pysr_model = PySRRegressor(**model_kwargs)
                    pysr_model.fit(X_memory.numpy(), y_target.numpy())

                    equations_df = pysr_model.equations_
                    print(f"  Pareto front has {len(equations_df)} equations")

                    if task == 'classification':
                        # For classification: select lowest loss (best accuracy)
                        best_eq_row = equations_df.nsmallest(1, 'loss').iloc[0]
                    else:
                        # For regression: select highest score (balance loss and complexity)
                        best_eq_row = equations_df.nlargest(1, 'score').iloc[0]

                    sympy_eq = best_eq_row['sympy_format'].subs(rename_map, simultaneous=True)
                    all_equations[memory_idx][output_name] = sympy_eq

                    print(f"  ✓ Best equation: {sympy_eq}")
                    print(f"    Loss: {best_eq_row['loss']:.6f}")
                    print(f"    Complexity: {best_eq_row['complexity']}")
                    print(f"    Score: {best_eq_row['score']:.6f}")

                    # Release the model (and its Julia-side state) before the next fit
                    del pysr_model

                except Exception as e:
                    # Best effort: a failed fit degrades to a constant equation
                    # instead of aborting the whole sweep.
                    print(f"  ✗ ERROR fitting PySR for memory {memory_idx}, output '{output_name}':")
                    print(f"    {type(e).__name__}: {str(e)}")
                    print(f"    Using fallback constant equation (mean value)")
                    if len(y_target) == 0:
                        print(f"    No samples available. Using zero as fallback.")
                        all_equations[memory_idx][output_name] = sp.sympify("0")
                    else:
                        # .float(): torch.mean rejects integer dtypes.
                        mean_val = float(y_target.float().mean())
                        all_equations[memory_idx][output_name] = sp.sympify(str(mean_val))
                        print(f"    Fallback equation: {mean_val}")

        # Verify all equations are present
        for memory_idx in range(memory_size):
            for output_idx in range(output_size):
                output_name = output_label(output_idx)
                if output_name not in all_equations[memory_idx]:
                    raise ValueError(f"Missing equation for memory {memory_idx}, output {output_name}")

        all_equation_sets.append(all_equations)

        print(f"\n✓ Configuration {config_idx + 1} completed successfully!")

    print("\n" + "="*70)
    print("ALL SYMBOLIC REGRESSIONS COMPLETED")
    print(f"Total configurations processed: {len(all_equation_sets)}")
    print("="*70 + "\n")

    return all_equation_sets