"""
SONATA Training Script

Trains improved SONATA model integrating:
- Martingale theory coreset selection
- Optimal stopping strategy
- Multi-scale time weighting with Ito formula
- Enhanced state estimation methods
"""

import numpy as np
import torch
import tqdm
import yaml
import time
import argparse
import os
import logging
import sys
from datetime import datetime

# Configure process-wide logging (INFO and above, timestamped lines) and
# expose the module-level logger used throughout this script.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Parse command line arguments
def parse_args_SONATA(argv=None):
    """Parse the SONATA command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing an explicit list lets
            the parser be used from tests or notebooks.

    Returns:
        argparse.Namespace with ``dataset``, ``method``, ``config``,
        ``output_dir``, ``device``, ``coreset_size`` and ``rank``. Every option
        except ``output_dir`` defaults to ``None`` so callers can tell "not
        given on the command line" apart from an explicit value.
    """
    parser = argparse.ArgumentParser(
        description="Martingale Dynamic CoreSet Tensor Factorization"
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default=None,
        help='Dataset name: beijing_15k, beijing_20k, server, or traffic. If not specified, uses active_dataset from config.'
    )
    parser.add_argument(
        "--method",
        type=str,
        default=None,
        help="Decomposition method: CP or Tucker. If not specified, uses active_method from config."
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        # Must match the fallback path used by load_config().
        help="Custom config file path. Default: ./config_SONATA.yaml"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./result_log/sonata",
        help="Output directory"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Computation device (cpu or cuda). If specified, overrides config."
    )
    parser.add_argument(
        "--coreset_size",
        type=int,
        default=None,
        help="Maximum coreset size. If specified, overrides config."
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=None,
        help="Embedding dimension (R_U). If specified, overrides config."
    )

    return parser.parse_args(argv)

# Load configuration file
def load_config(args):
    """Load the unified YAML config and resolve the settings for one run.

    The file is read from ``args.config`` or, when that is not set,
    ``./config_SONATA.yaml``. If the file does not exist, a default unified
    config is written there first (via ``create_default_unified_config``).

    The returned dict is built by layering, later layers winning:

    1. ``default``
    2. ``datasets[<dataset>]``
    3. ``methods[<METHOD>]``
    4. ``combinations["<dataset>_<METHOD>"]``
    5. top-level ``device`` / ``seed``

    The active dataset and method come from ``args.dataset`` / ``args.method``
    when given, otherwise from ``active_dataset`` / ``active_method`` in the
    file. In both cases dataset names are lower-cased and method names
    upper-cased, so ``active_method: Tucker`` in the file resolves the same
    sections as ``--method tucker`` on the command line.

    Args:
        args: namespace with ``config``, ``dataset`` and ``method`` attributes.

    Returns:
        dict: the merged configuration, always containing ``method`` and
        ``dataset``.

    Exits:
        Calls ``sys.exit(1)`` if the file cannot be read or parsed after a
        UTF-8 decode failure, or if it does not contain a YAML mapping.
    """
    config_path = args.config if args.config else "./config_SONATA.yaml"

    logger.info(f"Loading unified config file: {config_path}")

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            full_config = yaml.safe_load(f)
    except UnicodeDecodeError:
        # Fall back to latin-1, which can decode any byte sequence.
        try:
            with open(config_path, "r", encoding="latin1") as f:
                full_config = yaml.safe_load(f)
        except Exception as e:
            logger.error(f"Cannot read config file {config_path}: {e}")
            logger.error("Please ensure config file is saved with UTF-8 encoding")
            # sys.exit, not the site-provided exit(), which is not guaranteed
            # to exist (e.g. under `python -S` or in frozen executables).
            sys.exit(1)
    except FileNotFoundError:
        logger.error(f"Config file does not exist: {config_path}")
        logger.info(f"Creating default unified config file: {config_path}")
        create_default_unified_config(config_path)
        with open(config_path, "r", encoding="utf-8") as f:
            full_config = yaml.safe_load(f)

    # An empty YAML document loads as None; treat it as "no settings" rather
    # than crashing on None.get() below.
    if full_config is None:
        full_config = {}
    elif not isinstance(full_config, dict):
        logger.error(
            f"Config file {config_path} must contain a YAML mapping, "
            f"got {type(full_config).__name__}"
        )
        sys.exit(1)

    # Active dataset/method: command line wins over the config file. Both
    # sources are normalised the same way so section lookups behave identically.
    if args.dataset is not None:
        active_dataset = args.dataset.lower()
        logger.info(f"Using dataset from command line: {active_dataset}")
    else:
        active_dataset = str(full_config.get("active_dataset") or "beijing_20k").lower()
        logger.info(f"Using dataset from config file: {active_dataset}")

    if args.method is not None:
        active_method = args.method.upper()
        logger.info(f"Using method from command line: {active_method}")
    else:
        active_method = str(full_config.get("active_method") or "CP").upper()
        logger.info(f"Using method from config file: {active_method}")

    # Layers 1-4: defaults, then dataset, method and dataset_method overrides.
    config = dict(full_config.get("default") or {})
    override_layers = (
        ("datasets", active_dataset),
        ("methods", active_method),
        ("combinations", f"{active_dataset}_{active_method}"),
    )
    for section_name, key in override_layers:
        section = full_config.get(section_name) or {}
        overrides = section.get(key)
        if overrides:
            config.update(overrides)

    # Layer 5: top-level device and seed.
    if full_config.get("device"):
        config["device"] = full_config["device"]
    # `is not None` so that seed 0 (falsy) is still honoured.
    if full_config.get("seed") is not None:
        config["seed"] = full_config["seed"]

    # Downstream code reads these two keys directly.
    config["method"] = active_method
    config["dataset"] = active_dataset

    logger.info(f"Using configuration for dataset '{active_dataset}' with method '{active_method}'")

    return config

# Create default unified configuration file
def create_default_unified_config(config_path):
    """Write the default unified SONATA config to ``config_path`` as YAML.

    The file has the layered layout that ``load_config`` expects:

    * top-level ``active_dataset``, ``active_method``, ``device`` and ``seed``;
    * ``default`` (global parameters);
    * ``datasets`` (per-dataset overrides);
    * ``methods`` (per-method overrides);
    * ``combinations`` (``<dataset>_<METHOD>`` overrides).

    Any existing file at ``config_path`` is overwritten. The parent directory
    is created if needed.

    Args:
        config_path: Destination path. It may be a bare filename with no
            directory component.
    """
    default_config = {
        # Main Parameters
        "active_dataset": "beijing_20k",
        "active_method": "CP",
        "device": "cpu",
        "seed": 300,

        # Global Default Parameters
        "default": {
            # Basic parameters
            "epoch": 100,
            "fold": 1,
            "fix_int": True,
            "time_type": "continues",

            # Tensor factorization parameters
            "R_U": 3,
            "a0": 1,
            "b0": 1,
            "v": 1,

            # Coreset parameters
            "coreset_max_size": 100,
            "coreset_threshold": 0.6,
            "adaptive_threshold": True,
            "importance_weights": [0.3, 0.2, 0.2, 0.3],

            # Martingale theory parameters
            "prediction_history_size": 50,
            "simulation_samples": 5,
            "discount_factor": 0.9,
            "bellman_optimization": True,

            # Multi-scale time weighting
            "num_time_scales": 3,
            "scale_hidden_dim": 32,
            "attention_temperature": 1.0,
            "time_scale_factor": 0.1,

            # Exploration-exploitation balance
            "initial_exploration_rate": 0.9,
            "exploration_decay_rate": 0.1,

            # Gaussian Process parameters
            "kernel": "Matern_23",
            "lengthscale": 0.3,
            "variance": 1,
            "noise": 1,

            # Optimization parameters
            "DAMPING": 0.5,
            "DAMPING_tau": 0.6,
            "DAMPING_gamma": 0.5,
            "EVALU_T": 60,
            "INNER_ITER": 50,
            "THRE": 1.0e-4,
            "CEP_UPDATE_INNNER_MODE": False,
        },

        # Dataset-Specific Parameters
        "datasets": {
            "beijing_15k": {
                "data_path": "data/beijing_15k.npy",
                "coreset_max_size": 800,
                "lengthscale": 0.25,
                "EVALU_T": 50,
            },

            "beijing_20k": {
                "data_path": "data/beijing_20k.npy",
                "coreset_max_size": 1000,
                "lengthscale": 0.3,
                "EVALU_T": 60,
            },

            "server": {
                "data_path": "data/server.npy",
                "coreset_max_size": 1200,
                "R_U": 4,
                "lengthscale": 0.35,
                "EVALU_T": 40,
            },

            "traffic": {
                "data_path": "data/traffic.npy",
                "coreset_max_size": 1500,
                "R_U": 5,
                "lengthscale": 0.4,
                "EVALU_T": 30,
            },
        },

        # Method-Specific Parameters
        "methods": {
            "CP": {
                "DAMPING": 0.5,
                "DAMPING_tau": 0.6,
            },

            "TUCKER": {
                "DAMPING": 0.5,
                "DAMPING_tau": 0.6,
                "DAMPING_gamma": 0.5,
            },
        },

        # Combined Dataset-Method Parameter Overrides
        "combinations": {
            "beijing_15k_CP": {
                "importance_weights": [0.35, 0.15, 0.2, 0.3],
            },

            "beijing_20k_TUCKER": {
                "R_U": 4,
                "scale_hidden_dim": 48,
            },

            "server_CP": {
                "initial_exploration_rate": 0.85,
                "time_scale_factor": 0.15,
            },

            "traffic_TUCKER": {
                "R_U": 6,
                "discount_factor": 0.85,
                "attention_temperature": 1.2,
            },
        },
    }

    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has one (e.g. not for "config_SONATA.yaml").
    config_dir = os.path.dirname(config_path)
    if config_dir:
        os.makedirs(config_dir, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(default_config, f, default_flow_style=False, sort_keys=False)

    logger.info(f"Default unified config file created: {config_path}")


# Maintain backward compatibility with old config format
def create_default_config(config_path, method="CP"):
    """Write a flat, legacy-format default config to ``config_path`` as YAML.

    Unlike the unified layout, every parameter is a top-level key and there
    are no dataset, method or combination sections. Any existing file is
    overwritten, and the parent directory is created if needed.

    Args:
        config_path: Destination path. It may be a bare filename with no
            directory component.
        method: Decomposition method, stored upper-cased (e.g. ``"CP"``).
    """
    default_config = {
        "device": "cpu",
        "method": method.upper(),
        "seed": 300,
        "epoch": 100,
        "R_U": 3,
        "a0": 1,
        "b0": 1,
        "v": 1,
        "fold": 1,
        "time_type": "continues",

        # Coreset parameters
        "coreset_max_size": 100,
        "coreset_threshold": 0.6,
        "adaptive_threshold": True,
        "importance_weights": [0.3, 0.2, 0.2, 0.3],  # (uncertainty, influence, novelty, martingale_increment)

        # Multi-scale weighting parameters
        "num_time_scales": 3,
        "scale_hidden_dim": 32,
        "attention_temperature": 1.0,
        "time_scale_factor": 0.1,

        # Martingale theory related parameters
        "prediction_history_size": 50,
        "discount_factor": 0.9,
        "simulation_samples": 5,
        "bellman_optimization": True,

        # Exploration-exploitation balance parameters
        "initial_exploration_rate": 0.9,
        "exploration_decay_rate": 0.1,

        # Other model parameters
        "kernel": "Matern_23",
        "lengthscale": 0.3,
        "variance": 1,
        "noise": 1,
        "DAMPING": 0.5,
        "DAMPING_tau": 0.6,
        "DAMPING_gamma": 0.5,
        "EVALU_T": 20,
        "INNER_ITER": 50,
        "THRE": 1.0e-4,
        "CEP_UPDATE_INNNER_MODE": False,
    }

    # os.makedirs("") raises FileNotFoundError; skip it for bare filenames.
    config_dir = os.path.dirname(config_path)
    if config_dir:
        os.makedirs(config_dir, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(default_config, f, default_flow_style=False, sort_keys=False)

    logger.info(f"Default config file created (legacy format): {config_path}")

# Hard dependencies: the SONATA model factory and the martingale-coreset
# hyperparameter builder. The script cannot run without them, so abort early
# with an explanatory message instead of a bare traceback.
try:
    from model_SONATA import create_martingale_dctf
    from utils_martingale_coreset import make_martingale_coreset_dict
except ImportError as missing:
    for message in (
        f"Required module not found: {missing}",
        "Please make sure model_SONATA.py and utils_martingale_coreset.py are available.",
    ):
        logger.error(message)
    sys.exit(1)

# Try to import utils_streaming, with fallback implementation if not available
try:
    import utils_streaming
except ImportError:
    logger.warning("utils_streaming not found, using minimal implementation")

    class MinimalUtilsStreaming:
        """Minimal stand-in for the parts of ``utils_streaming`` used by this script.

        Provides only ``make_data_dict``, ``get_post`` and ``make_log``. An
        instance of this class is bound to the name ``utils_streaming``, so
        call sites do not change.
        """

        # Keys copied from the loaded data file into the data dict.
        _DATA_KEYS = (
            'tr_ind', 'tr_y', 'tr_T_disct',
            'te_ind', 'te_y', 'te_T_disct',
            'time_uni', 'ndims',
        )

        @staticmethod
        def make_data_dict(hyper_dict, data_file, fold_id):
            """Load the train/test split from a ``.npy`` file holding a pickled dict.

            ``hyper_dict`` and ``fold_id`` are accepted only for signature
            compatibility with ``utils_streaming`` and are not used.

            Returns:
                dict with keys ``tr_ind``, ``tr_y``, ``tr_T_disct``, ``te_ind``,
                ``te_y``, ``te_T_disct``, ``time_uni`` and ``ndims``. Keys
                missing from the file default to ``[]``. If loading fails, the
                error is logged and every value is ``[]``.
            """
            keys = MinimalUtilsStreaming._DATA_KEYS
            try:
                # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects,
                # which can execute code. Only load trusted data files.
                data = np.load(data_file, allow_pickle=True).item()
                return {key: data.get(key, []) for key in keys}
            except Exception as e:
                logger.error(f"Error loading data: {e}")
                # NOTE(review): best-effort fallback. The model will most likely
                # fail later on these empty inputs.
                return {key: [] for key in keys}

        @staticmethod
        def get_post(model, T):
            """Flatten the factor posteriors at index ``T`` into one vector.

            Concatenates, for every mode ``m``, the means
            ``model.post_U_m[m][..., T]`` and the diagonals (over dims 1 and 2)
            of ``model.post_U_v[m][..., T]``, followed by ``model.E_tau``. The
            result is used only to measure convergence between CEP sweeps.
            ``T`` indexes the last axis of the posterior tensors (presumably
            time; confirm against the model).
            """
            post_elements = []
            for m in range(model.nmods):
                post_elements.append(model.post_U_m[m][:, :, :, T].reshape(-1))
                diag_elements = torch.diagonal(model.post_U_v[m][:, :, :, T], dim1=1, dim2=2)
                post_elements.append(diag_elements.reshape(-1))

            # torch.tensor([E_tau]) always produced a CPU tensor of the default
            # dtype, which breaks torch.cat when the posteriors are on CUDA.
            # Put tau on the same device and dtype as the other elements.
            tau = torch.as_tensor(model.E_tau).reshape(-1)
            if post_elements:
                tau = tau.to(post_elements[0])
            post_elements.append(tau)

            return torch.cat(post_elements)

        @staticmethod
        def make_log(args, config, result_dict):
            """Write a plain-text run summary under ``./result_log/SONATA/<dataset>/``.

            ``args`` is accepted for signature compatibility and is not used.
            Errors are logged, never raised.
            """
            try:
                log_dir = f"./result_log/SONATA/{config['dataset']}"
                os.makedirs(log_dir, exist_ok=True)

                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                log_file = f"{log_dir}/{config['method']}_{config['R_U']}_{timestamp}.txt"

                # Explicit UTF-8: the text contains "±", which cannot be encoded
                # in every platform default encoding.
                with open(log_file, "w", encoding="utf-8") as f:
                    f.write(f"Dataset: {config['dataset']}\n")
                    f.write(f"Method: {config['method']}\n")
                    f.write(f"Rank: {config['R_U']}\n")
                    f.write(f"RMSE: {result_dict['rmse_avg']:.4f} ± {result_dict['rmse_std']:.4f}\n")
                    f.write(f"MAE: {result_dict['MAE_avg']:.4f} ± {result_dict['MAE_std']:.4f}\n")
                    f.write(f"Time: {result_dict['time']:.2f}s\n")

                    f.write("\nConfig:\n")
                    for k, v in config.items():
                        f.write(f"{k}: {v}\n")

                logger.info(f"Log saved to {log_file}")
            except Exception as e:
                logger.error(f"Error saving log: {e}")

    # Create instance to use as module
    utils_streaming = MinimalUtilsStreaming()

def _cep_sweep(model, T, flag, sequential_modes, is_tucker):
    """Run one CEP sweep at time ``T``: all factor modes, then tau (and gamma).

    Args:
        model: the SONATA model of the current fold.
        T: current training timestamp.
        flag: forwarded unchanged to ``model.filter_update``. Callers pass
            True on the final sweep of a time step.
        sequential_modes: if True, each mode's ``msg_approx_U`` is immediately
            followed by its ``filter_update``. Otherwise all modes are
            approximated first and then all are updated.
        is_tucker: also run the ``gamma`` message/posterior updates (Tucker only).
    """
    if sequential_modes:
        for mode in range(model.nmods):
            model.msg_approx_U(T, mode)
            model.filter_update(T, mode, flag)
    else:
        for mode in range(model.nmods):
            model.msg_approx_U(T, mode)
        for mode in range(model.nmods):
            model.filter_update(T, mode, flag)

    model.msg_approx_tau(T)
    model.post_update_tau(T)

    if is_tucker:
        model.msg_approx_gamma(T)
        model.post_update_gamma(T)


def _run_cep_iterations(model, T, inner_iter, thre, sequential_modes, is_tucker):
    """Iterate CEP sweeps at time ``T`` until converged or ``inner_iter`` sweeps have run.

    If convergence (``relative_change < thre``) is reached before the last
    iteration, one extra sweep runs with ``flag=True`` before stopping.
    """
    for inner_it in range(inner_iter):
        old_post = utils_streaming.get_post(model, T)
        is_last = inner_it == inner_iter - 1

        model.msg_U_m = []
        model.msg_U_V = []
        _cep_sweep(model, T, is_last, sequential_modes, is_tucker)

        new_post = utils_streaming.get_post(model, T)
        # NOTE(review): squared-difference sum divided by the *unsquared* norm
        # is not scale-invariant, despite the name. Left unchanged because
        # THRE is tuned against this quantity.
        relative_change = torch.square(new_post - old_post).sum() / (old_post.norm() + 1e-8)

        if is_last:
            # This sweep already ran with flag=True; stop.
            break
        if relative_change < thre:
            # Early convergence: run a final sweep with flag=True.
            # NOTE(review): msg_U_m / msg_U_V are NOT cleared before this extra
            # sweep, unlike the regular sweeps above. Confirm this is intended.
            _cep_sweep(model, T, True, sequential_modes, is_tucker)
            break


# Train SONATA model
def train_SONATA(config, args):
    """Train the SONATA model over ``config['fold']`` folds and collect metrics.

    For each fold: build the data dict and model, then stream over the unique
    training timestamps. Each step runs track_envloved_objects, filter_predict,
    msg_llk_init, update_coreset, CEP inner iterations and optional online
    evaluation. At the end, the model is smoothed and evaluated on the test set.

    Args:
        config: resolved run configuration; must contain ``dataset`` and
            ``method``. ``seed``, ``data_path`` and ``fold`` are optional.
        args: parsed command-line namespace. Only ``args.output_dir`` is read,
            and only for logging.

    Returns:
        Tuple ``(result_dict, rmse_array, MAE_array, coreset_sizes)``:

        * ``result_dict`` holds mean/std test RMSE and MAE, wall-clock ``time``,
          the running (fold 0 only) metrics and the coreset statistics.
        * ``rmse_array`` and ``MAE_array`` hold the per-fold test metrics.
        * ``coreset_sizes`` is the list of coreset sizes sampled every 5 time
          steps, across all folds.
    """
    output_dir = args.output_dir

    # Seed torch and numpy for reproducibility.
    seed = config.get("seed", 300)
    torch.random.manual_seed(seed)
    np.random.seed(seed)

    hyper_dict = make_martingale_coreset_dict(config)
    max_size = hyper_dict['coreset_max_size']

    logger.info(f"Dataset: {config['dataset']}, Method: {config['method']}, Rank: {hyper_dict['R_U']}, "
                f"Coreset max size: {max_size}")

    data_file = config.get("data_path", f"data/{config['dataset']}.npy")

    # Convergence threshold, inner iteration count and loop-invariant flags.
    THRE = hyper_dict["THRE"]
    INNER_ITER = hyper_dict["INNER_ITER"]
    num_fold = config.get("fold", 1)
    is_tucker = config["method"].upper() == "TUCKER"
    # Strict `== True` comparison kept on purpose: a truthy non-bool value
    # (e.g. the string "False") must not enable the sequential mode.
    sequential_modes = hyper_dict['CEP_UPDATE_INNNER_MODE'] == True  # noqa: E712

    # Online (fold 0) metrics, per-fold test metrics and coreset statistics.
    running_rmse, running_MAE, running_N, running_T = [], [], [], []
    test_rmse, test_MAE = [], []
    coreset_sizes, coreset_changes = [], []

    result_dict = {}
    start_time = time.time()
    # Defined up front so the final summary cannot raise NameError when there
    # are no folds.
    N = 0

    for fold_id in range(num_fold):
        logger.info(f"Processing fold {fold_id+1}/{num_fold}")

        data_dict = utils_streaming.make_data_dict(hyper_dict, data_file, fold_id)
        model = create_martingale_dctf(hyper_dict, data_dict)
        model.reset()

        # Observations seen so far in this fold.
        N = 0
        # Defined up front in case there are no training timestamps.
        T = None

        for T_id in tqdm.tqdm(range(len(model.unique_train_time)), desc="Processing time steps"):
            # Flow: track_envloved_objects -> filter_predict -> msg_llk_init
            #       -> update_coreset -> CEP iteration -> evaluate
            T = model.unique_train_time[T_id]
            model.track_envloved_objects(T_id)
            N = N + model.N_T

            # KF prediction step, then likelihood message initialisation.
            model.filter_predict(T)
            model.msg_llk_init()

            added, removed = model.update_coreset(T)

            # Sample coreset statistics every 5 steps; log every 10.
            if T_id % 5 == 0:
                current_size = model.coreset_manager.get_coreset_size()
                coreset_sizes.append(current_size)
                coreset_changes.append(len(added) + len(removed))

                if T_id % 10 == 0:
                    logger.info(f"T={T}, coreset size: {current_size}/{max_size}, "
                                f"current batch data points: {model.N_T}, total data points: {N}, "
                                f"added: {len(added)}, removed: {len(removed)}")

            _run_cep_iterations(model, T, INNER_ITER, THRE, sequential_modes, is_tucker)

            # Online evaluation: fold 0 only, every EVALU_T time units.
            evalu_t = hyper_dict["EVALU_T"]
            if evalu_t > 0 and fold_id == 0 and T % evalu_t == 0:
                model.inner_smooth()
                _, test_result = model.model_test(model.te_ind, model.te_y, model.test_time_ind)

                cs_size = model.coreset_manager.get_coreset_size()
                # Guard: no observations yet would otherwise divide by zero.
                proportion = cs_size / N * 100 if N else 0.0
                logger.info(f"T: {T}, running error: {test_result['rmse']:.4f}, coreset size: {cs_size}/{max_size}, "
                            f"current batch data points: {model.N_T}, total data points: {N}, "
                            f"coreset proportion: {proportion:.2f}%")
                running_MAE.append(test_result['MAE'].cpu().numpy().squeeze())
                running_rmse.append(test_result['rmse'].cpu().numpy().squeeze())
                running_T.append(T)
                running_N.append(N)

        # Final smoothing and posterior calculation, then the test evaluation.
        model.smooth()
        model.get_post_U()
        _, test_result = model.model_test(model.te_ind, model.te_y, model.test_time_ind)

        fold_MAE = test_result['MAE'].cpu().numpy().squeeze()
        fold_rmse = test_result['rmse'].cpu().numpy().squeeze()
        test_MAE.append(fold_MAE)
        test_rmse.append(fold_rmse)

        if fold_id == 0:
            running_MAE.append(fold_MAE)
            running_rmse.append(fold_rmse)
            running_T.append(T)
            running_N.append(N)

        logger.info(f"Fold {fold_id+1} test results: RMSE = {test_result['rmse']:.4f}, MAE = {test_result['MAE']:.4f}")

    rmse_array = np.array(test_rmse)
    MAE_array = np.array(test_MAE)

    result_dict['time'] = time.time() - start_time
    result_dict['rmse_avg'] = rmse_array.mean()
    result_dict['rmse_std'] = rmse_array.std()
    result_dict['MAE_avg'] = MAE_array.mean()
    result_dict['MAE_std'] = MAE_array.std()

    result_dict['running_rmse'] = np.array(running_rmse)
    result_dict['running_MAE'] = np.array(running_MAE)
    result_dict['running_T'] = np.array(running_T)
    result_dict['running_N'] = np.array(running_N)

    result_dict['coreset_sizes'] = np.array(coreset_sizes)
    result_dict['coreset_changes'] = np.array(coreset_changes)

    final_size = coreset_sizes[-1] if coreset_sizes else 0
    usage_rate = final_size / N * 100 if N else 0.0
    # Nothing has been written to disk yet; main() saves the results.
    logger.info(f"Training finished (results will be written to {output_dir}). "
                f"Final RMSE: {result_dict['rmse_avg']:.4f} ± {result_dict['rmse_std']:.4f}, "
                f"coreset size: {final_size}/{max_size}, "
                f"total data points: {N}, coreset usage rate: {usage_rate:.2f}%")

    return result_dict, rmse_array, MAE_array, coreset_sizes


# Main function
def main():
    """Command-line entry point: parse args, resolve the config, train, write logs.

    Files written:

    * under ``args.output_dir``:
      ``<dataset>_<method>_R<rank>_CS<size>_<timestamp>.npy``, a dict saved
      with ``np.save`` holding coreset statistics, per-fold and running
      metrics and the config;
    * under ``args.output_dir``: ``summary_<dataset>_<method>_<timestamp>.txt``,
      a human-readable summary;
    * whatever ``utils_streaming.make_log`` writes.
    """
    args = parse_args_SONATA()
    config = load_config(args)

    # Command-line overrides. Numeric options use `is not None` so that an
    # explicit 0 is not silently ignored.
    if args.device:
        config["device"] = args.device
        logger.info(f"Overriding device with command line setting: {args.device}")

    if args.coreset_size is not None:
        config["coreset_max_size"] = args.coreset_size
        logger.info(f"Overriding coreset_max_size with command line setting: {args.coreset_size}")

    if args.rank is not None:
        config["R_U"] = args.rank
        logger.info(f"Overriding rank (R_U) with command line setting: {args.rank}")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    logger.info("Active configuration:")
    logger.info(f"  Dataset: {config['dataset']}")
    logger.info(f"  Method: {config['method']}")
    # .get: a custom config may omit device; logging must not crash on it.
    logger.info(f"  Device: {config.get('device', 'Not specified')}")
    logger.info(f"  Rank (R_U): {config['R_U']}")
    logger.info(f"  Coreset max size: {config.get('coreset_max_size', 'Not specified')}")

    result_dict, rmse_array, MAE_array, coreset_sizes = train_SONATA(config, args)

    # utils_streaming.make_log expects an args-like object.
    class Args:
        def __init__(self, config):
            self.dataset = config['dataset']
            self.method = config['method']
            self.machine = 'local'
            self.num_fold = config.get('fold', 1)

    utils_streaming.make_log(Args(config), config, result_dict)

    dataset_name = config['dataset']
    method = config['method']
    rank = config['R_U']
    cs_size = config.get('coreset_max_size', 100)

    # The timestamp keeps repeated runs from overwriting each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = f"{output_dir}/{dataset_name}_{method}_R{rank}_CS{cs_size}_{timestamp}.npy"

    detailed_results = {
        'coreset_sizes': result_dict['coreset_sizes'],
        'coreset_changes': result_dict['coreset_changes'],
        'rmse': rmse_array,
        'mae': MAE_array,
        'running_rmse': result_dict['running_rmse'],
        'running_mae': result_dict['running_MAE'],
        'config': config,  # Save config for reproducibility
    }
    np.save(log_path, detailed_results)

    # Usage rate: guard against an empty running_N and against zero data
    # points, instead of indexing a -1 sentinel or dividing by zero.
    final_coreset_size = coreset_sizes[-1] if len(coreset_sizes) > 0 else 0
    running_N = result_dict.get('running_N', [])
    total_points = running_N[-1] if len(running_N) > 0 else 0
    usage_rate = final_coreset_size / total_points * 100 if total_points else 0.0

    summary_path = f"{output_dir}/summary_{dataset_name}_{method}_{timestamp}.txt"
    # Explicit UTF-8: the summary contains "±".
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(f"Dataset: {dataset_name}\n")
        f.write(f"Method: {method}\n")
        f.write(f"Rank (R_U): {rank}\n")
        f.write(f"Coreset max size: {cs_size}\n")
        f.write(f"RMSE: {result_dict['rmse_avg']:.4f} ± {result_dict['rmse_std']:.4f}\n")
        f.write(f"MAE: {result_dict['MAE_avg']:.4f} ± {result_dict['MAE_std']:.4f}\n")
        f.write(f"Runtime: {result_dict['time']:.2f} seconds\n")
        f.write(f"Final coreset size: {final_coreset_size}\n")
        f.write(f"Coreset usage rate: {usage_rate:.2f}%\n")

    logger.info(f"Results saved to {output_dir}")
    logger.info(f"Summary saved to {summary_path}")
    logger.info(f"Final RMSE: {result_dict['rmse_avg']:.4f} ± {result_dict['rmse_std']:.4f}")
    logger.info(f"Final MAE: {result_dict['MAE_avg']:.4f} ± {result_dict['MAE_std']:.4f}")
    logger.info(f"Runtime: {result_dict['time']:.2f} seconds")


# If script is executed directly
if __name__ == "__main__":
    main()