#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Inference script for pre-trained models.
"""

import os
import sys
import json
import joblib
import torch
import numpy as np
import pandas as pd
import argparse
import glob
from pathlib import Path
from datetime import datetime

from src.data.datasets import (
    load_dataset, preprocess_data, split_data, 
    classify_bins_by_samples, map_shot_types, calculate_shot_wise_mae
)
from src.models.basic_models import MLP, MLPadv, SimpleThreeMLPEnsemble, XGBoostWrapper, LightGBMWrapper
from src.models.CASMIR_V1 import CASMIR_V1, precompute_density_and_boundaries
from src.evaluation.evaluation import calculate_region_mae_with_thresholds
from sklearn.preprocessing import StandardScaler
from config import CONFIG

# Default compute device: the CUDA device when PyTorch can see one, else the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def find_artifact(dataset=None, algorithm=None, artifacts_dir="learned_models"):
    """Find artifact directory by dataset and algorithm name."""
    if not os.path.exists(artifacts_dir):
        raise FileNotFoundError(f"Artifacts directory not found: {artifacts_dir}")
    
    pattern = f"V*_{dataset}_{algorithm}_*" if dataset and algorithm else "V*"
    search_path = os.path.join(artifacts_dir, pattern)
    
    matches = glob.glob(search_path)
    
    if not matches:
        raise FileNotFoundError(f"No matching artifact found: {search_path}")
    
    # Exact algorithm name matching
    exact_matches = []
    for match in matches:
        basename = os.path.basename(match)
        parts = basename.split('_')
        try:
            dataset_idx = parts.index(dataset)
            algo_parts = []
            for i in range(dataset_idx + 1, len(parts)):
                if len(parts[i]) == 4 and parts[i].isdigit():
                    break
                algo_parts.append(parts[i])
            extracted_algo = '_'.join(algo_parts)
            if extracted_algo == algorithm:
                exact_matches.append(match)
        except (ValueError, IndexError):
            continue
    
    if exact_matches:
        latest = sorted(exact_matches)[-1]
    else:
        latest = sorted(matches)[-1]
    
    print(f"[INFO] Found artifact: {latest}")
    return latest

def load_artifact_metadata(artifact_dir):
    """Load metadata, experiment config and hyperparameters from an artifact directory.

    Files are looked up in order of preference:

    * ``metadata.json``, then ``model_info.json`` (required);
    * ``config.json``, then ``experiment_config.json`` (optional; ``{}`` if absent);
    * ``best_params.json``, then ``hyperparams.json`` (required).

    Args:
        artifact_dir: Path to the artifact directory.

    Returns:
        Tuple ``(metadata, config, best_params)`` of the parsed JSON objects.

    Raises:
        FileNotFoundError: If the directory or a required file is missing.
        ValueError: If one of the JSON files cannot be parsed.
    """
    if not os.path.exists(artifact_dir):
        raise FileNotFoundError(f"Artifact directory not found: {artifact_dir}")

    print(f"[LOAD] Artifact loading: {artifact_dir}")

    try:
        metadata_path = os.path.join(artifact_dir, 'metadata.json')
        model_info_path = os.path.join(artifact_dir, 'model_info.json')

        if os.path.exists(metadata_path):
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        elif os.path.exists(model_info_path):
            with open(model_info_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        else:
            raise FileNotFoundError("metadata.json or model_info.json not found")

        # The experiment config is optional; fall back to an empty dict.
        config = {}
        config_path = os.path.join(artifact_dir, 'config.json')
        if not os.path.exists(config_path):
            config_path = os.path.join(artifact_dir, 'experiment_config.json')
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

        # Hyperparameters are required: open() raises FileNotFoundError if neither exists.
        params_path = os.path.join(artifact_dir, 'best_params.json')
        if not os.path.exists(params_path):
            params_path = os.path.join(artifact_dir, 'hyperparams.json')
        with open(params_path, 'r', encoding='utf-8') as f:
            best_params = json.load(f)

        print(f"[OK] Metadata loaded")

        dataset_name = metadata.get('dataset') or metadata.get('dataset_name', 'Unknown')
        algorithm_name = metadata.get('algorithm') or metadata.get('algorithm_name', 'Unknown')
        # `or {}` also covers keys that are present but null in the JSON.
        saved_performance = metadata.get('performance') or {}

        print(f"   - Dataset: {dataset_name}")
        print(f"   - Algorithm: {algorithm_name}")

        if saved_performance:
            ori_overall = (saved_performance.get('ori_mae') or {}).get('overall')
            bal_overall = (saved_performance.get('bal_mae') or {}).get('overall')
            # Explicit numeric checks so a perfect 0.0 MAE is still reported.
            if isinstance(ori_overall, (int, float)):
                print(f"   - Saved ori_mae.overall: {ori_overall:.4f}")
            if isinstance(bal_overall, (int, float)):
                print(f"   - Saved bal_mae.overall: {bal_overall:.4f}")

        return metadata, config, best_params

    except FileNotFoundError as e:
        raise FileNotFoundError(f"Required file not found: {e}") from e
    except json.JSONDecodeError as e:
        raise ValueError(f"JSON parsing error: {e}") from e

def load_model(artifact_dir, metadata, config, best_params, device=DEVICE, input_dim=None):
    """
    Load a saved model (and, for PyTorch models, an input scaler) from an artifact directory.

    Tree-based algorithms (XGBoost, LightGBM, CatBoost, SMOTER_XGBoost,
    GaussianNoise_XGBoost) are restored whole from ``model.joblib``. CASMIR and
    the MLP / Simple_Ensemble variants are rebuilt from ``best_params``, and their
    weights are loaded from ``model.pth``. For MLP variants, hidden sizes are
    taken from the saved state dict when the expected keys exist.

    Args:
        artifact_dir: Path to artifact directory.
        metadata: Model metadata dictionary; ``algorithm`` (or ``algorithm_name``)
            selects which model is built.
        config: Experiment configuration dictionary (consulted for ``input_dim``).
        best_params: Hyperparameter dictionary used to rebuild the architecture.
        device: Target device (cpu or cuda) for PyTorch models.
        input_dim: Input dimension. If None, it is looked up in ``best_params``,
            then ``config``, then ``data_split_info.json`` (``original_data_shape[1]``).

    Returns:
        Tuple of (model, scaler). Scaler is None for tree-based models. PyTorch
        models are moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If ``model.joblib`` / ``model.pth`` is missing.
        ValueError: If ``input_dim`` cannot be determined or the algorithm is unsupported.
    """
    algorithm = metadata.get('algorithm') or metadata.get('algorithm_name', 'Unknown')

    print(f"[LOAD] Model loading: {algorithm}")

    # Tree / boosting models are serialized whole with joblib and need no scaler.
    if algorithm in ['XGBoost', 'LightGBM', 'CatBoost', 'SMOTER_XGBoost', 'GaussianNoise_XGBoost']:
        model_path = os.path.join(artifact_dir, 'model.joblib')
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # NOTE(security): joblib.load unpickles the file; only load trusted artifacts.
        model = joblib.load(model_path)
        print(f"[OK] Tree model loaded")
        return model, None

    else:
        # Prefer a dedicated scaler file; otherwise fall back to the saved preprocessor.
        # NOTE(review): preprocessor.joblib is applied to the raw DataFrame elsewhere in
        # this script; confirm it is also valid as a scaler on already-numeric arrays.
        scaler_path = os.path.join(artifact_dir, 'scaler.joblib')
        if not os.path.exists(scaler_path):
            scaler_path = os.path.join(artifact_dir, 'preprocessor.joblib')

        if os.path.exists(scaler_path):
            scaler = joblib.load(scaler_path)
            print(f"[OK] Scaler loaded")
        else:
            # NOTE(review): this StandardScaler is never fitted, so calling
            # scaler.transform() on it will raise; callers must replace or discard it.
            print(f"[WARN] Scaler file not found. Creating new StandardScaler.")
            scaler = StandardScaler()

        # Resolve input_dim: explicit argument > best_params > config > data_split_info.json.
        if input_dim is None:
            input_dim = best_params.get('input_dim', config.get('input_dim'))

        if input_dim is None:
            # NOTE(review): original_data_shape[1] may count raw (pre-encoding) columns
            # or the target column; confirm it equals the model's input width.
            data_split_path = os.path.join(artifact_dir, 'data_split_info.json')
            if os.path.exists(data_split_path):
                with open(data_split_path, 'r', encoding='utf-8') as f:
                    data_split_info = json.load(f)
                    original_shape = data_split_info.get('original_data_shape', [])
                    if len(original_shape) >= 2:
                        input_dim = original_shape[1]
                        print(f"   [INFO] Auto-detected input_dim={input_dim} from data_split_info")

        if input_dim is None:
            raise ValueError("input_dim not found. Load data first.")

        print(f"   [INFO] Using input_dim={input_dim}")

        if algorithm == 'CASMIR':
            # Rebuild CASMIR_V1 from saved hyperparameters; hard-coded defaults fill any gaps.
            expert_dim1 = best_params.get("expert_dim1", 128)
            expert_dim2 = best_params.get("expert_dim2", 64)
            expert_hidden_dims = [expert_dim1, expert_dim2]

            gate_dim1 = best_params.get("gate_dim1", 64)
            gate_hidden_dims = [gate_dim1]

            # Note the key rename: best_params 'k_neighbors' is passed as cas_params 'k'.
            cas_params = {
                'k': best_params.get('k_neighbors', 10),
                'feature_bw': best_params.get('feature_bw', 1.5),
                'label_bw': best_params.get('label_bw', 10.0),
                'density_factor': best_params.get('density_factor', 0.1),
                'strength_base': best_params.get('strength_base', 0.6),
                'density_c': best_params.get('density_c', 20.0),
                'epsilon': 1e-6
            }

            num_experts = best_params.get('num_experts', 3)
            expert_type = best_params.get('expert_type', 'tabular')

            model = CASMIR_V1(
                input_dim=input_dim,
                num_experts=num_experts,
                expert_hidden_dims=expert_hidden_dims,
                gate_hidden_dims=gate_hidden_dims,
                cas_params=cas_params,
                expert_type=expert_type
            )

            model_path = os.path.join(artifact_dir, 'model.pth')
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model weights not found: {model_path}")

            # NOTE(security): torch.load may unpickle arbitrary objects depending on the
            # installed PyTorch's weights_only default; only load trusted artifacts.
            model.load_state_dict(torch.load(model_path, map_location=device))

        elif algorithm.startswith('MLP') or algorithm == 'Simple_Ensemble':
            # NOTE(review): every 'MLP*' algorithm name is rebuilt as plain MLP; the imported
            # MLPadv class is never used here. Confirm this is intended.
            model_path = os.path.join(artifact_dir, 'model.pth')
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model weights not found: {model_path}")

            # Load on CPU first so weight shapes can be inspected; moved to `device` below.
            state_dict = torch.load(model_path, map_location='cpu')

            hidden_dim1_from_state = None
            hidden_dim2_from_state = None

            # Infer hidden sizes from the checkpoint so the rebuilt model matches it even if
            # best_params is missing or stale.
            # NOTE(review): assumes the first two hidden Linear layers sit at
            # feature_layers.0 and feature_layers.4 — confirm against the MLP class.
            if 'feature_layers.0.weight' in state_dict:
                hidden_dim1_from_state = state_dict['feature_layers.0.weight'].shape[0]
            if 'feature_layers.4.weight' in state_dict:
                hidden_dim2_from_state = state_dict['feature_layers.4.weight'].shape[0]

            hidden_dim1 = hidden_dim1_from_state if hidden_dim1_from_state else best_params.get('hidden_dim1', 128)
            hidden_dim2 = hidden_dim2_from_state if hidden_dim2_from_state else best_params.get('hidden_dim2', 64)

            hidden_dims = [hidden_dim1, hidden_dim2]

            if algorithm == 'Simple_Ensemble':
                model = SimpleThreeMLPEnsemble(
                    input_dim=input_dim,
                    hidden_dims=hidden_dims,
                    output_dim=1,
                    dropout_rate=best_params.get('dropout_rate', 0.1),
                    use_residual=True,
                    num_models=3
                )
            else:
                model = MLP(
                    input_dim=input_dim,
                    hidden_dims=hidden_dims,
                    output_dim=1,
                    dropout_rate=best_params.get('dropout_rate', 0.1),
                    use_residual=True
                )

            model.load_state_dict(state_dict)

        else:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        model = model.to(device)
        # Inference mode: disables dropout and similar training-only behavior.
        model.eval()

        print(f"[OK] PyTorch model loaded")
        return model, scaler

def predict(model, X, scaler=None, device=DEVICE, batch_size=None):
    """Run model prediction and return a flat 1-D NumPy array.

    Args:
        model: Either an estimator exposing ``predict`` (tree models / wrappers)
            or a ``torch.nn.Module``. The module is not switched to eval mode here.
        X: Feature matrix, one row per sample.
        scaler: Optional fitted transformer applied to ``X`` before the PyTorch
            forward pass. Ignored for non-PyTorch estimators.
        device: Device the input tensor is moved to for PyTorch models.
        batch_size: Optional mini-batch size for PyTorch inference. ``None`` (the
            default) or a non-positive value runs the whole input in a single
            forward pass.

    Returns:
        1-D ``np.ndarray`` with one prediction per row of ``X``. If a PyTorch model
        returns a tuple, its first element is used as the prediction.
    """
    if hasattr(model, 'predict') and not isinstance(model, torch.nn.Module):
        # Flatten so an (n, 1) output cannot silently broadcast against an (n,)
        # target downstream (which would produce an (n, n) error matrix).
        return np.asarray(model.predict(X)).ravel()

    X_scaled = scaler.transform(X) if scaler is not None else X
    X_array = np.asarray(X_scaled, dtype=np.float32)

    n_samples = len(X_array)
    if batch_size is None or batch_size <= 0:
        batch_size = max(n_samples, 1)

    outputs = []
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            batch = torch.from_numpy(X_array[start:start + batch_size]).to(device)
            output = model(batch)
            if isinstance(output, tuple):
                output = output[0]
            outputs.append(output.cpu().numpy().ravel())

    if not outputs:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(outputs)

def verify_performance(calculated_mae, saved_mae, mae_type='ori', tolerance=0.001):
    """
    Check a freshly computed overall MAE against the value stored with the artifact.

    Args:
        calculated_mae: Dict of computed MAE values; must contain ``'overall'``.
        saved_mae: Stored MAE dict (``ori_mae`` or ``bal_mae``); may be empty or None.
        mae_type: Label ('ori' or 'bal') used in the returned message.
        tolerance: Maximum allowed absolute difference (inclusive).

    Returns:
        ``(status, message)`` where status is True (within tolerance), False
        (outside tolerance) or None (nothing saved to compare against).
    """
    # Guard clauses: nothing to compare against.
    if not saved_mae:
        return None, f"No saved {mae_type}_mae info"
    reference = saved_mae.get('overall')
    if reference is None:
        return None, f"No saved {mae_type}_mae.overall"

    gap = abs(calculated_mae['overall'] - reference)
    label = mae_type.upper()
    if not gap <= tolerance:
        return False, f"{label} verification failed (diff: {gap:.6f}, tolerance: {tolerance})"
    return True, f"{label} verification passed (diff: {gap:.6f})"

def _format_mae(value):
    """Return *value* formatted to 4 decimals, or ``"N/A"`` when it is None."""
    # `is not None` rather than truthiness so a perfect 0.0 MAE is still shown.
    return f"{value:.4f}" if value is not None else "N/A"


def _count_shot_types(shot_mapping):
    """Return ``(few, medium, many)`` sample counts for a ``{index: shot_type}`` dict."""
    if not shot_mapping:
        return 0, 0, 0
    shot_types = list(shot_mapping.values())
    return shot_types.count('few'), shot_types.count('medium'), shot_types.count('many')


def _print_metrics_section(title, n_samples, mae, rmse, mae_results, shot_counts):
    """Print overall MAE/RMSE plus the few/medium/many MAE breakdown for one test set."""
    few_count, med_count, many_count = shot_counts
    print(f"\n{'='*60}")
    print(f"[METRICS] {title} (n={n_samples})")
    print(f"{'='*60}")
    print(f"   MAE:  {mae:.4f}")
    print(f"   RMSE: {rmse:.4f}")
    print(f"\n   Shot-wise MAE:")
    print(f"   - Few    (n={few_count:3d}): {_format_mae(mae_results.get('few'))}")
    print(f"   - Medium (n={med_count:3d}): {_format_mae(mae_results.get('medium'))}")
    print(f"   - Many   (n={many_count:3d}): {_format_mae(mae_results.get('many'))}")
    print(f"   - Overall: {mae_results['overall']:.4f}")


def _report_verification(mae_type, calculated, saved):
    """Print PASS/FAIL/SKIP for one MAE type and return verify_performance's status."""
    passed, message = verify_performance(calculated, saved, mae_type)
    if passed is None:
        print(f"   [SKIP] {message}")
        return None
    print(f"   [{'PASS' if passed else 'FAIL'}] {message}")
    print(f"          Calculated: {calculated['overall']:.6f}, Saved: {saved.get('overall', 'N/A')}")
    return passed


def _json_default(obj):
    """json.dump hook: convert NumPy scalars/arrays to plain Python objects."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Run inference for one saved artifact from the command line.

    Steps: resolve the artifact, load its metadata, rebuild the original test split
    (or read ``--custom_data``), load the model, predict on the original and (when
    available) balanced test sets, print MAE/RMSE with shot-wise breakdowns,
    optionally verify against the saved performance, and optionally save results.

    Returns:
        ``0`` on success and ``1`` on any failure, so the value can be passed
        directly to ``sys.exit``.
    """
    parser = argparse.ArgumentParser(description="Run inference with pre-trained models")

    parser.add_argument('--artifact', type=str, default=None,
                        help='Path to artifact directory')
    parser.add_argument('--dataset', type=str, default=None,
                        help='Dataset name')
    parser.add_argument('--algorithm', type=str, default=None,
                        help='Algorithm name')
    parser.add_argument('--custom_data', type=str, default=None,
                        help='Path to custom data CSV file')
    parser.add_argument('--artifacts_dir', type=str, default='learned_models',
                        help='Path to artifacts directory (default: learned_models)')
    parser.add_argument('--use_hpo_results', action='store_true',
                        help='Load from optimized_hpo_results')
    parser.add_argument('--output', type=str, default=None,
                        help='Path to save predictions (CSV)')
    parser.add_argument('--no_verify', action='store_true',
                        help='Skip verification against saved performance')
    parser.add_argument('--save_results', action='store_true',
                        help='Auto-save prediction results to artifact folder')

    args = parser.parse_args()

    # --use_hpo_results overrides --artifacts_dir.
    artifacts_dir = 'optimized_hpo_results' if args.use_hpo_results else args.artifacts_dir

    if args.artifact:
        artifact_dir = args.artifact
    elif args.dataset and args.algorithm:
        try:
            artifact_dir = find_artifact(args.dataset, args.algorithm, artifacts_dir)
        except FileNotFoundError as e:
            print(f"[ERROR] Artifact lookup failed: {e}")
            return 1
    else:
        print("[ERROR] --artifact or (--dataset + --algorithm) required.")
        parser.print_help()
        return 1

    try:
        metadata, config, best_params = load_artifact_metadata(artifact_dir)
    except Exception as e:
        print(f"[ERROR] Metadata load failed: {e}")
        return 1

    dataset_name = metadata.get('dataset') or metadata.get('dataset_name', 'Unknown')
    algorithm_name = metadata.get('algorithm') or metadata.get('algorithm_name', 'Unknown')
    saved_performance = metadata.get('performance') or {}

    input_dim = None
    use_saved_preprocessor = False

    # Balanced test set and shot mappings, filled from data_split_info.json when present.
    X_test_balanced = None
    y_test_balanced = None
    test_balanced_shot_indices = None
    test_shot_mapping = None

    if args.custom_data:
        print(f"[DATA] Loading custom data: {args.custom_data}")
        df = pd.read_csv(args.custom_data)
        # Convention: every column except the last is a feature; the last is the target.
        X = df.iloc[:, :-1].values
        y = df.iloc[:, -1].values if df.shape[1] > 1 else None
        input_dim = X.shape[1]
        print(f"   - Samples: {len(X)}")
        print(f"   - Features: {input_dim}")
        few_threshold = None
        many_threshold = None
        y_bins = None
    else:
        print(f"[DATA] Loading dataset: {dataset_name}")
        try:
            from src.data.datasets import split_data_stratified, freedman_diaconis_bins, create_balanced_dataset

            df, target_col = load_dataset(dataset_name)
            y_raw = df[target_col]
            X_raw = df.drop(columns=[target_col])

            preprocessor_path = os.path.join(artifact_dir, 'preprocessor.joblib')
            if os.path.exists(preprocessor_path):
                # NOTE(security): joblib.load unpickles the file; only load trusted artifacts.
                saved_preprocessor = joblib.load(preprocessor_path)
                X_processed = saved_preprocessor.transform(X_raw)
                print(f"   [INFO] Using saved preprocessor.joblib")
            else:
                X_processed, y_raw, _ = preprocess_data(df, target_col)
                print(f"   [WARN] No saved preprocessor, using new preprocess_data")

            dataset_config = CONFIG.get('dataset_configs', {}).get(dataset_name, {})
            y_bins = dataset_config.get('y_bins', 0)
            if y_bins == 0:
                y_bins = freedman_diaconis_bins(y_raw)

            # Re-create the split with the stored seed/sizes so the test set matches training time.
            data_split_seed = config.get('data_split_seed', CONFIG.get('data_split_seed', 42))
            test_size = config.get('test_size', CONFIG.get('test_size', 0.2))
            validation_size = config.get('validation_size', CONFIG.get('validation_size', 0.15))

            X_processed_df = pd.DataFrame(X_processed, index=X_raw.index)

            _, _, X_test_df, _, _, y_test = split_data_stratified(
                X_processed_df, y_raw,
                test_size=test_size,
                validation_size=validation_size,
                n_bins=y_bins,
                random_state=data_split_seed
            )

            X = X_test_df.values if hasattr(X_test_df, 'values') else X_test_df
            y = y_test.values if hasattr(y_test, 'values') else y_test

            use_saved_preprocessor = True
            print(f"   [INFO] Data split with stratified method")

            input_dim = X.shape[1]

            few_threshold = dataset_config.get('few_threshold')
            many_threshold = dataset_config.get('many_threshold')

            print(f"   - Test samples (Original): {len(X)}")
            print(f"   - Features: {input_dim}")
            if few_threshold is not None and many_threshold is not None:
                print(f"   - Few threshold: {few_threshold}")
                print(f"   - Many threshold: {many_threshold}")

            data_split_path = os.path.join(artifact_dir, 'data_split_info.json')

            # If absent here, follow artifact_reference.txt to the original artifact folder.
            if not os.path.exists(data_split_path):
                ref_path = os.path.join(artifact_dir, 'artifact_reference.txt')
                if os.path.exists(ref_path):
                    with open(ref_path, 'r', encoding='utf-8') as f:
                        ref_content = f.read()
                    # Expected line: "Complete artifact location: ./artifacts/V002_..."
                    for line in ref_content.splitlines():
                        if 'Complete artifact location:' in line:
                            original_artifact_dir = line.split(':', 1)[1].strip()
                            data_split_path = os.path.join(original_artifact_dir, 'data_split_info.json')
                            if os.path.exists(data_split_path):
                                print(f"   [INFO] Found data_split_info.json in original artifact: {original_artifact_dir}")
                            break

            if os.path.exists(data_split_path):
                with open(data_split_path, 'r', encoding='utf-8') as f:
                    data_split_info = json.load(f)

                # JSON object keys are strings; convert sample indices back to int.
                # `or {}` also covers keys that are present but null.
                test_shot_mapping = data_split_info.get('test_shot_mapping') or {}
                test_shot_mapping = {int(k): v for k, v in test_shot_mapping.items()}

                test_balanced_indices = data_split_info.get('test_balanced_indices', None)
                test_balanced_shot_indices = data_split_info.get('test_balanced_shot_indices') or {}
                test_balanced_shot_indices = {
                    int(k) if isinstance(k, str) else k: v
                    for k, v in test_balanced_shot_indices.items()
                }

                if test_balanced_indices is not None:
                    # NOTE(review): indices are used positionally into the original test arrays.
                    X_test_balanced = X[test_balanced_indices]
                    y_test_balanced = y[test_balanced_indices]
                    print(f"   - Test samples (Balanced): {len(X_test_balanced)}")
                else:
                    print(f"   [WARN] No test_balanced_indices in data_split_info.json")
            else:
                print(f"   [WARN] data_split_info.json not found, balanced test not available")

        except Exception as e:
            print(f"[ERROR] Data loading failed: {e}")
            import traceback
            traceback.print_exc()
            return 1

    try:
        model, scaler = load_model(artifact_dir, metadata, config, best_params, input_dim=input_dim)
        if use_saved_preprocessor:
            scaler = None
            print(f"   [INFO] Scaler disabled (already applied by preprocessor)")
    except Exception as e:
        print(f"[ERROR] Model load failed: {e}")
        return 1

    print(f"\n[PREDICT] Running predictions...")
    try:
        # ==========================================
        # 1. Original Test Set Prediction
        # ==========================================
        predictions_ori = predict(model, X, scaler)
        print(f"[OK] Original test prediction complete: {len(predictions_ori)} samples")

        if y is not None and hasattr(y, 'values'):
            y = y.values

        ori_mae_results = None
        bal_mae_results = None
        overall_mae_ori = None
        predictions_bal = None
        # Flattened targets, or None when the input carries no target column.
        y_flat = np.asarray(y).flatten() if y is not None else None

        if y_flat is not None:
            overall_mae_ori = float(np.mean(np.abs(predictions_ori - y_flat)))
            rmse_ori = float(np.sqrt(np.mean((predictions_ori - y_flat) ** 2)))

            if test_shot_mapping:
                # Use the shot mapping saved at training time.
                ori_mae_results = calculate_shot_wise_mae(y_flat, predictions_ori, test_shot_mapping)
            elif few_threshold is not None and many_threshold is not None and y_bins is not None:
                # Fallback: recompute the shot mapping from the test targets.
                hist, bin_edges = np.histogram(y_flat, bins=y_bins)
                bin_types = classify_bins_by_samples(hist, few_threshold, many_threshold)
                shot_mapping = map_shot_types(pd.Series(y_flat), bin_edges, bin_types)
                ori_mae_results = calculate_shot_wise_mae(y_flat, predictions_ori, shot_mapping)
            else:
                ori_mae_results = {'overall': overall_mae_ori, 'few': None, 'medium': None, 'many': None}

            # Per-shot counts are taken from the saved mapping only.
            _print_metrics_section("Original Test Set", len(y_flat), overall_mae_ori, rmse_ori,
                                   ori_mae_results, _count_shot_types(test_shot_mapping))

        # ==========================================
        # 2. Balanced Test Set Prediction
        # ==========================================
        if X_test_balanced is not None and y_test_balanced is not None:
            predictions_bal = predict(model, X_test_balanced, scaler)
            print(f"\n[OK] Balanced test prediction complete: {len(predictions_bal)} samples")

            y_bal_flat = np.asarray(y_test_balanced).flatten()

            overall_mae_bal = float(np.mean(np.abs(predictions_bal - y_bal_flat)))
            rmse_bal = float(np.sqrt(np.mean((predictions_bal - y_bal_flat) ** 2)))

            if test_balanced_shot_indices:
                bal_mae_results = calculate_shot_wise_mae(y_bal_flat, predictions_bal, test_balanced_shot_indices)
            else:
                bal_mae_results = {'overall': overall_mae_bal, 'few': None, 'medium': None, 'many': None}

            _print_metrics_section("Balanced Test Set", len(y_bal_flat), overall_mae_bal, rmse_bal,
                                   bal_mae_results, _count_shot_types(test_balanced_shot_indices))
        else:
            print(f"\n[INFO] Balanced test set not available")

        # ==========================================
        # 3. Comparison Summary
        # ==========================================
        if ori_mae_results and bal_mae_results:
            print(f"\n{'='*60}")
            print(f"[SUMMARY] Performance Comparison")
            print(f"{'='*60}")
            print(f"{'Metric':<15} {'Original':<12} {'Balanced':<12} {'Diff':<12}")
            print(f"{'-'*51}")

            for metric in ['few', 'medium', 'many', 'overall']:
                ori_val = ori_mae_results.get(metric)
                bal_val = bal_mae_results.get(metric)
                if ori_val is not None and bal_val is not None:
                    diff = bal_val - ori_val
                    print(f"{metric.capitalize():<15} {ori_val:<12.4f} {bal_val:<12.4f} {diff:+.4f}")
                else:
                    print(f"{metric.capitalize():<15} {_format_mae(ori_val):<12} {_format_mae(bal_val):<12} {'N/A':<12}")

        # ==========================================
        # 4. Verification against saved performance
        # ==========================================
        if not args.no_verify and saved_performance:
            print(f"\n{'='*60}")
            print(f"[VERIFY] Comparing with saved performance...")
            print(f"{'='*60}")

            saved_ori_mae = saved_performance.get('ori_mae') or {}
            saved_bal_mae = saved_performance.get('bal_mae') or {}

            statuses = []
            if ori_mae_results:
                statuses.append(_report_verification('ori', ori_mae_results, saved_ori_mae))
            if bal_mae_results:
                statuses.append(_report_verification('bal', bal_mae_results, saved_bal_mae))

            # Skipped checks (None) neither pass nor fail.
            checked = [status for status in statuses if status is not None]
            if not checked:
                print(f"\n   *** NO VERIFICATIONS PERFORMED ***")
            elif all(checked):
                print(f"\n   *** ALL VERIFICATIONS PASSED ***")
            else:
                print(f"\n   *** SOME VERIFICATIONS FAILED ***")

        # ==========================================
        # 5. Save results
        # ==========================================
        prediction_results = {
            'dataset': dataset_name,
            'algorithm': algorithm_name,
            'artifact': artifact_dir,
            'timestamp': datetime.now().isoformat(),
            'ori_mae': ori_mae_results if ori_mae_results else {'overall': overall_mae_ori},
            'bal_mae': bal_mae_results,
            'sample_count': {
                'original': len(predictions_ori),
                'balanced': len(predictions_bal) if predictions_bal is not None else None,
            },
        }

        if args.save_results:
            results_path = os.path.join(artifact_dir, 'prediction_results.json')
            with open(results_path, 'w', encoding='utf-8') as f:
                json.dump(prediction_results, f, indent=2, ensure_ascii=False, default=_json_default)
            print(f"\n[SAVE] Results saved: {results_path}")

        if args.output:
            if y_flat is not None:
                result_df = pd.DataFrame({
                    'prediction': predictions_ori,
                    'actual': y_flat,
                    'error': np.abs(predictions_ori - y_flat),
                })
            else:
                # No targets available: write predictions only.
                result_df = pd.DataFrame({'prediction': predictions_ori})
            result_df.to_csv(args.output, index=False)
            print(f"[SAVE] Predictions CSV saved: {args.output}")

        # ==========================================
        # 6. Sample predictions
        # ==========================================
        print(f"\n[SAMPLE] Prediction samples (first 5, Original test set):")
        for i, pred in enumerate(predictions_ori[:5], start=1):
            if y_flat is not None:
                actual = y_flat[i - 1]
                print(f"   [{i}] Pred: {pred:.4f}, Actual: {actual:.4f}, Error: {abs(pred - actual):.4f}")
            else:
                print(f"   [{i}] Pred: {pred:.4f}")

    except Exception as e:
        print(f"[ERROR] Prediction failed: {e}")
        import traceback
        traceback.print_exc()
        return 1

    print(f"\n[DONE] Inference complete!")
    return 0

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (a None return exits with status 0).
    sys.exit(main())
