#!/usr/bin/env python3
"""
Evaluation script for the proposed Entropy-Weighted Spatial Correlation Aggregation variant of ELCM.
Each image is scored by combining a global max-softmax score with a weighted local score computed
from entropy-filtered, top-K confident patches. Optional spatial weighting favours patches whose
neighbouring patches agree with them, promoting spatially consistent predictions while suppressing
noisy individual patches.
"""

import argparse
import torch
from typing import Dict, Tuple
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr
from datasets.test_loader import set_test_loader
import clip_w_local
from clip_w_local import clip
from tqdm import tqdm
from torch.nn import functional as F
from utils.eval_util import get_and_print_results, add_results, add_overall_results, save_results_to_json
import os


def get_test_labels(in_dataset: str):
    """
    Load the array of class names used to build the text prompts.

    Only 'ImageNet' is supported. Its names are read with ``np.load`` from
    ``label_names/imagenet_class_clean.npy``, relative to the current working directory.

    Raises:
        ValueError: if ``in_dataset`` is anything other than 'ImageNet'.
    """
    if in_dataset != 'ImageNet':
        raise ValueError(f"Invalid dataset: {in_dataset}")

    label_path = "label_names/imagenet_class_clean.npy"
    with open(label_path, 'rb') as label_file:
        class_names = np.load(label_file)
    return class_names


def compute_spatial_correlation_weights(probs, patch_grid_size=None, correlation_threshold=0.3, 
                                       spatial_decay=2.0, eps=1e-8):
    """
    Compute per-patch spatial-consistency multipliers from neighbouring-patch correlations.

    Patches are laid out row-major on a square ``patch_grid_size`` grid: patch ``i`` sits at
    row ``i // g`` and column ``i % g``. For each patch, the Pearson correlation between its
    class distribution and that of each 8-connected neighbour is multiplied by the distance
    decay ``1 / (1 + spatial_decay * distance)``. Then:

    * if any decayed correlation exceeds ``correlation_threshold``, the weight becomes
      ``1 + 0.5 * mean(decayed correlations above the threshold)``;
    * otherwise, if the mean decayed correlation is negative, the weight becomes
      ``1 + 0.2 * mean``;
    * otherwise the weight stays 1.

    Args:
        probs: Probability distribution over classes [batch, num_patches, num_classes]
        patch_grid_size: Side of the square patch grid (inferred from num_patches if None)
        correlation_threshold: Decayed correlation above which a neighbour counts as agreeing
        spatial_decay: Decay factor for spatial distance weighting
        eps: Small epsilon for numerical stability

    Returns:
        spatial_weights: float32 array of multipliers [batch, num_patches]

    Raises:
        ValueError: if num_patches is not patch_grid_size ** 2.
    """
    batch_size, num_patches, num_classes = probs.shape

    if num_patches == 0:
        return np.ones((batch_size, 0), dtype=np.float32)

    if patch_grid_size is None:
        # Round (not truncate) so floating-point error in sqrt cannot shrink a perfect square.
        patch_grid_size = int(round(np.sqrt(num_patches)))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if num_patches != patch_grid_size * patch_grid_size:
        raise ValueError(
            f"Expected {patch_grid_size}x{patch_grid_size}={patch_grid_size**2} patches, got {num_patches}"
        )

    # Row-major (row, col) coordinate of every patch and all pairwise Euclidean distances.
    coords = np.array([(i, j) for i in range(patch_grid_size)
                       for j in range(patch_grid_size)])
    spatial_distances = squareform(pdist(coords, metric='euclidean'))

    # Loop-invariant geometry: 8-neighbourhood (self excluded) and the distance decay factors.
    adjacency = spatial_distances <= np.sqrt(2) + eps
    np.fill_diagonal(adjacency, False)
    decay_factors = 1.0 / (1.0 + spatial_decay * spatial_distances)
    neighbour_lists = [np.flatnonzero(row) for row in adjacency]

    spatial_weights = np.ones((batch_size, num_patches), dtype=np.float32)

    for batch_idx in range(batch_size):
        batch_probs = probs[batch_idx]  # [num_patches, num_classes]

        # Pearson correlation between patch distributions via z-scored rows.
        centered = batch_probs - np.mean(batch_probs, axis=1, keepdims=True)
        normalized = centered / (np.std(batch_probs, axis=1, keepdims=True) + eps)
        correlations = np.clip(np.dot(normalized, normalized.T) / num_classes, -1.0, 1.0)

        for i, neighbours in enumerate(neighbour_lists):
            if neighbours.size == 0:
                continue

            weighted_correlations = correlations[i, neighbours] * decay_factors[i, neighbours]
            positive_corr_mask = weighted_correlations > correlation_threshold

            if np.any(positive_corr_mask):
                # Amplify patches that agree with at least one close neighbour.
                spatial_score = np.mean(weighted_correlations[positive_corr_mask])
                spatial_weights[batch_idx, i] *= (1.0 + 0.5 * spatial_score)
            else:
                # Gently dampen patches that are, on average, anti-correlated with neighbours.
                mean_correlation = np.mean(weighted_correlations)
                if mean_correlation < 0:
                    spatial_weights[batch_idx, i] *= (1.0 + mean_correlation * 0.2)

    return spatial_weights


def compute_class_conditional_factors(probs, beta=1.0, top_k=3, eps=1e-8):
    """
    Compute class-conditional scaling factors based on top-k class discrimination.

    The factor is ``(max_prob / (mean of the top-k probabilities + eps)) ** beta``. It equals
    about 1 when the top-k classes are tied and grows as the top class dominates them.

    Args:
        probs: Probability distribution over classes on the last axis, e.g.
            [batch, num_patches, num_classes]. Any number of leading axes is accepted.
        beta: Strength of class-conditional scaling (exponent applied to the ratio)
        top_k: Number of top classes to consider; clamped to num_classes, must be >= 1
        eps: Small epsilon for numerical stability

    Returns:
        factors: Class-conditional scaling factors with the shape of ``probs`` minus its last axis

    Raises:
        ValueError: if ``top_k`` < 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    num_classes = probs.shape[-1]
    k = min(top_k, num_classes)

    # Only the k largest probabilities matter, so partition (O(C)) instead of a full sort.
    split = num_classes - k
    top_k_probs = np.partition(probs, split, axis=-1)[..., split:]
    # Order the k values descending so the mean is accumulated in the same order as a full sort.
    top_k_probs = np.sort(top_k_probs, axis=-1)[..., ::-1]

    max_prob = top_k_probs[..., 0]
    mean_top_k = np.mean(top_k_probs, axis=-1)

    # How much the top class dominates the mean of the top-k classes.
    discrimination_ratio = max_prob / (mean_top_k + eps)
    return np.power(discrimination_ratio, beta)


def compute_enhanced_entropy_weights(probs, alpha=1.0, beta=1.0, gamma=0.5, top_k=3, 
                                   use_class_conditional=True, use_spatial_correlation=True, 
                                   patch_grid_size=None, correlation_threshold=0.3, 
                                   spatial_decay=2.0, eps=1e-8):
    """
    Compute enhanced entropy weights with class-conditional scaling and spatial correlation.

    The base weight is ``exp(-alpha * H)``, where H is the Shannon entropy of each patch.
    It is optionally multiplied by the class-conditional factors from
    ``compute_class_conditional_factors``. It is also optionally blended with the spatial
    multipliers ``s`` from ``compute_spatial_correlation_weights`` as ``1 + gamma * (s - 1)``.

    Args:
        probs: Probability distribution over classes for each patch [batch, num_patches, num_classes]
        alpha: Weight for entropy scaling
        beta: Strength of class-conditional scaling
        gamma: Weight for spatial correlation component (0 disables its effect)
        top_k: Number of top classes to consider for discrimination
        use_class_conditional: Whether to apply class-conditional scaling
        use_spatial_correlation: Whether to apply spatial correlation weighting
        patch_grid_size: Side of the square patch grid for spatial correlation (inferred if None)
        correlation_threshold: Minimum correlation threshold for neighbors
        spatial_decay: Decay factor for spatial distance weighting
        eps: Small epsilon for numerical stability

    Returns:
        weights: Enhanced entropy-based weights [batch, num_patches]
    """
    # Shannon entropy H = -sum_c p_c * log(p_c); eps guards log(0).
    entropy = -np.sum(probs * np.log(probs + eps), axis=-1)  # [batch, num_patches]

    # np.exp returns a fresh array, so the in-place updates below never touch caller data.
    enhanced_weights = np.exp(-alpha * entropy)

    if use_class_conditional:
        enhanced_weights *= compute_class_conditional_factors(probs, beta=beta, top_k=top_k, eps=eps)

    if use_spatial_correlation:
        spatial_weights = compute_spatial_correlation_weights(
            probs,
            patch_grid_size=patch_grid_size,  # previously accepted but silently dropped
            correlation_threshold=correlation_threshold,
            spatial_decay=spatial_decay,
            eps=eps,
        )
        enhanced_weights = enhanced_weights * (1.0 + gamma * (spatial_weights - 1.0))

    return enhanced_weights


def _percentile_stabilized_weights(patch_entropies, percentile_beta, min_weight_gamma):
    """Map patch entropies to weights in [min_weight_gamma, 1] using percentile cutoffs.

    Entropies at or above the ``100 - percentile_beta`` percentile get ``min_weight_gamma``
    (this takes precedence). Entropies strictly between the ``percentile_beta`` and
    ``100 - percentile_beta`` percentiles are linearly interpolated from 1 down to
    ``min_weight_gamma``. All others keep weight 1, and a single patch always gets weight 1.
    """
    if len(patch_entropies) <= 1:
        return np.ones_like(patch_entropies)

    low = np.percentile(patch_entropies, percentile_beta)
    high = np.percentile(patch_entropies, 100 - percentile_beta)

    weights = np.ones_like(patch_entropies)
    weights[patch_entropies >= high] = min_weight_gamma

    mid = (patch_entropies > low) & (patch_entropies < high)
    if np.any(mid):
        # Strict inequalities above guarantee high > low here, so no division by zero.
        fraction = (patch_entropies[mid] - low) / (high - low)
        weights[mid] = 1.0 - fraction * (1.0 - min_weight_gamma)
    return weights


def _neighbor_consistency_factors(valid_indices, grid_size, gamma):
    """Return ``1 + 0.1 * gamma * n`` per patch, where n counts its selected 4-neighbours."""
    selected = {int(p) for p in valid_indices}  # set: O(1) membership tests
    factors = np.ones(len(valid_indices))
    for idx, patch in enumerate(valid_indices):
        row, col = divmod(int(patch), grid_size)
        count = 0
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nr, nc = row + dr, col + dc
            if 0 <= nr < grid_size and 0 <= nc < grid_size and nr * grid_size + nc in selected:
                count += 1
        if count:
            factors[idx] = 1.0 + gamma * 0.1 * count
    return factors


def compute_topk_class_conditional_weights(probs, alpha=1.0, beta=1.0, gamma=0.5, temperature=1.0, k=16, top_k_classes=3, 
                                          entropy_threshold=None, use_class_conditional=True, 
                                          use_percentile_weights=True, percentile_beta=25.0, min_weight_gamma=0.1,
                                          use_spatial_correlation=False, patch_grid_size=None, 
                                          correlation_threshold=0.3, spatial_decay=2.0, eps=1e-8):
    """
    Compute top-K patch selection with entropy filtering, percentile-based weight stabilization,
    class-conditional weighting, and optional spatial neighbour weighting.

    Steps, independently for every image in the batch:
    1. Entropy of the temperature-sharpened patch distributions ``p ** (1/T)`` (renormalised).
    2. Entropy filter: keep patches with entropy <= ``entropy_threshold`` (by default the
       image's own 75th entropy percentile, so a score never depends on batch-mates).
    3. Top-K selection by max-probability times the class-conditional factor.
    4. Per-patch weights: percentile-stabilised (or ``exp(-alpha * H)``), times the
       class-conditional factor, times an optional neighbour-consistency boost.
    5. Normalise the selected weights to sum to 1. If no selected patch passes the filter,
       the K most confident patches get uniform weight ``1/K`` instead.

    Args:
        probs: Probability distribution over classes for each patch [batch, num_patches, num_classes]
        alpha: Weight for entropy scaling (used only if use_percentile_weights=False)
        beta: Strength of class-conditional scaling
        gamma: Strength of the neighbour-consistency boost
        temperature: Temperature (> 0) applied as p ** (1/T) before computing entropy
        k: Number of top patches to select (>= 1; clamped to num_patches)
        top_k_classes: Number of top classes for class-conditional discrimination (default: 3)
        entropy_threshold: Maximum entropy for filtering; a scalar, or None for the
            per-image 75th percentile
        use_class_conditional: Whether to apply class-conditional scaling (default: True)
        use_percentile_weights: Whether to use percentile-based weight stabilization (default: True)
        percentile_beta: Percentile cutoff for weight stabilization (default: 25.0)
        min_weight_gamma: Minimum weight for high-entropy patches (default: 0.1)
        use_spatial_correlation: Whether to apply the 4-neighbour consistency boost
        patch_grid_size: Side of the square patch grid (inferred from num_patches if None);
            the boost is skipped when the grid does not match num_patches
        correlation_threshold: Unused by this lightweight neighbour check; kept for interface compatibility
        spatial_decay: Unused by this lightweight neighbour check; kept for interface compatibility
        eps: Small epsilon for numerical stability

    Returns:
        selected_weights: Weights for selected top-K patches [batch, num_patches] (zero for non-selected)

    Raises:
        ValueError: if ``k`` < 1 or ``temperature`` <= 0.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    batch_size, num_patches, num_classes = probs.shape

    # Temperature-sharpen before measuring entropy. Dividing p by T and renormalising would
    # cancel out exactly and leave T with no effect, so use p ** (1/T) instead.
    if temperature == 1.0:
        temp_probs = probs / np.sum(probs, axis=-1, keepdims=True)
    else:
        # Log space keeps tiny probabilities raised to large powers from underflowing to 0/0.
        scaled_logits = np.log(np.maximum(probs, eps)) / temperature
        scaled_logits = scaled_logits - np.max(scaled_logits, axis=-1, keepdims=True)
        temp_probs = np.exp(scaled_logits)
        temp_probs = temp_probs / np.sum(temp_probs, axis=-1, keepdims=True)

    # Shannon entropy H = -sum_c p_c * log(p_c); eps guards log(0).
    entropy = -np.sum(temp_probs * np.log(temp_probs + eps), axis=-1)  # [batch, num_patches]

    patch_max_probs = np.max(probs, axis=-1)  # [batch, num_patches]

    if use_class_conditional:
        class_factors = compute_class_conditional_factors(probs, beta=beta, top_k=top_k_classes, eps=eps)
    else:
        class_factors = np.ones_like(entropy)

    if entropy_threshold is None:
        # Per-image 75th percentile. A batch-wide percentile would make each image's score
        # depend on which other images happen to share its batch.
        entropy_threshold = np.percentile(entropy, 75, axis=1, keepdims=True)
    entropy_filter_mask = entropy <= entropy_threshold  # [batch, num_patches]

    class_scaled_confidence = patch_max_probs * class_factors

    # Real confidences are products of probabilities and positive factors, so -1 ranks
    # filtered-out patches last.
    filtered_confidence = class_scaled_confidence.copy()
    filtered_confidence[~entropy_filter_mask] = -1.0

    k = min(k, num_patches)
    topk_indices = np.argsort(filtered_confidence, axis=1)[:, -k:]  # [batch, k]

    if patch_grid_size is None:
        patch_grid_size = int(round(np.sqrt(num_patches)))
    grid_matches = patch_grid_size * patch_grid_size == num_patches

    selected_weights = np.zeros_like(entropy)  # [batch, num_patches]

    for batch_idx in range(batch_size):
        selected_patch_indices = topk_indices[batch_idx]
        valid_selected = entropy_filter_mask[batch_idx, selected_patch_indices]

        if not np.any(valid_selected):
            # Nothing passed the filter: weight the K most confident patches uniformly.
            # The filtered ranking cannot be used here because every entry in it is -1.
            fallback_indices = np.argsort(class_scaled_confidence[batch_idx])[-k:]
            selected_weights[batch_idx, fallback_indices] = 1.0 / k
            continue

        valid_indices = selected_patch_indices[valid_selected]
        patch_entropies = entropy[batch_idx, valid_indices]

        if use_percentile_weights:
            patch_weights = _percentile_stabilized_weights(patch_entropies, percentile_beta, min_weight_gamma)
        else:
            patch_weights = np.exp(-alpha * patch_entropies)

        if use_class_conditional:
            patch_weights = patch_weights * class_factors[batch_idx, valid_indices]

        if use_spatial_correlation and len(valid_indices) > 1 and grid_matches:
            patch_weights = patch_weights * _neighbor_consistency_factors(valid_indices, patch_grid_size, gamma)

        # Normalise so the selected patches' weights sum to 1.
        selected_weights[batch_idx, valid_indices] = patch_weights / (np.sum(patch_weights) + eps)

    return selected_weights


def get_ood_scores(model, method, loader, test_labels, lambda_local: float = 0.5, T: float = 1.0, alpha: float = 1.0, beta: float = 1.0, gamma: float = 0.5, entropy_temp: float = 1.0, args=None):
    """
    Compute one OOD score per image in ``loader``.

    Image features (global and per-patch local) and the prompt embeddings
    ``"a photo of a {class}"`` are L2-normalised, turned into similarity logits, and
    softmaxed with temperature ``T``. Scores are negated maximum softmax probabilities,
    so a larger score means lower confidence:

    * ``'mcm'``: global score only.
    * ``'gl-mcm'``: global score + ``lambda_local`` * (negated max over all patches and classes).
    * ``'improved-elcm'``: global score + ``lambda_local`` * (negated weighted sum of per-patch
      max probabilities, weighted by ``compute_topk_class_conditional_weights``).

    Args:
        model: CLIP-style model whose ``encode_image`` returns (global, local) features and
            which provides ``encode_text``. Inputs are moved to the device of its parameters.
        method: One of 'mcm', 'gl-mcm', 'improved-elcm'.
        loader: Iterable of (images, labels) batches exposing ``len()`` and ``.dataset``.
        test_labels: Class names used to build the text prompts.
        lambda_local: Weight of the local score.
        T: Softmax temperature.
        alpha, beta, gamma, entropy_temp: improved-elcm weighting hyper-parameters.
        args: Optional namespace with further improved-elcm options (``topk``, ``top_k_classes``,
            ``entropy_threshold``, ...). Missing attributes fall back to the defaults below.

    Returns:
        1-D numpy array of scores, trimmed to ``len(loader.dataset)``.

    Raises:
        NotImplementedError: for an unknown ``method``.
    """
    # Fail fast, before any image is encoded.
    if method not in ('mcm', 'gl-mcm', 'improved-elcm'):
        raise NotImplementedError(f"Method {method} not implemented")

    # Follow wherever the caller placed the model instead of assuming CUDA.
    device = next(model.parameters()).device
    scores = []

    with torch.no_grad():
        # Prompt embeddings do not depend on the images: encode them once, not once per batch.
        text_inputs = clip.tokenize([f"a photo of a {c}" for c in test_labels]).to(device)
        text_features = model.encode_text(text_inputs).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

        for images, _labels in tqdm(loader, total=len(loader)):
            images = images.to(device)
            global_features, local_features = model.encode_image(images)

            global_features = global_features.float()
            local_features = local_features.float()
            global_features /= global_features.norm(dim=-1, keepdim=True)
            local_features /= local_features.norm(dim=-1, keepdim=True)

            output_global = global_features @ text_features.T
            output_local = local_features @ text_features.T

            smax_global = F.softmax(output_global / T, dim=1).cpu().numpy()
            smax_local = F.softmax(output_local / T, dim=-1).cpu().numpy()  # batch, h*w, class

            global_score = -np.max(smax_global, axis=1)

            if method == 'mcm':
                scores.append(global_score)
            elif method == 'gl-mcm':
                local_score = -np.max(smax_local, axis=(1, 2))
                scores.append(global_score + lambda_local * local_score)
            else:  # 'improved-elcm' (validated above)
                selected_weights = compute_topk_class_conditional_weights(
                    smax_local, alpha=alpha, beta=beta, gamma=gamma, temperature=entropy_temp,
                    k=getattr(args, 'topk', 16),
                    top_k_classes=getattr(args, 'top_k_classes', 3),
                    entropy_threshold=getattr(args, 'entropy_threshold', None),
                    use_class_conditional=getattr(args, 'use_class_conditional', True),
                    use_percentile_weights=getattr(args, 'use_percentile_weights', True),
                    percentile_beta=getattr(args, 'percentile_beta', 25.0),
                    min_weight_gamma=getattr(args, 'min_weight_gamma', 0.1),
                    # Same default as the CLI flag and the weighting function itself.
                    use_spatial_correlation=getattr(args, 'use_spatial_correlation', False),
                    correlation_threshold=getattr(args, 'correlation_threshold', 0.3),
                    spatial_decay=getattr(args, 'spatial_decay', 2.0)
                )  # [batch, num_patches]

                patch_max_probs = np.max(smax_local, axis=-1)  # [batch, num_patches]
                local_score = -np.sum(selected_weights * patch_max_probs, axis=1)  # [batch]
                scores.append(global_score + lambda_local * local_score)

    return np.concatenate(scores, axis=0)[:len(loader.dataset)].copy()


def main(args: argparse.Namespace) -> None:
    """
    Run the evaluation end to end.

    Scores the in-distribution ImageNet loader and each OOD dataset with ``get_ood_scores``.
    Metrics go through ``get_and_print_results`` / ``add_results`` / ``add_overall_results``.
    Writes ``scores.npz`` (one array per dataset) and ``results.json`` into
    ``args.output_dir``, creating that directory if needed.
    """
    print("Starting Improved ELCM evaluation...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, preprocess = clip_w_local.load(args.model_name)
    model = model.to(device)
    model.eval()

    # Validate the dataset name (and load the prompt class names) before building any loader.
    test_labels = get_test_labels(args.in_dataset)

    def score_loader(loader) -> np.ndarray:
        """Score ``loader`` with the method and hyper-parameters carried by ``args``."""
        return get_ood_scores(
            model=model,
            method=args.method,
            loader=loader,
            test_labels=test_labels,
            lambda_local=args.lambda_value,
            T=args.T,
            alpha=args.alpha,
            beta=args.beta,
            gamma=args.gamma,
            entropy_temp=args.entropy_temp,
            args=args,
        )

    id_data_loader = set_test_loader(args.root, 'imagenet', preprocess, args.sample_size, args.batch_size, args.seed)
    in_score = score_loader(id_data_loader)

    auroc_list, fpr_list = [], []
    results_data = []
    scores_dict: Dict[str, np.ndarray] = {args.in_dataset: in_score}

    for out_dataset in ['iNaturalist', 'SUN', 'places365', 'Texture']:
        print(f"Evaluating OOD dataset: {out_dataset}")
        ood_data_loader = set_test_loader(args.root, out_dataset, preprocess, args.sample_size, args.batch_size, args.seed)
        out_score = score_loader(ood_data_loader)

        results = get_and_print_results(args, in_score, out_score, auroc_list, fpr_list)
        scores_dict[out_dataset] = out_score
        results_data = add_results(results_data, args.method, results, out_dataset)

    results_data = add_overall_results(results_data, args.method, auroc_list, fpr_list)

    # main() may be called without the __main__ block, so make sure the directory exists.
    os.makedirs(args.output_dir, exist_ok=True)
    np.savez(os.path.join(args.output_dir, "scores.npz"), **scores_dict)
    save_results_to_json(results_data, args.output_dir, "results.json")
    print("Improved ELCM evaluation completed")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Improved ELCM experiment script")
    parser.add_argument(
        "--output-dir", 
        type=str, 
        required=True,
        help="Path to output directory where results will be saved"
    )
    parser.add_argument(
        "--alpha", 
        type=float, 
        default=1.0, 
        help="Entropy weighting parameter (default: 1.0)"
    )
    parser.add_argument(
        "--beta", 
        type=float, 
        default=1.0, 
        help="Class-conditional scaling strength (default: 1.0)"
    )
    parser.add_argument(
        "--entropy-temp", 
        type=float, 
        default=1.0, 
        help="Temperature parameter for entropy computation (default: 1.0)"
    )
    parser.add_argument(
        "--topk", 
        type=int, 
        default=16, 
        help="Number of top patches to select after entropy filtering (default: 16)"
    )
    parser.add_argument(
        "--top-k-classes", 
        type=int, 
        default=3, 
        help="Number of top classes for class-conditional discrimination (default: 3)"
    )
    parser.add_argument(
        "--entropy-threshold", 
        type=float, 
        default=None, 
        help="Maximum entropy threshold for filtering patches (auto-computed if None)"
    )
    # store_true with default=True can never be switched off, so pair each flag with a
    # store_false "--no-..." counterpart writing to the same destination.
    parser.add_argument(
        "--use-class-conditional", 
        dest="use_class_conditional",
        action="store_true", 
        default=True,
        help="Enable class-conditional weighting (default: True)"
    )
    parser.add_argument(
        "--no-use-class-conditional",
        dest="use_class_conditional",
        action="store_false",
        help="Disable class-conditional weighting"
    )
    parser.add_argument(
        "--use-percentile-weights", 
        dest="use_percentile_weights",
        action="store_true", 
        default=True,
        help="Use percentile-based entropy weight stabilization (default: True)"
    )
    parser.add_argument(
        "--no-use-percentile-weights",
        dest="use_percentile_weights",
        action="store_false",
        help="Use exponential entropy weights instead of percentile-based stabilization"
    )
    parser.add_argument(
        "--percentile-beta", 
        type=float, 
        default=25.0, 
        help="Percentile cutoff for weight stabilization (default: 25.0)"
    )
    parser.add_argument(
        "--min-weight-gamma", 
        type=float, 
        default=0.1, 
        help="Minimum weight for high-entropy patches (default: 0.1)"
    )
    parser.add_argument(
        "--gamma", 
        type=float, 
        default=0.5, 
        help="Weight for spatial correlation component (default: 0.5)"
    )
    parser.add_argument(
        "--use-spatial-correlation", 
        action="store_true", 
        default=False,
        help="Enable spatial correlation weighting (default: False)"
    )
    parser.add_argument(
        "--patch-grid-size", 
        type=int, 
        default=7, 
        help="Size of the square patch grid for spatial correlation (default: 7)"
    )
    parser.add_argument(
        "--correlation-threshold", 
        type=float, 
        default=0.3, 
        help="Minimum correlation threshold for neighbors (default: 0.3)"
    )
    parser.add_argument(
        "--spatial-decay", 
        type=float, 
        default=2.0, 
        help="Decay factor for spatial distance weighting (default: 2.0)"
    )
    # Run configuration; defaults equal the values this script previously hard-coded.
    parser.add_argument(
        "--root",
        type=str,
        default="/datasets/LoCoOp",
        help="Dataset root directory (default: /datasets/LoCoOp)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Random seed passed to the test loaders (default: 1)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=100,
        help="Evaluation batch size (default: 100)"
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=100,
        help="Sample size passed to the test loaders (default: 100)"
    )
    parser.add_argument(
        "--lambda-value",
        type=float,
        default=0.5,
        help="Weight of the local score added to the global score (default: 0.5)"
    )
    
    args = parser.parse_args()
    
    # Set up arguments that match the experiment.py requirements
    full_args = argparse.Namespace()
    
    # Core arguments for Improved ELCM experiment
    full_args.root = args.root  # Path to dataset
    full_args.output_dir = args.output_dir
    full_args.seed = args.seed
    full_args.batch_size = args.batch_size
    full_args.model_name = 'ViT-B/16'
    full_args.in_dataset = 'ImageNet'
    full_args.method = 'improved-elcm'  # Use Improved ELCM method
    full_args.T = 1.0  # Temperature for softmax
    full_args.lambda_value = args.lambda_value  # Weight of the local score added to the global score
    full_args.alpha = args.alpha  # Entropy weighting parameter
    full_args.beta = args.beta  # Class-conditional scaling strength
    full_args.gamma = args.gamma  # Weight for spatial correlation component
    full_args.entropy_temp = args.entropy_temp  # Temperature for entropy computation
    full_args.topk = args.topk  # Number of top patches to select
    full_args.top_k_classes = args.top_k_classes  # Number of top classes for discrimination
    full_args.entropy_threshold = args.entropy_threshold  # Entropy filtering threshold
    full_args.use_class_conditional = args.use_class_conditional  # Enable class-conditional weighting
    full_args.use_percentile_weights = args.use_percentile_weights  # Use percentile-based weight stabilization
    full_args.percentile_beta = args.percentile_beta  # Percentile cutoff for weight stabilization
    full_args.min_weight_gamma = args.min_weight_gamma  # Minimum weight for high-entropy patches
    full_args.use_spatial_correlation = args.use_spatial_correlation  # Enable spatial correlation weighting
    # NOTE: recorded and printed, but get_ood_scores does not forward it to the weighting code.
    full_args.patch_grid_size = args.patch_grid_size  # Size of the square patch grid
    full_args.correlation_threshold = args.correlation_threshold  # Minimum correlation threshold
    full_args.spatial_decay = args.spatial_decay  # Spatial decay factor
    full_args.sample_size = args.sample_size  # Reduced sample size for faster testing
    
    print("Running Improved ELCM experiment with parameters:")
    print(f"  Root: {full_args.root}")
    print(f"  Output directory: {full_args.output_dir}")
    print(f"  Seed: {full_args.seed}")
    print(f"  Batch size: {full_args.batch_size}")
    print(f"  Model: {full_args.model_name}")
    print(f"  In-distribution dataset: {full_args.in_dataset}")
    print(f"  Method: {full_args.method}")
    print(f"  Temperature (T): {full_args.T}")
    print(f"  Lambda value: {full_args.lambda_value}")
    print(f"  Alpha (entropy weighting): {full_args.alpha}")
    print(f"  Beta (class-conditional scaling): {full_args.beta}")
    print(f"  Gamma (spatial correlation): {full_args.gamma}")
    print(f"  Entropy temperature: {full_args.entropy_temp}")
    print(f"  Top-K patches: {full_args.topk}")
    print(f"  Top-K classes: {full_args.top_k_classes}")
    print(f"  Entropy threshold: {full_args.entropy_threshold}")
    print(f"  Use class-conditional: {full_args.use_class_conditional}")
    print(f"  Use percentile weights: {full_args.use_percentile_weights}")
    print(f"  Percentile beta: {full_args.percentile_beta}")
    print(f"  Min weight gamma: {full_args.min_weight_gamma}")
    print(f"  Use spatial correlation: {full_args.use_spatial_correlation}")
    print(f"  Patch grid size: {full_args.patch_grid_size}")
    print(f"  Correlation threshold: {full_args.correlation_threshold}")
    print(f"  Spatial decay: {full_args.spatial_decay}")
    print(f"  Sample size: {full_args.sample_size}")
    print()
    
    # Create output directory if it doesn't exist
    os.makedirs(full_args.output_dir, exist_ok=True)
    
    # Run the main experiment
    main(full_args)