#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
visualize_tsne.py

Retrain model with optimal hyperparameters from HPO results and perform
PCA/t-SNE visualization of learned features.

Usage:
    python scripts/visualize_tsne.py --artifact_dir "artifacts_path" --dataset dataset_name
"""

import os
import sys

# Add parent directory to path for module imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import json
import numpy as np
import pandas as pd
import torch
import joblib
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.decomposition import PCA
import copy
import random
import glob
from pathlib import Path

# Internal module imports
from src.data.datasets import load_dataset, preprocess_data, split_data_stratified, create_balanced_dataset, map_shot_types, classify_bins_by_samples, freedman_diaconis_bins
from src.models.basic_models import MLP, MLPadv, SimpleThreeMLPEnsemble
from src.models.CASMIR_V1 import CASMIR_V1, AdaSmoothMoEV1, precompute_density_and_boundaries
from src.training.losses import ConRLoss, RankSimLoss, GAILossMD, BMCLossMD, WeightedL1Loss, WeightedMSELoss
from src.training.train_utils import train_pytorch_model, train_CASMIR_V1
from src.utils.utils import calculate_balanced_weights, get_gmm
from config import CONFIG

# Optional MLP variant models
try:
    from src.models.ranksim_models import RankSimMLP
except ImportError:
    RankSimMLP = None

try:
    from src.models.conr_models import ConRMLP  
except ImportError:
    ConRMLP = None
    
try:
    from src.models.fds_models import FDSMLP
except ImportError:
    FDSMLP = None
    
try:
    from src.models.lds_models import LDSMLP
except ImportError:
    LDSMLP = None

# Reached only when every mandatory import above succeeded; the optional
# MLP-variant classes may still be None at this point.
print("All internal modules imported successfully")

# Device setup: prefer CUDA when torch reports it as available, else CPU.
# DEVICE is read as a module-level global by the model/training helpers below.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training device: {DEVICE}")

def set_global_random_seed(seed=42):
    """Seed every random number source used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available,
    CUDA), switches cuDNN to deterministic/non-benchmark mode, requests
    deterministic PyTorch algorithms and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    print(f"Setting all random sources to seed {seed}...")

    # CPU-side generators share the same seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        print("    CUDA deterministic mode enabled")

    try:
        torch.use_deterministic_algorithms(True)
    except AttributeError:
        print("    PyTorch version too old for deterministic algorithms")
    except RuntimeError as e:
        print(f"    Warning during deterministic algorithms setup: {e}")
    else:
        print("    PyTorch deterministic algorithms enabled")

    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
    # this presumably only influences processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f"All random sources fixed to seed {seed}\n")

def scan_artifacts_folder(artifacts_folder, dataset_name, algorithm_filter=None):
    """Scan artifacts folder for artifact directories matching the dataset.

    Entries are matched against ``V*_<dataset_name>_*`` inside
    ``artifacts_folder``. The algorithm name is parsed from the folder name
    (expected layout: ``V###_dataset_algorithm_date_time``), boosting-type
    algorithms are skipped, and an entry is only accepted when it contains
    both ``metadata.json`` and ``hyperparams.json``.

    Args:
        artifacts_folder: Directory that contains the artifact folders.
        dataset_name: Dataset name embedded in the folder names (may itself
            contain underscores).
        algorithm_filter: Optional algorithm name; when given, only artifacts
            whose parsed algorithm equals it (case-insensitively) are kept.

    Returns:
        A list of dicts with keys ``path``, ``version``, ``dataset``,
        ``algorithm`` and ``folder_name``, ordered by sorted path. Empty when
        the folder is missing or nothing matches.
    """

    filter_msg = f" (algorithm filter: {algorithm_filter})" if algorithm_filter else ""
    print(f"Scanning {artifacts_folder} for {dataset_name} dataset artifacts{filter_msg}...")

    if not os.path.exists(artifacts_folder):
        print(f"[ERROR] Artifact folder does not exist: {artifacts_folder}")
        return []

    # Pattern: V###_dataset_algorithm_date_time format folders
    pattern = f"V*_{dataset_name}_*"
    # Escape the caller-supplied parts so glob metacharacters in the folder
    # path or dataset name (e.g. "runs[1]") are matched literally; only the
    # two "*" wildcards of the pattern stay active.
    search_path = os.path.join(
        glob.escape(artifacts_folder),
        f"V*_{glob.escape(dataset_name)}_*",
    )

    artifact_dirs = glob.glob(search_path)

    if not artifact_dirs:
        print(f"[ERROR] No artifacts found for dataset: {dataset_name}")
        print(f"   Search pattern: {pattern}")
        return []

    # Loop invariants.
    # Boosting algorithms are filtered out (no representation learning).
    boosting_algorithms = ('XGBoost', 'LightGBM', 'GaussianNoise_XGBoost', 'SMOTER_XGBoost')
    unwanted_words = ('Notebook',)
    dataset_end_idx = 1 + len(dataset_name.split('_'))
    wanted_algorithm = algorithm_filter.lower() if algorithm_filter else None

    valid_artifacts = []
    filtered_count = 0

    for artifact_dir in sorted(artifact_dirs):
        folder_name = os.path.basename(artifact_dir)
        parts = folder_name.split('_')
        if len(parts) < 4:
            # Too few components to be a V###_dataset_algorithm_date folder.
            continue

        version = parts[0]

        # Index of the first 4-digit "20xx" component (the date's year), or -1.
        year_idx = next(
            (
                i for i in range(1, len(parts))
                if len(parts[i]) == 4 and parts[i].isdigit() and parts[i].startswith('20')
            ),
            -1,
        )

        if year_idx > 1:
            # Everything between the dataset name and the year is the algorithm.
            algorithm = '_'.join(
                part for part in parts[dataset_end_idx:year_idx]
                if part not in unwanted_words
            )
            dataset = dataset_name
        else:
            # No recognisable year: fall back to fixed positions.
            dataset = parts[1]
            algorithm = parts[2]

        if any(boost_alg in algorithm for boost_alg in boosting_algorithms):
            print(f"[SKIP] Filtered: {folder_name} (algorithm: {algorithm}) - boosting type")
            filtered_count += 1
            continue

        if wanted_algorithm is not None and wanted_algorithm != algorithm.lower():
            print(f"[SKIP] Filtered: {folder_name} (algorithm: {algorithm})")
            filtered_count += 1
            continue

        metadata_file = os.path.join(artifact_dir, 'metadata.json')
        hyperparams_file = os.path.join(artifact_dir, 'hyperparams.json')

        if not (os.path.exists(metadata_file) and os.path.exists(hyperparams_file)):
            print(f"[WARN] Skip: {folder_name} (required files missing)")
            continue

        valid_artifacts.append({
            'path': artifact_dir,
            'version': version,
            'dataset': dataset,
            'algorithm': algorithm,
            'folder_name': folder_name,
        })
        print(f"[OK] Found: {folder_name} (algorithm: {algorithm})")

    if algorithm_filter and filtered_count > 0:
        print(f"[INFO] {filtered_count} artifacts filtered by algorithm")

    print(f"[INFO] Total {len(valid_artifacts)} valid artifacts found")
    return valid_artifacts

def process_single_artifact(artifact_info, args):
    """Run the visualization pipeline for one scanned artifact.

    Args:
        artifact_info: Dict describing the artifact; the keys
            ``folder_name``, ``algorithm``, ``dataset`` and ``path`` are read.
        args: Parsed command-line options, passed through unchanged.

    Returns:
        Whatever ``process_artifact_visualization`` returns, or ``None`` when
        it raises (the error and traceback are printed instead).
    """

    separator = '=' * 80
    print(f"\n{separator}")
    print(f"Processing: {artifact_info['folder_name']}")
    print(f"   Algorithm: {artifact_info['algorithm']}")
    print(f"   Dataset: {artifact_info['dataset']}")
    print(f"{separator}")

    try:
        outcome = process_artifact_visualization(artifact_info['path'], args)
    except Exception as e:
        # One failing artifact must not abort the caller's loop over artifacts.
        print(f"[ERROR] {artifact_info['folder_name']} processing failed: {e}")
        import traceback
        traceback.print_exc()
        outcome = None

    return outcome

def load_and_recreate_data_split(dataset_name, data_split_info, preprocessor):
    """Load the raw dataset and apply the fitted preprocessor to its features.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        data_split_info: Split description from the artifact (not read here).
        preprocessor: Fitted transformer exposing ``transform``.

    Returns:
        ``(X_processed, y)``: the transformed features and the target.

    Raises:
        ValueError: If ``load_dataset`` does not return a 2-tuple, or if it
            names a target column that is absent from ``X``.
    """

    print("Loading data and recreating split...")

    dataset_result = load_dataset(dataset_name)

    is_pair = isinstance(dataset_result, tuple) and len(dataset_result) == 2
    if not is_pair:
        raise ValueError(f"Unexpected load_dataset return type: {type(dataset_result)}")

    X, y_info = dataset_result

    if not isinstance(y_info, str):
        # The loader already returned the target itself.
        y = y_info
    elif y_info in X.columns:
        # The loader returned the target column's name: split it off X.
        y = X[y_info].copy()
        X = X.drop(columns=[y_info]).copy()
        print(f"[OK] Target column '{y_info}' extracted")
    else:
        raise ValueError(f"Target column '{y_info}' not found")

    print(f"[OK] Dataset loaded: X.shape={X.shape}, y.shape={y.shape}")

    return preprocessor.transform(X), y

def extract_features_for_visualization(model, X, y, data_split_info, casmir_weighted=False):
    """Feature extraction and shot type classification.

    Runs the validation + test rows of ``X`` through ``model`` and collects
    the hooked hidden features, then assigns a shot type to every row.

    Args:
        model: Trained model to extract features from.
        X: Preprocessed feature matrix, indexable by a list of row indices.
        y: Targets; a pandas object (``.iloc``) or an indexable array.
        data_split_info: Dict with ``val_indices`` / ``test_indices`` and,
            optionally, ``val_shot_mapping`` / ``test_shot_mapping`` keyed by
            the stringified position inside the respective split.
        casmir_weighted: Forwarded to ``FeatureExtractorModel``.

    Returns:
        ``(features, shot_types)``: the feature array and a list of shot type
        labels, validation rows first, then test rows.

    Raises:
        ValueError: If no features could be extracted.
    """

    print("\nExtracting features from Val + Test data...")

    val_indices = data_split_info.get('val_indices', [])
    test_indices = data_split_info.get('test_indices', [])

    print(f"Data size: Val ({len(val_indices)}) + Test ({len(test_indices)}) = Total {len(val_indices) + len(test_indices)}")

    val_test_indices = val_indices + test_indices
    X_val_test = X[val_test_indices]
    y_val_test = y.iloc[val_test_indices] if hasattr(y, 'iloc') else y[val_test_indices]

    feature_extractor = FeatureExtractorModel(model, casmir_weighted=casmir_weighted)
    try:
        features = feature_extractor.extract_features(X_val_test)
    finally:
        # Always detach the forward hooks so they do not stay registered on
        # the caller's model after extraction (or after a failure).
        feature_extractor.cleanup()

    if features is None:
        raise ValueError("Feature extraction failed")

    print(f"[OK] Feature extraction complete: {features.shape}")

    print("\nShot type classification...")
    val_shot_mapping = data_split_info.get('val_shot_mapping', {})
    test_shot_mapping = data_split_info.get('test_shot_mapping', {})

    if val_shot_mapping and test_shot_mapping:
        print("[OK] Using original val + test shot mapping (no data leakage)")

        # Mappings are keyed by the position inside each split, as strings.
        shot_types = [
            val_shot_mapping.get(str(relative_idx), 'unknown')
            for relative_idx in range(len(val_indices))
        ]
        shot_types.extend(
            test_shot_mapping.get(str(relative_idx), 'unknown')
            for relative_idx in range(len(test_indices))
        )
    else:
        print("[WARN] Shot mapping not found, using y-value based classification")
        y_val_test_flat = y_val_test.values.flatten() if hasattr(y_val_test, 'values') else y_val_test.flatten()
        few_threshold = np.percentile(y_val_test_flat, 33.33)
        many_threshold = np.percentile(y_val_test_flat, 66.67)

        # Fallback: bucket rows by target-value tercile.
        shot_types = []
        for val in y_val_test_flat:
            if val <= few_threshold:
                shot_types.append('few')
            elif val >= many_threshold:
                shot_types.append('many')
            else:
                shot_types.append('medium')

    from collections import Counter
    shot_counts = Counter(shot_types)
    print(f"Shot type distribution: Few={shot_counts.get('few', 0)}, Medium={shot_counts.get('medium', 0)}, Many={shot_counts.get('many', 0)}")

    return features, shot_types

def process_artifact_visualization(artifact_dir, args):
    """Retrain the model of one artifact directory and plot its features.

    Steps: load the artifact files, rebuild the data split, recreate and
    train the model, extract val + test features and create the plot.

    Args:
        artifact_dir: Path of the artifact directory.
        args: Parsed options; ``dataset``, ``coloring``, ``visualization``,
            ``colormap`` and ``display_off`` are read, and ``epochs``,
            ``patience`` and ``weighted_feature`` are read when present.

    Returns:
        The value returned by ``create_visualization``, or ``None`` when any
        step raises (the error and traceback are printed).
    """

    try:
        # 1. Load artifact
        loaded = load_artifact_metadata(artifact_dir)
        metadata, hyperparams, preprocessor, data_split_info, experiment_config = loaded

        # 2. Load data and recreate split
        X, y = load_and_recreate_data_split(args.dataset, data_split_info, preprocessor)

        train_indices = data_split_info.get('train_indices', [])
        val_indices = data_split_info.get('val_indices', [])
        test_indices = data_split_info.get('test_indices', [])

        def _balanced_absolute(split_indices, key):
            # Balanced indices are stored relative to their split; default to
            # the whole split when the key is absent.
            relative = data_split_info.get(key, list(range(len(split_indices))))
            return [split_indices[pos] for pos in relative]

        def _select_targets(indices):
            if hasattr(y, 'iloc'):
                return y.iloc[indices].values
            return y[indices]

        # The test-balanced indices are resolved but not used further below.
        test_balanced_absolute_indices = _balanced_absolute(test_indices, 'test_balanced_indices')
        val_balanced_absolute_indices = _balanced_absolute(val_indices, 'val_balanced_indices')

        X_train = X[train_indices]
        X_val = X[val_balanced_absolute_indices]
        y_train = _select_targets(train_indices)
        y_val = _select_targets(val_balanced_absolute_indices)

        print("[OK] Data split recreated:")
        print(f"   Train (Original Imbalanced): {len(X_train)}")
        print(f"   Val (Balanced): {len(X_val)}")

        # 3. Model recreation and training
        print("\nRecreating model with correct hyperparameters...")
        input_dim = X_train.shape[1]
        model = recreate_model_with_correct_hyperparams(metadata, hyperparams, experiment_config, input_dim)

        print("\nSetting up optimizer and loss function...")
        optimizer, criterion = setup_optimizer_and_loss(model, hyperparams, metadata['algorithm_name'])

        print("\nTraining model with Early Stopping...")
        model = train_model_with_early_stopping(
            model, optimizer, criterion,
            X_train, y_train, X_val, y_val,
            experiment_config,
            max_epochs=getattr(args, 'epochs', 600),
            patience=getattr(args, 'patience', 600),
        )

        # 4. Feature extraction and visualization
        use_weighted = getattr(args, 'weighted_feature', False)
        features, shot_types = extract_features_for_visualization(
            model, X, y, data_split_info, casmir_weighted=use_weighted
        )

        # 5. Prepare target values for coloring
        target_values_for_vis = None
        if args.coloring in ['target_value', 'target_log', 'target_quantile']:
            val_test_indices = val_indices + test_indices
            if hasattr(y, 'iloc'):
                y_val_test = y.iloc[val_test_indices]
            else:
                y_val_test = y[val_test_indices]

            if hasattr(y_val_test, 'values'):
                target_values_for_vis = y_val_test.values.flatten()
            else:
                target_values_for_vis = y_val_test.flatten()

            print(f"\nTarget value range: {target_values_for_vis.min():.3f} ~ {target_values_for_vis.max():.3f}")

        # 6. Create visualization
        output_path = create_visualization(
            features,
            shot_types,
            metadata['algorithm_name'],
            args.dataset,
            method=args.visualization,
            coloring=args.coloring,
            target_values=target_values_for_vis,
            colormap_name=args.colormap,
            display_off=args.display_off,
        )

        print(f"[OK] Visualization complete: {output_path}")
        return output_path

    except Exception as e:
        print(f"[ERROR] Visualization processing failed: {e}")
        import traceback
        traceback.print_exc()
        return None

class FeatureExtractorModel:
    """Extract features before the final prediction layer from PyTorch models.

    Forward hooks are registered on the wrapped model at construction time;
    ``extract_features`` runs one forward pass and returns what the hooks
    captured as NumPy arrays. Three model families are recognised:

    * ``MLP`` instances,
    * MLP variants, detected by a tag in the model's type string,
    * CASMIR mixture-of-experts models (``CASMIR_V1`` / ``AdaSmoothMoEV1``),
      whose per-expert hidden features are either concatenated or averaged
      with the gate weights (``casmir_weighted=True``).

    Call ``cleanup()`` when done: it removes the hooks and undoes the
    ``forward`` override installed for gate-weight capture.
    """

    # Substrings of ``str(type(model))`` that mark an MLP-variant model.
    _VARIANT_TAGS = ('RankSim', 'ConR', 'FDS', 'LDS', 'GAI', 'BMC', 'SQRT_INV')

    # (container attribute, minimum length, index of the hooked layer, log suffix),
    # tried in this order.
    _LAYER_CANDIDATES = (
        ('feature_layers', 1, -1, ''),
        ('layers', 2, -2, ' (layers)'),
        ('network', 2, -2, ' (network)'),
    )

    def __init__(self, model, casmir_weighted=False):
        """Put ``model`` in eval mode and register the feature hooks.

        Args:
            model: Trained PyTorch model to extract features from.
            casmir_weighted: For CASMIR models, return the gate-weighted
                average of the expert features instead of their concatenation.
        """
        self.model = model
        self.model.eval()
        self.extracted_features = None
        self.hook_handles = []
        self.casmir_weighted = casmir_weighted  # CASMIR weighted feature option

        # Variables for CASMIR weighted feature
        self.gate_weights = None
        self.expert_features = [None] * 3  # Default 3 experts

        # True once per-expert hooks were registered; extract_features then
        # combines expert captures instead of reading extracted_features.
        self._uses_expert_hooks = False

        # Bookkeeping that lets cleanup() undo the forward override.
        self._forward_patched = False
        self._forward_was_instance_attr = False
        self._original_forward = None

        self._register_hooks()

    # ------------------------------------------------------------------
    # Hook registration
    # ------------------------------------------------------------------
    def _attach(self, layer, hook_fn):
        """Register ``hook_fn`` on ``layer`` and remember the handle for cleanup."""
        self.hook_handles.append(layer.register_forward_hook(hook_fn))

    def _hook_last_feature_layer(self, label):
        """Hook the last feature layer of a sequential-style model.

        Tries ``feature_layers[-1]``, then ``layers[-2]``, then ``network[-2]``.

        Returns:
            True when a hook was registered, False when no container matched.
        """
        for attr, min_len, index, suffix in self._LAYER_CANDIDATES:
            container = getattr(self.model, attr, None)
            if container is not None and len(container) >= min_len:
                layer = container[index]
                self._attach(layer, self._feature_hook)
                print(f"[OK] {label} Feature Hook registered{suffix}: {layer}")
                return True
        return False

    def _register_hooks(self):
        """Register hooks to extract features before the final layer."""
        model = self.model

        if isinstance(model, MLP):
            if not self._hook_last_feature_layer("MLP"):
                print("[WARN] MLP feature layers not found")
                print(f"   Model attributes: {dir(model)}")
                print(f"   Model structure: {model}")
        elif any(tag in str(type(model)) for tag in self._VARIANT_TAGS):
            self._register_variant_hooks()
        elif isinstance(model, (CASMIR_V1, AdaSmoothMoEV1)):
            self._register_expert_hooks()
        else:
            print(f"[WARN] Unsupported model type: {type(model)}")

    def _register_variant_hooks(self):
        """Hook an MLP-variant model, falling back to its backbone/encoder."""
        if self._hook_last_feature_layer("MLP Variant"):
            return

        model = self.model
        if hasattr(model, 'backbone') or hasattr(model, 'encoder'):
            backbone = getattr(model, 'backbone', getattr(model, 'encoder', None))
            layers = getattr(backbone, 'layers', None) if backbone is not None else None
            if layers is not None and len(layers) > 1:
                layer = layers[-2]
                self._attach(layer, self._feature_hook)
                print(f"[OK] MLP Variant Feature Hook registered (backbone): {layer}")
            else:
                print("[WARN] MLP Variant backbone/encoder not found")
        else:
            print("[WARN] MLP Variant feature layers not found")
            print(f"   Model type: {type(model)}")
            print(f"   Model attributes: {[attr for attr in dir(model) if not attr.startswith('_')]}")

    def _register_expert_hooks(self):
        """Hook the third-from-last layer of every CASMIR expert network."""
        experts = self.model.experts
        self._uses_expert_hooks = True

        if self.casmir_weighted:
            print("CASMIR Gate-weighted Average Feature extraction mode")
            self.expert_features = [None] * len(experts)
            make_hook, tag = self._make_weighted_expert_hook, "Weighted"
        else:
            print("CASMIR Expert Concatenation Feature extraction mode")
            self.expert_hidden_features = [None] * len(experts)
            make_hook, tag = self._make_expert_hook, "Hidden"

        for i, expert in enumerate(experts):
            network = getattr(expert, 'network', None)
            # network[-3] is hooked, so at least three layers are required.
            if network is not None and len(network) >= 3:
                hidden_layer = network[-3]
                self._attach(hidden_layer, make_hook(i))
                print(f"[OK] CASMIR Expert {i} {tag} Hook registered: {hidden_layer}")
            else:
                print(f"[WARN] Expert {i} network structure not found")

        if self.casmir_weighted:
            self._setup_gate_capture()
            print("[OK] CASMIR Gate-weighted Average Hook registration complete")
            print("   This provides unified representation similar to MLP dimensions")
        else:
            print("[OK] CASMIR Expert Hidden Concatenation Hook registration complete")

    # ------------------------------------------------------------------
    # Hook callbacks
    # ------------------------------------------------------------------
    @staticmethod
    def _to_numpy(output):
        """Return the (first) hook output as a detached CPU NumPy array."""
        features = output[0] if isinstance(output, tuple) else output
        return features.detach().cpu().numpy()

    def _feature_hook(self, module, input, output):
        """Hook function for feature extraction (MLP models)."""
        self.extracted_features = self._to_numpy(output)

    def _make_expert_hook(self, expert_idx):
        """Create hook function for specific CASMIR expert (Concatenation)."""
        def expert_hook(module, input, output):
            self.expert_hidden_features[expert_idx] = self._to_numpy(output)
        return expert_hook

    def _make_weighted_expert_hook(self, expert_idx):
        """Create hook function for specific CASMIR expert (Weighted Average)."""
        def weighted_expert_hook(module, input, output):
            self.expert_features[expert_idx] = self._to_numpy(output)
        return weighted_expert_hook

    def _setup_gate_capture(self):
        """Wrap ``model.forward`` so every call also records the gate weights.

        The wrapper runs the original forward, then recomputes the gate
        distribution from the model's gating network and stores it in
        ``self.gate_weights``. ``cleanup()`` removes the wrapper again.
        """
        model = self.model
        original_forward = model.forward

        # Remember how to undo the override: an instance-level ``forward`` is
        # put back, otherwise the override is deleted so the class method
        # becomes visible again.
        self._forward_was_instance_attr = 'forward' in getattr(model, '__dict__', {})
        self._original_forward = original_forward

        def capture_gate_forward(x, y=None, density=None, apply_smoothing=True):
            result = original_forward(x, y, density, apply_smoothing)

            x_norm = model.feature_norm(x) if model.use_feature_norm else x
            gate_logits = model.gating_network(x_norm)
            gate_weights = F.softmax(gate_logits / model.gate_temperature.clamp(min=0.5), dim=1)

            self.gate_weights = gate_weights.detach().cpu().numpy()

            return result

        model.forward = capture_gate_forward
        self._forward_patched = True

    # ------------------------------------------------------------------
    # Extraction
    # ------------------------------------------------------------------
    def extract_features(self, X_data):
        """Extract features from given data.

        Args:
            X_data: Array-like accepted by ``torch.tensor``; converted to a
                float32 tensor on ``DEVICE`` and passed through the model in
                a single forward call.

        Returns:
            The captured feature array, or ``None`` (after printing an error)
            when the hooks captured nothing.
        """
        # Drop captures from a previous call so stale data is never returned.
        self.extracted_features = None
        if self._uses_expert_hooks:
            if self.casmir_weighted:
                self.gate_weights = None
                self.expert_features = [None] * len(self.model.experts)
            else:
                self.expert_hidden_features = [None] * len(self.model.experts)

        with torch.no_grad():
            X_tensor = torch.tensor(X_data, dtype=torch.float32).to(DEVICE)
            # Only the hook side effects matter; the prediction is discarded.
            self.model(X_tensor)

        if self._uses_expert_hooks:
            return self._combine_expert_features()

        if self.extracted_features is not None:
            return self.extracted_features

        print("[ERROR] Features extraction failed")
        return None

    def _combine_expert_features(self):
        """Merge per-expert captures into a single feature array, or None."""
        if self.casmir_weighted:
            if self.gate_weights is not None and self.expert_features[0] is not None:
                weighted_features = np.zeros_like(self.expert_features[0])
                for i, expert_feat in enumerate(self.expert_features):
                    if expert_feat is not None:
                        # Column i of the gate weights scales expert i's features.
                        weighted_features += self.gate_weights[:, i:i+1] * expert_feat
                print(f"CASMIR Gate-weighted Average Features created: {weighted_features.shape}")
                return weighted_features
        else:
            hidden = getattr(self, 'expert_hidden_features', None)
            if hidden and hidden[0] is not None:
                concatenated_features = np.concatenate(
                    [feat for feat in hidden if feat is not None],
                    axis=1
                )
                print(f"CASMIR Expert Concatenation Features created: {concatenated_features.shape}")
                return concatenated_features

        print("[ERROR] CASMIR Features extraction failed - no hook data")
        return None

    def cleanup(self):
        """Remove all hooks and restore the model's original ``forward``."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

        if getattr(self, '_forward_patched', False):
            if self._forward_was_instance_attr:
                self.model.forward = self._original_forward
            else:
                delattr(self.model, 'forward')
            self._forward_patched = False
            self._original_forward = None

def load_artifact_metadata(artifact_dir):
    """Read every file of an artifact directory needed to retrain its model.

    Args:
        artifact_dir: Directory containing ``metadata.json``,
            ``hyperparams.json``, ``preprocessor.joblib``,
            ``data_split_info.json`` and ``experiment_config.json``.

    Returns:
        ``(metadata, hyperparams, preprocessor, data_split_info,
        experiment_config)``.
    """

    print(f"Loading artifact metadata: {artifact_dir}")

    def _read_json(filename):
        with open(os.path.join(artifact_dir, filename), 'r', encoding='utf-8') as f:
            return json.load(f)

    metadata = _read_json('metadata.json')
    hyperparams = _read_json('hyperparams.json')

    # NOTE(review): joblib.load unpickles the file, so artifact directories
    # should only come from trusted sources.
    preprocessor_path = os.path.join(artifact_dir, 'preprocessor.joblib')
    preprocessor = joblib.load(preprocessor_path)

    data_split_info = _read_json('data_split_info.json')
    experiment_config = _read_json('experiment_config.json')

    print(f"[OK] Metadata loaded: {metadata['algorithm_name']} on {metadata['dataset_name']}")

    return metadata, hyperparams, preprocessor, data_split_info, experiment_config

def recreate_model_with_correct_hyperparams(metadata, hyperparams, experiment_config, input_dim):
    """Build an untrained model matching the artifact's algorithm and hyperparameters.

    Args:
        metadata: Artifact metadata; ``algorithm_name`` selects the model.
        hyperparams: Tuned hyperparameters; missing entries fall back to the
            defaults coded below.
        experiment_config: Experiment configuration; its ``global_config``
            supplies CASMIR layer sizes when the hyperparameters lack them.
        input_dim: Number of input features.

    Returns:
        The model, moved to ``DEVICE``.

    Raises:
        ValueError: If the algorithm is neither CASMIR/AdaSmoothMoE nor an
            ``MLP*`` variant.
    """

    algorithm_name = metadata['algorithm_name']
    print(f"Recreating model with correct hyperparameters: {algorithm_name}")
    print(f"Hyperparameters: {hyperparams}")

    # Handle both CASMIR and AdaSmoothMoE names
    is_casmir = algorithm_name in ['CASMIR', 'AdaSmoothMoE']
    if not is_casmir and not algorithm_name.startswith('MLP'):
        raise ValueError(f"Unsupported algorithm: {algorithm_name}")

    if is_casmir:
        has_expert_dims = 'expert_dim1' in hyperparams and 'expert_dim2' in hyperparams
        if has_expert_dims:
            expert_hidden_dims = [hyperparams['expert_dim1'], hyperparams['expert_dim2']]
        else:
            expert_hidden_dims = experiment_config.get('global_config', {}).get(
                'casmir_expert_hidden_dims', [32, 16]
            )

        if 'gate_dim1' in hyperparams:
            gate_hidden_dims = [hyperparams['gate_dim1']]
        else:
            gate_hidden_dims = experiment_config.get('global_config', {}).get(
                'casmir_gate_hidden_dims', [32]
            )

        # CAS parameters: (model keyword, hyperparameter key, default)
        cas_spec = (
            ('k', 'k_neighbors', 10),
            ('feature_bw', 'feature_bw', 1.5),
            ('label_bw', 'label_bw', 10.0),
            ('density_factor', 'density_factor', 0.1),
            ('strength_base', 'strength_base', 0.6),
            ('density_c', 'density_c', 20.0),
        )
        cas_params = {name: hyperparams.get(key, default) for name, key, default in cas_spec}
        cas_params['epsilon'] = 1e-6

        model = CASMIR_V1(
            input_dim=input_dim,
            expert_hidden_dims=expert_hidden_dims,
            gate_hidden_dims=gate_hidden_dims,
            num_experts=3,
            cas_params=cas_params,
            expert_dropout=hyperparams.get('expert_dropout', 0.2),
        ).to(DEVICE)

        print(f"[OK] CASMIR structure: expert_dims={expert_hidden_dims}, gate_dims={gate_hidden_dims}")
        return model

    # MLP family
    if 'hidden_dim1' in hyperparams and 'hidden_dim2' in hyperparams:
        hidden_dims = [hyperparams['hidden_dim1'], hyperparams['hidden_dim2']]
    else:
        # Per-layer naming scheme: n_layers plus n_units_l<i>.
        n_layers = hyperparams.get("n_layers", 2)
        hidden_dims = []
        for layer_idx in range(n_layers):
            hidden_dims.append(hyperparams.get(f"n_units_l{layer_idx}", 128))

    dropout_rate = hyperparams.get("dropout", 0.2)

    model = MLP(
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        output_dim=1,
        dropout_rate=dropout_rate,
    ).to(DEVICE)

    print(f"[OK] MLP structure: {hidden_dims}, dropout: {dropout_rate}")
    return model

def setup_optimizer_and_loss(model, hyperparams, algorithm_name):
    """Setup optimizer and loss function.

    Args:
        model: Model whose parameters are optimized.
        hyperparams: Dict read for ``optimizer`` (default ``'Adam'``), ``lr``
            (default 0.001), ``weight_decay`` (default 0.0) and the
            ``conr_*`` / ``ranksim_*`` loss settings.
        algorithm_name: Selects the loss: ``ConRLoss`` when the name contains
            ``'ConR'``, ``RankSimLoss`` when it contains ``'RankSim'``,
            otherwise ``nn.MSELoss``.

    Returns:
        ``(optimizer, criterion)``.
    """

    optimizer_name = hyperparams.get('optimizer', 'Adam')
    lr = hyperparams.get('lr', 0.001)
    weight_decay = hyperparams.get('weight_decay', 0.0)

    if optimizer_name == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_name == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        # Only Adam/AdamW are supported. Keep the historical fallback (plain
        # Adam, no weight decay) but say so instead of silently pretending
        # that the requested optimizer is in use.
        optimizer = optim.Adam(model.parameters(), lr=lr)
        print(f"[WARN] Unsupported optimizer '{optimizer_name}'; falling back to Adam without weight decay")

    # Report what was actually constructed, not what was requested.
    effective_name = type(optimizer).__name__
    effective_weight_decay = optimizer.defaults.get('weight_decay', weight_decay)
    print(f"[OK] Optimizer: {effective_name}, lr: {lr}, weight_decay: {effective_weight_decay}")

    if 'ConR' in algorithm_name:
        w = hyperparams.get('conr_distance_threshold', 1.0)
        t = hyperparams.get('conr_temperature', 0.07)
        e = hyperparams.get('conr_pushing_power', 0.01)
        alpha = hyperparams.get('conr_alpha', 1.0)
        mse_weight = hyperparams.get('conr_mse_weight', 1.0)
        criterion = ConRLoss(w=w, t=t, e=e, alpha=alpha, mse_weight=mse_weight)
        print(f"[OK] ConRLoss: w={w}, t={t}, e={e}, alpha={alpha}, mse_weight={mse_weight}")
    elif 'RankSim' in algorithm_name:
        lambda_val = hyperparams.get('ranksim_lambda_val', 1.0)
        alpha = hyperparams.get('ranksim_alpha', 1.0)
        criterion = RankSimLoss(lambda_val=lambda_val, alpha=alpha)
        print(f"[OK] RankSimLoss: lambda={lambda_val}, alpha={alpha}")
    else:
        criterion = nn.MSELoss()
        print("[OK] MSELoss")

    return optimizer, criterion

def train_model_with_early_stopping(model, optimizer, criterion, X_train, y_train, X_val, y_val, experiment_config, max_epochs=600, patience=600):
    """Train ``model`` with mini-batches and keep the best-validation weights.

    Training batches are consecutive slices of the training tensors, taken in
    the same order every epoch. After each epoch the validation MSE is
    computed; the state dict with the lowest value is restored at the end.

    Args:
        model: Model to train (modified in place).
        optimizer: Optimizer bound to the model's parameters.
        criterion: Training loss, called as ``criterion(prediction, target)``.
        X_train, y_train: Training features and targets; targets are reshaped
            to a single column.
        X_val, y_val: Validation features and targets, same conventions.
        experiment_config: Dict; ``batch_size`` is read (default 32).
        max_epochs: Maximum number of epochs.
        patience: Number of consecutive non-improving epochs before stopping.

    Returns:
        The trained model.
    """

    batch_size = experiment_config.get('batch_size', 32)

    print(f"Training model: {max_epochs} epochs, patience: {patience}")

    def _on_device(array):
        return torch.tensor(array, dtype=torch.float32).to(DEVICE)

    def _primary(outputs):
        # Models may return (prediction, extras...); only the prediction is scored.
        return outputs[0] if isinstance(outputs, tuple) else outputs

    train_inputs = _on_device(X_train)
    train_targets = _on_device(y_train.reshape(-1, 1))
    val_inputs = _on_device(X_val)
    val_targets = _on_device(y_val.reshape(-1, 1))

    val_criterion = nn.MSELoss()

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    for epoch in range(max_epochs):
        # --- training pass ---
        model.train()
        running_loss = 0.0
        num_batches = 0
        for start in range(0, len(train_inputs), batch_size):
            stop = start + batch_size

            optimizer.zero_grad()
            predictions = _primary(model(train_inputs[start:stop]))
            loss = criterion(predictions, train_targets[start:stop])
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        avg_train_loss = running_loss / num_batches

        # --- validation pass ---
        model.eval()
        with torch.no_grad():
            val_predictions = _primary(model(val_inputs))
            val_loss = val_criterion(val_predictions, val_targets).item()

        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        # Log every 50th epoch and every epoch that set a new best.
        if improved or epoch % 50 == 0:
            print(f"  Epoch {epoch:3d}: Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Best: {best_val_loss:.4f}")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch} (patience: {patience})")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"[OK] Best validation model restored (val_loss: {best_val_loss:.4f})")

    return model

def _draw_embedding_panel(ax, embedding, axis_prefix, shot_labels, shot_palette,
                          color_values=None, colormap=None, cbar_label=None):
    """Scatter one 2-D embedding onto ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        embedding: Array indexed as ``embedding[:, 0]`` / ``embedding[:, 1]``.
        axis_prefix: Text placed in front of "Component 1/2" in the axis labels.
        shot_labels: NumPy array of per-sample shot-type labels.
        shot_palette: Mapping of shot-type label -> matplotlib color; its
            iteration order is the drawing/legend order.
        color_values: Per-sample values for continuous coloring. When ``None``
            the samples are colored per shot type instead.
        colormap: Colormap used for continuous coloring.
        cbar_label: Colorbar label used for continuous coloring.
    """
    marker_style = dict(alpha=0.7, s=50, edgecolors='black', linewidth=0.3)

    if color_values is not None:
        # Continuous color plot with a colorbar
        scatter = ax.scatter(
            embedding[:, 0],
            embedding[:, 1],
            c=color_values,
            cmap=colormap,
            **marker_style,
        )
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label(cbar_label, rotation=270, labelpad=15)
    else:
        # Shot type based color plot: one scatter call per category so that
        # every category gets its own legend entry with a sample count.
        plotted_any = False
        for shot_type, color in shot_palette.items():
            mask = shot_labels == shot_type
            if np.any(mask):
                ax.scatter(
                    embedding[mask, 0],
                    embedding[mask, 1],
                    c=color,
                    label=f'{shot_type.capitalize()}-shot ({np.sum(mask)})',
                    **marker_style,
                )
                plotted_any = True
        # Avoid matplotlib's "no artists with labels" warning on empty plots
        if plotted_any:
            ax.legend()

    ax.set_xlabel(f'{axis_prefix} Component 1')
    ax.set_ylabel(f'{axis_prefix} Component 2')
    ax.grid(True, alpha=0.3)


def create_visualization(features, shot_types, algorithm_name, dataset_name, method='both', coloring='shot_type', target_values=None, colormap_name='Spectral', display_off=False):
    """Create a PCA and/or t-SNE scatter plot of ``features`` and save it.

    The figure is written to ``./outputs/FIXED_<algorithm>_<dataset>_<timestamp>.png``
    and the raw features plus shot types to a sibling ``*_features.npz`` file.

    Args:
        features: 2-D feature matrix (one row per sample) passed to PCA/t-SNE.
        shot_types: Per-sample shot-type labels ('few', 'medium', 'many' or
            'unknown'), in the same order as ``features``.
        algorithm_name: Algorithm name, used in log output and file names.
        dataset_name: Dataset name, used in log output and file names.
        method: 'pca', 'tsne' or 'both'.
        coloring: 'shot_type', 'target_value', 'target_log' or
            'target_quantile'. Any other value falls back to 'shot_type'.
        target_values: Per-sample target values; needed for the three
            ``target_*`` coloring modes. If it is ``None`` in one of those
            modes, a warning is printed and shot-type coloring is used.
        colormap_name: Name of an attribute of ``plt.cm``.
        display_off: When True the figure is only saved, never shown.

    Returns:
        Path of the saved PNG file.
    """

    print(f"Creating visualization: {algorithm_name} on {dataset_name}")

    # Colormap setup
    colormap = getattr(plt.cm, colormap_name)

    # Visualization method selection
    want_pca = method in ['both', 'pca']
    want_tsne = method in ['both', 'tsne']

    if want_pca:
        pca = PCA(n_components=2, random_state=42)
        embedding_pca = pca.fit_transform(features)

    # Always bound, so later checks never depend on short-circuit evaluation
    tsne_success = False
    if want_tsne:
        try:
            tsne = TSNE(n_components=2, perplexity=min(30, len(features)//4), random_state=42)
            embedding_tsne = tsne.fit_transform(features)
            tsne_success = True
        except Exception as e:
            # Best effort: a failed t-SNE must not prevent the PCA plot
            print(f"[WARN] t-SNE failed: {e}")

    # Color method selection and target value transformation.
    # The palette is defined unconditionally so the shot-type fallback below
    # always works, whatever coloring mode was requested.
    color_map = {'few': 'red', 'medium': 'orange', 'many': 'green', 'unknown': 'gray'}
    color_values = None
    cbar_label = None

    if coloring in ('target_value', 'target_log', 'target_quantile'):
        if target_values is None:
            print(f"[WARN] coloring='{coloring}' requires target_values; falling back to shot_type coloring")
        elif coloring == 'target_value':
            # Continuous color based on target value
            color_values = target_values
            cbar_label = 'Target Value'
        elif coloring == 'target_log':
            # Log transformed target value (improves skewed distribution)
            color_values = np.log1p(target_values)  # log(1 + x) for safe transformation
            cbar_label = 'Log(1 + Target Value)'
        else:
            # Quantile transformation (uniform distribution)
            from scipy.stats import rankdata
            color_values = rankdata(target_values) / len(target_values)  # normalize to 0-1 range
            cbar_label = 'Target Value Quantile'

    shot_labels = np.asarray(shot_types)

    # Plot setup: two panels only when both embeddings are available
    two_panels = method == 'both' and tsne_success
    if two_panels:
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    else:
        fig, single_ax = plt.subplots(1, 1, figsize=(8, 6))
        axes = [single_ax]

    # PCA plot
    if want_pca:
        _draw_embedding_panel(
            axes[0], embedding_pca, 'PCA', shot_labels, color_map,
            color_values=color_values, colormap=colormap, cbar_label=cbar_label,
        )

    # t-SNE plot (second panel in 'both' mode, the only panel otherwise)
    if want_tsne and tsne_success:
        _draw_embedding_panel(
            axes[1] if two_panels else axes[0], embedding_tsne, 't-SNE', shot_labels, color_map,
            color_values=color_values, colormap=colormap, cbar_label=cbar_label,
        )

    plt.tight_layout()

    # Save
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
    output_path = f"./outputs/FIXED_{algorithm_name}_{dataset_name}_{timestamp}.png"
    os.makedirs('./outputs', exist_ok=True)
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"[OK] Visualization saved: {output_path}")

    features_path = f"./outputs/FIXED_{algorithm_name}_{dataset_name}_{timestamp}_features.npz"
    np.savez_compressed(features_path, features=features, shot_types=shot_types)
    print(f"[OK] Features saved: {features_path}")

    if not display_off:
        plt.show()
    else:
        print("[INFO] Server mode: display skipped")
    # Release the figure in both modes so repeated calls (e.g. auto scan)
    # do not accumulate open figures.
    plt.close(fig)

    return output_path

def _predict_flattened(model, X, algorithm_name, device):
    """Run ``model`` on ``X`` without gradients and return flat NumPy predictions.

    For the 'CASMIR' / 'AdaSmoothMoE' algorithms the model is called with
    ``apply_smoothing=False`` and, if it returns a tuple, the first element is
    taken as the prediction. The caller is expected to have put the model in
    eval mode.
    """
    X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
    with torch.no_grad():
        if algorithm_name in ['CASMIR', 'AdaSmoothMoE']:
            outputs = model(X_tensor, apply_smoothing=False)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
        else:
            outputs = model(X_tensor)
    return outputs.cpu().numpy().flatten()


def _print_shot_metric_block(title, results, suffix="MAE"):
    """Print the few/medium/many/overall entries of ``results`` under ``title``."""
    print(f"   === {title} ===")
    print(f"   Few-shot {suffix}: {results['few']:.6f}")
    print(f"   Medium-shot {suffix}: {results['medium']:.6f}")
    print(f"   Many-shot {suffix}: {results['many']:.6f}")
    print(f"   Overall {suffix}: {results['overall']:.6f}")


def _run_direct_artifact(args):
    """Retrain, evaluate and visualize the single artifact in ``args.artifact_dir``.

    Steps: load artifact metadata, recreate the stored data split, rebuild and
    retrain the model, extract features for val + test, report shot-wise MAE
    (optionally against a stored benchmark file), plot, and save the features.
    Exceptions propagate to the caller; the feature extractor is cleaned up on
    every exit path once it has been created.
    """
    # 1. Load artifact
    metadata, hyperparams, preprocessor, data_split_info, experiment_config = load_artifact_metadata(args.artifact_dir)

    # 2. Load data and recreate split
    print("\nLoading data and recreating split...")
    df, target_col = load_dataset(args.dataset)
    X, y, _ = preprocess_data(df, target_col)

    train_indices = data_split_info['train_indices']
    val_indices = data_split_info['val_indices']
    test_indices = data_split_info['test_indices']

    # The balanced subsets are stored as positions *within* val/test; when
    # they are missing, the full val/test sets are used.
    test_balanced_relative_indices = data_split_info.get('test_balanced_indices', list(range(len(test_indices))))
    val_balanced_relative_indices = data_split_info.get('val_balanced_indices', list(range(len(val_indices))))

    test_balanced_absolute_indices = [test_indices[i] for i in test_balanced_relative_indices]
    val_balanced_absolute_indices = [val_indices[i] for i in val_balanced_relative_indices]

    X_train_df = X.iloc[train_indices].copy()
    X_val_df = X.iloc[val_balanced_absolute_indices].copy()
    X_test_df = X.iloc[test_balanced_absolute_indices].copy()
    X_test_original_df = X.iloc[test_indices].copy()

    y_train = y.iloc[train_indices].copy().values
    y_val = y.iloc[val_balanced_absolute_indices].copy().values
    y_test = y.iloc[test_balanced_absolute_indices].copy().values
    y_test_original = y.iloc[test_indices].copy().values

    X_train = preprocessor.transform(X_train_df)
    X_val = preprocessor.transform(X_val_df)
    X_test = preprocessor.transform(X_test_df)
    X_test_original = preprocessor.transform(X_test_original_df)

    print("[OK] Data split recreated:")
    print(f"   Train (Original Imbalanced): {len(X_train)}")
    print(f"   Val (Balanced): {len(X_val)}")
    print(f"   Test (Balanced): {len(X_test)}")

    # 3. Model recreation
    print("\nRecreating model with correct hyperparameters...")
    input_dim = X_train.shape[1]
    model = recreate_model_with_correct_hyperparams(metadata, hyperparams, experiment_config, input_dim)

    # 4. Optimizer and loss function setup
    print("\nSetting up optimizer and loss function...")
    algorithm_name = metadata['algorithm_name']
    optimizer, criterion = setup_optimizer_and_loss(model, hyperparams, algorithm_name)

    # 5. Train with Early Stopping
    print("\nTraining model with Early Stopping...")
    model = train_model_with_early_stopping(model, optimizer, criterion, X_train, y_train, X_val, y_val, experiment_config, max_epochs=args.epochs, patience=args.patience)

    # 6. Feature extraction
    print("\nExtracting features from Val + Test data...")
    print(f"Data size: Val ({len(val_indices)}) + Test ({len(test_indices)}) = Total {len(val_indices) + len(test_indices)}")

    # Coerce to lists so "+" always concatenates; on array inputs it would
    # add element-wise instead.
    val_test_indices = list(val_indices) + list(test_indices)
    X_val_test_original = X.iloc[val_test_indices].copy()
    y_val_test_original = y.iloc[val_test_indices].copy()
    X_val_test = preprocessor.transform(X_val_test_original)

    feature_extractor = FeatureExtractorModel(model)
    try:
        features = feature_extractor.extract_features(X_val_test)

        if features is None:
            print("[ERROR] Feature extraction failed")
            return

        print(f"[OK] Feature extraction complete: {features.shape}")

        # 7. Shot type classification
        print("\nShot type classification...")
        val_shot_mapping = data_split_info.get('val_shot_mapping', {})
        test_shot_mapping = data_split_info.get('test_shot_mapping', {})

        if val_shot_mapping and test_shot_mapping:
            print("[OK] Using original val + test shot mapping (no data leakage)")
            # Mappings are looked up by the stringified position inside each split
            shot_types = [val_shot_mapping.get(str(idx), 'unknown') for idx in range(len(val_indices))]
            shot_types += [test_shot_mapping.get(str(idx), 'unknown') for idx in range(len(test_indices))]
        else:
            print("[WARN] Shot mapping not found, using y-value based classification")
            y_val_test_flat = y_val_test_original.values.flatten()
            few_threshold = np.percentile(y_val_test_flat, 33.33)
            many_threshold = np.percentile(y_val_test_flat, 66.67)

            shot_types = []
            for val in y_val_test_flat:
                if val <= few_threshold:
                    shot_types.append('few')
                elif val >= many_threshold:
                    shot_types.append('many')
                else:
                    shot_types.append('medium')

        print(f"Shot type distribution: Few={shot_types.count('few')}, Medium={shot_types.count('medium')}, Many={shot_types.count('many')}")

        # 8. Performance evaluation
        print("\nPerformance evaluation (benchmark comparison)...")

        model.eval()
        device = model.device if hasattr(model, 'device') else next(model.parameters()).device
        y_pred_balanced = _predict_flattened(model, X_test, algorithm_name, device)
        y_pred_original = _predict_flattened(model, X_test_original, algorithm_name, device)

        from src.data.datasets import calculate_shot_wise_mae

        # Stored mapping keys are converted to int before scoring
        test_balanced_shot_indices = data_split_info.get('test_balanced_shot_indices', {})
        test_balanced_shot_indices = {int(k): v for k, v in test_balanced_shot_indices.items()}
        mae_results_balanced = calculate_shot_wise_mae(y_test.flatten(), y_pred_balanced, test_balanced_shot_indices)

        test_shot_mapping = data_split_info.get('test_shot_mapping', {})
        test_shot_mapping = {int(k): v for k, v in test_shot_mapping.items()}
        mae_results_original = calculate_shot_wise_mae(y_test_original.flatten(), y_pred_original, test_shot_mapping)

        print("Retrained model performance:")
        _print_shot_metric_block("Balanced Test Set", mae_results_balanced)
        _print_shot_metric_block("Original Imbalanced Test Set", mae_results_original)

        # Compare with original benchmark if available
        bench_file_pattern = f"bench_{args.dataset}_{algorithm_name}_*.json"
        bench_files = glob.glob(os.path.join(args.artifact_dir, bench_file_pattern))

        if bench_files:
            # NOTE(review): glob order is unspecified; with several matching
            # files the one compared against is arbitrary.
            bench_file = bench_files[0]
            with open(bench_file, 'r', encoding='utf-8') as f:
                bench_data = json.load(f)

            if args.dataset in bench_data and algorithm_name in bench_data[args.dataset]:
                bench_entry = bench_data[args.dataset][algorithm_name]
                original_bench_balanced = bench_entry['bal_mae']
                original_bench_original = bench_entry['ori_mae']

                print("\nOriginal benchmark results:")
                _print_shot_metric_block("Balanced Test Set", original_bench_balanced)
                _print_shot_metric_block("Original Imbalanced Test Set", original_bench_original)

                shot_keys = ('few', 'medium', 'many', 'overall')
                diff_balanced = {k: abs(mae_results_balanced[k] - original_bench_balanced[k]) for k in shot_keys}
                diff_original = {k: abs(mae_results_original[k] - original_bench_original[k]) for k in shot_keys}

                print("\nDifference analysis:")
                _print_shot_metric_block("Balanced Test Set Diff", diff_balanced, suffix="diff")
                _print_shot_metric_block("Original Imbalanced Test Set Diff", diff_original, suffix="diff")

                # The reproduction verdict is based on the original
                # (imbalanced) test set only.
                max_diff_original = max(diff_original.values())

                if max_diff_original < 0.1:
                    print("[OK] Performance reproduced: diff < 0.1")
                elif max_diff_original < 1.0:
                    print("[WARN] Performance partially reproduced: diff < 1.0")
                else:
                    print("[ERROR] Performance reproduction failed: diff >= 1.0")

        # 9. Visualization
        print("\nRepresentation Learning visualization...")
        if args.coloring in ['target_value', 'target_log', 'target_quantile']:
            target_values_for_vis = y_val_test_original.values.flatten()
            print(f"Target value range: {target_values_for_vis.min():.3f} ~ {target_values_for_vis.max():.3f}")

            print("Target value distribution:")
            print(f"   0-1: {np.sum((target_values_for_vis >= 0) & (target_values_for_vis < 1))}")
            print(f"   1-10: {np.sum((target_values_for_vis >= 1) & (target_values_for_vis < 10))}")
            print(f"   10-50: {np.sum((target_values_for_vis >= 10) & (target_values_for_vis < 50))}")
            print(f"   50+: {np.sum(target_values_for_vis >= 50)}")
        else:
            target_values_for_vis = None

        output_path = create_visualization(
            features, shot_types,
            algorithm_name, args.dataset,
            method=args.visualization,
            coloring=args.coloring,
            target_values=target_values_for_vis,
            colormap_name=args.colormap,
            display_off=args.display_off
        )

        # 10. Save features
        # y_test is the balanced test subset and does not line up row-wise
        # with `features` (val + test); y_val_test is the aligned target array.
        features_save_path = output_path.replace('.png', '_features.npz')
        np.savez(features_save_path,
                 features=features,
                 shot_types=shot_types,
                 y_test=y_test,
                 y_val_test=y_val_test_original.values,
                 algorithm=algorithm_name,
                 dataset=args.dataset,
                 hyperparams=hyperparams,
                 experiment_config=experiment_config)
        print(f"[OK] Features saved: {features_save_path}")
    finally:
        # Runs on success, on the early return above, and on any exception
        feature_extractor.cleanup()

    print("\nAll tasks complete!")
    print(f"Result: {output_path}")


def main():
    """Command-line entry point: retrain from HPO artifacts and visualize features.

    Two modes are supported:
      * normal mode (``--artifact_dir``): process one artifact directory;
      * auto scan mode (``--auto_scan`` with ``--artifacts_folder``): process
        every artifact returned by ``scan_artifacts_folder`` for the dataset.

    Argument problems and runtime errors are reported on stdout; the function
    returns ``None`` in every case.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Model retraining and representation visualization based on HPO results")

    parser.add_argument('--artifact_dir', type=str, help='Artifact directory (direct specification)')
    parser.add_argument('--artifacts_folder', type=str, help='Artifact base folder (for auto scan)')
    parser.add_argument('--auto_scan', action='store_true', help='Auto scan mode: find and process folders matching dataset in artifacts_folder')
    parser.add_argument('--algorithm', type=str, help='Filter specific algorithm (e.g., CASMIR, MLP, MLP_ConR)')

    parser.add_argument('--dataset', type=str, required=True, help='Dataset name')
    parser.add_argument('--visualization', choices=['pca', 'tsne', 'both'], default='tsne',
                        help='Visualization method')
    parser.add_argument('--coloring', choices=['shot_type', 'target_value', 'target_log', 'target_quantile'], default='target_value',
                        help='Coloring method: target_value (continuous, default), shot_type (Few/Medium/Many), target_log, target_quantile')
    parser.add_argument('--colormap', choices=['viridis', 'coolwarm', 'RdYlBu', 'plasma', 'inferno', 'magma', 'cividis', 'RdBu', 'seismic', 'Spectral'], default='Spectral',
                        help='Colormap')
    parser.add_argument('--weighted_feature', action='store_true',
                        help='CASMIR only: Use gate-weighted average feature (default: Expert Concatenation)')
    parser.add_argument('--display_off', action='store_true',
                        help='Server mode: save files only without display')
    parser.add_argument('--epochs', type=int, default=600,
                        help='Number of training epochs (default: 600)')
    parser.add_argument('--patience', type=int, default=600,
                        help='Early stopping patience (default: 600)')

    args = parser.parse_args()

    if args.auto_scan and not args.artifacts_folder:
        print("[ERROR] --auto_scan mode requires --artifacts_folder.")
        return

    if not args.auto_scan and not args.artifact_dir:
        print("[ERROR] Normal mode requires --artifact_dir.")
        return

    set_global_random_seed(42)

    print("Model Retraining and Visualization based on HPO Results")
    print("=" * 70)

    if args.auto_scan:
        print(f"Auto scan mode: {args.artifacts_folder}")

        artifacts = scan_artifacts_folder(args.artifacts_folder, args.dataset, args.algorithm)

        if not artifacts:
            print("[ERROR] No artifacts to process.")
            return

        success_count = 0
        for artifact_info in artifacts:
            if process_single_artifact(artifact_info, args):
                success_count += 1

        print("\nAuto scan complete!")
        print(f"Results: {success_count}/{len(artifacts)} successful")
        return

    # Top-level boundary: report any failure with a traceback instead of
    # letting it escape the script.
    try:
        _run_direct_artifact(args)
    except Exception as e:
        print(f"[ERROR] Error occurred: {e}")
        import traceback
        traceback.print_exc()
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
