#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CASMIR Ablation Study script for analyzing component contributions.

Usage:
    python scripts/ablation_study.py --artifacts_folder "artifacts_path" --dataset dataset_name
"""

import os
import sys

# Add parent directory to path for module imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import json
import numpy as np
import pandas as pd
import torch
import joblib
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.decomposition import PCA
import copy
import random
import glob
from pathlib import Path
from collections import OrderedDict
import seaborn as sns

# Internal module imports
from src.data.datasets import load_dataset, preprocess_data, split_data_stratified, create_balanced_dataset, map_shot_types, classify_bins_by_samples, freedman_diaconis_bins, calculate_shot_wise_mae
from src.models.basic_models import MLP, MLPadv, SimpleThreeMLPEnsemble
from src.models.CASMIR_V1 import CASMIR_V1, AdaSmoothMoEV1, precompute_density_and_boundaries, CAS_Module, TabularExpert, GatingNetwork, FeatureNormalizer
# Backward-compatibility alias: CAS_Module_Alias refers to the same class as CAS_Module
CAS_Module_Alias = CAS_Module
from src.training.losses import ConRLoss, RankSimLoss, GAILossMD, BMCLossMD, WeightedL1Loss, WeightedMSELoss
from src.training.train_utils import train_pytorch_model, train_CASMIR_V1
# Alias for backward compatibility
train_casmir = train_CASMIR_V1
from src.utils.utils import calculate_balanced_weights, get_gmm
from config import CONFIG

# Runs at import time; reaching this line means every `src.*` / `config` import above succeeded.
print("All internal modules imported successfully")

# Device setup: use CUDA when torch reports it as available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training device: {DEVICE}")

def set_global_random_seed(seed=42):
    """Seed every random number source used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (plus all CUDA devices when
    CUDA is available), switches cuDNN to deterministic mode, requests
    deterministic PyTorch algorithms and exports ``PYTHONHASHSEED``.

    Args:
        seed: Seed value applied to all generators (default 42).
    """
    print(f" Setting all random sources to seed {seed}...")

    # The CPU-side generators all receive the same seed, in a fixed order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn_backend = torch.backends.cudnn
        cudnn_backend.deterministic = True
        cudnn_backend.benchmark = False
        # cuBLAS workspace setting exported together with the deterministic flags.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        print("    CUDA deterministic mode enabled")

    try:
        torch.use_deterministic_algorithms(True)
    except AttributeError:
        # Raised when the installed torch does not provide this function.
        print("     PyTorch version too old for deterministic algorithms")
    except RuntimeError as e:
        print(f"     Warning during deterministic algorithms setup: {e}")
    else:
        print("    PyTorch deterministic algorithms enabled")

    # NOTE: the running interpreter fixed its hash seed at startup; this export
    # is only picked up by processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f" All random sources fixed to seed {seed}\n")

# =============================================================================
# Model variant classes for Ablation Study
# =============================================================================

class CAS_Module_FeatureOnly(nn.Module):
    """
    Feature-Only CAS module - uses only feature similarity without target similarity.
    Comparison version for verifying coupled smoothing effect.

    Each sample is blended with a similarity-weighted mean of its k nearest
    in-batch neighbours. Neighbour weights depend only on the (learned,
    per-feature weighted) squared feature distance; targets never enter the
    computation. The blend ratio is ``strength_base / (density * density_c)``
    clamped to ``[0.05, 0.8]``.

    Args:
        input_dim: Number of input features (size of the learned weight vector).
        k: Number of neighbours (cast to ``int``).
        feature_bw: Bandwidth of the Gaussian feature-similarity kernel.
        label_bw: Stored but unused here; kept for signature compatibility.
        density_factor: Stored but unused in ``forward``.
        strength_base: Numerator of the smoothing-strength formula.
        density_c: Multiplier applied to the density in the strength formula.
        epsilon: Small constant guarding divisions.
        weight_init_strategy: ``'random'`` (1 + 0.1 * N(0, 1)), ``'uniform'``
            (ones); any other value initialises the raw weights to zeros.
        adaptive_k: If True, cap k at ``max(3, batch_size // 4)`` per batch.
    """
    def __init__(self, input_dim, k, feature_bw, label_bw, density_factor,
                 strength_base, density_c, epsilon=1e-6,
                 weight_init_strategy='random', adaptive_k=False):
        super().__init__()
        self.input_dim = input_dim
        self.k = int(k)
        self.feature_bw = feature_bw
        self.label_bw = label_bw  # kept for compatibility
        self.density_factor = density_factor
        self.strength_base = strength_base
        self.density_c = density_c
        self.epsilon = epsilon

        # Feature weight initialization (raw values; softplus is applied in forward)
        if weight_init_strategy == 'random':
            initial_weights = torch.randn(input_dim) * 0.1 + 1.0
        elif weight_init_strategy == 'uniform':
            initial_weights = torch.ones(input_dim)
        else:  # 'zeros'
            initial_weights = torch.zeros(input_dim)
        self.raw_feature_weights = nn.Parameter(initial_weights)

        self.adaptive_k = adaptive_k

    def forward(self, x_batch, y_batch, densities_batch):
        """Apply Feature-Only CAS per batch.

        Args:
            x_batch: Feature tensor; the code indexes it as ``(batch, input_dim)``.
            y_batch: Accepted for interface parity with the other CAS modules;
                it is not read, so ``None`` is fine.
            densities_batch: One density value per sample; extra singleton
                dimensions are squeezed away.

        Returns:
            The calibrated features, or ``x_batch`` itself (unchanged) when the
            batch is too small for the requested k or the k-NN search fails.
        """
        batch_size = x_batch.shape[0]

        # Adaptive k value calculation
        if self.adaptive_k:
            effective_k = min(self.k, max(3, batch_size // 4))
        else:
            effective_k = self.k

        if batch_size <= effective_k:
            print(f"Warning: Feature-Only CAS - batch size({batch_size}) <= effective_k({effective_k}).")
            return x_batch

        # batch_size > effective_k holds from here on, so effective_k is already
        # at most batch_size - 1; only a non-positive k still has to be rejected.
        if effective_k < 1:
            print(f"Warning: Feature-Only CAS - effective K ({effective_k}) < 1.")
            return x_batch

        # Convert learned weights to positive values
        positive_feature_weights = F.softplus(self.raw_feature_weights)

        # Weighted pairwise squared distances, shape (batch, batch)
        diff_sq = (x_batch.unsqueeze(1) - x_batch.unsqueeze(0))**2
        weighted_diff_sq = diff_sq * positive_feature_weights.view(1, 1, -1)
        dist_sq_matrix_stable = F.relu(torch.sum(weighted_diff_sq, dim=-1))

        # Find k-NN; one extra column is requested because each row's smallest
        # distance is its own zero self-distance.
        try:
            nn_dist_sq, nn_indices = torch.topk(
                dist_sq_matrix_stable, effective_k + 1, dim=1, largest=False, sorted=True
            )
        except RuntimeError as e:
            print(f"Error: Feature-Only CAS torch.topk execution failed: {e}")
            return x_batch
        nn_indices = nn_indices[:, 1:]  # exclude self
        dist_sq_neighbors = nn_dist_sq[:, 1:]

        neighbor_x = x_batch[nn_indices]

        # Densities are used per sample for the smoothing strength only.
        if densities_batch.ndim > 1:
            densities_batch = densities_batch.squeeze()

        # Weight calculation - Feature-Only (key difference): feature similarity
        # alone, with no target-similarity or neighbour-density factor.
        sim_x = torch.exp(-dist_sq_neighbors / (2 * self.feature_bw**2))
        weights = sim_x / (torch.sum(sim_x, dim=1, keepdim=True) + self.epsilon)

        # Calculate smoothed feature mean
        mu_smooth = torch.sum(weights.unsqueeze(-1) * neighbor_x, dim=1)

        # Calculate smoothing strength (inversely proportional to density)
        smoothing_strength = self.strength_base / (densities_batch * self.density_c + self.epsilon)
        smoothing_strength = torch.clamp(smoothing_strength, min=0.05, max=0.8).unsqueeze(1)

        # Feature calibration (blending)
        return (1 - smoothing_strength) * x_batch + smoothing_strength * mu_smooth

class CASMIR_MoEOnly(nn.Module):
    """
    Ablation Study 1: MoE structure only (CAS removed).

    The (optionally normalized) input features go straight to the gating
    network and to every expert; no smoothing module is involved.
    ``cas_params`` is accepted so all ablation variants share one constructor
    signature, but it is ignored here.
    """
    def __init__(self, input_dim, num_experts, expert_hidden_dims, gate_hidden_dims,
                 cas_params=None, use_feature_norm=True, expert_dropout=0.2, gate_temperature=1.0):
        super().__init__()
        self.use_feature_norm = use_feature_norm

        # Feature normalization is kept so the comparison with the full model is fair.
        if use_feature_norm:
            self.feature_norm = FeatureNormalizer(input_dim, learnable=True)

        # No CAS in this variant.
        self.cas_module = None

        # Mixture-of-experts part: the experts, then the gate and its temperature.
        expert_list = []
        for _ in range(num_experts):
            expert_list.append(
                TabularExpert(input_dim, expert_hidden_dims, dropout_rate=expert_dropout)
            )
        self.experts = nn.ModuleList(expert_list)

        gate_dropout = expert_dropout*0.5
        self.gating_network = GatingNetwork(input_dim, gate_hidden_dims, num_experts, dropout_rate=gate_dropout)
        self.gate_temperature = nn.Parameter(torch.tensor(gate_temperature))

    def forward(self, x, y=None, density=None, apply_smoothing=True):
        """Return the gate-weighted sum of the expert outputs.

        ``y`` and ``density`` are unused. In training mode with
        ``apply_smoothing`` truthy the result is the tuple
        ``(prediction, gate_weights, gate_logits)``; otherwise only the
        prediction is returned.
        """
        features = self.feature_norm(x) if self.use_feature_norm else x

        # Gate first, experts second (same evaluation order as before).
        gate_logits = self.gating_network(features)
        temperature = self.gate_temperature.clamp(min=0.5)  # temperature is floored at 0.5
        gate_weights = F.softmax(gate_logits / temperature, dim=1)

        per_expert = [expert(features) for expert in self.experts]
        expert_outputs = torch.stack(per_expert, dim=1)
        final_prediction = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs, dim=1)

        if not (self.training and apply_smoothing):
            return final_prediction
        return final_prediction, gate_weights, gate_logits

class CASMIR_CASOnly(nn.Module):
    """
    Ablation Study 2: CAS structure only (MoE removed, single Expert).

    Keeps feature normalization and the CAS module but replaces the mixture of
    experts with one ``TabularExpert``. ``num_experts``, ``gate_hidden_dims``
    and ``gate_temperature`` are accepted for signature parity and ignored.
    """
    def __init__(self, input_dim, num_experts=1, expert_hidden_dims=None, gate_hidden_dims=None,
                 cas_params=None, use_feature_norm=True, expert_dropout=0.2, gate_temperature=1.0):
        super().__init__()
        self.use_feature_norm = use_feature_norm

        if use_feature_norm:
            self.feature_norm = FeatureNormalizer(input_dim, learnable=True)

        # CAS is built only when parameters for it were supplied.
        self.cas_module = (
            CAS_Module(input_dim=input_dim, **cas_params) if cas_params is not None else None
        )

        # One expert stands in for the whole mixture.
        self.single_expert = TabularExpert(input_dim, expert_hidden_dims, dropout_rate=expert_dropout)

        # No gating in this variant; the attributes exist but are empty.
        self.gating_network = None
        self.gate_temperature = None

    def forward(self, x, y=None, density=None, apply_smoothing=True):
        """Predict with the single expert, applying CAS during training only.

        CAS runs when the module is in training mode, ``apply_smoothing`` is
        truthy, ``y`` and ``density`` are given and a CAS module exists. In
        training mode with ``apply_smoothing`` truthy the result is
        ``(prediction, None, None)``; otherwise the bare prediction.
        """
        features = self.feature_norm(x) if self.use_feature_norm else x

        smoothing_requested = self.training and apply_smoothing
        inputs_available = y is not None and density is not None
        if smoothing_requested and inputs_available and self.cas_module is not None:
            features = self.cas_module(features, y, density)

        final_prediction = self.single_expert(features)

        # The two None placeholders stand in for gate weights and gate logits.
        if smoothing_requested:
            return final_prediction, None, None
        return final_prediction

class CASMIR_CAS_FeatureOnly(nn.Module):
    """
    Ablation Study 4: Feature-Only CAS structure (target similarity excluded).

    Single ``TabularExpert`` fed by ``CAS_Module_FeatureOnly``; there is no
    mixture of experts. Serves as the comparison model for verifying the
    coupled smoothing effect. ``num_experts``, ``gate_hidden_dims`` and
    ``gate_temperature`` are accepted for signature parity and ignored.
    """
    def __init__(self, input_dim, num_experts=1, expert_hidden_dims=None, gate_hidden_dims=None,
                 cas_params=None, use_feature_norm=True, expert_dropout=0.2, gate_temperature=1.0):
        super().__init__()
        self.use_feature_norm = use_feature_norm

        if use_feature_norm:
            self.feature_norm = FeatureNormalizer(input_dim, learnable=True)

        # Key difference from the CAS-only variant: the feature-only smoother.
        if cas_params is None:
            self.cas_module = None
        else:
            self.cas_module = CAS_Module_FeatureOnly(input_dim=input_dim, **cas_params)

        # Single expert; gating attributes are present but unused.
        self.single_expert = TabularExpert(input_dim, expert_hidden_dims, dropout_rate=expert_dropout)
        self.gating_network = None
        self.gate_temperature = None

    def forward(self, x, y=None, density=None, apply_smoothing=True):
        """Predict with the single expert; smooth features during training only.

        Returns ``(prediction, None, None)`` in training mode when
        ``apply_smoothing`` is truthy, otherwise just the prediction.
        """
        if self.use_feature_norm:
            x_norm = self.feature_norm(x)
        else:
            x_norm = x

        in_smoothing_mode = self.training and apply_smoothing

        # Smoothing additionally needs targets, densities and a CAS module.
        x_processed = x_norm
        if in_smoothing_mode and not (y is None or density is None or self.cas_module is None):
            x_processed = self.cas_module(x_norm, y, density)

        final_prediction = self.single_expert(x_processed)

        if in_smoothing_mode:
            # No MoE, hence no gate weights / logits to report.
            return final_prediction, None, None
        return final_prediction

class CASMIR_Full(CASMIR_V1):
    """
    Ablation Study 3: Full CASMIR structure
    - Use both CAS + MoE
    - Same as CASMIR_V1

    This subclass adds no behaviour of its own: the constructor forwards every
    argument, unchanged and in the same order, to ``CASMIR_V1.__init__``. It
    exists so the full model has a name alongside the other ablation variants.
    """
    def __init__(self, input_dim, num_experts, expert_hidden_dims, gate_hidden_dims, 
                 cas_params, use_feature_norm=True, expert_dropout=0.2, gate_temperature=1.0):
        # Initialize same as parent class (CASMIR_V1); arguments are passed positionally.
        super().__init__(input_dim, num_experts, expert_hidden_dims, gate_hidden_dims, 
                        cas_params, use_feature_norm, expert_dropout, gate_temperature)

# =============================================================================
# Additional Model variant classes for Ablation Study
# =============================================================================

class CAS_Module_FixedStrength(nn.Module):
    """
    Fixed Smoothing Strength CAS module - uses fixed smoothing strength value.

    Neighbour weights combine feature similarity and target similarity
    (optionally scaled by inverse neighbour density), but every sample is
    blended with its smoothed neighbour mean using the same constant
    ``fixed_strength`` instead of a density-dependent strength.

    Args:
        input_dim: Number of input features (size of the learned weight vector).
        k: Number of neighbours (cast to ``int``).
        feature_bw: Bandwidth of the Gaussian feature-similarity kernel.
        label_bw: Bandwidth of the Gaussian target-similarity kernel.
        density_factor: If > 0, target similarity is multiplied by
            ``density_factor / (neighbour_density + epsilon)``.
        strength_base: Stored but unused in ``forward`` (strength is fixed).
        density_c: Stored but unused in ``forward``.
        epsilon: Small constant guarding divisions.
        weight_init_strategy: ``'random'`` (1 + 0.1 * N(0, 1)), ``'uniform'``
            (ones); any other value initialises the raw weights to zeros.
        adaptive_k: If True, cap k at ``max(3, batch_size // 4)`` per batch.
        fixed_strength: Blend ratio applied to every sample. The default of
            0.0 leaves the features unchanged.
    """
    def __init__(self, input_dim, k, feature_bw, label_bw, density_factor,
                 strength_base, density_c, epsilon=1e-6,
                 weight_init_strategy='random', adaptive_k=False, fixed_strength=0.0):
        super().__init__()
        self.input_dim = input_dim
        self.k = int(k)
        self.feature_bw = feature_bw
        self.label_bw = label_bw
        self.density_factor = density_factor
        self.strength_base = strength_base
        self.density_c = density_c
        self.epsilon = epsilon
        self.fixed_strength = fixed_strength  # fixed smoothing strength

        # Feature weight initialization (raw values; softplus is applied in forward)
        if weight_init_strategy == 'random':
            initial_weights = torch.randn(input_dim) * 0.1 + 1.0
        elif weight_init_strategy == 'uniform':
            initial_weights = torch.ones(input_dim)
        else:
            initial_weights = torch.zeros(input_dim)
        self.raw_feature_weights = nn.Parameter(initial_weights)

        self.adaptive_k = adaptive_k

    def forward(self, x_batch, y_batch, densities_batch):
        """Apply Fixed Strength CAS per batch.

        Args:
            x_batch: Feature tensor; the code indexes it as ``(batch, input_dim)``.
            y_batch: One target per sample, shaped ``(batch,)`` or ``(batch, 1)``.
            densities_batch: One density value per sample; extra singleton
                dimensions are squeezed away.

        Returns:
            The calibrated features, or ``x_batch`` itself (unchanged) when the
            batch is too small for the requested k or the k-NN search fails.
        """
        batch_size = x_batch.shape[0]
        device = x_batch.device

        # Adaptive k value calculation
        if self.adaptive_k:
            effective_k = min(self.k, max(3, batch_size // 4))
        else:
            effective_k = self.k

        # Not enough samples for k neighbours, or a non-positive k: no smoothing.
        if batch_size <= effective_k or effective_k < 1:
            return x_batch

        # Convert learned weights to positive values
        positive_feature_weights = F.softplus(self.raw_feature_weights)

        # Weighted pairwise squared distances, shape (batch, batch)
        diff_sq = (x_batch.unsqueeze(1) - x_batch.unsqueeze(0))**2
        weighted_diff_sq = diff_sq * positive_feature_weights.view(1, 1, -1)
        dist_sq_matrix_stable = F.relu(torch.sum(weighted_diff_sq, dim=-1))

        # Find k-NN; the first column is each row's own zero self-distance.
        try:
            _, nn_indices = torch.topk(dist_sq_matrix_stable, effective_k + 1, dim=1, largest=False, sorted=True)
        except RuntimeError:
            return x_batch
        nn_indices = nn_indices[:, 1:]

        # Get neighbor information. Targets are flattened to (batch,) first so
        # that both (batch,) and (batch, 1) inputs give a (batch, k) neighbour
        # table; the previous squeeze-based form only broadcast correctly for
        # (batch, 1) targets.
        neighbor_x = x_batch[nn_indices]
        y_flat = y_batch.reshape(batch_size)
        neighbor_y = y_flat[nn_indices]

        # Get density values
        if densities_batch.ndim > 1:
            densities_batch = densities_batch.squeeze()

        try:
            clamped_indices = torch.clamp(nn_indices, 0, densities_batch.shape[0] - 1)
            neighbor_densities = densities_batch[clamped_indices]
        except IndexError:
            # Best-effort fallback: use the mean density for every neighbour.
            neighbor_densities = torch.full_like(nn_indices, densities_batch.mean(),
                                               dtype=torch.float32, device=device)

        # Weight calculation: feature similarity times target similarity
        dist_sq_neighbors = torch.gather(dist_sq_matrix_stable, 1, nn_indices)
        sim_x = torch.exp(-dist_sq_neighbors / (2 * self.feature_bw**2))

        dist_sq_y = (y_flat.unsqueeze(1) - neighbor_y)**2
        sim_y = torch.exp(-dist_sq_y / (2 * self.label_bw**2))
        if self.density_factor > 0:
            sim_y = sim_y * (self.density_factor / (neighbor_densities + self.epsilon))

        weights_unnorm = sim_x * sim_y
        weights = weights_unnorm / (torch.sum(weights_unnorm, dim=1, keepdim=True) + self.epsilon)

        # Calculate smoothed feature mean
        mu_smooth = torch.sum(weights.unsqueeze(-1) * neighbor_x, dim=1)

        # Use fixed smoothing strength (instead of adaptive)
        smoothing_strength = torch.full((batch_size,), self.fixed_strength,
                                       dtype=torch.float32, device=device).unsqueeze(1)

        # Feature calibration
        return (1 - smoothing_strength) * x_batch + smoothing_strength * mu_smooth

class CAS_Module_NoLearnableMetric(nn.Module):
    """
    No Learnable Metric CAS module - w fixed to 1 (learnable metric removed).

    Identical in structure to the coupled CAS smoothing (feature similarity x
    target similarity, density-dependent strength clamped to ``[0.05, 0.8]``),
    except that the per-feature distance weights are a constant buffer of
    ones, so the module has no trainable parameters.

    Args:
        input_dim: Number of input features (size of the fixed weight vector).
        k: Number of neighbours (cast to ``int``).
        feature_bw: Bandwidth of the Gaussian feature-similarity kernel.
        label_bw: Bandwidth of the Gaussian target-similarity kernel.
        density_factor: If > 0, target similarity is multiplied by
            ``density_factor / (neighbour_density + epsilon)``.
        strength_base: Numerator of the smoothing-strength formula.
        density_c: Multiplier applied to the density in the strength formula.
        epsilon: Small constant guarding divisions.
        weight_init_strategy: Accepted for signature parity; ignored.
        adaptive_k: If True, cap k at ``max(3, batch_size // 4)`` per batch.
    """
    def __init__(self, input_dim, k, feature_bw, label_bw, density_factor,
                 strength_base, density_c, epsilon=1e-6,
                 weight_init_strategy='random', adaptive_k=False):
        super().__init__()
        self.input_dim = input_dim
        self.k = int(k)
        self.feature_bw = feature_bw
        self.label_bw = label_bw
        self.density_factor = density_factor
        self.strength_base = strength_base
        self.density_c = density_c
        self.epsilon = epsilon

        # Learnable weights removed: a buffer (not a Parameter) of ones, so it
        # follows the module across devices but receives no gradient updates.
        self.register_buffer('fixed_feature_weights', torch.ones(input_dim))

        self.adaptive_k = adaptive_k

    def forward(self, x_batch, y_batch, densities_batch):
        """Apply No Learnable Metric CAS per batch.

        Args:
            x_batch: Feature tensor; the code indexes it as ``(batch, input_dim)``.
            y_batch: One target per sample, shaped ``(batch,)`` or ``(batch, 1)``.
            densities_batch: One density value per sample; extra singleton
                dimensions are squeezed away.

        Returns:
            The calibrated features, or ``x_batch`` itself (unchanged) when the
            batch is too small for the requested k or the k-NN search fails.
        """
        batch_size = x_batch.shape[0]
        device = x_batch.device

        # Adaptive k value calculation
        if self.adaptive_k:
            effective_k = min(self.k, max(3, batch_size // 4))
        else:
            effective_k = self.k

        # Not enough samples for k neighbours, or a non-positive k: no smoothing.
        if batch_size <= effective_k or effective_k < 1:
            return x_batch

        # Use fixed weights (all 1)
        positive_feature_weights = self.fixed_feature_weights

        # Weighted distance calculation (effectively unweighted)
        diff_sq = (x_batch.unsqueeze(1) - x_batch.unsqueeze(0))**2
        weighted_diff_sq = diff_sq * positive_feature_weights.view(1, 1, -1)
        dist_sq_matrix_stable = F.relu(torch.sum(weighted_diff_sq, dim=-1))

        # Find k-NN; the first column is each row's own zero self-distance.
        try:
            _, nn_indices = torch.topk(dist_sq_matrix_stable, effective_k + 1, dim=1, largest=False, sorted=True)
        except RuntimeError:
            return x_batch
        nn_indices = nn_indices[:, 1:]

        # Get neighbor information. Targets are flattened to (batch,) first so
        # that both (batch,) and (batch, 1) inputs give a (batch, k) neighbour
        # table; the previous squeeze-based form only broadcast correctly for
        # (batch, 1) targets.
        neighbor_x = x_batch[nn_indices]
        y_flat = y_batch.reshape(batch_size)
        neighbor_y = y_flat[nn_indices]

        # Get density values
        if densities_batch.ndim > 1:
            densities_batch = densities_batch.squeeze()

        try:
            clamped_indices = torch.clamp(nn_indices, 0, densities_batch.shape[0] - 1)
            neighbor_densities = densities_batch[clamped_indices]
        except IndexError:
            # Best-effort fallback: use the mean density for every neighbour.
            neighbor_densities = torch.full_like(nn_indices, densities_batch.mean(),
                                               dtype=torch.float32, device=device)

        # Neighbour weights: feature similarity times target similarity
        dist_sq_neighbors = torch.gather(dist_sq_matrix_stable, 1, nn_indices)
        sim_x = torch.exp(-dist_sq_neighbors / (2 * self.feature_bw**2))

        dist_sq_y = (y_flat.unsqueeze(1) - neighbor_y)**2
        sim_y = torch.exp(-dist_sq_y / (2 * self.label_bw**2))
        if self.density_factor > 0:
            sim_y = sim_y * (self.density_factor / (neighbor_densities + self.epsilon))

        weights_unnorm = sim_x * sim_y
        weights = weights_unnorm / (torch.sum(weights_unnorm, dim=1, keepdim=True) + self.epsilon)

        # Calculate smoothed feature mean
        mu_smooth = torch.sum(weights.unsqueeze(-1) * neighbor_x, dim=1)

        # Smoothing strength: inversely proportional to density, clamped
        smoothing_strength = self.strength_base / (densities_batch * self.density_c + self.epsilon)
        smoothing_strength = torch.clamp(smoothing_strength, min=0.05, max=0.8).unsqueeze(1)

        # Feature calibration
        return (1 - smoothing_strength) * x_batch + smoothing_strength * mu_smooth

class MixUpModule(nn.Module):
    """
    MixUp module - use MixUp instead of CAS.

    Mixes each sample's features with those of a randomly chosen sample from
    the same batch. Only the features are mixed; ``y_batch`` is accepted but
    not read, and no mixed targets are returned.
    """
    def __init__(self, alpha=0.2):
        super().__init__()
        # Both shape parameters of the Beta distribution the mixing ratios come from.
        self.alpha = alpha

    def forward(self, x_batch, y_batch):
        """Return the feature batch after per-sample MixUp.

        Batches with at most one sample are returned as-is. With
        ``alpha <= 0`` every mixing ratio is 1, so the output equals the input.
        """
        n_samples = x_batch.shape[0]
        if n_samples <= 1:
            return x_batch

        target_device = x_batch.device

        # One mixing ratio per sample, drawn from NumPy's global generator.
        if self.alpha > 0:
            sampled = np.random.beta(self.alpha, self.alpha, size=n_samples)
            lam = torch.tensor(sampled, dtype=torch.float32, device=target_device)
        else:
            lam = torch.ones(n_samples, dtype=torch.float32, device=target_device)

        # Partner sample for every row, drawn from torch's generator.
        partner = torch.randperm(n_samples, device=target_device)

        lam_col = lam.unsqueeze(1)
        return lam_col * x_batch + (1 - lam_col) * x_batch[partner]

class CASMIR_FixedStrength(nn.Module):
    """
    Fixed Smoothing Strength model - CASMIR with fixed s_i.

    One ``TabularExpert`` fed by ``CAS_Module_FixedStrength``. The extra
    ``fixed_strength`` argument is forwarded to that module. ``num_experts``,
    ``gate_hidden_dims`` and ``gate_temperature`` are accepted for signature
    parity and ignored.
    """
    def __init__(self, input_dim, num_experts=1, expert_hidden_dims=None, gate_hidden_dims=None,
                 cas_params=None, use_feature_norm=True, expert_dropout=0.2,
                 gate_temperature=1.0, fixed_strength=0.0):
        super().__init__()
        self.use_feature_norm = use_feature_norm

        if use_feature_norm:
            self.feature_norm = FeatureNormalizer(input_dim, learnable=True)

        # The smoother is only built when CAS parameters were supplied.
        if cas_params is None:
            self.cas_module = None
        else:
            self.cas_module = CAS_Module_FixedStrength(
                input_dim=input_dim,
                fixed_strength=fixed_strength,
                **cas_params
            )

        # Single expert, no gating.
        self.single_expert = TabularExpert(input_dim, expert_hidden_dims, dropout_rate=expert_dropout)
        self.gating_network = None
        self.gate_temperature = None

    def forward(self, x, y=None, density=None, apply_smoothing=True):
        """Predict with the single expert; smooth features during training only.

        Returns ``(prediction, None, None)`` in training mode when
        ``apply_smoothing`` is truthy, otherwise just the prediction.
        """
        hidden = self.feature_norm(x) if self.use_feature_norm else x

        train_path = self.training and apply_smoothing
        if train_path and y is not None and density is not None and self.cas_module is not None:
            hidden = self.cas_module(hidden, y, density)

        final_prediction = self.single_expert(hidden)

        return (final_prediction, None, None) if train_path else final_prediction

class CASMIR_NoLearnableMetric(nn.Module):
    """
    No Learnable Metric model - CASMIR with w fixed to 1.

    One ``TabularExpert`` fed by ``CAS_Module_NoLearnableMetric``.
    ``num_experts``, ``gate_hidden_dims`` and ``gate_temperature`` are accepted
    for signature parity and ignored.
    """
    def __init__(self, input_dim, num_experts=1, expert_hidden_dims=None, gate_hidden_dims=None,
                 cas_params=None, use_feature_norm=True, expert_dropout=0.2, gate_temperature=1.0):
        super().__init__()
        self.use_feature_norm = use_feature_norm

        if use_feature_norm:
            self.feature_norm = FeatureNormalizer(input_dim, learnable=True)

        # Smoother with the fixed (all-ones) metric, built only when configured.
        self.cas_module = (
            None if cas_params is None
            else CAS_Module_NoLearnableMetric(input_dim=input_dim, **cas_params)
        )

        # Single expert, no gating.
        self.single_expert = TabularExpert(input_dim, expert_hidden_dims, dropout_rate=expert_dropout)
        self.gating_network = None
        self.gate_temperature = None

    def forward(self, x, y=None, density=None, apply_smoothing=True):
        """Predict with the single expert; smooth features during training only.

        Returns ``(prediction, None, None)`` in training mode when
        ``apply_smoothing`` is truthy, otherwise just the prediction.
        """
        x_norm = x
        if self.use_feature_norm:
            x_norm = self.feature_norm(x)

        wants_smoothing = self.training and apply_smoothing
        has_everything = not (y is None or density is None or self.cas_module is None)

        if wants_smoothing and has_everything:
            x_processed = self.cas_module(x_norm, y, density)
        else:
            x_processed = x_norm

        final_prediction = self.single_expert(x_processed)

        if not wants_smoothing:
            return final_prediction
        return final_prediction, None, None

class CASMIR_MixUp(nn.Module):
    """
    MixUp + MoE model - MoE with MixUp instead of CAS.

    ``cas_params`` is accepted for signature parity and ignored. During
    training the experts see MixUp-processed features while the gating network
    always sees the un-mixed (normalized) features.
    """
    def __init__(self, input_dim, num_experts=3, expert_hidden_dims=None, gate_hidden_dims=None,
                 cas_params=None, use_feature_norm=True, expert_dropout=0.2,
                 gate_temperature=1.0, mixup_alpha=0.2):
        super().__init__()
        self.use_feature_norm = use_feature_norm

        if use_feature_norm:
            self.feature_norm = FeatureNormalizer(input_dim, learnable=True)

        # MixUp takes the place of the CAS module.
        self.mixup_module = MixUpModule(alpha=mixup_alpha)

        # Mixture-of-experts part: the experts, then the gate and its temperature.
        expert_list = []
        for _ in range(num_experts):
            expert_list.append(
                TabularExpert(input_dim, expert_hidden_dims, dropout_rate=expert_dropout)
            )
        self.experts = nn.ModuleList(expert_list)

        gate_dropout = expert_dropout*0.5
        self.gating_network = GatingNetwork(input_dim, gate_hidden_dims, num_experts, dropout_rate=gate_dropout)
        self.gate_temperature = nn.Parameter(torch.tensor(gate_temperature))

    def forward(self, x, y=None, density=None, apply_smoothing=True):
        """Return the gate-weighted sum of the expert outputs.

        MixUp is applied only in training mode, with ``apply_smoothing`` truthy
        and ``y`` given; ``density`` is unused. In training mode with
        ``apply_smoothing`` truthy the result is
        ``(prediction, gate_weights, gate_logits)``; otherwise the prediction.
        """
        x_norm = self.feature_norm(x) if self.use_feature_norm else x
        train_path = self.training and apply_smoothing

        # Expert inputs: mixed during training, plain otherwise.
        expert_input = x_norm
        if train_path and y is not None:
            expert_input = self.mixup_module(x_norm, y)

        # The gate is computed from the un-mixed features.
        gate_logits = self.gating_network(x_norm)
        temperature = self.gate_temperature.clamp(min=0.5)  # temperature is floored at 0.5
        gate_weights = F.softmax(gate_logits / temperature, dim=1)

        per_expert = [expert(expert_input) for expert in self.experts]
        expert_outputs = torch.stack(per_expert, dim=1)
        final_prediction = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs, dim=1)

        if not train_path:
            return final_prediction
        return final_prediction, gate_weights, gate_logits

# =============================================================================
# Artifact scanning and loading functions
# =============================================================================

def scan_adasmooth_artifacts(artifacts_folder, dataset_name):
    """Scan a folder for complete CASMIR artifact directories of one dataset.

    Entries named ``V*_{dataset_name}_CASMIR_*`` are searched first; only when
    none exist is the ``V*_{dataset_name}_AdaSmoothMoE_*`` naming tried. An
    entry is kept only if it contains every required file.

    Args:
        artifacts_folder: Folder containing the artifact directories.
        dataset_name: Dataset name embedded in the directory names.

    Returns:
        A list of dicts with keys ``path``, ``folder_name``, ``algorithm``
        (always ``'CASMIR'``) and ``dataset``, sorted by path. Empty when the
        folder does not exist or nothing matches.
    """
    print(f" Scanning {artifacts_folder} for {dataset_name} CASMIR artifacts...")

    if not os.path.exists(artifacts_folder):
        print(f" Artifact folder not found: {artifacts_folder}")
        return []

    # Patterns as reported to the user.
    pattern_casmir = f"V*_{dataset_name}_CASMIR_*"
    pattern_adasmooth = f"V*_{dataset_name}_AdaSmoothMoE_*"

    # The folder and the dataset name are literal text, so glob metacharacters
    # in them ('[', ']', '*', '?') are escaped; only the explicit '*' wildcards
    # of the naming scheme stay active.
    safe_folder = glob.escape(os.fspath(artifacts_folder))
    safe_dataset = glob.escape(str(dataset_name))

    artifact_dirs = []
    for algorithm_tag in ("CASMIR", "AdaSmoothMoE"):
        artifact_dirs = glob.glob(
            os.path.join(safe_folder, f"V*_{safe_dataset}_{algorithm_tag}_*")
        )
        if artifact_dirs:
            break

    if not artifact_dirs:
        print(f"No CASMIR artifacts found for dataset: {dataset_name}")
        print(f"  Searched patterns: {pattern_casmir}, {pattern_adasmooth}")
        return []

    # Files every usable artifact directory must contain.
    required_files = ('metadata.json', 'hyperparams.json', 'preprocessor.joblib',
                      'data_split_info.json', 'experiment_config.json')

    valid_artifacts = []
    for artifact_dir in sorted(artifact_dirs):
        folder_name = os.path.basename(artifact_dir)

        if all(os.path.exists(os.path.join(artifact_dir, f)) for f in required_files):
            valid_artifacts.append({
                'path': artifact_dir,
                'folder_name': folder_name,
                'algorithm': 'CASMIR',
                'dataset': dataset_name
            })
            print(f" Found: {folder_name}")
        else:
            print(f" Skip: {folder_name} (required files missing)")

    print(f" Total {len(valid_artifacts)} CASMIR artifacts found")
    return valid_artifacts

def load_artifact_metadata(artifact_dir):
    """Load the saved description files of a single artifact folder.

    Reads, in this order, ``metadata.json``, ``hyperparams.json``,
    ``preprocessor.joblib``, ``data_split_info.json`` and
    ``experiment_config.json`` from ``artifact_dir``.

    Args:
        artifact_dir: Path of the artifact folder.

    Returns:
        The tuple ``(metadata, hyperparams, preprocessor, data_split_info,
        experiment_config)``; the four JSON payloads are returned as parsed,
        the preprocessor is whatever ``joblib.load`` returns.
    """
    print(f" Loading artifact metadata: {artifact_dir}")

    def read_json(file_name):
        # All JSON artifacts live directly inside the folder and are UTF-8.
        json_path = os.path.join(artifact_dir, file_name)
        with open(json_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    metadata = read_json('metadata.json')
    hyperparams = read_json('hyperparams.json')
    preprocessor = joblib.load(os.path.join(artifact_dir, 'preprocessor.joblib'))
    data_split_info = read_json('data_split_info.json')
    experiment_config = read_json('experiment_config.json')

    print(f" Metadata loaded: {metadata['algorithm_name']} on {metadata['dataset_name']}")
    return metadata, hyperparams, preprocessor, data_split_info, experiment_config

# =============================================================================
# Model creation and training functions
# =============================================================================

def create_ablation_models(hyperparams, experiment_config, input_dim):
    """Create the model variants compared in the ablation study.

    Four basic variants (``MoE_Only``, ``CAS_Only``, ``CAS_Feature_Only``,
    ``CASMIR``) and four additional variants (``Smooth_Strength_0``,
    ``Smooth_Strength_1``, ``CAS_No_Learnable_Metric``, ``MixUp_MoE``) are
    built, in that order, and moved to ``DEVICE``.

    Args:
        hyperparams: Tuned hyperparameters. ``expert_dim1``/``expert_dim2`` and
            ``gate_dim1`` define the hidden sizes when present; the CAS
            settings (``k_neighbors``, ``feature_bw``, ``label_bw``,
            ``density_factor``, ``strength_base``, ``density_c``) fall back to
            the defaults written below.
        experiment_config: Only consulted (via its ``global_config`` entry)
            when the hidden sizes are missing from ``hyperparams``.
        input_dim: Number of input features handed to every model.

    Returns:
        An ``OrderedDict`` mapping variant name to model, in creation order.
    """
    print(" Creating Ablation Study models...")

    # Hidden sizes: prefer tuned values, otherwise the experiment's global config.
    if 'expert_dim1' in hyperparams and 'expert_dim2' in hyperparams:
        expert_hidden_dims = [hyperparams['expert_dim1'], hyperparams['expert_dim2']]
    else:
        global_config = experiment_config.get('global_config', {})
        expert_hidden_dims = global_config.get('adasmooth_expert_hidden_dims', [32, 16])

    if 'gate_dim1' in hyperparams:
        gate_hidden_dims = [hyperparams['gate_dim1']]
    else:
        global_config = experiment_config.get('global_config', {})
        gate_hidden_dims = global_config.get('adasmooth_gate_hidden_dims', [32])

    # CAS parameters
    cas_params = {
        'k': hyperparams.get('k_neighbors', 10),
        'feature_bw': hyperparams.get('feature_bw', 1.5),
        'label_bw': hyperparams.get('label_bw', 10.0),
        'density_factor': hyperparams.get('density_factor', 0.1),
        'strength_base': hyperparams.get('strength_base', 0.6),
        'density_c': hyperparams.get('density_c', 20.0),
        'epsilon': 1e-6
    }

    # Keyword arguments shared by every variant; individual variants override
    # or extend them through the per-variant dict in the tables below.
    common_params = {
        'input_dim': input_dim,
        'num_experts': 3,
        'expert_hidden_dims': expert_hidden_dims,
        'gate_hidden_dims': gate_hidden_dims,
        'cas_params': cas_params,
        'use_feature_norm': True,
        'expert_dropout': 0.2,
        'gate_temperature': 1.0
    }

    # (name, model class, keyword overrides, message printed after creation)
    basic_variants = [
        # MoE Only (CAS removed)
        ('MoE_Only', CASMIR_MoEOnly, {},
         f" MoE Only Create models: expert_dims={expert_hidden_dims}, gate_dims={gate_hidden_dims}"),
        # CAS Only (MoE removed) - Coupled CAS
        ('CAS_Only', CASMIR_CASOnly, {},
         f" CAS Only (Coupled) Create models: expert_dims={expert_hidden_dims}"),
        # CAS Feature-Only (MoE removed)
        ('CAS_Feature_Only', CASMIR_CAS_FeatureOnly, {},
         f" CAS Feature-Only Create models: expert_dims={expert_hidden_dims}"),
        # CASMIR (full model) - Coupled CAS + MoE
        ('CASMIR', CASMIR_Full, {},
         f" CASMIR model created: expert_dims={expert_hidden_dims}, gate_dims={gate_hidden_dims}"),
    ]
    additional_variants = [
        # Smooth Strength = 0 (no smoothing)
        ('Smooth_Strength_0', CASMIR_FixedStrength,
         {'num_experts': 1, 'fixed_strength': 0.0},
         " Smooth Strength 0 model created (no smoothing)"),
        # Smooth Strength = 1 (full smoothing)
        ('Smooth_Strength_1', CASMIR_FixedStrength,
         {'num_experts': 1, 'fixed_strength': 1.0},
         " Smooth Strength 1 model created (full smoothing)"),
        # CAS No Learnable Metric (w=1 fixed)
        ('CAS_No_Learnable_Metric', CASMIR_NoLearnableMetric,
         {'num_experts': 1},
         " CAS No Learnable Metric model created (w=1 fixed)"),
        # MixUp + MoE
        ('MixUp_MoE', CASMIR_MixUp,
         {'mixup_alpha': 0.2},
         " MixUp + MoE model created: alpha=0.2"),
    ]

    models = OrderedDict()
    for group_title, variants in (("Basic", basic_variants),
                                  ("Additional", additional_variants)):
        # The count is derived from the table so the header cannot drift from
        # the number of models that are really built.
        print(f"\n {group_title} Ablation Study models ({len(variants)}):")
        for name, model_cls, overrides, created_message in variants:
            models[name] = model_cls(**{**common_params, **overrides}).to(DEVICE)
            print(created_message)

    print(f"\nTotal {len(models)} models created!")

    return models

def setup_optimizer_and_loss(model, hyperparams):
    """Build the optimizer and the loss used to train ``model``.

    ``hyperparams`` may provide ``optimizer`` (default ``'Adam'``), ``lr``
    (default ``0.001``) and ``weight_decay`` (default ``0.0``). ``'AdamW'``
    and ``'Adam'`` receive the weight decay; any other optimizer name falls
    back to Adam with the learning rate only.

    Returns:
        ``(optimizer, criterion)`` where ``criterion`` is ``nn.MSELoss()``.
    """
    requested = hyperparams.get('optimizer', 'Adam')
    learning_rate = hyperparams.get('lr', 0.001)
    decay = hyperparams.get('weight_decay', 0.0)

    # Only the two recognised names forward the weight decay.
    extra_kwargs = {'weight_decay': decay} if requested in ('AdamW', 'Adam') else {}
    optimizer_cls = optim.AdamW if requested == 'AdamW' else optim.Adam
    optimizer = optimizer_cls(model.parameters(), lr=learning_rate, **extra_kwargs)

    return optimizer, nn.MSELoss()

def train_single_model(model, model_name, optimizer, criterion, X_train, y_train, X_val, y_val,
                      densities_train=None, experiment_config=None, max_epochs=600, patience=600):
    """Fit one model with sequential mini-batches and keep the best validation state.

    Each epoch walks the training tensors in order in slices of ``batch_size``
    (taken from ``experiment_config['batch_size']``, default 32), optimises
    ``criterion`` on the model's main output, then scores the validation set
    with MSE. The state dict with the lowest validation loss is restored
    before returning.

    Args:
        model: Model to train. When it exposes a non-``None`` ``cas_module``
            it is called as ``model(X, y, density, apply_smoothing=True)``,
            otherwise as ``model(X, apply_smoothing=True)``.
        model_name: Label used in progress messages.
        optimizer: Optimizer already bound to the model parameters.
        criterion: Training loss applied to ``(main_output, batch_y)``.
        X_train, y_train, X_val, y_val: Training / validation data; the
            targets are reshaped to column vectors.
        densities_train: Optional per-sample densities, only used for models
            with a ``cas_module``.
        experiment_config: Optional mapping providing ``batch_size``.
        max_epochs: Upper bound on the number of epochs.
        patience: Number of consecutive non-improving epochs tolerated.

    Returns:
        ``(model, best_val_loss)``.
    """
    print(f" {model_name} model training started...")

    batch_size = 32
    if experiment_config:
        batch_size = experiment_config.get('batch_size', 32)

    def as_device_tensor(array):
        return torch.tensor(array, dtype=torch.float32).to(DEVICE)

    def model_has_cas():
        return getattr(model, 'cas_module', None) is not None

    def main_output_of(raw_output):
        # Some variants return (prediction, extras...); only the prediction is scored.
        return raw_output[0] if isinstance(raw_output, tuple) else raw_output

    train_features = as_device_tensor(X_train)
    train_targets = as_device_tensor(y_train.reshape(-1, 1))
    val_features = as_device_tensor(X_val)
    val_targets = as_device_tensor(y_val.reshape(-1, 1))

    # Densities are only needed by models that carry a CAS module.
    train_densities = None
    if densities_train is not None and model_has_cas():
        train_densities = as_device_tensor(densities_train)

    best_val_loss = float('inf')
    best_state = None
    epochs_without_gain = 0

    for epoch in range(max_epochs):
        model.train()

        loss_total = 0.0
        batches_seen = 0

        for start in range(0, len(train_features), batch_size):
            window = slice(start, start + batch_size)
            batch_X = train_features[window]
            batch_y = train_targets[window]
            batch_density = None if train_densities is None else train_densities[window]

            optimizer.zero_grad()

            if model_has_cas():
                raw = model(batch_X, batch_y, batch_density, apply_smoothing=True)
            else:
                raw = model(batch_X, apply_smoothing=True)

            loss = criterion(main_output_of(raw), batch_y)
            loss.backward()
            optimizer.step()

            loss_total += loss.item()
            batches_seen += 1

        avg_train_loss = loss_total / batches_seen

        # Validation pass (no smoothing, no gradients).
        model.eval()
        with torch.no_grad():
            val_prediction = main_output_of(model(val_features, apply_smoothing=False))
            val_loss = nn.MSELoss()(val_prediction, val_targets).item()

        # Track the best checkpoint and the early-stopping counter.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            epochs_without_gain = 0
        else:
            epochs_without_gain += 1

        # Report every 50th epoch and every epoch that improved.
        if epoch % 50 == 0 or epochs_without_gain == 0:
            print(f"  Epoch {epoch:3d}: Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Best: {best_val_loss:.4f}")

        if epochs_without_gain >= patience:
            print(f" {model_name} Early stopping at epoch {epoch}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f" {model_name} Best validation model restored (val_loss: {best_val_loss:.4f})")

    return model, best_val_loss

# =============================================================================
# Evaluation and visualization functions
# =============================================================================

def evaluate_models(models, X_test_balanced, y_test_balanced, test_balanced_shot_indices,
                    X_test_original, y_test_original, test_original_shot_mapping):
    """Score every model on the balanced and the original (imbalanced) test set.

    Each model is switched to eval mode and called with
    ``apply_smoothing=False``. Predictions are flattened and handed to
    ``calculate_shot_wise_mae`` together with the flattened targets and the
    shot information of the corresponding test set.

    Args:
        models: Mapping of model name to model.
        X_test_balanced, y_test_balanced: Balanced test features / targets.
        test_balanced_shot_indices: Shot information for the balanced set,
            passed through to ``calculate_shot_wise_mae``.
        X_test_original, y_test_original: Original test features / targets.
        test_original_shot_mapping: Shot information for the original set,
            passed through to ``calculate_shot_wise_mae``.

    Returns:
        An ``OrderedDict`` keyed by model name whose values hold
        ``predictions_balanced``, ``predictions_original``,
        ``mae_results_balanced`` and ``mae_results_original``.
    """
    print("\n Evaluating model performance...")

    def predict_flat(model, features):
        # Some variants return (prediction, extras...); keep the prediction only.
        outputs = model(features, apply_smoothing=False)
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        return outputs.cpu().numpy().flatten()

    # The inputs are identical for every model, so convert/transfer them once
    # instead of once per model.
    balanced_inputs = torch.tensor(X_test_balanced, dtype=torch.float32).to(DEVICE)
    original_inputs = torch.tensor(X_test_original, dtype=torch.float32).to(DEVICE)

    results = OrderedDict()

    for model_name, model in models.items():
        print(f"\n {model_name} Evaluating...")

        model.eval()
        with torch.no_grad():
            y_pred_balanced = predict_flat(model, balanced_inputs)
            y_pred_original = predict_flat(model, original_inputs)

        # Calculate Shot-wise MAE
        mae_results_balanced = calculate_shot_wise_mae(
            y_test_balanced.flatten(), y_pred_balanced, test_balanced_shot_indices)
        mae_results_original = calculate_shot_wise_mae(
            y_test_original.flatten(), y_pred_original, test_original_shot_mapping)

        results[model_name] = {
            'predictions_balanced': y_pred_balanced,
            'predictions_original': y_pred_original,
            'mae_results_balanced': mae_results_balanced,
            'mae_results_original': mae_results_original
        }

        for header, mae_results in (
                ("   === Balanced Test Set ===", mae_results_balanced),
                ("   === Original Imbalanced Test Set ===", mae_results_original)):
            print(header)
            print(f"   Few-shot MAE: {mae_results['few']:.6f}")
            print(f"   Medium-shot MAE: {mae_results['medium']:.6f}")
            print(f"   Many-shot MAE: {mae_results['many']:.6f}")
            print(f"   Overall MAE: {mae_results['overall']:.6f}")

    return results

# =============================================================================
#  Prediction Stability Analysis (Bias-Variance Trade-off proof)
# =============================================================================

def analyze_prediction_stability(models, X_test, y_test, shot_indices, n_runs=20):
    """Measure prediction variance per shot group with Monte Carlo Dropout.

    For every non-empty shot group and every model, the model is put in
    train mode (so dropout stays active), queried ``n_runs`` times with
    ``apply_smoothing=False`` and without gradients, and the spread of those
    predictions is summarised. Each model's previous train/eval mode is
    restored afterwards.

    Args:
        models: Mapping of model name to model.
        X_test: Test features; indexed with each group's indices.
        y_test: Test targets; indexed with each group's indices and flattened
            before comparison with the (flattened) predictions.
        shot_indices: Mapping of shot type (e.g. ``'few'``) to sample indices.
        n_runs: Number of stochastic forward passes per model and group.

    Returns:
        A ``pandas.DataFrame`` with one row per (shot type, model) and the
        columns ``shot_type``, ``model``, ``prediction_variance``,
        ``prediction_std``, ``mae`` and ``n_samples``. Variance and std are
        per-sample statistics over the runs, averaged over samples; ``mae``
        compares the run-averaged prediction with the targets.
    """
    print(f"\n Analyzing Prediction Stability (MC Dropout, n={n_runs})...")

    results = []

    for shot_type, indices in shot_indices.items():
        if len(indices) == 0:
            continue

        print(f"\n {shot_type.capitalize()}-shot analysis (n={len(indices)} samples)...")

        X_shot = torch.tensor(X_test[indices], dtype=torch.float32).to(DEVICE)
        # Predictions are flattened to 1-D below; flatten the targets as well so
        # a column-vector y does not broadcast into a (k, k) difference matrix.
        y_shot = np.asarray(y_test[indices]).reshape(-1)

        for model_name, model in models.items():
            all_predictions = []

            # Monte Carlo Dropout: train mode keeps dropout active. Remember the
            # current mode so the caller's model is not left in train mode.
            was_training = model.training
            model.train()
            try:
                with torch.no_grad():
                    for _ in range(n_runs):
                        outputs = model(X_shot, apply_smoothing=False)
                        if isinstance(outputs, tuple):
                            outputs = outputs[0]
                        all_predictions.append(outputs.cpu().numpy().flatten())
            finally:
                model.train(was_training)

            all_predictions = np.array(all_predictions)  # (n_runs, n_samples)

            # Calculate prediction variance (mean of per-sample variance)
            pred_variance = all_predictions.var(axis=0).mean()
            pred_std = all_predictions.std(axis=0).mean()

            # MAE with mean prediction
            mean_pred = all_predictions.mean(axis=0)
            mae = np.abs(mean_pred - y_shot).mean()

            results.append({
                'shot_type': shot_type,
                'model': model_name,
                'prediction_variance': pred_variance,
                'prediction_std': pred_std,
                'mae': mae,
                'n_samples': len(indices)
            })

            print(f"  {model_name:20s}: variance={pred_variance:.6f}, std={pred_std:.4f}, mae={mae:.4f}")

    df_results = pd.DataFrame(results)
    print(f"\n Prediction Stability analysis complete")

    return df_results

def plot_prediction_stability(df_results, dataset_name, output_dir="./outputs", n_runs=20):
    """Draw grouped bars of prediction variance per shot type and save them as PNG.

    Four series are plotted (fixed strength 0, adaptive CAS-only, fixed
    strength 1, full CASMIR). The fixed-strength series are recognised under
    both ``Strength_*`` and ``Smooth_Strength_*`` model names. A series with
    no row for a shot type is drawn with height 0. When both the adaptive and
    the strength-0 series have few-shot data, the relative variance reduction
    is annotated.

    Args:
        df_results: DataFrame with the columns ``model``, ``shot_type`` and
            ``prediction_variance``.
        dataset_name: Used in the title and in the output file name.
        output_dir: Directory for the PNG; created when missing.
        n_runs: Number of Monte Carlo runs quoted in the title.

    Returns:
        The path of the written PNG file.
    """
    print("\n Creating Prediction Stability graph...")

    # (legend label, model names accepted for the series, bar colour)
    series_specs = [
        ('$s_i$ = 0 (No Smoothing)', ('Strength_0', 'Smooth_Strength_0'), '#E74C3C'),
        ('Adaptive $s_i$ (CAS Only)', ('CAS_Only',), '#27AE60'),
        ('$s_i$ = 1 (Full Smoothing)', ('Strength_1', 'Smooth_Strength_1'), '#F39C12'),
        ('CASMIR (Full Model)', ('CASMIR',), '#3498DB'),
    ]
    label_by_model = {name: label for label, names, _ in series_specs for name in names}
    adaptive_label = series_specs[1][0]
    no_smoothing_label = series_specs[0][0]

    df_plot = df_results[df_results['model'].isin(list(label_by_model))].copy()
    df_plot['model_label'] = df_plot['model'].map(label_by_model)

    fig, ax = plt.subplots(figsize=(12, 7))
    fig.patch.set_facecolor('#FDFEFE')

    shot_types = ['few', 'medium', 'many']
    x = np.arange(len(shot_types))
    width = 0.2

    # One group of bars per series, shifted by one bar width each.
    for i, (model_label, _, color) in enumerate(series_specs):
        model_data = df_plot[df_plot['model_label'] == model_label]

        variances = []
        for shot in shot_types:
            shot_data = model_data[model_data['shot_type'] == shot]
            if len(shot_data) > 0:
                variances.append(shot_data['prediction_variance'].values[0])
            else:
                variances.append(0)

        bars = ax.bar(x + i*width, variances, width,
                      label=model_label, color=color,
                      alpha=0.85, edgecolor='white', linewidth=1.5)

        # Value labels (skipped for empty / zero bars)
        for bar, var in zip(bars, variances):
            if var > 0:
                ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                        f'{var:.4f}', ha='center', va='bottom',
                        fontsize=9, fontweight='bold')

    ax.set_xlabel('Shot Type (Data Uncertainty)', fontsize=13, fontweight='bold')
    ax.set_ylabel('Prediction Variance\n(Lower = More Stable)', fontsize=13, fontweight='bold')
    ax.set_title(f'Empirical Evidence of Bias-Variance Trade-off\n{dataset_name} Dataset\n'
                 f'Monte Carlo Dropout Analysis (n={n_runs} runs)',
                 fontsize=15, fontweight='bold', pad=20)
    ax.set_xticks(x + width * 1.5)
    ax.set_xticklabels(['Few-shot\n(High Uncertainty)',
                        'Medium-shot\n(Moderate Uncertainty)',
                        'Many-shot\n(Low Uncertainty)'],
                       fontsize=11)
    ax.legend(fontsize=10, loc='upper right')
    ax.grid(True, alpha=0.3, axis='y', linestyle='--')
    ax.set_facecolor('#FDFDFD')

    # Highlight Few-shot variance reduction (adaptive vs. no smoothing)
    few_data = df_plot[df_plot['shot_type'] == 'few']
    few_labels = few_data['model_label'].values

    if adaptive_label in few_labels and no_smoothing_label in few_labels:
        adaptive_var = few_data[few_data['model_label'] == adaptive_label]['prediction_variance'].values[0]
        s0_var = few_data[few_data['model_label'] == no_smoothing_label]['prediction_variance'].values[0]

        if s0_var > 0:
            reduction = ((s0_var - adaptive_var) / s0_var) * 100

            # The arrow points at the adaptive bar of the few-shot group
            # (second series, i.e. one bar width right of the group origin).
            ax.annotate(f'Variance\nReduction:\n{reduction:.1f}%',
                        xy=(0 + width, adaptive_var),
                        xytext=(0.5, s0_var * 0.7),
                        arrowprops=dict(arrowstyle='->', lw=2.5, color='#27AE60'),
                        fontsize=12, fontweight='bold', color='#27AE60',
                        bbox=dict(boxstyle='round,pad=0.5', facecolor='#D5F4E6',
                                  edgecolor='#27AE60', linewidth=2))

    plt.tight_layout()

    # Save
    os.makedirs(output_dir, exist_ok=True)
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(output_dir, f"prediction_stability_{dataset_name}_{timestamp}.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='#FDFEFE')
    print(f" Prediction Stability graph saved: {output_path}")

    # Close this specific figure rather than "whatever is current".
    plt.close(fig)

    return output_path

def print_stability_summary(df_stability, dataset_name):
    """Print a textual summary of the prediction-stability analysis.

    For each of the shot types ``few``/``medium``/``many`` present in
    ``df_stability`` a table with variance, std and MAE of the key models is
    printed, followed by a few-shot and a many-shot comparison between the
    adaptive model (``CAS_Only``) and the fixed-strength models. The
    fixed-strength models are recognised under both ``Strength_*`` and
    ``Smooth_Strength_*`` names; comparisons whose models are missing are
    skipped or marked ``n/a`` instead of raising.

    Args:
        df_stability: DataFrame with the columns ``shot_type``, ``model``,
            ``prediction_variance``, ``prediction_std``, ``mae`` and
            ``n_samples``.
        dataset_name: Name shown in the summary header.
    """
    print(f"\n{'='*80}")
    print(f" Prediction Stability Analysis Summary - {dataset_name}")
    print(f"{'='*80}")

    # Reported role -> model names under which that role may appear.
    key_models = [
        ('Strength_0', ('Strength_0', 'Smooth_Strength_0')),
        ('CAS_Only', ('CAS_Only',)),
        ('Strength_1', ('Strength_1', 'Smooth_Strength_1')),
        ('CASMIR', ('CASMIR',)),
    ]
    accepted_names = dict(key_models)

    def first_row(frame, role):
        """Return the first row of ``frame`` for ``role`` or None when absent."""
        rows = frame[frame['model'].isin(accepted_names[role])]
        return rows.iloc[0] if len(rows) > 0 else None

    for shot_type in ['few', 'medium', 'many']:
        shot_data = df_stability[df_stability['shot_type'] == shot_type]

        if len(shot_data) == 0:
            continue

        print(f"\n{shot_type.upper()}-SHOT ({shot_data.iloc[0]['n_samples']} samples):")
        print(f"{'Model':<25s} {'Variance':<12s} {'Std':<10s} {'MAE':<10s}")
        print(f"{'-'*60}")

        for role, _ in key_models:
            data = first_row(shot_data, role)
            if data is not None:
                print(f"{data['model']:<25s} {data['prediction_variance']:<12.6f} "
                      f"{data['prediction_std']:<10.4f} {data['mae']:<10.4f}")

    # Key findings
    print(f"\n{'='*80}")
    print(" KEY FINDINGS (Bias-Variance Trade-off Evidence):")
    print(f"{'='*80}")

    # Few-shot analysis
    few_data = df_stability[df_stability['shot_type'] == 'few']
    few_adaptive = first_row(few_data, 'CAS_Only')
    few_s0 = first_row(few_data, 'Strength_0')
    if few_adaptive is not None and few_s0 is not None:
        adaptive_var = few_adaptive['prediction_variance']
        s0_var = few_s0['prediction_variance']

        if s0_var > 0:
            reduction = ((s0_var - adaptive_var) / s0_var) * 100

            print(f"\n Few-shot (High Uncertainty):")
            if reduction > 0:
                print(f"   - s_i=0 variance: {s0_var:.6f} (unstable)")
                print(f"   - Adaptive variance: {adaptive_var:.6f} (stable)")
                print(f"   - Variance Reduction: {reduction:.1f}%")
                print(f"   → Adaptive s_i successfully reduces variance in uncertain regions")
            else:
                # Do not claim a success the numbers do not show.
                print(f"   - s_i=0 variance: {s0_var:.6f}")
                print(f"   - Adaptive variance: {adaptive_var:.6f}")
                print(f"   - Variance Reduction: {reduction:.1f}%")
                print("   → Adaptive s_i did NOT reduce variance relative to s_i=0 in this run")

    # Many-shot analysis
    many_data = df_stability[df_stability['shot_type'] == 'many']
    many_adaptive = first_row(many_data, 'CAS_Only')
    many_s0 = first_row(many_data, 'Strength_0')
    if many_adaptive is not None and many_s0 is not None:
        # The full-smoothing model is optional here; its absence must not crash.
        many_s1 = first_row(many_data, 'Strength_1')

        print(f"\n Many-shot (Low Uncertainty):")
        print(f"   - s_i=0 variance: {many_s0['prediction_variance']:.6f}")
        print(f"   - Adaptive variance: {many_adaptive['prediction_variance']:.6f}")
        if many_s1 is not None:
            print(f"   - s_i=1 variance: {many_s1['prediction_variance']:.6f}")
        else:
            print("   - s_i=1 variance: n/a (model not evaluated)")
        print(f"   → Adaptive maintains low variance without over-smoothing (low bias)")

    print(f"\n{'='*80}")
    print("CONCLUSION:")
    print("  Adaptive s_i intelligently navigates the bias-variance trade-off:")
    print("  1. High uncertainty (Few-shot) → High s_i → Reduces variance")
    print("  2. Low uncertainty (Many-shot) → Low s_i → Maintains low bias")
    print("  This provides quantitative evidence beyond qualitative description.")
    print(f"{'='*80}\n")

def create_comparison_visualization(results, dataset_name, output_dir="./outputs"):
    """Plot the ablation comparison (balanced vs. original test set) and save it.

    A 2x3 figure is produced:
      1. shot-wise MAE per model on the balanced test set,
      2. shot-wise MAE per model on the original imbalanced test set,
      3. overall MAE per model for both test sets,
      4. overall MAE gap (original - balanced) per model,
      5. few-shot coupling effect (``CAS_Feature_Only`` - ``CAS_Only``) per test set,
      6. shot-wise coupling effect on the balanced test set.
    Panels 5 and 6 show a placeholder text unless both ``CAS_Only`` and
    ``CAS_Feature_Only`` are present in ``results``.

    Args:
        results: Mapping of model name to a dict holding
            ``mae_results_balanced`` and ``mae_results_original``, each with
            the keys ``few``, ``medium``, ``many`` and ``overall``.
        dataset_name: Used in titles and in the output file name.
        output_dir: Directory for the PNG; created when missing.

    Returns:
        ``(output_path, df_results)`` where ``df_results`` is the long-format
        DataFrame (columns ``Model``, ``Shot_Type``, ``MAE``, ``Test_Set``)
        that was plotted.
    """
    print("\n Creating performance comparison visualization...")

    pastel_colors = [
        '#AED6F1', '#F8BBD9', '#ABEBC6', '#D2B4DE',
        '#F9E79F', '#FAD7A0', '#D5F4E6', '#E8DAEF', '#FADBD8'
    ]
    positive_pink, positive_green = '#F8BBD9', '#ABEBC6'
    plt.style.use('default')  # Use default style

    model_names = list(results.keys())
    metrics = ['few', 'medium', 'many', 'overall']

    # One colour per model (cycling if there are more models than colours), so
    # the palette always matches the number of hue levels actually plotted.
    model_palette = {name: pastel_colors[i % len(pastel_colors)]
                     for i, name in enumerate(model_names)}

    # Long-format rows: all balanced rows first, then all original rows.
    def rows_for(result_key, test_set_label):
        return [
            {
                'Model': model_name,
                'Shot_Type': metric.replace('_', '-').title(),
                'MAE': results[model_name][result_key][metric],
                'Test_Set': test_set_label,
            }
            for model_name in model_names
            for metric in metrics
        ]

    df_results = pd.DataFrame(
        rows_for('mae_results_balanced', 'Balanced')
        + rows_for('mae_results_original', 'Original_Imbalanced')
    )

    fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, figsize=(28, 14))
    fig.patch.set_facecolor('#FDFEFE')

    def apply_panel_style(ax):
        ax.grid(True, alpha=0.2, linestyle='--')
        ax.set_facecolor('#FDFDFD')

    def draw_signed_bars(ax, labels, values, color_if_positive, color_otherwise, label_fontsize):
        """Bar chart around a zero line with a signed value label on every bar."""
        bar_colors = [color_if_positive if v > 0 else color_otherwise for v in values]
        bars = ax.bar(labels, values, color=bar_colors, alpha=0.8,
                      edgecolor='white', linewidth=1.5)
        ax.axhline(y=0, color='#85929E', linestyle='-', alpha=0.7, linewidth=1.5)
        for bar, value in zip(bars, values):
            height = bar.get_height()
            # Nudge the label away from the bar end, above or below the axis.
            ax.text(bar.get_x() + bar.get_width()/2., height + (0.001 if height >= 0 else -0.001),
                    f'{value:+.4f}', ha='center', va='bottom' if height >= 0 else 'top',
                    fontweight='bold', fontsize=label_fontsize, color='#2C3E50')

    # 1 + 2. Shot-wise comparison on each test set
    for ax, test_set_key, title_part in (
            (ax1, 'Balanced', 'Balanced Test Set'),
            (ax2, 'Original_Imbalanced', 'Original Imbalanced Test Set')):
        subset = df_results[df_results['Test_Set'] == test_set_key]
        sns.barplot(data=subset, x='Shot_Type', y='MAE', hue='Model', ax=ax, palette=model_palette)
        ax.set_title(f'CASMIR Ablation Study: {title_part}\n{dataset_name} Dataset',
                     fontsize=14, fontweight='bold', pad=20)
        ax.set_ylabel('Mean Absolute Error', fontsize=12)
        ax.legend(title='Model Variant', title_fontsize=11, fontsize=10)
        apply_panel_style(ax)

    # 3. Overall Performance - Balanced vs Original comparison
    overall_data = df_results[df_results['Shot_Type'] == 'Overall']
    sns.barplot(data=overall_data, x='Model', y='MAE', hue='Test_Set', ax=ax3,
                palette=['#E8F8F5', '#FDF2E9'])
    ax3.set_title(f'Overall Performance: Balanced vs Original Test Sets\n{dataset_name} Dataset',
                  fontsize=14, fontweight='bold', pad=20)
    ax3.set_ylabel('Overall MAE', fontsize=12)
    ax3.tick_params(axis='x', rotation=45)
    ax3.legend(title='Test Set Type', title_fontsize=11, fontsize=10)
    apply_panel_style(ax3)

    # 4. Model-wise Overall MAE difference (Original - Balanced);
    #    a positive gap is highlighted in pink, otherwise green.
    mae_gaps = [
        results[name]['mae_results_original']['overall']
        - results[name]['mae_results_balanced']['overall']
        for name in model_names
    ]
    draw_signed_bars(ax4, model_names, mae_gaps, positive_pink, positive_green, 10)
    ax4.set_title(f'Performance Gap: Original - Balanced Test Sets\n{dataset_name} Dataset',
                  fontsize=14, fontweight='bold', pad=20)
    ax4.set_ylabel('MAE Difference (Original - Balanced)', fontsize=12)
    ax4.tick_params(axis='x', rotation=45)
    apply_panel_style(ax4)

    has_coupling_pair = 'CAS_Only' in model_names and 'CAS_Feature_Only' in model_names

    def coupling_effect(result_key, shot_type):
        # Positive value = the coupled CAS has the lower MAE.
        return (results['CAS_Feature_Only'][result_key][shot_type]
                - results['CAS_Only'][result_key][shot_type])

    # 5. CAS Coupling Effect Analysis - Few-shot difference
    if has_coupling_pair:
        test_labels = ['Balanced', 'Original']
        few_effects = [coupling_effect('mae_results_balanced', 'few'),
                       coupling_effect('mae_results_original', 'few')]
        draw_signed_bars(ax5, test_labels, few_effects, positive_green, positive_pink, 11)
        ax5.set_title(f'Coupling Effect: Few-shot Performance\n(Feature-Only - Coupled CAS)',
                      fontsize=14, fontweight='bold', pad=20)
        ax5.set_ylabel('MAE Improvement by Coupling', fontsize=12)
        apply_panel_style(ax5)
    else:
        ax5.text(0.5, 0.5, 'CAS Coupling Analysis\nRequires both CAS_Only and\nCAS_Feature_Only models',
                 ha='center', va='center', transform=ax5.transAxes, fontsize=12)
        ax5.set_facecolor('#FDFDFD')

    # 6. Shot-wise Coupling Effect Analysis (balanced test set)
    if has_coupling_pair:
        shot_types = ['few', 'medium', 'many']
        shot_labels = [shot.capitalize() for shot in shot_types]
        shot_effects = [coupling_effect('mae_results_balanced', shot) for shot in shot_types]
        draw_signed_bars(ax6, shot_labels, shot_effects, positive_green, positive_pink, 11)
        ax6.set_title(f'Shot-wise Coupling Effect\n(Feature-Only - Coupled CAS)',
                      fontsize=14, fontweight='bold', pad=20)
        ax6.set_ylabel('MAE Improvement by Coupling', fontsize=12)
        apply_panel_style(ax6)
    else:
        ax6.text(0.5, 0.5, 'Shot-wise Coupling Analysis\nRequires both CAS_Only and\nCAS_Feature_Only models',
                 ha='center', va='center', transform=ax6.transAxes, fontsize=12)
        ax6.set_facecolor('#FDFDFD')

    plt.tight_layout(pad=3.0)

    # Save
    os.makedirs(output_dir, exist_ok=True)
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(output_dir, f"CASMIR_Ablation_Study_{dataset_name}_{timestamp}.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='#FDFEFE')
    print(f" Performance comparison visualization saved: {output_path}")

    plt.show()
    # Release the figure once it has been shown so repeated calls do not
    # accumulate open figures.
    plt.close(fig)

    return output_path, df_results

def print_ablation_summary(results, dataset_name):
    """Print a console summary of the ablation-study results.

    For each test set (balanced and original imbalanced) this prints the
    overall MAE ranking, the shot-wise MAE of every model, and a series of
    pairwise analyses. Each analysis is only run when all the models it
    compares are present in ``results``; otherwise a short notice is printed
    instead of raising ``KeyError``. Finally the per-model gap between the
    original and balanced test sets is printed.

    Args:
        results: Mapping of model name -> result dict. Every result dict must
            contain ``'mae_results_balanced'`` and ``'mae_results_original'``,
            each a mapping with the keys ``'few'``, ``'medium'``, ``'many'``
            and ``'overall'``.
        dataset_name: Dataset label shown in the report header.

    Returns:
        None. All output goes to stdout.
    """
    print(f"\n{'='*80}")
    print(f" CASMIR Extended Ablation Study results summary - {dataset_name} Dataset")
    print(f" Total {len(results)} models evaluated")
    print(f"{'='*80}")

    # Analysis for both test sets
    for test_type, test_name in [('balanced', 'Balanced Test Set'), ('original', 'Original Imbalanced Test Set')]:
        print(f"\n{'='*60}")
        print(f" {test_name} Results")
        print(f"{'='*60}")

        mae_key = f'mae_results_{test_type}'

        # Sort by Overall MAE (lower is better)
        sorted_results = sorted(results.items(), key=lambda item: item[1][mae_key]['overall'])

        print("\n Overall Performance Ranking:")
        for rank, (model_name, result) in enumerate(sorted_results, 1):
            print(f"   {rank}. {model_name}: {result[mae_key]['overall']:.6f}")

        print("\n Detailed Shot-wise Performance:")
        for model_name, result in results.items():
            mae_results = result[mae_key]
            print(f"\n    {model_name}:")
            print(f"      Few-shot MAE:    {mae_results['few']:.6f}")
            print(f"      Medium-shot MAE: {mae_results['medium']:.6f}")
            print(f"      Many-shot MAE:   {mae_results['many']:.6f}")
            print(f"      Overall MAE:     {mae_results['overall']:.6f}")

        # Improvement analysis
        print("\n Component Contribution Analysis:")

        # Guard like the other analyses below: a partial run (some models
        # missing) must not abort the whole summary with a KeyError.
        missing_core = [name for name in ('MoE_Only', 'CAS_Only', 'CASMIR') if name not in results]
        if missing_core:
            print(f"   Skipped - missing required model(s): {', '.join(missing_core)}")
        else:
            moe_only_mae = results['MoE_Only'][mae_key]['overall']
            cas_only_mae = results['CAS_Only'][mae_key]['overall']
            case_mae = results['CASMIR'][mae_key]['overall']

            cas_contribution = moe_only_mae - case_mae
            moe_contribution = cas_only_mae - case_mae

            print(f"   CAS Contribution: {cas_contribution:+.6f} (MoE Only vs CASMIR)")
            print(f"   MoE Contribution: {moe_contribution:+.6f} (CAS Only vs CASMIR)")

            if cas_contribution > 0:
                print(f"    CAS {cas_contribution:.6f} improves performance")
            else:
                print(f"    CAS {abs(cas_contribution):.6f} degrades performance")

            if moe_contribution > 0:
                print(f"    MoE {moe_contribution:.6f} improves performance")
            else:
                print(f"    MoE {abs(moe_contribution):.6f} degrades performance")

            # Interaction effect analysis: mean of the single-component MAEs
            # minus the MAE of the combined model.
            synergy_effect = (moe_only_mae + cas_only_mae) / 2 - case_mae
            if synergy_effect > 0:
                print(f"    Synergy Effect: +{synergy_effect:.6f} (additional improvement from component interaction)")
            else:
                print(f"     Negative Synergy: {synergy_effect:.6f} (performance degradation from component interference)")

        # CAS Coupled Smoothing Effect Analysis
        if 'CAS_Feature_Only' in results and 'CAS_Only' in results:
            print("\n CAS Coupling Effect Analysis:")

            coupled_mae = results['CAS_Only'][mae_key]
            feature_only_mae = results['CAS_Feature_Only'][mae_key]

            cas_only_mae = coupled_mae['overall']
            cas_feature_only_mae = feature_only_mae['overall']
            coupling_effect = cas_feature_only_mae - cas_only_mae

            print(f"   Feature-Only CAS MAE: {cas_feature_only_mae:.6f}")
            print(f"   Coupled CAS MAE:     {cas_only_mae:.6f}")
            print(f"   Coupling Effect:     {coupling_effect:+.6f}")

            if coupling_effect > 0:
                print(f"    Coupled smoothing improves performance by {coupling_effect:.6f}")
                coupling_improvement_pct = (coupling_effect / cas_feature_only_mae) * 100
                print(f"    Relative improvement: {coupling_improvement_pct:.2f}%")
            else:
                print(f"    Coupled smoothing degrades performance by {abs(coupling_effect):.6f}")

            # Shot-wise coupling effect analysis
            print("\n Shot-wise Coupling Effects:")
            shot_effects = {
                shot: feature_only_mae[shot] - coupled_mae[shot]
                for shot in ('few', 'medium', 'many')
            }
            largest_effect = max(shot_effects.values())
            for shot, shot_coupling_effect in shot_effects.items():
                shot_name = shot.capitalize() + '-shot'
                print(f"   {shot_name:12}: {shot_coupling_effect:+.6f}")

                # Only claim "largest" when the few-shot effect really is the
                # maximum across the three shot groups.
                if shot == 'few' and shot_coupling_effect > 0 and shot_coupling_effect >= largest_effect:
                    print("       Largest coupling effect observed in Few-shot!")
        elif 'CAS_Feature_Only' in results:
            print("\n  CAS_Only model is required for CAS Coupling Effect analysis.")
        else:
            print("\n  CAS_Feature_Only model is required for CAS Coupling Effect analysis.")

        # Additional Ablation Analysis
        print("\n Additional Ablation Study Analysis:")

        # Smoothing Strength Analysis
        if 'Smooth_Strength_0' in results and 'Smooth_Strength_1' in results and 'CAS_Only' in results:
            print("\n Smoothing Strength Effect:")
            strength_0_mae = results['Smooth_Strength_0'][mae_key]['overall']
            strength_1_mae = results['Smooth_Strength_1'][mae_key]['overall']
            adaptive_mae = results['CAS_Only'][mae_key]['overall']

            print(f"   Strength=0 (No Smoothing):  {strength_0_mae:.6f}")
            print(f"   Strength=1 (Full Smoothing): {strength_1_mae:.6f}")
            print(f"   Adaptive Strength:          {adaptive_mae:.6f}")

            # Ties (within 1e-6) are resolved in favour of adaptive, then strength 0.
            best_strength = min(strength_0_mae, strength_1_mae, adaptive_mae)
            if abs(adaptive_mae - best_strength) < 1e-6:
                print("    Adaptive strength achieves best performance!")
            elif abs(strength_0_mae - best_strength) < 1e-6:
                print("     No smoothing performs better (consider data characteristics)")
            else:
                print("     Full smoothing performs better")

        # Learnable Metric Analysis
        if 'CAS_No_Learnable_Metric' in results and 'CAS_Only' in results:
            print("\n Learnable Metric (w) Effect:")
            no_learnable_mae = results['CAS_No_Learnable_Metric'][mae_key]['overall']
            learnable_mae = results['CAS_Only'][mae_key]['overall']

            print(f"   Fixed w=1 (No Learning): {no_learnable_mae:.6f}")
            print(f"   Learnable w:            {learnable_mae:.6f}")
            print(f"   Improvement:            {(no_learnable_mae - learnable_mae):+.6f}")

            if learnable_mae < no_learnable_mae:
                improvement_pct = ((no_learnable_mae - learnable_mae) / no_learnable_mae) * 100
                print(f"    Learnable metric improves performance by {improvement_pct:.2f}%")
            else:
                print("    Learnable metric degrades performance")

        # MixUp vs CAS Comparison
        if 'MixUp_MoE' in results and 'CASMIR' in results:
            print("\n MixUp vs CAS (with MoE):")
            mixup_moe_mae = results['MixUp_MoE'][mae_key]['overall']
            case_mae = results['CASMIR'][mae_key]['overall']

            print(f"   MixUp + MoE: {mixup_moe_mae:.6f}")
            print(f"   CAS + MoE:   {case_mae:.6f}")
            print(f"   CAS Advantage: {(mixup_moe_mae - case_mae):+.6f}")

            if case_mae < mixup_moe_mae:
                improvement_pct = ((mixup_moe_mae - case_mae) / mixup_moe_mae) * 100
                print(f"    CAS outperforms MixUp by {improvement_pct:.2f}%")

                # Few-shot performance comparison; a zero baseline has no
                # meaningful relative change, so report 0 instead of dividing
                # by zero (same convention as the test-set gap below).
                mixup_few = results['MixUp_MoE'][mae_key]['few']
                case_few = results['CASMIR'][mae_key]['few']
                few_improvement = ((mixup_few - case_few) / mixup_few) * 100 if mixup_few != 0 else 0.0
                print(f"    Few-shot CAS improvement: {few_improvement:.2f}%")
            else:
                print("     MixUp performs better")

    # Performance difference analysis between two test sets
    print(f"\n{'='*60}")
    print(" Test Set Comparison Analysis")
    print(f"{'='*60}")

    print("\n Performance Gap (Original - Balanced):")
    for model_name, result in results.items():
        balanced_mae = result['mae_results_balanced']['overall']
        original_mae = result['mae_results_original']['overall']
        gap = original_mae - balanced_mae
        gap_pct = (gap / balanced_mae) * 100 if balanced_mae > 0 else 0

        print(f"   {model_name}: {gap:+.6f} ({gap_pct:+.2f}%)")
        if gap > 0:
            print("      → Original performs worse (effect of imbalance)")
        else:
            print("      → Original performs better (rare case)")

    print(f"\n{'='*80}")

# =============================================================================
# Main Execution Function
# =============================================================================

def _ablation_json_default(obj):
    """Fallback encoder for ``json.dump``: convert NumPy scalars/arrays to built-ins."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def process_ablation_study(artifact_info, args):
    """Run Ablation Study for a single CASMIR artifact.

    Pipeline: load the artifact, reload the dataset and rebuild the stored
    train/val/test split, precompute target densities, build and train every
    ablation model, evaluate on the balanced and original test sets, run the
    prediction-stability analysis, then visualize, print and save the results
    as JSON under ``./outputs``.

    Args:
        artifact_info: Mapping with at least ``'folder_name'`` (used for
            logging) and ``'path'`` (passed to ``load_artifact_metadata``).
        args: Namespace-like object. ``args.dataset`` is required;
            ``args.epochs`` and ``args.patience`` are optional and default to
            300 and 50.

    Returns:
        ``(evaluation_results, vis_path)`` on success, or ``(None, None)`` if
        any step raised; the exception and traceback are printed, not re-raised.
    """
    output_dir = "./outputs"

    print(f"\n{'='*100}")
    print(f" CASMIR Ablation Study Start: {artifact_info['folder_name']}")
    print(f"{'='*100}")

    try:
        # 1. Load artifact
        _metadata, hyperparams, preprocessor, data_split_info, experiment_config = load_artifact_metadata(artifact_info['path'])

        # 2. Load data and reproduce split
        print("\n Loading data and reproducing split...")
        dataset_result = load_dataset(args.dataset)

        if not (isinstance(dataset_result, tuple) and len(dataset_result) == 2):
            raise ValueError(f"load_dataset returned unexpected value: {type(dataset_result)}")

        X, y_info = dataset_result
        if isinstance(y_info, str):
            # Second element names the target column inside X
            if y_info not in X.columns:
                raise ValueError(f"Target column '{y_info}' not found")
            y = X[y_info].copy()
            X = X.drop(columns=[y_info]).copy()
        else:
            # Second element already is the target vector
            y = y_info

        # Apply preprocessing
        X_processed = preprocessor.transform(X)

        # Extract data split indices
        train_indices = data_split_info['train_indices']
        val_indices = data_split_info['val_indices']
        test_indices = data_split_info['test_indices']

        # Balanced indices are stored relative to their split; when absent,
        # the whole split is used.
        test_balanced_relative_indices = data_split_info.get('test_balanced_indices', list(range(len(test_indices))))
        val_balanced_relative_indices = data_split_info.get('val_balanced_indices', list(range(len(val_indices))))

        test_balanced_absolute_indices = [test_indices[i] for i in test_balanced_relative_indices]
        val_balanced_absolute_indices = [val_indices[i] for i in val_balanced_relative_indices]

        def select_targets(indices):
            # pandas objects need positional .iloc access; arrays index directly
            if hasattr(y, 'iloc'):
                return y.iloc[indices].values
            return y[indices]

        # Prepare training data
        X_train = X_processed[train_indices]
        X_val = X_processed[val_balanced_absolute_indices]
        X_test_balanced = X_processed[test_balanced_absolute_indices]
        X_test_original = X_processed[test_indices]  # Original imbalanced test set

        y_train = select_targets(train_indices)
        y_val = select_targets(val_balanced_absolute_indices)
        y_test_balanced = select_targets(test_balanced_absolute_indices)
        y_test_original = select_targets(test_indices)  # Original imbalanced test set

        print(f" Data split reproduction complete:")
        print(f"   Train: {len(X_train)}, Val: {len(X_val)}")
        print(f"   Test Balanced: {len(X_test_balanced)}, Test Original: {len(X_test_original)}")

        # 3. Calculate density for CAS (if needed)
        print("\n Calculating density for CAS...")
        y_train_reshaped = y_train.reshape(-1, 1)
        densities_train, _density_boundaries, _kde = precompute_density_and_boundaries(y_train_reshaped)

        if densities_train is None:
            print(" Density calculation failed, proceeding without CAS")
        else:
            print(f" Density calculation complete: {densities_train.shape}")

        # 4. Create Ablation Study models
        input_dim = X_train.shape[1]
        models = create_ablation_models(hyperparams, experiment_config, input_dim)

        # 5. Train each model
        print(f"\n Training Ablation Study models...")
        trained_models = OrderedDict()
        training_histories = OrderedDict()

        # Training budget is the same for every model, so resolve it once
        max_epochs = getattr(args, 'epochs', 300)
        patience = getattr(args, 'patience', 50)

        for model_name, model in models.items():
            print(f"\n{'='*60}")
            print(f" {model_name} training started...")
            print(f"{'='*60}")

            # Set up optimizer and loss function
            optimizer, criterion = setup_optimizer_and_loss(model, hyperparams)

            # Train model
            trained_model, best_val_loss = train_single_model(
                model, model_name, optimizer, criterion,
                X_train, y_train, X_val, y_val,
                densities_train=densities_train,
                experiment_config=experiment_config,
                max_epochs=max_epochs, patience=patience
            )

            trained_models[model_name] = trained_model
            training_histories[model_name] = {'best_val_loss': best_val_loss}

            print(f" {model_name} training complete (Best Val Loss: {best_val_loss:.6f})")

        # 6. Evaluate model performance
        print(f"\n Evaluating all model performance...")

        # Prepare test shot mapping (keys are normalized to int)
        test_balanced_shot_indices = data_split_info.get('test_balanced_shot_indices', {})
        test_balanced_shot_indices = {int(k): v for k, v in test_balanced_shot_indices.items()}

        test_original_shot_mapping = data_split_info.get('test_shot_mapping', {})
        test_original_shot_mapping = {int(k): v for k, v in test_original_shot_mapping.items()}

        evaluation_results = evaluate_models(
            trained_models,
            X_test_balanced, y_test_balanced, test_balanced_shot_indices,
            X_test_original, y_test_original, test_original_shot_mapping
        )

        #  6.5. Prediction Stability Analysis (Bias-Variance Trade-off proof)
        print(f"\n Prediction Stability Analysis (Bias-Variance Trade-off proof)...")

        # Reconstruct shot indices (based on Original test set).
        # Build the absolute-index -> position lookup once instead of scanning
        # test_indices for every sample; setdefault keeps the first position,
        # matching what list.index() would return.
        test_position = {}
        for position, absolute_idx in enumerate(test_indices):
            test_position.setdefault(absolute_idx, position)

        shot_indices_original = {'few': [], 'medium': [], 'many': []}
        for idx, shot_type in test_original_shot_mapping.items():
            if shot_type in shot_indices_original and idx in test_position:
                shot_indices_original[shot_type].append(test_position[idx])

        # Perform Prediction Stability analysis
        stability_results = analyze_prediction_stability(
            trained_models,
            X_test_original,
            y_test_original,
            shot_indices_original,
            n_runs=20
        )

        # Stability visualization
        stability_plot_path = plot_prediction_stability(
            stability_results,
            args.dataset,
            output_dir=output_dir
        )

        # Stability summary output
        print_stability_summary(stability_results, args.dataset)

        # 7. Results visualization and summary
        print(f"\n Results visualization and summary...")

        vis_path, _results_df = create_comparison_visualization(evaluation_results, args.dataset)
        print_ablation_summary(evaluation_results, args.dataset)

        # 8. Results Save
        timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
        results_save_path = f"{output_dir}/CASMIR_Ablation_Results_{args.dataset}_{timestamp}.json"

        # Keep only the JSON-friendly parts of each evaluation result
        save_results = {
            model_name: {
                'mae_results_balanced': result['mae_results_balanced'],
                'mae_results_original': result['mae_results_original'],
                'training_history': training_histories[model_name]
            }
            for model_name, result in evaluation_results.items()
        }

        # Include Stability Results
        stability_dict = stability_results.to_dict('records') if stability_results is not None else []

        save_data = {
            'dataset': args.dataset,
            'artifact_path': artifact_info['path'],
            'timestamp': timestamp,
            'hyperparams': hyperparams,
            'results': save_results,
            'prediction_stability': stability_dict
        }

        os.makedirs(output_dir, exist_ok=True)
        with open(results_save_path, 'w', encoding='utf-8') as f:
            # default= only kicks in for values the stock encoder rejects, so
            # a stray NumPy scalar/array cannot abort the save after training.
            json.dump(save_data, f, indent=2, ensure_ascii=False, default=_ablation_json_default)

        print(f" Ablation Study Results saved: {results_save_path}")
        print(f" Prediction Stability graph saved: {stability_plot_path}")

        return evaluation_results, vis_path

    except Exception as e:
        # Top-level boundary for one artifact: report and signal failure to the caller
        print(f" Ablation Study execution failed: {e}")
        import traceback
        traceback.print_exc()
        return None, None

def main():
    """Command-line entry point for the CASMIR ablation study.

    Parses the CLI arguments, seeds the random generators, scans the artifacts
    folder for the given dataset and runs the ablation study on the first
    artifact returned by the scan. Progress and the final status are printed
    to stdout; nothing is returned.
    """
    import argparse

    parser = argparse.ArgumentParser(description="CASMIR Ablation Study")

    parser.add_argument('--artifacts_folder', type=str, required=True,
                       help='Folder containing CASMIR artifacts')
    parser.add_argument('--dataset', type=str, required=True,
                       help='Dataset name')
    parser.add_argument('--epochs', type=int, default=300,
                       help='Training epochs for ablation models (default: 300)')
    parser.add_argument('--patience', type=int, default=50,
                       help='Early stopping patience (default: 50)')
    parser.add_argument('--display_off', action='store_true',
                       help='Server mode: save files only without display')

    args = parser.parse_args()

    if args.display_off:
        # Honor server mode: with the non-interactive Agg backend, figures are
        # still written by savefig() but plt.show() no longer opens a window.
        plt.switch_backend('Agg')

    # Set global random seed
    set_global_random_seed(42)

    print("CASMIR Ablation Study Starting")
    print("=" * 80)
    print(f"  Epochs: {args.epochs}, Patience: {args.patience}")

    # Scan CASMIR artifacts
    artifacts = scan_adasmooth_artifacts(args.artifacts_folder, args.dataset)

    if not artifacts:
        print("No CASMIR artifacts to process.")
        return

    # Use the first (latest) artifact
    if len(artifacts) > 1:
        print(f"Found {len(artifacts)} artifacts. Using latest: {artifacts[0]['folder_name']}")

    artifact_info = artifacts[0]
    results, vis_path = process_ablation_study(artifact_info, args)

    if results is not None:
        print(f"\nCASMIR Ablation Study Complete!")
        print(f"Visualization saved: {vis_path}")
    else:
        print("CASMIR Ablation Study failed.")

# Run the ablation study only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
