# -*- coding: utf-8 -*-
"""
Main training script for CASMIR.
Supports multiple algorithms and datasets with HPO (Hyperparameter Optimization).
"""
import datetime
import subprocess
import sys
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import torch
import optuna
import argparse
import os
import json
import glob
import time
import joblib
from sklearn.preprocessing import StandardScaler

from resreg.resampler import pdf_relevance, sigmoid_relevance
from src.data.datasets import calculate_shot_wise_mae, update_shot_configs
from resreg import resampler
from sklearn.neighbors import KernelDensity
import traceback
import random

from src.data.datasets import (
    classify_bins_by_samples, create_balanced_dataset, freedman_diaconis_bins,
    load_dataset, map_shot_types, preprocess_data, split_data,
    OPENML_DATASETS, SKLEARN_DATASETS, LOCAL_DIR_DATASETS, UCI_DATASETS, KAGGLE_DATASETS,
    split_data_stratified
)
from src.models.basic_models import MLP, MLPadv, SimpleThreeMLPEnsemble, XGBoostWrapper, LightGBMWrapper, CatBoostWrapper
from src.models.CASMIR_V1 import CASMIR_V1, precompute_density_and_boundaries
from src.training.losses import GAILossMD, BMCLossMD, ConRLoss, RankSimLoss, WeightedL1Loss, WeightedMSELoss
from src.models.samplers import apply_smoter, apply_gaussian_noise
from src.evaluation.evaluation import calculate_region_mae, calculate_region_mae_with_thresholds
from src.training.hpo import objective_MLP, objective_tree, objective_CASMIR_V1
from src.training.train_utils import train_pytorch_model, train_CASMIR_V1
import torch.optim as optim
import torch.nn as nn
from src.models.fds_layer import FDSLayer
from sklearn.preprocessing import KBinsDiscretizer
from src.utils.utils import calculate_balanced_weights, get_gmm
from config import CONFIG

# Module-wide default compute device: the current CUDA device if CUDA is usable, else CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def set_global_random_seed(seed=None):
    """
    Set global random seed for reproducibility of Python, NumPy and PyTorch.

    Seeds the ``random``, NumPy and PyTorch (CPU and, if available, all CUDA
    devices) generators, configures cuDNN/cuBLAS for deterministic behaviour,
    and enables PyTorch deterministic algorithms in *warn-only* mode. In
    warn-only mode, an op that has no deterministic implementation emits a
    warning instead of raising ``RuntimeError`` in the middle of training.

    Args:
        seed: Random seed value. If None, uses CONFIG['random_state'] (42 if unset).

    Note:
        ``PYTHONHASHSEED`` is exported for child processes only; the hash seed
        of the already-running interpreter is fixed at startup and is not changed.
    """
    if seed is None:
        seed = CONFIG.get('random_state', 42)

    print(f"[SEED] Setting global random seed to {seed}...")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # cuBLAS needs a fixed workspace configuration for deterministic matmuls.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except TypeError:
        # Older PyTorch without the ``warn_only`` keyword: fall back to strict mode.
        try:
            torch.use_deterministic_algorithms(True)
        except (AttributeError, RuntimeError):
            pass
    except (AttributeError, RuntimeError):
        # PyTorch builds that lack this API entirely: best effort, skip.
        pass

    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f"[SEED] All random sources fixed to seed {seed}\n")


def setup_device(gpu_id=None, force_cpu=False):
    """
    Set up the torch device used for training.

    Args:
        gpu_id: Index of the CUDA device to use. Ignored when CUDA is not
            available. An out-of-range index (including a negative one)
            prints a warning and falls back to the default device.
        force_cpu: If True, always return the CPU device.

    Returns:
        torch.device: The selected device.
    """
    if force_cpu:
        device = torch.device("cpu")
        print(f"Device: {device} (CPU forced)")
        return device

    cuda_available = torch.cuda.is_available()
    default_device = torch.device("cuda" if cuda_available else "cpu")

    if gpu_id is None or not cuda_available:
        print(f"Device: {default_device}")
        return default_device

    # Negative indices must be rejected explicitly: torch.device("cuda:-1") raises.
    if 0 <= gpu_id < torch.cuda.device_count():
        device = torch.device(f"cuda:{gpu_id}")
        torch.cuda.set_device(gpu_id)
        print(f"Device: {device} (GPU {gpu_id})")
        return device

    print(f"Warning: GPU {gpu_id} not available. Using default device.")
    print(f"Device: {default_device}")
    return default_device


def save_requirements():
    """
    Save current environment dependencies to ``./artifacts/requirements.txt``.

    ``pip freeze`` is run through the running interpreter
    (``sys.executable -m pip``), so the snapshot describes this environment.
    A bare ``pip`` from PATH could belong to a different environment.

    Returns:
        str | None: The ``pip freeze`` output, or None if pip could not be
        run, exited with an error, or the file could not be written.
    """
    try:
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'freeze'],
            capture_output=True, text=True, check=False,
        )
        if result.returncode != 0:
            # Do not write an empty/partial requirements file on failure.
            print(f"Error saving dependencies: pip freeze exited with code "
                  f"{result.returncode}: {result.stderr.strip()}")
            return None
        requirements_content = result.stdout
        os.makedirs('./artifacts', exist_ok=True)
        with open('./artifacts/requirements.txt', 'w', encoding='utf-8') as f:
            f.write(requirements_content)
        return requirements_content
    except Exception as e:
        # Best effort: a missing pip or an I/O error must not stop training.
        print(f"Error saving dependencies: {e}")
        return None


def get_next_version_number(artifacts_dir="./artifacts"):
    """
    Generate the next artifact version tag (``V001``, ``V002``, ...) automatically.

    Scans ``artifacts_dir`` for sub-directories whose name starts with
    ``V<digits>``, where the digits end at the first ``_`` or at the end of the
    name. That matches the ``f"{version}_{dataset}_{algorithm}_{timestamp}"`` run
    names built by ``save_complete_model_artifact``. At least three ASCII
    digits are required. The full digit run is parsed, so ``V1000`` counts as
    1000 (not 100).

    Args:
        artifacts_dir: Directory containing versioned artifact folders.

    Returns:
        str: One more than the highest version found, zero-padded to at least
        three digits; ``"V001"`` when no versioned folder exists.
    """
    if not os.path.isdir(artifacts_dir):
        return "V001"

    version_numbers = []
    for entry in os.listdir(artifacts_dir):
        if not entry.startswith('V'):
            continue
        if not os.path.isdir(os.path.join(artifacts_dir, entry)):
            continue
        # Splitting on '_' also keeps int() from reading "1_2" as 12.
        digits = entry[1:].split('_', 1)[0]
        if len(digits) >= 3 and digits.isascii() and digits.isdigit():
            version_numbers.append(int(digits))

    if not version_numbers:
        return "V001"

    return f"V{max(version_numbers) + 1:03d}"


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented UTF-8 JSON.

    ``default=str`` stringifies values the json module cannot encode natively
    (e.g. NumPy ``float32``/``int64`` scalars), so one such value does not
    abort the whole artifact save.
    """
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, ensure_ascii=False, default=str)


def _to_jsonable(value):
    """Convert NumPy arrays and torch tensors to nested lists; return other values unchanged."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, torch.Tensor):
        # detach() so tensors that require grad can be converted too.
        return value.detach().cpu().tolist()
    return value


def _save_model_file(model, is_pytorch_like, out_dir):
    """Save ``model`` into ``out_dir`` as ``model.pth`` (state_dict) or ``model.joblib``."""
    if is_pytorch_like:
        torch.save(model.state_dict(), f"{out_dir}/model.pth")
    else:
        # Wrappers holding the fitted estimator in ``.model`` are unwrapped first.
        joblib.dump(getattr(model, 'model', model), f"{out_dir}/model.joblib")


def save_complete_model_artifact(model, algorithm_name, dataset_name, best_params, 
                                preprocessor, data_split_info, config_snapshot,
                                fds_discretizer=None, gmm_params=None, 
                                calculated_weights_tensor=None, run_name="",
                                experiment_results=None, save_to_learned_models=False,
                                learned_models_dir="learned_models"):
    """
    Save complete model artifact including weights, hyperparameters, and metadata.

    Files are written to ``./artifacts/<version>_<dataset>_<algorithm>_<timestamp>/``:
    ``model.pth`` or ``model.joblib``, ``hyperparams.json``,
    ``preprocessor.joblib``, ``data_split_info.json``, ``metadata.json`` and
    ``experiment_config.json``. Optional inputs add ``fds_discretizer.joblib``,
    ``gmm_params.json``, ``weights_tensor.pth`` and
    ``bench_<dataset>_<algorithm>_<timestamp>.json``. Every JSON file is written
    with ``default=str``, so non-native values are stringified rather than
    aborting the save.

    Args:
        model: Trained model instance (PyTorch module or tree wrapper/estimator).
        algorithm_name: Name of the algorithm.
        dataset_name: Name of the dataset.
        best_params: Best hyperparameters from HPO.
        preprocessor: Data preprocessor.
        data_split_info: Information about data splits.
        config_snapshot: Configuration snapshot (dict).
        fds_discretizer: FDS discretizer (optional).
        gmm_params: GMM parameters (optional); arrays/tensors are stored as lists.
        calculated_weights_tensor: Pre-calculated weights (optional).
        run_name: Experiment run name.
        experiment_results: Experiment results dictionary.
        save_to_learned_models: Whether to also save a copy to ``learned_models_dir``.
        learned_models_dir: Directory for model storage.

    Returns:
        tuple: (artifact_dir, versioned_run_name) or (None, None) on failure.
        A failure while copying to ``learned_models_dir`` only prints a warning
        and does not change the return value.
    """
    version = get_next_version_number()
    timestamp = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M")
    versioned_run_name = f"{version}_{dataset_name}_{algorithm_name}_{timestamp}"
    artifact_dir = f"./artifacts/{versioned_run_name}"

    try:
        # Inside the try so directory-creation errors also yield (None, None).
        os.makedirs(artifact_dir, exist_ok=True)
        print(f"Saving artifacts to: {artifact_dir}")

        is_pytorch_like = hasattr(model, 'state_dict') or isinstance(model, (MLPadv, SimpleThreeMLPEnsemble, CASMIR_V1))

        # Model, hyperparameters, preprocessor, data split
        _save_model_file(model, is_pytorch_like, artifact_dir)
        _write_json(f"{artifact_dir}/hyperparams.json", best_params)
        joblib.dump(preprocessor, f"{artifact_dir}/preprocessor.joblib")
        _write_json(f"{artifact_dir}/data_split_info.json", data_split_info)

        # Algorithm-specific components
        if fds_discretizer is not None:
            joblib.dump(fds_discretizer, f"{artifact_dir}/fds_discretizer.joblib")

        if gmm_params is not None:
            gmm_serializable = {key: _to_jsonable(value) for key, value in gmm_params.items()}
            _write_json(f"{artifact_dir}/gmm_params.json", gmm_serializable)

        if calculated_weights_tensor is not None:
            torch.save(calculated_weights_tensor, f"{artifact_dir}/weights_tensor.pth")

        # Benchmark results
        if experiment_results is not None:
            bench_filename = f"bench_{dataset_name}_{algorithm_name}_{timestamp}.json"
            bench_data = {dataset_name: {algorithm_name: experiment_results}}
            _write_json(f"{artifact_dir}/{bench_filename}", bench_data)

        # Metadata
        metadata = {
            "version": version,
            "algorithm_name": algorithm_name,
            "dataset_name": dataset_name,
            "versioned_run_name": versioned_run_name,
            "original_run_name": run_name,
            "timestamp": datetime.datetime.now().isoformat(),
            "config_snapshot": config_snapshot,
            "model_type": "pytorch" if is_pytorch_like else "sklearn",
            "device": str(DEVICE),
            "python_version": sys.version,
            "torch_version": torch.__version__ if hasattr(torch, '__version__') else None,
            "numpy_version": np.__version__,
        }
        _write_json(f"{artifact_dir}/metadata.json", metadata)

        # Experiment config
        experiment_config = {
            "version": version,
            "versioned_run_name": versioned_run_name,
            "algorithm": algorithm_name,
            "dataset": dataset_name,
            "best_params": best_params,
            "global_config": config_snapshot,
            "timestamp": datetime.datetime.now().isoformat(),
        }
        _write_json(f"{artifact_dir}/experiment_config.json", experiment_config)

        print(f"Artifact saved: {artifact_dir} (Version: {version})")

        # Optional lightweight copy; best effort, never fails the main save.
        if save_to_learned_models:
            output_model_dir = f"./{learned_models_dir}/{versioned_run_name}"
            try:
                os.makedirs(output_model_dir, exist_ok=True)
                _save_model_file(model, is_pytorch_like, output_model_dir)
                _write_json(f"{output_model_dir}/hyperparams.json", best_params)
                joblib.dump(preprocessor, f"{output_model_dir}/preprocessor.joblib")

                simple_metadata = {
                    "version": version,
                    "algorithm": algorithm_name,
                    "dataset": dataset_name,
                    "versioned_run_name": versioned_run_name,
                    "timestamp": datetime.datetime.now().isoformat(),
                    "model_type": "pytorch" if is_pytorch_like else "tree",
                    "performance": experiment_results if experiment_results else {},
                    "random_seed": (config_snapshot or {}).get("random_state", 42),
                }
                _write_json(f"{output_model_dir}/model_info.json", simple_metadata)

                with open(f"{output_model_dir}/artifact_reference.txt", 'w', encoding='utf-8') as f:
                    f.write(f"Complete artifact location: {artifact_dir}\n")
                    f.write(f"Version: {version}\n")

                print(f"Model saved to: {output_model_dir}")
            except Exception as learned_err:
                print(f"Warning: Error saving to learned_models: {learned_err}")

        return artifact_dir, versioned_run_name

    except Exception as e:
        print(f"Error saving artifact: {e}")
        traceback.print_exc()
        return None, None


def get_final_pytorch_criterion(best_params, algo_name, y_train_npy, weights_tensor=None, gmm_params=None):
    """
    Build the loss module (criterion) for the final PyTorch model.

    Loss-specific values are looked up in ``best_params`` first, then in
    CONFIG, then fall back to a hard-coded default.

    Args:
        best_params: Hyperparameters chosen by HPO.
        algo_name: Algorithm name; selects the loss family.
        y_train_npy: Training targets (kept for interface compatibility; unused).
        weights_tensor: Per-sample weights for the reweighting variants (optional).
        gmm_params: GMM parameters required by ``MLP_GAI_BMSE`` (optional).

    Returns:
        torch.nn.Module: The loss. ``nn.L1Loss`` is returned for unknown
        algorithms and when ``MLP_GAI_BMSE`` has no GMM parameters.
    """
    reweighting_algos = ("MLP_SQRT_INV", "MLP_LDS_Notebook", "MLP_FDS_Notebook", "MLP_LDS_FDS_Notebook")

    def _hp(param_key, config_key, fallback):
        # HPO result overrides CONFIG, which overrides the fallback.
        return best_params.get(param_key, CONFIG.get(config_key, fallback))

    if algo_name in ('MLP', 'MLP_FDS'):
        wants_mse = best_params.get("base_loss", "l1") == "mse"
        return nn.MSELoss() if wants_mse else nn.L1Loss()

    if algo_name in reweighting_algos:
        loss_kind = CONFIG.get("loss_for_reweighting", "l1")
        if weights_tensor is None:
            # No weights were computed: use the plain (unweighted) loss.
            return nn.L1Loss() if loss_kind == "l1" else nn.MSELoss()
        cpu_weights = weights_tensor.cpu()
        if loss_kind == "l1":
            return WeightedL1Loss(cpu_weights)
        return WeightedMSELoss(cpu_weights)

    if algo_name == "MLP_GAI_BMSE":
        if gmm_params is None:
            return nn.L1Loss()
        noise_sigma = _hp("gai_init_noise_sigma", 'gai_init_noise_sigma', 6.0)
        return GAILossMD(gmm_dict=gmm_params, init_noise_sigma=noise_sigma)

    if algo_name == "MLP_BMC_BMSE":
        noise_sigma = _hp("bmc_init_noise_sigma", 'bmse_noise_sigma', 6.0)
        return BMCLossMD(init_noise_sigma=noise_sigma)

    if algo_name == 'MLP_RankSim':
        return RankSimLoss(
            lambda_val=_hp('ranksim_lambda_val', 'ranksim_lambda_val', 1.0),
            alpha=_hp('ranksim_alpha', 'ranksim_alpha', 1.0),
        )

    if algo_name == 'MLP_ConR':
        return ConRLoss(
            w=_hp('conr_distance_threshold', 'conr_distance_threshold', 1.0),
            t=_hp('conr_temperature', 'conr_temperature', 0.07),
            e=_hp('conr_pushing_power', 'conr_pushing_power', 0.01),
            alpha=_hp('conr_alpha', 'conr_alpha', 1.0),
            mse_weight=_hp('conr_mse_weight', 'conr_mse_weight', 1.0),
        )

    return nn.L1Loss()


def _read_params_json(params_path):
    """Read a hyperparameter JSON file; print an error and return None on failure."""
    try:
        with open(params_path, 'r', encoding='utf-8') as f:
            best_params = json.load(f)
        print(f"Loaded hyperparameters: {len(best_params)} params")
        return best_params
    except Exception as e:
        print(f"Failed to load hyperparameters: {e}")
        return None


def _hpo_run_sort_key(dir_name, dataset_name, algorithm_name):
    """Return ``(version, timestamp)`` if ``dir_name`` is a run folder for exactly
    this dataset/algorithm pair, otherwise None.

    Expected format: ``V<digits>_<dataset_name>_<algorithm_name>_<YYYY_MM_DD__HH_MM>``.
    """
    version_part, sep, rest = dir_name.partition('_')
    digits = version_part[1:]
    if not sep or not version_part.startswith('V') or not (digits.isascii() and digits.isdigit()):
        return None
    prefix = f"{dataset_name}_{algorithm_name}_"
    if not rest.startswith(prefix):
        return None
    # Requiring the remainder to be a bare timestamp rejects algorithms that merely
    # share a prefix (e.g. "MLP" must not match "MLP_FDS_2024_...").
    try:
        stamp = datetime.datetime.strptime(rest[len(prefix):], "%Y_%m_%d__%H_%M")
    except ValueError:
        return None
    return int(digits), stamp


def load_best_params_from_hpo(load_hpo_dir, dataset_name, algorithm_name):
    """
    Load best hyperparameters from HPO results directory.

    Two layouts are supported:

    1. ``load_hpo_dir`` is itself an artifact directory containing
       ``hyperparams.json``.
    2. ``load_hpo_dir`` is a parent directory holding run folders named
       ``V<version>_<dataset_name>_<algorithm_name>_<YYYY_MM_DD__HH_MM>``.
       Only folders matching that format exactly for this dataset/algorithm are
       considered. The one with the highest numeric version, then the latest
       timestamp, is used. Inside it, ``hyperparams.json`` is preferred and
       ``best_params.json`` is the fallback.

    Args:
        load_hpo_dir: Directory containing HPO results.
        dataset_name: Name of the dataset.
        algorithm_name: Name of the algorithm.
    
    Returns:
        dict: Best hyperparameters, or None if loading failed.
    """
    if not load_hpo_dir or not os.path.exists(load_hpo_dir):
        print(f"HPO directory not found: {load_hpo_dir}")
        return None

    # Case 1: Direct artifact path
    direct_params_path = os.path.join(load_hpo_dir, 'hyperparams.json')
    if os.path.exists(direct_params_path):
        return _read_params_json(direct_params_path)

    # Case 2: Exact run-name matching in parent directory
    pattern = os.path.join(load_hpo_dir, f"V*_{dataset_name}_{algorithm_name}_*")
    candidates = []
    if os.path.isdir(load_hpo_dir):
        for entry in os.listdir(load_hpo_dir):
            sort_key = _hpo_run_sort_key(entry, dataset_name, algorithm_name)
            full_path = os.path.join(load_hpo_dir, entry)
            if sort_key is not None and os.path.isdir(full_path):
                candidates.append((sort_key, full_path))

    if not candidates:
        print(f"No matching HPO results found: {pattern}")
        return None

    # Numeric version ordering: V1000 must rank above V999.
    _, latest = max(candidates)
    print(f"Loading HPO results: {latest}")

    params_path = os.path.join(latest, 'hyperparams.json')
    if not os.path.exists(params_path):
        params_path = os.path.join(latest, 'best_params.json')

    if not os.path.exists(params_path):
        print(f"Hyperparameters file not found: {params_path}")
        return None

    return _read_params_json(params_path)


def run_experiment(dataset_name, algorithm_name, ada_expert_type='tabular', use_mlpadv=True, 
                   num_workers=0, n_jobs=1, load_hpo_dir=None, skip_hpo=False, 
                   save_models=False, output_dir="learned_models"):
    """
    Run complete HPO and evaluation for a dataset/algorithm combination.

    Args:
        dataset_name: Name of the dataset.
        algorithm_name: Name of the algorithm.
        ada_expert_type: CASMIR expert type ('tabular' or 'basic').
        use_mlpadv: Whether to use MLPadv (deprecated, always True).
        num_workers: DataLoader workers.
        n_jobs: Optuna parallel jobs.
        load_hpo_dir: Directory to load pre-trained HPO results.
        skip_hpo: Skip HPO and use loaded parameters directly.
        save_models: Save trained models to learned_models directory.
        output_dir: Directory for model storage.
    
    Returns:
        dict: Experiment results with balanced and original MAE metrics.
    """
    start_time = time.time()
    print(f"\n--- Running experiment: dataset='{dataset_name}', algorithm='{algorithm_name}' ---")
    
    if load_hpo_dir:
        print(f"HPO load directory: {load_hpo_dir}")
        if skip_hpo:
            print(f"Skip HPO: Using pre-optimized hyperparameters")

    # Run name setup
    now = datetime.datetime.now()
    formatted_string = now.strftime("%Y_%m_%d__%H_%M")
    run_name = f"{dataset_name}_{algorithm_name}_{formatted_string}"

    # Initialize algorithm-specific variables
    precomputed_train_densities_np = None
    density_boundaries = None
    fitted_kde = None
    fds_discretizer = None
    calculated_weights_tensor = None
    gmm_params = None
    apply_weights_in_loss = False

    # 1. Data loading and preprocessing
    df, target_col = load_dataset(dataset_name)
    if pd.api.types.is_numeric_dtype(df[target_col]):
        skewness = df[target_col].skew()
        print(f"Target column '{target_col}' skewness: {skewness:.2f}")

    X, y, preprocessor = preprocess_data(df, target_col)

    # Get dataset config
    dataset_config = CONFIG['dataset_configs'][dataset_name]
    few_threshold = dataset_config['few_threshold']
    many_threshold = dataset_config['many_threshold']
    y_bins = dataset_config['y_bins'] if dataset_config['y_bins'] != 0 else freedman_diaconis_bins(y)

    # 2. Stratified data split
    X_train_df, X_val_df, X_test_df, y_train_sr, y_val_sr, y_test_sr = split_data_stratified(
        X, y,
        test_size=CONFIG['test_size'],
        validation_size=CONFIG['validation_size'],
        n_bins=y_bins,
        random_state=CONFIG['data_split_seed']
    )

    # Data split info for reproducibility
    data_split_info = {
        "original_data_shape": X.shape,
        "train_indices": X_train_df.index.tolist(),
        "val_indices": X_val_df.index.tolist(),
        "test_indices": X_test_df.index.tolist(),
        "train_size": len(X_train_df),
        "val_size": len(X_val_df),
        "test_size": len(X_test_df),
        "y_bins": y_bins,
        "few_threshold": few_threshold,
        "many_threshold": many_threshold,
        "target_col": target_col,
        "data_split_seed": CONFIG['data_split_seed'],
        "algorithm_random_state": CONFIG['random_state'],
    }

    # 3. Shot type classification
    val_hist, val_bin_edges = np.histogram(y_val_sr, bins=y_bins)
    test_hist, test_bin_edges = np.histogram(y_test_sr, bins=val_bin_edges)

    val_bin_types = classify_bins_by_samples(val_hist, few_threshold, many_threshold)
    test_bin_types = classify_bins_by_samples(test_hist, few_threshold, many_threshold)

    val_shot_mapping = map_shot_types(y_val_sr, val_bin_edges, val_bin_types)
    test_shot_mapping = map_shot_types(y_test_sr, test_bin_edges, test_bin_types)

    print(f"Validation shot distribution: {dict((shot, list(val_shot_mapping.values()).count(shot)) for shot in ['few', 'medium', 'many'])}")
    print(f"Test shot distribution: {dict((shot, list(test_shot_mapping.values()).count(shot)) for shot in ['few', 'medium', 'many'])}")

    # 4. Create balanced datasets
    X_val_balanced, y_val_balanced, val_balanced_indices, val_balanced_shot_indices = create_balanced_dataset(
        CONFIG, X_val_df.to_numpy(), y_val_sr.to_numpy(), val_shot_mapping
    )
    X_test_balanced, y_test_balanced, test_balanced_indices, test_balanced_shot_indices = create_balanced_dataset(
        CONFIG, X_test_df.to_numpy(), y_test_sr.to_numpy(), test_shot_mapping
    )

    X_val_balanced_df = pd.DataFrame(X_val_balanced, columns=X_val_df.columns)
    X_test_balanced_df = pd.DataFrame(X_test_balanced, columns=X_test_df.columns)
    y_val_balanced_sr = pd.Series(y_val_balanced)
    y_test_balanced_sr = pd.Series(y_test_balanced)

    # Keep original for evaluation
    X_val_original_df = X_val_df.copy()
    X_test_original_df = X_test_df.copy()
    y_val_original_sr = y_val_sr.copy()
    y_test_original_sr = y_test_sr.copy()

    # Use balanced for HPO and evaluation
    X_val_df = X_val_balanced_df
    y_val_sr = y_val_balanced_sr
    X_test_df = X_test_balanced_df
    y_test_sr = y_test_balanced_sr

    # 5. Preprocessing
    X_train = preprocessor.fit_transform(X_train_df)
    X_val = preprocessor.transform(X_val_df)
    X_test = preprocessor.transform(X_test_df)
    X_test_ori = preprocessor.transform(X_test_original_df)

    y_train_npy = y_train_sr.to_numpy().reshape(-1, 1)
    y_val_npy = y_val_sr.to_numpy().reshape(-1, 1)
    y_test_npy = y_test_sr.to_numpy().reshape(-1, 1)
    y_test_ori_npy = y_test_original_sr.to_numpy().reshape(-1, 1)
    y_train_flat = y_train_npy.flatten()

    # 6. Algorithm-specific preprocessing
    if algorithm_name == "CASMIR":
        densities_np, boundaries, kde_model = precompute_density_and_boundaries(
            y_train_npy,
            kde_bandwidth=CONFIG.get("casmir_kde_bandwidth", 'silverman'),
            density_percentiles=CONFIG.get("casmir_density_percentiles", [33.3, 66.7])
        )
        if densities_np is None:
            raise ValueError(f"Density calculation failed for {dataset_name}")
        precomputed_train_densities_np = densities_np
        density_boundaries = boundaries
        fitted_kde = kde_model

    elif algorithm_name == "MLP_SQRT_INV":
        weights_np = calculate_balanced_weights(y_train_flat, reweight='sqrt_inv', lds=False)
        if weights_np is not None:
            calculated_weights_tensor = torch.from_numpy(weights_np).cpu()
            apply_weights_in_loss = True

    elif algorithm_name == "MLP_LDS_Notebook":
        weights_np = calculate_balanced_weights(
            y_train_flat, reweight=CONFIG.get('lds_reweight_base', 'sqrt_inv'),
            binning_method='quantile', lds=True,
            lds_kernel=CONFIG.get('lds_kernel', 'gaussian'),
            lds_ks=CONFIG.get('lds_ks', 5),
            lds_sigma=CONFIG.get('lds_sigma', 2)
        )
        if weights_np is not None:
            calculated_weights_tensor = torch.from_numpy(weights_np).cpu()
            apply_weights_in_loss = True

    elif algorithm_name == "MLP_FDS_Notebook":
        fds_n_bins = CONFIG.get("fds_num_target_bins", 50)
        fds_strategy = CONFIG.get("fds_discretizer_strategy", 'uniform')
        try:
            fds_discretizer = KBinsDiscretizer(n_bins=fds_n_bins, encode='ordinal', strategy=fds_strategy, subsample=None)
            fds_discretizer.fit(y_train_npy)
        except Exception as disc_err:
            print(f"FDS Discretizer failed: {disc_err}")
            fds_discretizer = None

    elif algorithm_name == "MLP_LDS_FDS_Notebook":
        # LDS weights
        weights_np = calculate_balanced_weights(
            y_train_flat, reweight=CONFIG.get('lds_reweight_base', 'sqrt_inv'),
            binning_method='quantile', lds=True,
            lds_kernel=CONFIG.get('lds_kernel', 'gaussian'),
            lds_ks=CONFIG.get('lds_ks', 5),
            lds_sigma=CONFIG.get('lds_sigma', 2)
        )
        if weights_np is not None:
            calculated_weights_tensor = torch.from_numpy(weights_np).cpu()
            apply_weights_in_loss = True
        
        # FDS discretizer
        fds_n_bins = CONFIG.get("fds_num_target_bins", 50)
        fds_strategy = CONFIG.get("fds_discretizer_strategy", 'uniform')
        try:
            fds_discretizer = KBinsDiscretizer(n_bins=fds_n_bins, encode='ordinal', strategy=fds_strategy, subsample=None)
            fds_discretizer.fit(y_train_npy)
        except Exception as disc_err:
            print(f"FDS Discretizer failed: {disc_err}")
            fds_discretizer = None

    elif algorithm_name in ["MLP_FDS", "MLP_LDS_FDS"]:
        fds_n_bins = CONFIG.get("fds_num_target_bins", 50)
        fds_strategy = CONFIG.get("fds_discretizer_strategy", 'uniform')
        try:
            fds_discretizer = KBinsDiscretizer(n_bins=fds_n_bins, encode='ordinal', strategy=fds_strategy, subsample=None)
            fds_discretizer.fit(y_train_npy)
        except Exception as disc_err:
            print(f"FDS Discretizer failed: {disc_err}")
            fds_discretizer = None

    elif algorithm_name == "MLP_GAI_BMSE":
        gmm_components = CONFIG.get('gmm_components', 3)
        gmm_params = get_gmm(y_train_flat, n_components=gmm_components)

    # 7. HPO
    best_params = None
    
    if skip_hpo and load_hpo_dir:
        print("Skip HPO mode: Loading pre-optimized hyperparameters...")
        best_params = load_best_params_from_hpo(load_hpo_dir, dataset_name, algorithm_name)
        if best_params is None:
            print("Failed to load HPO results. Aborting experiment.")
            return None
    else:
        print("Starting hyperparameter optimization...")
        try:
            sampler = optuna.samplers.TPESampler(seed=CONFIG["random_state"])
            study = optuna.create_study(
                direction="minimize",
                sampler=sampler,
                pruner=optuna.pruners.MedianPruner()
            )

            hpo_config = {
                "algorithm_name": algorithm_name,
                "input_dim": X_train.shape[1],
                "hpo_epochs": CONFIG["hpo_epochs"],
                "hpo_patience": CONFIG["hpo_patience"],
                "random_state": CONFIG["random_state"],
                "loss_bins": CONFIG.get("loss_bins", 10),
                "precomputed_train_densities_np": precomputed_train_densities_np,
                "density_boundaries": density_boundaries,
                "casmir_num_experts": CONFIG.get("casmir_num_experts", 3),
                "ada_expert_type": ada_expert_type,
                "use_mlpadv": use_mlpadv,
                "fds_discretizer": fds_discretizer,
                "fds_start_epoch_hpo": 0,
                "num_workers": num_workers,
                "calculated_weights_tensor": calculated_weights_tensor,
                "gmm_params": gmm_params,
                "apply_weights_in_loss": apply_weights_in_loss,
                "base_loss_for_reweighting": CONFIG.get("loss_for_reweighting", "l1"),
                "lds_ks_default": CONFIG.get('lds_ks', 5),
                "lds_sigma_default": CONFIG.get('lds_sigma', 2),
                "gai_init_noise_sigma_default": CONFIG.get('gai_init_noise_sigma', 6.0),
            }

            objective_fn = None
            if (algorithm_name.startswith("MLP") or algorithm_name == "Simple_Ensemble") and algorithm_name != "CASMIR":
                objective_fn = objective_MLP
            elif algorithm_name in ["XGBoost", "LightGBM", "CatBoost"] or algorithm_name.endswith(("_XGBoost", "_LightGBM", "_CatBoost")):
                hpo_config["sampler"] = None
                if "SMOTER" in algorithm_name:
                    hpo_config["sampler"] = "SMOTER"
                elif "GaussianNoise" in algorithm_name:
                    hpo_config["sampler"] = "GaussianNoise"
                objective_fn = objective_tree
            elif algorithm_name == "CASMIR":
                objective_fn = objective_CASMIR_V1

            if objective_fn:
                study.optimize(
                    lambda trial: objective_fn(trial, hpo_config, X_train, y_train_npy.flatten(), X_val, y_val_npy.flatten(), shot_mapping=val_balanced_shot_indices),
                    n_trials=CONFIG["n_trials"],
                    timeout=CONFIG.get("hpo_timeout", None),
                    n_jobs=n_jobs
                )
                best_trial = study.best_trial
                best_params = best_trial.params
                print(f"HPO complete. Best validation MAE: {best_trial.value:.4f}")
            else:
                return None

        except Exception as e:
            print(f"HPO error for {run_name}: {e}")
            traceback.print_exc()
            return None

    # 8. Final model training
    print("Training final model with best hyperparameters...")
    
    if best_params is None:
        print("Error: No hyperparameters available for final training.")
        return None

    # Reset seed for reproducibility
    set_global_random_seed()

    final_model = None
    final_train_config = {
        'epochs': CONFIG["final_epochs"],
        'patience': CONFIG["final_patience"],
        'batch_size': best_params.get("batch_size", 64),
        'lr': best_params.get("lr", 0.001),
        'optimizer_name': best_params.get("optimizer", "Adam"),
        "apply_weights_in_loss": apply_weights_in_loss
    }

    try:
        if (algorithm_name.startswith("MLP") or algorithm_name == "Simple_Ensemble") and algorithm_name != "CASMIR":
            use_fds_final = algorithm_name in ["MLP_FDS", "MLP_LDS_FDS", "MLP_FDS_Notebook", "MLP_LDS_FDS_Notebook"]
            fds_config_final = None
            
            if use_fds_final and fds_discretizer is not None:
                fds_config_final = {
                    'num_target_bins': fds_discretizer.n_bins_,
                    'fds_momentum': best_params.get('fds_momentum', CONFIG['fds_momentum']),
                    'start_update_epoch': CONFIG.get('fds_start_epoch', 0),
                    'kernel': best_params.get('fds_kernel', 'gaussian'),
                    'kernel_size': best_params.get('fds_kernel_size', 5),
                    'kernel_sigma': best_params.get('fds_kernel_sigma', 2.0),
                    'start_smooth_epoch': best_params.get('fds_start_smooth_epoch', 1)
                }
            elif use_fds_final and fds_discretizer is None:
                use_fds_final = False

            n_layers = best_params.get("n_layers", 2)
            hidden_dims = [best_params.get(f"n_units_l{i}", 128) for i in range(n_layers)]
            if not hidden_dims:
                hidden_dims = [best_params.get("hidden_dim1", 128), best_params.get("hidden_dim2", 64)]
            
            dropout_rate = best_params.get("dropout", 0.2)

            if algorithm_name == "Simple_Ensemble":
                final_model = SimpleThreeMLPEnsemble(
                    input_dim=X_train.shape[1], hidden_dims=hidden_dims, output_dim=1, dropout_rate=dropout_rate,
                    use_fds=use_fds_final, fds_config=fds_config_final, fds_discretizer=fds_discretizer, use_residual=True
                ).to(DEVICE)
            else:
                final_model = MLPadv(
                    input_dim=X_train.shape[1], hidden_dims=hidden_dims, output_dim=1, dropout_rate=dropout_rate,
                    use_fds=use_fds_final, fds_config=fds_config_final, fds_discretizer=fds_discretizer, use_residual=True
                ).to(DEVICE)

            if final_train_config['optimizer_name'] == "Adam":
                optimizer = optim.Adam(final_model.parameters(), lr=final_train_config['lr'])
            elif final_train_config['optimizer_name'] == "AdamW":
                optimizer = optim.AdamW(final_model.parameters(), lr=final_train_config['lr'])
            else:
                optimizer = optim.RMSprop(final_model.parameters(), lr=final_train_config['lr'])
            
            criterion = get_final_pytorch_criterion(best_params, algorithm_name, y_train_npy.flatten(), weights_tensor=calculated_weights_tensor, gmm_params=gmm_params)

            train_pytorch_model(
                final_model, X_train, y_train_npy.flatten(), X_val, y_val_npy.flatten(),
                criterion, optimizer, final_train_config, DEVICE, trial=None
            )

        elif algorithm_name in ["XGBoost", "LightGBM", "CatBoost"] or algorithm_name.endswith(("_XGBoost", "_LightGBM", "_CatBoost")):
            smoter_params_keys = ['pdf_bandwidth', 'relevance_method', 'relevance_percentile', 'resampling_strategy',
                                  'sigmoid_cl', 'sigmoid_ch', 'sampler_k_neighbors', 'sampler_delta']
            tree_params = {k: v for k, v in best_params.items() if k not in smoter_params_keys}
            sampler_k = best_params.get('sampler_k_neighbors', None)
            sampler_delta = best_params.get('sampler_delta', None)

            base_algo = algorithm_name
            sampler_type = None
            if "SMOTER_" in algorithm_name:
                base_algo = algorithm_name.split("SMOTER_")[1]
                sampler_type = "SMOTER"
            elif "GaussianNoise_" in algorithm_name:
                base_algo = algorithm_name.split("GaussianNoise_")[1]
                sampler_type = "GaussianNoise"

            if base_algo == 'XGBoost':
                final_model = XGBoostWrapper(tree_params)
            elif base_algo == 'LightGBM':
                final_model = LightGBMWrapper(tree_params)
            elif base_algo == 'CatBoost':
                final_model = CatBoostWrapper(tree_params)

            X_train_resampled, y_train_resampled = X_train, y_train_npy.flatten()

            if sampler_type == "SMOTER":
                relevance_method = best_params.get('relevance_method', 'sigmoid')
                if relevance_method == 'sigmoid':
                    cl = best_params.get('sigmoid_cl', np.percentile(y_train_npy.flatten(), 33))
                    ch = best_params.get('sigmoid_ch', np.percentile(y_train_npy.flatten(), 67))
                    relevance = sigmoid_relevance(y_train_npy.flatten(), cl=cl, ch=ch)
                else:
                    bandwidth = best_params.get('pdf_bandwidth', 1.0)
                    relevance = pdf_relevance(y_train_npy.flatten(), bandwidth=bandwidth)

                relevance_threshold = np.percentile(relevance, best_params.get('relevance_percentile', 70))
                resampling_strategy = best_params.get('resampling_strategy', 'balance')
                
                try:
                    X_train_resampled, y_train_resampled = apply_smoter(
                        X_train, y_train_npy.flatten(), relevance=relevance,
                        relevance_threshold=relevance_threshold, k=sampler_k, over=resampling_strategy
                    )
                except Exception:
                    try:
                        X_train_resampled, y_train_resampled = apply_smoter(
                            X_train, y_train_npy.flatten(), relevance=relevance,
                            relevance_threshold=relevance_threshold, k=sampler_k, over='balance'
                        )
                    except Exception:
                        X_train_resampled, y_train_resampled = X_train, y_train_npy.flatten()

            elif sampler_type == "GaussianNoise":
                relevance = pdf_relevance(y_train_npy.flatten(), bandwidth=best_params.get('pdf_bandwidth', 1.0))
                for percentile in [80, 85, 90, 95]:
                    try:
                        relevance_threshold = np.percentile(relevance, percentile)
                        X_train_resampled, y_train_resampled = apply_gaussian_noise(
                            X_train, y_train_npy.flatten(), relevance, relevance_threshold=relevance_threshold, delta=sampler_delta
                        )
                        break
                    except Exception:
                        continue
                else:
                    X_train_resampled, y_train_resampled = X_train, y_train_npy.flatten()

            final_model.fit(X_train_resampled, y_train_resampled.flatten(), X_val, y_val_npy.flatten())

        elif algorithm_name == "CASMIR":
            expert_dim1 = best_params.get("expert_dim1", 128)
            expert_dim2 = best_params.get("expert_dim2", 64)
            expert_hidden = [expert_dim1, expert_dim2]
            gate_dim1 = best_params.get("gate_dim1", 64)
            gate_hidden = [gate_dim1]

            cas_params_final = {
                'k': best_params.get('k_neighbors', CONFIG['casmir_k_neighbors']),
                'feature_bw': best_params.get('feature_bw', CONFIG['casmir_feature_bw']),
                'label_bw': best_params.get('label_bw', CONFIG['casmir_label_bw']),
                'density_factor': best_params.get('density_factor', CONFIG['casmir_density_factor']),
                'strength_base': best_params.get('strength_base', CONFIG['casmir_strength_base']),
                'density_c': best_params.get('density_c', CONFIG['casmir_density_c']),
                'epsilon': 1e-6
            }

            final_model = CASMIR_V1(
                input_dim=X_train.shape[1],
                num_experts=CONFIG.get('casmir_num_experts', 3),
                expert_hidden_dims=expert_hidden,
                gate_hidden_dims=gate_hidden,
                cas_params=cas_params_final,
                expert_type=ada_expert_type
            ).to(DEVICE)

            if final_train_config['optimizer_name'] == "Adam":
                optimizer = optim.Adam(final_model.parameters(), lr=final_train_config['lr'])
            elif final_train_config['optimizer_name'] == "AdamW":
                optimizer = optim.AdamW(final_model.parameters(), lr=final_train_config['lr'])
            else:
                optimizer = optim.RMSprop(final_model.parameters(), lr=final_train_config['lr'])

            casmir_train_params = final_train_config.copy()
            casmir_train_params['lambda_aux'] = best_params.get('lambda_aux', CONFIG['casmir_lambda_aux'])
            casmir_train_params['lambda_load'] = best_params.get('lambda_load', CONFIG['casmir_lambda_load'])
            casmir_train_params['density_boundaries'] = density_boundaries

            train_CASMIR_V1(
                model=final_model, X_train=X_train, y_train=y_train_npy,
                train_densities_np=precomputed_train_densities_np,
                X_val=X_val, y_val=y_val_npy,
                optimizer=optimizer, config=casmir_train_params,
                device=DEVICE, trial=None
            )

    except Exception as e:
        print(f"Final training error for {run_name}: {e}")
        traceback.print_exc()
        return None

    # 9. Final evaluation
    print("Evaluating final model on test set...")

    if not final_model:
        print("Final model was not trained.")
        return None

    y_pred_test = None

    if isinstance(final_model, (CASMIR_V1, MLPadv, SimpleThreeMLPEnsemble)):
        final_model.eval()
        with torch.no_grad():
            X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(DEVICE)
            X_test_tensor_ori = torch.tensor(X_test_ori, dtype=torch.float32).to(DEVICE)
            if isinstance(final_model, CASMIR_V1):
                y_pred_test = final_model(X_test_tensor, apply_smoothing=False).cpu().numpy().flatten()
                y_pred_test_ori = final_model(X_test_tensor_ori, apply_smoothing=False).cpu().numpy().flatten()
            else:
                y_pred_test = final_model(X_test_tensor).cpu().numpy().flatten()
                y_pred_test_ori = final_model(X_test_tensor_ori).cpu().numpy().flatten()

    elif isinstance(final_model, (XGBoostWrapper, LightGBMWrapper, CatBoostWrapper)):
        y_pred_test = final_model.predict(X_test)
        y_pred_test_ori = final_model.predict(X_test_ori)

    if y_pred_test is not None:
        mae_results_ori = calculate_shot_wise_mae(y_test_ori_npy.flatten(), y_pred_test_ori, test_shot_mapping)
        mae_results_bal = calculate_shot_wise_mae(y_test_npy.flatten(), y_pred_test, test_balanced_shot_indices)
        
        print(f"Few-shot MAE: {mae_results_bal['few']}, {mae_results_ori['few']}")
        print(f"Medium-shot MAE: {mae_results_bal['medium']}, {mae_results_ori['medium']}")
        print(f"Many-shot MAE: {mae_results_bal['many']}, {mae_results_ori['many']}")
        print(f"Overall MAE: {mae_results_bal['overall']}, {mae_results_ori['overall']}")

        # Save artifacts
        data_split_info.update({
            "val_balanced_indices": val_balanced_indices.tolist() if val_balanced_indices is not None else None,
            "test_balanced_indices": test_balanced_indices.tolist() if test_balanced_indices is not None else None,
            "val_balanced_shot_indices": val_balanced_shot_indices,
            "test_balanced_shot_indices": test_balanced_shot_indices,
            "val_shot_mapping": val_shot_mapping,
            "test_shot_mapping": test_shot_mapping
        })
        
        config_snapshot = CONFIG.copy()
        experiment_results = {"bal_mae": mae_results_bal, "ori_mae": mae_results_ori}
        
        artifact_dir, versioned_run_name = save_complete_model_artifact(
            model=final_model,
            algorithm_name=algorithm_name,
            dataset_name=dataset_name,
            best_params=best_params,
            preprocessor=preprocessor,
            data_split_info=data_split_info,
            config_snapshot=config_snapshot,
            fds_discretizer=fds_discretizer,
            gmm_params=gmm_params,
            calculated_weights_tensor=calculated_weights_tensor,
            run_name=run_name,
            experiment_results=experiment_results,
            save_to_learned_models=save_models,
            learned_models_dir=output_dir
        )

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Experiment '{run_name}' completed in {elapsed_time:.2f} seconds")
    
    return {"bal_mae": mae_results_bal, "ori_mae": mae_results_ori}


def _mae_to_serializable(mae_dict):
    """Convert one shot-wise MAE mapping into JSON-friendly Python values.

    Args:
        mae_dict: Mapping expected to carry the keys "few", "medium", "many"
            and "overall" (the shape of the "bal_mae"/"ori_mae" entries that
            run_experiment returns). May be None or empty.

    Returns:
        None if ``mae_dict`` is None/empty or every value in it is NaN/None.
        Otherwise a dict with exactly the keys "few", "medium", "many",
        "overall", each a Python ``float``, or ``None`` when the value is
        NaN/None or the key is absent.
    """
    if not mae_dict or pd.Series(mae_dict).isna().all():
        return None
    converted = {}
    for key in ("few", "medium", "many", "overall"):
        value = mae_dict.get(key)  # .get: tolerate partially-filled metric dicts
        converted[key] = None if pd.isna(value) else float(value)
    return converted


def main(args):
    """Run every requested (dataset, algorithm) experiment and save a summary.

    Steps:
      * Seeds all RNGs: ``args.seed``, falling back to CONFIG["random_state"],
        then to 42.
      * Writes the CLI overrides (trials, seed, optional HPO/final epochs) into
        the shared CONFIG dict, which run_experiment reads (e.g. "final_epochs").
      * Ensures ./data and ./outputs exist.
      * For each dataset x algorithm, calls run_experiment and records its
        balanced/original shot-wise MAEs (None when unavailable).
      * Prints the summary as JSON and writes it to
        ./Benchmark/ALL_benchmark_results_<first algorithm>.json.

    Args:
        args: argparse.Namespace produced by this script's CLI parser.
    """
    seed = args.seed if args.seed is not None else CONFIG.get("random_state", 42)
    set_global_random_seed(seed)

    output_dir = args.output_dir if args.output_dir else "learned_models"
    save_requirements()

    # CLI overrides are stored in the global CONFIG so downstream code sees them.
    CONFIG["n_trials"] = args.trials
    CONFIG["random_state"] = seed
    if args.hpo_epochs is not None:
        CONFIG["hpo_epochs"] = args.hpo_epochs
    if args.final_epochs is not None:
        CONFIG["final_epochs"] = args.final_epochs

    os.makedirs('./data', exist_ok=True)
    os.makedirs('./outputs', exist_ok=True)

    all_results = {}

    for dataset in args.datasets:
        all_results[dataset] = {}

        for algorithm in args.algorithms:
            print(f"\n===== {dataset} / {algorithm} =====")

            results = run_experiment(
                dataset, algorithm, args.ada, args.mlpadv, args.num_workers, args.n_jobs,
                load_hpo_dir=args.load_hpo, skip_hpo=args.skip_hpo,
                save_models=args.save_models, output_dir=output_dir
            )

            entry = {"balanced": None, "original": None}
            if results is not None:
                if "bal_mae" in results:
                    entry["balanced"] = _mae_to_serializable(results["bal_mae"])
                if "ori_mae" in results:
                    entry["original"] = _mae_to_serializable(results["ori_mae"])
            all_results[dataset][algorithm] = entry

    print("\n--- Results Summary ---")

    def _to_json_compatible(obj):
        # json only invokes `default` for objects it cannot encode natively, so
        # plain floats (NaN included) never reach this hook; metric NaNs are
        # already mapped to None by _mae_to_serializable.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # The `default` contract requires TypeError for unsupported objects;
        # returning them unchanged makes json fail with a misleading
        # "Circular reference detected" ValueError instead.
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    print(json.dumps(all_results, indent=2, default=_to_json_compatible, ensure_ascii=False))

    # NOTE: the file name only reflects the first algorithm, so separate runs
    # that share a first algorithm overwrite each other's summary.
    algo_tag = args.algorithms[0] if args.algorithms else "none"
    results_file = f'./Benchmark/ALL_benchmark_results_{algo_tag}.json'
    try:
        os.makedirs('./Benchmark', exist_ok=True)
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(all_results, f, indent=2, default=_to_json_compatible, ensure_ascii=False)
        print(f"Results saved to {results_file}")
    except Exception as write_err:
        # Best-effort: a failed write should not discard the printed summary.
        print(f"Error saving results: {write_err}")


if __name__ == "__main__":
    # Command-line interface. Each entry is (flag, add_argument keyword options),
    # declared in the order they should appear in --help.
    parser = argparse.ArgumentParser(description="DIR Benchmark Training")
    _cli_options = [
        ('--datasets', dict(nargs='+', default=CONFIG["datasets_to_run"], help='Datasets to run')),
        ('--algorithms', dict(nargs='+', default=CONFIG["algorithms_to_run"], help='Algorithms to run')),
        ('--trials', dict(type=int, default=CONFIG["n_trials"], help='HPO trials')),
        ('--gpu', dict(type=int, default=None, help='GPU device ID')),
        ('--cpu', dict(action='store_true', help='Force CPU only')),
        ('--ada', dict(type=str, default='tabular', choices=['tabular', 'basic'], help='CASMIR expert type')),
        # store_true with default=True: this flag is always True (kept for compatibility).
        ('--mlpadv', dict(action='store_true', default=True, help='[Deprecated] MLP is now an alias for MLPadv')),
        ('--num_workers', dict(type=int, default=0, help='DataLoader num_workers')),
        ('--n_jobs', dict(type=int, default=1, help='Optuna parallel trials')),
        ('--load_hpo', dict(type=str, default=None, help='Load HPO results from directory')),
        ('--skip_hpo', dict(action='store_true', help='Skip HPO and use loaded hyperparameters')),
        ('--save_models', dict(action='store_true', help='Save trained models to learned_models/')),
        ('--hpo_epochs', dict(type=int, default=None, help='HPO epochs')),
        ('--final_epochs', dict(type=int, default=None, help='Final training epochs')),
        ('--seed', dict(type=int, default=None, help='Random seed')),
        ('--output_dir', dict(type=str, default=None, help='Model output directory')),
    ]
    for _flag, _opts in _cli_options:
        parser.add_argument(_flag, **_opts)

    args = parser.parse_args()

    # Selecting a specific GPU while forcing CPU is contradictory: bail out.
    if args.cpu and args.gpu is not None:
        print("Error: Cannot use --gpu and --cpu together.")
        sys.exit(1)

    # This assignment runs at module scope, so it rebinds the global DEVICE
    # that the rest of the module reads.
    DEVICE = setup_device(args.gpu, force_cpu=args.cpu)
    main(args)
