import argparse
import glob
import logging
import math
import multiprocessing
import os
import pickle
import random
import string
import sys
import threading
import time
import warnings
from datetime import datetime
from tabulate import tabulate
from typing import Optional, Tuple

import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.exceptions import UndefinedMetricWarning

# --- Script Configuration ---
# Log line layout. It includes process and thread names because work in this
# script runs in multiprocessing workers and in a progress-monitor thread.
DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - [%(processName)s/%(threadName)s] - %(message)s'
# Suppress specific warnings. These filters are installed when this module is
# imported (they are top-level statements).
# Silence sklearn's UndefinedMetricWarning.
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
# `message` is a regular expression matched against the START of the warning
# text. The trailing `*` only makes the final "e" optional, so this is
# effectively a prefix match on "Liblinear failed to converg".
warnings.filterwarnings("ignore", message="Liblinear failed to converge*")
# For convergence warnings not caught by the message filter above (presumably
# sklearn's ConvergenceWarning, a UserWarning subclass). NOTE(review): this also
# hides every other UserWarning attributed to a module whose name starts with
# "sklearn".
warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')

# --- Logger Setup ---
# Module-wide logger shared by every function below. Its level and handlers
# are attached by setup_logger(). Its name is __name__, which is "__main__"
# when the file is run as a script.
logger = logging.getLogger(__name__)

def setup_logger(log_level_str: str = "INFO", log_file: Optional[str] = None):
    """Configures the global logger for the script.

    Handlers attached by a previous call are removed *and closed* before the
    new ones are added. Repeated calls therefore neither duplicate output nor
    leak open log-file handles.

    Args:
        log_level_str (str, optional): Logging level name (e.g., "INFO", "debug").
            Case-insensitive. Names that are not logging levels fall back to
            INFO. Defaults to "INFO".
        log_file (str, optional): Path to a log file. If provided, logs are also
            written (appended, UTF-8) here. Defaults to None.
    """
    log_level = getattr(logging, log_level_str.upper(), None)
    # Some names exist on the logging module but are not levels (e.g.
    # "BASIC_FORMAT" is a format string). setLevel() would reject them.
    if not isinstance(log_level, int):
        log_level = logging.INFO
    logger.setLevel(log_level)

    # Remove existing handlers to avoid duplicate logs if called multiple times.
    # Closing them releases FileHandler streams. A StreamHandler's close() does
    # not close sys.stdout.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter(DEFAULT_LOG_FORMAT)

    # Console handler.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # Optional file handler.
    if log_file:
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(log_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

# --- Utility Functions ---
def generate_random_string(length: int, rng: Optional[random.Random] = None) -> str:
    """Generates a random lowercase-alphanumeric string for unique identifiers.

    By default the characters are drawn from ``random.SystemRandom`` (OS
    entropy), which is unaffected by ``random.seed``. Workers reseed the
    global ``random`` module, so drawing identifiers from it would give two
    workers that share a seed identical file names. Their result files would
    then silently overwrite each other.

    Args:
        length (int): Length of the random string to generate. Values <= 0 yield "".
        rng (random.Random, optional): Generator to draw from instead, e.g. a
            seeded ``random.Random`` when reproducible identifiers are wanted.
            Defaults to None (use ``random.SystemRandom``).

    Returns:
        str: Generated random string.
    """
    alphabet = string.ascii_lowercase + string.digits
    source = rng if rng is not None else random.SystemRandom()
    return ''.join(source.choice(alphabet) for _ in range(length))

def format_time_seconds(seconds: float) -> str:
    """Formats a duration in seconds into an HH:MM:SS string.

    Fractional seconds are truncated. Hours are not capped, so 90061 s becomes
    "25:01:01".

    Args:
        seconds (float): Duration in seconds.

    Returns:
        str: Formatted time string (H:M:S), or "--:--:--" for negative or
        non-finite (NaN / infinity) input.
    """
    # int() raises on NaN/inf, so treat them like "unknown" instead of crashing.
    if not math.isfinite(seconds) or seconds < 0:
        return "--:--:--"
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

def calculate_auc_metric(clf: LogisticRegression, x_data: np.ndarray, y_data: np.ndarray) -> float:
    """Score a fitted classifier with ROC AUC, falling back to chance level.

    Args:
        clf: Classifier that is expected to be already fitted.
        x_data: Feature data for evaluation.
        y_data: True labels corresponding to ``x_data``.

    Returns:
        ``roc_auc_score(y_data, clf.predict_proba(x_data)[:, 1])``. Returns 0.5
        (random chance) instead when ``y_data`` holds fewer than two distinct
        labels, when ``clf`` has no ``classes_`` or saw fewer than two classes,
        or when any exception is raised while scoring.
    """
    chance_level = 0.5
    if np.unique(y_data).size < 2:
        # AUC is undefined when only one class is present.
        return chance_level
    try:
        fitted_classes = getattr(clf, "classes_", None)
        if fitted_classes is None or len(fitted_classes) < 2:
            # Unfitted, or fitted on a single class: no usable probability column.
            return chance_level
        positive_scores = clf.predict_proba(x_data)[:, 1]
        return roc_auc_score(y_data, positive_scores)
    except Exception:
        # Deliberate best-effort: any scoring failure (ValueError from
        # roc_auc_score, NotFittedError, ...) is reported as chance level.
        return chance_level

# --- Core Shapley Value Calculator Class ---
class ShapleyCalculator:
    """Calculates Shapley values for a single label using TMC-Shapley with Logistic Regression.

    Each TMC iteration is persisted to ``output_dir`` as its own
    ``tmc_iter_<random id>.pkl`` file containing the tuple
    ``(marginal_contributions, permutation_indices, num_points_processed,
    final_performance)``. Averaging ``marginal_contributions`` over all files
    yields the Shapley estimate for every training point.
    """

    def __init__(self,
                 x_train: np.ndarray,
                 y_train_label: np.ndarray,
                 x_test: np.ndarray,
                 y_test_label: np.ndarray,
                 output_dir: str,
                 c_value: float,
                 metric_function: callable,
                 lr_base_params: dict,
                 truncation_tolerance: float = 0.01,
                 random_seed: Optional[int] = None,
                 truncation_patience: int = 5):
        """Initializes the ShapleyCalculator for a specific label.

        Creates ``output_dir`` if needed. When ``random_seed`` is given, reseeds
        the global ``numpy.random`` and ``random`` generators. Fits one
        classifier on the full training set to obtain the grand-coalition
        utility v(N).

        Args:
            x_train (np.ndarray): Training features.
            y_train_label (np.ndarray): Training labels for the specific target.
            x_test (np.ndarray): Test features.
            y_test_label (np.ndarray): Test labels for the specific target.
            output_dir (str): Directory to save iteration results for this label.
            c_value (float): Regularization parameter C for Logistic Regression.
            metric_function (callable): ``f(clf, x, y) -> float`` scoring a fitted
                model (e.g., calculate_auc_metric).
            lr_base_params (dict): Base parameters for the Logistic Regression classifier.
            truncation_tolerance (float, optional): Tolerance for TMC early stopping. Defaults to 0.01.
            random_seed (int, optional): Seed for random number generation. Defaults to None.
            truncation_patience (int, optional): Number of consecutive steps that
                must score within tolerance of v(N) before an iteration is
                truncated. Defaults to 5.
        """
        self.x_train = x_train
        self.y_train_label = y_train_label
        self.x_test = x_test
        self.y_test_label = y_test_label
        self.output_dir = output_dir
        self.c_value = c_value
        self.metric_function = metric_function
        self.lr_base_params = lr_base_params
        self.truncation_tolerance = truncation_tolerance
        self.truncation_patience = truncation_patience
        self.n_train_samples = len(x_train)
        self.random_seed = random_seed
        self.calculator_label_name = os.path.basename(output_dir)

        os.makedirs(self.output_dir, exist_ok=True)

        if self.random_seed is not None:
            np.random.seed(self.random_seed)
            random.seed(self.random_seed)

        self._base_clf = self._initialize_classifier()
        # Utility of a chance-level model (0.5 for AUC). This is also the score
        # given to coalitions that cannot be fitted because they hold one class.
        self.random_performance = 0.5
        self.full_data_performance = self._calculate_full_data_performance()

    def _initialize_classifier(self) -> LogisticRegression:
        """Initializes a Logistic Regression classifier with the specified C value."""
        current_lr_params = self.lr_base_params.copy()
        current_lr_params['C'] = self.c_value
        return LogisticRegression(**current_lr_params)

    def _get_cloned_classifier(self) -> LogisticRegression:
        """Returns a fresh, unfitted clone of the base classifier to prevent state leakage."""
        return clone(self._base_clf)

    def _calculate_full_data_performance(self) -> float:
        """Calculates model performance on the test set after training on the entire training dataset.

        This value corresponds to v(N) in TMC-Shapley (utility of the grand coalition).
        Returns random performance (0.5 for AUC) if training data is empty, has
        only one class, or fitting/scoring raises.
        """
        # An empty label array also lands here: np.unique of it has length 0.
        if len(np.unique(self.y_train_label)) < 2:
            return self.random_performance
        if self.n_train_samples == 0:
            return self.random_performance

        try:
            clf = self._get_cloned_classifier()
            clf.fit(self.x_train, self.y_train_label)
            return self.metric_function(clf, self.x_test, self.y_test_label)
        except Exception as e:
            logger.warning(f"Calculator for {self.calculator_label_name}: Error during full data performance calculation: {e}. Setting to 0.5.", exc_info=False) # Keep log clean
            return self.random_performance

    def _run_one_tmc_iteration(self) -> None:
        """Executes one iteration of the Truncated Monte Carlo (TMC) Shapley algorithm.

        Steps:
        1. Randomly permute training sample indices (global ``numpy.random``).
        2. Grow a coalition one point at a time in permutation order.
        3. Re-fit and evaluate the model on every coalition that contains both
           classes. Single-class coalitions score ``random_performance``, and a
           failed fit repeats the previous score.
        4. Record each point's marginal contribution (the score change it caused).
        5. Stop early once ``truncation_patience`` consecutive coalitions score
           within ``truncation_tolerance * |v(N) - random|`` of v(N). This only
           applies when v(N) differs from random. Points never reached keep a
           contribution of 0.
        6. Persist marginal contributions and permutation to a .pkl file.
        """
        n = self.n_train_samples
        if n == 0:
            return

        permutation_indices = np.random.permutation(n)
        marginal_contributions = np.zeros(n, dtype=float)

        # A prefix of the permuted labels holds two classes exactly when it
        # reaches the first label that differs from the first one. Computing
        # that position once replaces an np.unique call on every step.
        y_permuted = self.y_train_label[permutation_indices]
        differs_from_first = np.flatnonzero(y_permuted != y_permuted[0])
        min_two_class_size = differs_from_first[0] + 1 if differs_from_first.size else n + 1

        # Loop-invariant truncation parameters. The scale falls back to 1.0
        # when v(N) is numerically equal to random, to avoid a zero threshold.
        denominator = abs(self.full_data_performance - self.random_performance)
        if denominator < 1e-6:
            denominator = 1.0
        truncation_threshold = self.truncation_tolerance * denominator
        # If the full model is no better than random, never truncate.
        truncation_enabled = self.full_data_performance != self.random_performance

        truncation_counter = 0
        previous_performance = self.random_performance
        num_points_processed = 0

        for position, sample_idx in enumerate(permutation_indices):
            subset_size = position + 1
            num_points_processed = subset_size

            if subset_size < min_two_class_size:
                current_performance = self.random_performance
            else:
                try:
                    # Rows are materialized in permutation order only when a fit is needed.
                    subset = permutation_indices[:subset_size]
                    clf = self._get_cloned_classifier()
                    clf.fit(self.x_train[subset], self.y_train_label[subset])
                    current_performance = self.metric_function(clf, self.x_test, self.y_test_label)
                except Exception:
                    # Best effort: a coalition that cannot be fitted adds nothing.
                    current_performance = previous_performance

            marginal_contributions[sample_idx] = current_performance - previous_performance
            previous_performance = current_performance

            if abs(current_performance - self.full_data_performance) <= truncation_threshold:
                if truncation_enabled:
                    truncation_counter += 1
                    if truncation_counter >= self.truncation_patience:
                        break
            else:
                truncation_counter = 0

        self._save_iteration_result(
            (marginal_contributions, permutation_indices, num_points_processed, previous_performance))

    def _save_iteration_result(self, payload: tuple) -> None:
        """Atomically pickles one iteration's payload as ``tmc_iter_<id>.pkl`` in ``output_dir``.

        The data is written to ``<name>.pkl.tmp`` first and then renamed with
        ``os.replace``. A crash mid-write therefore never leaves a truncated
        file matching ``tmc_iter_*.pkl``, which the progress monitor would
        count and the aggregator would fail to load.
        """
        output_filename = f"tmc_iter_{generate_random_string(16)}.pkl"
        final_path = os.path.join(self.output_dir, output_filename)
        temp_path = final_path + ".tmp"
        try:
            with open(temp_path, 'wb') as f:
                pickle.dump(payload, f)
            os.replace(temp_path, final_path)
        except OSError as e:
            logger.error(f"Calculator for {self.calculator_label_name}: Failed to save iteration result {output_filename}: {e}")
            try:
                os.remove(temp_path)
            except OSError:
                pass  # Best-effort cleanup: the temp file may never have been created.

    def run_tmc_iterations(self, num_iterations: int):
        """Runs the specified number of TMC-Shapley iterations.

        Args:
            num_iterations (int): The number of TMC iterations to execute.
        """
        for _ in range(num_iterations):
            self._run_one_tmc_iteration()


# --- Worker Function for Multiprocessing ---
def _derive_worker_seed(global_random_seed_base: int, worker_id: int, label_index: int) -> int:
    """Derives a reproducible 32-bit seed that is distinct for each (base, worker, label) triple.

    ``numpy.random.SeedSequence`` hashes the three integers together, so
    distinct triples get well-separated seeds however many workers or labels
    there are. An additive formula such as ``base + worker*10000 + label*100``
    collides as soon as ``label_index`` reaches 100. Colliding seeds give
    duplicate permutations and identical result file names.
    """
    seed_sequence = np.random.SeedSequence([global_random_seed_base, worker_id, label_index])
    return int(seed_sequence.generate_state(1)[0])


def shapley_calculation_worker(
    worker_id_arg: int, # Renamed to avoid conflict with global worker_id
    label_index: int,
    label_name: str,
    iterations_for_worker: int,
    base_output_dir_timestamped: str,
    x_train_data: np.ndarray,
    y_train_all_labels: np.ndarray,
    x_test_data: np.ndarray,
    y_test_all_labels: np.ndarray,
    c_value_for_label: float,
    metric_func_name: str,
    lr_base_params_dict: dict,
    truncation_tol: float,
    global_random_seed_base: int
):
    """Worker function for parallel Shapley calculation for a specific label via TMC.

    Each worker executes an assigned number of TMC iterations for its designated label.
    It seeds the global ``numpy.random`` and ``random`` generators from
    (``global_random_seed_base``, ``worker_id_arg``, ``label_index``), so
    results are reproducible for a fixed triple. Errors are logged rather than
    raised.

    Args:
        worker_id_arg (int): Conceptual worker ID for deriving unique random seeds.
        label_index (int): Index of the target label (column of the label matrices).
        label_name (str): Name of the target label (also its output sub-directory).
        iterations_for_worker (int): Number of TMC iterations assigned to this worker for the label.
        base_output_dir_timestamped (str): Timestamped base directory for run outputs.
        x_train_data (np.ndarray): Training features.
        y_train_all_labels (np.ndarray): All training labels, indexed as ``[:, label_index]``.
        x_test_data (np.ndarray): Test features.
        y_test_all_labels (np.ndarray): All test labels, indexed as ``[:, label_index]``.
        c_value_for_label (float): Logistic Regression C parameter for this label.
        metric_func_name (str): Name of the performance metric (case-insensitive; only "auc").
        lr_base_params_dict (dict): Base parameters for Logistic Regression.
        truncation_tol (float): Truncation tolerance for TMC.
        global_random_seed_base (int): Global base random seed (must be non-negative).
    """
    process_name = multiprocessing.current_process().name
    worker_seed = _derive_worker_seed(global_random_seed_base, worker_id_arg, label_index)
    np.random.seed(worker_seed)
    random.seed(worker_seed)

    # This log helps identify which process is doing what.
    logger.debug(f"Process {process_name} (Conceptual Worker ID: {worker_id_arg}) for label '{label_name}' (Idx: {label_index}) started. Seed: {worker_seed}. Iterations assigned: {iterations_for_worker}.")

    y_train_current_label = y_train_all_labels[:, label_index].astype(int)
    y_test_current_label = y_test_all_labels[:, label_index].astype(int)

    # Checked first: an empty training set also has fewer than two classes.
    if len(np.unique(y_train_current_label)) < 2:
        logger.warning(f"Process {process_name} for '{label_name}': Training data has only one class. Skipping Shapley calculation.")
        return
    if len(x_train_data) == 0:
        logger.warning(f"Process {process_name} for '{label_name}': No training data. Skipping.")
        return

    label_specific_output_dir = os.path.join(base_output_dir_timestamped, label_name)
    try:
        os.makedirs(label_specific_output_dir, exist_ok=True)
    except OSError as e:
        logger.error(f"Process {process_name} for '{label_name}': Failed to create output dir {label_specific_output_dir}: {e}. Aborting task for this label.")
        return

    supported_metrics = {'auc': calculate_auc_metric}
    metric_function_to_use = supported_metrics.get(metric_func_name.lower())
    if metric_function_to_use is None:
        logger.error(f"Process {process_name} for '{label_name}': Unsupported metric '{metric_func_name}'. Aborting task.")
        return

    try:
        calculator = ShapleyCalculator(
            x_train=x_train_data,
            y_train_label=y_train_current_label,
            x_test=x_test_data,
            y_test_label=y_test_current_label,
            output_dir=label_specific_output_dir,
            c_value=c_value_for_label,
            metric_function=metric_function_to_use,
            lr_base_params=lr_base_params_dict,
            truncation_tolerance=truncation_tol,
            random_seed=worker_seed
        )
        calculator.run_tmc_iterations(iterations_for_worker)
        logger.debug(f"Process {process_name} (Conceptual Worker ID: {worker_id_arg}) for '{label_name}' completed {iterations_for_worker} iterations successfully.")
    except Exception as e:
        logger.error(f"Process {process_name} (Conceptual Worker ID: {worker_id_arg}) for '{label_name}' encountered an error: {e}", exc_info=True)


# --- Progress Monitoring Thread ---
def _progress_rate_and_eta(completed: int, target: int, elapsed: float) -> Tuple[str, str]:
    """Returns ``(rate_str, eta_str)`` for ``completed`` of ``target`` items after ``elapsed`` seconds.

    The rate is reported in items per hour, rounded to an integer. Both values
    stay at their placeholders ("0", "--:--:--") until at least one item is
    done and more than 10 seconds have elapsed. This hides the noisy earliest
    estimates.
    """
    rate_str = "0"
    eta_str = "--:--:--"
    if completed <= 0 or elapsed <= 10:
        return rate_str, eta_str
    rate_per_sec = completed / elapsed
    if rate_per_sec <= 1e-6:
        # Too slow to extrapolate meaningfully.
        return rate_str, ("inf" if target > completed else "00:00:00")
    remaining = target - completed
    eta_str = format_time_seconds(remaining / rate_per_sec) if remaining > 0 else "00:00:00"
    return f"{rate_per_sec * 3600:.0f}", eta_str


def monitor_shapley_progress(
    output_dir_timestamped: str,
    label_names_list: list,
    total_iterations_per_label: int, # This is the grand total target for each label
    main_start_time: float,
    stop_event: threading.Event,
    update_interval_seconds: float = 30.0):
    """Monitors Shapley calculation progress by counting result files.

    Periodically scans each label's output directory for completed
    ``tmc_iter_*.pkl`` files. It logs a table with one OVERALL row plus one
    row per label, showing progress, elapsed time, ETA and rate. Rates are
    averaged over the time since ``main_start_time``.

    Args:
        output_dir_timestamped (str): Timestamped base directory for the run's output.
        label_names_list (list): List of all label names.
        total_iterations_per_label (int): Target total iterations for each label.
        main_start_time (float): Timestamp of the main program start, for elapsed time calculation.
        stop_event (threading.Event): Event to signal thread termination.
        update_interval_seconds (float, optional): Seconds between updates. Defaults to 30.0.
    """
    headers = ["Label", "Progress (%)", "Files (Done/Target)", "Elapsed Time", "Est. ETA", "Rate (iter/hr)"]
    # The overall target is the sum of the per-label targets.
    overall_expected_files_target = len(label_names_list) * total_iterations_per_label

    while not stop_event.is_set():
        elapsed_global = time.time() - main_start_time
        elapsed_str = format_time_seconds(elapsed_global)
        progress_data = []
        overall_completed_files = 0

        for label_name in label_names_list:
            label_dir = os.path.join(output_dir_timestamped, label_name)
            completed_files_for_label = 0
            if os.path.exists(label_dir):
                try: # Globbing can fail, e.g. on permission issues during listing.
                    completed_files_for_label = len(glob.glob(os.path.join(label_dir, 'tmc_iter_*.pkl')))
                except Exception as e_glob:
                    logger.warning(f"Monitor: Error accessing/globbing files in {label_dir}: {e_glob}")
            overall_completed_files += completed_files_for_label

            if total_iterations_per_label > 0:
                progress_percent = (completed_files_for_label / total_iterations_per_label) * 100
            else:
                progress_percent = 0
            # The per-label rate uses global elapsed time, which assumes work
            # on a label is roughly continuous even though it is distributed.
            rate_str, eta_str = _progress_rate_and_eta(
                completed_files_for_label, total_iterations_per_label, elapsed_global)

            progress_data.append([
                label_name,
                f"{progress_percent:.2f}",
                f"{completed_files_for_label}/{total_iterations_per_label}",
                elapsed_str,
                eta_str,
                rate_str
            ])

        if overall_expected_files_target > 0:
            overall_progress_percent = (overall_completed_files / overall_expected_files_target) * 100
        else:
            overall_progress_percent = 0
        overall_rate_str, overall_eta_str = _progress_rate_and_eta(
            overall_completed_files, overall_expected_files_target, elapsed_global)

        summary_row = [
            "OVERALL",
            f"{overall_progress_percent:.2f}",
            f"{overall_completed_files}/{overall_expected_files_target}",
            elapsed_str,
            overall_eta_str,
            overall_rate_str
        ]

        table_output = tabulate([summary_row] + progress_data, headers=headers, tablefmt="pipe", stralign="center", numalign="center")

        # Log the table sequentially (no ANSI screen clearing, which would
        # garble file logs). The logger's timestamp distinguishes updates.
        logger.info(f"Monitoring Update:\n{table_output}")

        stop_event.wait(update_interval_seconds)

    logger.info("Progress monitoring thread finished.")


# --- Aggregation Logic ---
def aggregate_label_tmc_results(label_specific_dir: str, expected_n_train: Optional[int] = None) -> Tuple[Optional[np.ndarray], int, int]:
    """Aggregates TMC-Shapley results for a single label from .pkl files.

    Reads all ``tmc_iter_*.pkl`` files in the label-specific directory (in
    sorted order, so the floating-point sum is reproducible). It accumulates
    and averages the marginal contributions to derive the final Shapley
    values. Files whose contribution vector length differs from n_train are
    skipped, as are unreadable or malformed files.

    NOTE(security): files are loaded with ``pickle``, which can execute
    arbitrary code. Only point this at directories produced by this pipeline.

    Args:
        label_specific_dir (str): Path to the directory with results for a specific label.
        expected_n_train (int, optional): Anticipated number of training samples for validation.
                                          If None, inferred from the first processed file. Defaults to None.

    Returns:
        tuple[Optional[np.ndarray], int, int]:
            - shapley_values (np.ndarray | None): Calculated Shapley values, or None on failure.
            - iterations_processed (int): Number of iterations successfully processed.
            - n_train_for_this_label (int): Inferred/confirmed number of training samples (0 if unknown).
    """
    label_name_log = os.path.basename(label_specific_dir)
    logger.info(f"Aggregating results for label: {label_name_log} from {label_specific_dir}")

    # glob order is filesystem-dependent. Sort it for deterministic aggregation.
    result_files = sorted(glob.glob(os.path.join(label_specific_dir, 'tmc_iter_*.pkl')))
    if not result_files:
        logger.warning(f"No .pkl result files found for label '{label_name_log}'. Skipping.")
        return None, 0, expected_n_train or 0

    shapley_sum_accumulator = None
    iterations_processed_count = 0
    n_train_current = expected_n_train

    for pkl_file_path in result_files:
        try:
            with open(pkl_file_path, 'rb') as f:
                iter_data = pickle.load(f)

            # Expected layout: (marginal_contribs, perm_indices, num_points_processed, ...)
            if not isinstance(iter_data, tuple) or len(iter_data) < 3:
                logger.warning(f"Skipping invalid .pkl file (unexpected format): {pkl_file_path}")
                continue

            marginal_contribs = np.asarray(iter_data[0], dtype=np.float64)
            current_file_n_train = len(marginal_contribs)

            if n_train_current is None:
                n_train_current = current_file_n_train
                logger.info(f"Inferred n_train = {n_train_current} for label '{label_name_log}' from {pkl_file_path}.")
            elif current_file_n_train != n_train_current:
                logger.warning(f"Inconsistent n_train in {pkl_file_path} (expected {n_train_current}, got {current_file_n_train}). Skipping file.")
                continue

            # Created lazily for BOTH the inferred and the caller-supplied
            # n_train. Previously it was only created on inference, so every
            # file failed whenever expected_n_train was given.
            if shapley_sum_accumulator is None:
                shapley_sum_accumulator = np.zeros(n_train_current, dtype=np.float64)
            shapley_sum_accumulator += marginal_contribs
            iterations_processed_count += 1

        except FileNotFoundError: # e.g. removed between glob and open
            logger.warning(f"File not found during aggregation: {pkl_file_path}")
        except (pickle.UnpicklingError, EOFError) as e_pickle:
            logger.warning(f"Error unpickling file {pkl_file_path}: {e_pickle}. Skipping.")
        except Exception as e:
            logger.error(f"Unexpected error processing file {pkl_file_path}: {e}", exc_info=True)

    if iterations_processed_count == 0 or shapley_sum_accumulator is None:
        logger.warning(f"No valid iterations processed for label '{label_name_log}'.")
        return None, 0, n_train_current or 0 # Return inferred n_train even if no files

    final_shapley_values = shapley_sum_accumulator / iterations_processed_count
    logger.info(f"Aggregation for label '{label_name_log}' complete. Processed {iterations_processed_count} iterations.")
    return final_shapley_values, iterations_processed_count, n_train_current

def _save_aggregated_results(output_aggregation_file: str, results_by_label: dict, total_iterations: int) -> bool:
    """Pickles ``results_by_label`` to ``output_aggregation_file``, creating its parent directory if needed.

    Returns True on success. On any failure it logs an error and returns False.
    """
    try:
        parent_dir = os.path.dirname(output_aggregation_file)
        if parent_dir and not os.path.exists(parent_dir):
            os.makedirs(parent_dir, exist_ok=True)
            logger.info(f"Created directory for aggregated output: {parent_dir}")

        with open(output_aggregation_file, 'wb') as f_out:
            pickle.dump(results_by_label, f_out, protocol=pickle.HIGHEST_PROTOCOL)
    except OSError as e:
        logger.error(f"Failed to save aggregated results to {output_aggregation_file}: {e}")
        return False
    except Exception as e:
        logger.error(f"An unexpected error occurred while saving aggregated results: {e}", exc_info=True)
        return False

    logger.info(f"Successfully saved {len(results_by_label)} aggregated Shapley results to: {output_aggregation_file}")
    logger.info(f"Total iterations contributing to the final aggregated file: {total_iterations}")
    return True


def run_results_aggregation(base_calculation_dir: str, output_aggregation_file: str, label_names: list, n_train_samples_global: int = None):
    """Aggregates the Shapley results of a completed calculation run into one pickle file.

    For every label it aggregates ``<base_calculation_dir>/<label>`` via
    ``aggregate_label_tmc_results`` and logs a per-label summary table. It then
    writes a ``{label_name: shapley_values}`` dict for the successful labels.
    The first positive n_train seen, or ``n_train_samples_global`` if given,
    is passed on to later labels for consistency checking.

    Args:
        base_calculation_dir (str): Base directory of Shapley calculation results.
        output_aggregation_file (str): File path to save final aggregated results (.pkl).
        label_names (list): List of all label names.
        n_train_samples_global (int, optional): Global count of training samples for validation.
            If None, it is inferred during aggregation. Defaults to None.

    Returns:
        bool: True if at least one label was aggregated and the file was saved, False otherwise.
    """
    logger.info("=" * 60)
    logger.info(f"Starting Shapley results aggregation from: {base_calculation_dir}")
    logger.info(f"Aggregated results will be saved to: {output_aggregation_file}")

    if not os.path.isdir(base_calculation_dir):
        logger.error(f"Results directory not found: {base_calculation_dir}. Aborting aggregation.")
        return False

    results_by_label = {}
    summary_rows = []
    total_iterations = 0
    reference_n_train = n_train_samples_global

    for label_name in label_names:
        label_dir_path = os.path.join(base_calculation_dir, label_name)
        if not os.path.isdir(label_dir_path):
            logger.warning(f"Directory for label '{label_name}' not found at {label_dir_path}. Skipping.")
            summary_rows.append([label_name, "Not Found", 0, "N/A"])
            continue

        shap_values, num_iters, label_n_train = aggregate_label_tmc_results(label_dir_path, reference_n_train)

        if shap_values is None:
            logger.warning(f"Failed to aggregate results for label '{label_name}'.")
            summary_rows.append([label_name, "Failed/No Data", num_iters, label_n_train if label_n_train > 0 else "N/A"])
            continue

        results_by_label[label_name] = shap_values
        total_iterations += num_iters
        summary_rows.append([label_name, "Success", num_iters, label_n_train])

        # Only a positive n_train can establish or contradict the reference value.
        if label_n_train <= 0:
            continue
        if reference_n_train is None:
            reference_n_train = label_n_train
            logger.info(f"Global n_train inferred as {reference_n_train} from label '{label_name}'.")
        elif reference_n_train != label_n_train:
            logger.warning(f"Inconsistent n_train ({label_n_train}) for label '{label_name}' compared to previously inferred value ({reference_n_train}).")

    logger.info("--- Aggregation Summary per Label ---")
    logger.info(tabulate(summary_rows, headers=["Label", "Status", "Iterations Aggregated", "N_Train"], tablefmt="grid"))
    logger.info("--- End Aggregation Summary ---")

    if not results_by_label:
        logger.error("No results were successfully aggregated. Output file will not be created.")
        return False

    return _save_aggregated_results(output_aggregation_file, results_by_label, total_iterations)

# --- Main Execution ---
def main():
    """Main function to orchestrate Shapley value calculation and aggregation."""
    # Main execution function for the script.
    # Responsibilities include:
    # 1. Parsing command-line arguments.
    # 2. Configuring the logger.
    # 3. Determining whether to skip the calculation or aggregation phases based on arguments.
    # 4. (If calculation is performed) Loading training and test datasets.
    # 5. (If calculation is performed) Determining label names and their corresponding C values for Logistic Regression.
    # 6. (If calculation is performed) Initializing a multiprocessing pool and distributing Shapley calculation tasks to worker processes.
    # 7. (If calculation is performed) Launching a progress monitoring thread.
    # 8. (If aggregation is performed) Invoking the `run_results_aggregation` function to consolidate results.
    parser = argparse.ArgumentParser(description="Unified TMC-Shapley Value Calculator using Logistic Regression.",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    data_group = parser.add_argument_group("Data Input Arguments")
    data_group.add_argument("--train-features-path", type=str, required=False, help="Path to training features (.npy file). Required if not skipping calculation.")
    data_group.add_argument("--train-labels-path", type=str, required=False, help="Path to training labels (.npy file, samples x labels). Required if not skipping calculation.")
    data_group.add_argument("--test-features-path", type=str, required=False, help="Path to test features (.npy file). Required if not skipping calculation.")
    data_group.add_argument("--test-labels-path", type=str, required=False, help="Path to test labels (.npy file, samples x labels). Required if not skipping calculation.")

    output_group = parser.add_argument_group("Output Configuration Arguments")
    output_group.add_argument("--base-output-dir", type=str, default="./shapley_runs_output", help="Base directory for all outputs.")
    output_group.add_argument("--run-name-prefix", type=str, default="shapley_run", help="Prefix for the timestamped run directory.")
    
    calc_group = parser.add_argument_group("Shapley Calculation Parameters")
    calc_group.add_argument("--num-workers", type=int, default=max(1, multiprocessing.cpu_count() // 2), help="Number of worker processes.")
    calc_group.add_argument("--iterations-per-label", type=int, default=1000, help="Total TMC iterations targeted for each label.")
    calc_group.add_argument("--truncation-tolerance", type=float, default=0.01, help="Tolerance for TMC truncation.")
    calc_group.add_argument("--random-seed", type=int, default=42, help="Base random seed.")
    calc_group.add_argument("--label-names", type=str, default=None, help="Comma-separated label names. If None, inferred or generated.")
    calc_group.add_argument("--c-values", type=str, default="1.0", help="Comma-separated C values for Logistic Regression (one per label, or one for all).")
    calc_group.add_argument("--metric", type=str, default="auc", choices=["auc"], help="Performance metric.")

    lr_group = parser.add_argument_group("Logistic Regression Parameters")
    lr_group.add_argument("--lr-max-iter", type=int, default=100, help="Max iterations for Logistic Regression.")
    lr_group.add_argument("--lr-solver", type=str, default="liblinear", help="Solver for Logistic Regression.")
    lr_group.add_argument("--lr-penalty", type=str, default="l2", choices=["l1", "l2"], help="Penalty for Logistic Regression.")
    
    mode_group = parser.add_argument_group("Operational Mode Arguments")
    mode_group.add_argument("--skip-calculation", action="store_true", help="Skip Shapley calculation. Requires --aggregate-only-dir if aggregation is desired.")
    mode_group.add_argument("--skip-aggregation", action="store_true", help="Skip final results aggregation.")
    mode_group.add_argument("--aggregate-only-dir", type=str, default=None, help="Path to an existing run's output directory to perform only aggregation.")

    log_group = parser.add_argument_group("Logging Arguments")
    log_group.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level.")
    
    args = parser.parse_args()

    if args.aggregate_only_dir:
        if not args.skip_calculation:
            print("Info: --aggregate-only-dir is set. Forcing --skip-calculation.")
        args.skip_calculation = True 
        timestamped_run_output_dir = args.aggregate_only_dir
        if not os.path.isdir(timestamped_run_output_dir):
            print(f"Error: --aggregate-only-dir does not exist: {timestamped_run_output_dir}")
            sys.exit(1)
        log_file_path = os.path.join(timestamped_run_output_dir, f"{args.run_name_prefix}_aggregation_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
    else:
        if not args.skip_calculation: # Check if data paths are provided for calculation
            required_paths_for_calc = [args.train_features_path, args.train_labels_path, args.test_features_path, args.test_labels_path]
            if any(p is None for p in required_paths_for_calc):
                parser.error("All data input arguments (--train-features-path, --train-labels-path, --test-features-path, --test-labels-path) are required if not skipping calculation or using --aggregate-only-dir.")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_dir_name = f"{args.run_name_prefix}_{timestamp}"
        # Create base output dir first if it doesn't exist
        if not os.path.exists(args.base_output_dir):
            os.makedirs(args.base_output_dir, exist_ok=True)
        timestamped_run_output_dir = os.path.join(args.base_output_dir, run_dir_name)
        os.makedirs(timestamped_run_output_dir, exist_ok=True)
        log_file_path = os.path.join(timestamped_run_output_dir, f"{args.run_name_prefix}_main_run_log.log")

    setup_logger(args.log_level, log_file_path) # Setup logger ASAP
    logger.info("=" * 80)
    logger.info(f"Starting Unified Shapley Calculator. Run output directory: {timestamped_run_output_dir}")
    logger.info(f"Full Command: {' '.join(sys.argv)}")
    logger.info(f"Parsed Arguments: {vars(args)}") # Use vars(args) for a dict-like view
    logger.info("=" * 80)

    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    
    x_train, y_train_all, x_test, y_test_all = None, None, None, None
    n_samples_train, n_labels = None, None

    if not args.skip_calculation:
        logger.info("Loading data for calculation...")
        try:
            x_train = np.load(args.train_features_path)
            y_train_all = np.load(args.train_labels_path)
            x_test = np.load(args.test_features_path)
            y_test_all = np.load(args.test_labels_path)
            logger.info(f"  Train features: {x_train.shape}, Train labels: {y_train_all.shape}")
            logger.info(f"  Test features: {x_test.shape}, Test labels: {y_test_all.shape}")

            if x_train.shape[0] != y_train_all.shape[0] or x_test.shape[0] != y_test_all.shape[0]:
                logger.error("Sample count mismatch between features and labels. Aborting.")
                sys.exit(1)
            if x_train.shape[1] != x_test.shape[1]:
                 logger.warning(f"Feature count mismatch: Train ({x_train.shape[1]}) vs Test ({x_test.shape[1]}).")

            n_samples_train = x_train.shape[0]
            n_labels = y_train_all.shape[1]
            if n_labels == 0:
                logger.error("No labels found in training data (y_train_all.shape[1] is 0). Aborting.")
                sys.exit(1)


        except FileNotFoundError as e:
            logger.error(f"Data loading error: {e}. Aborting.")
            sys.exit(1)
        except Exception as e:
            logger.error(f"Unexpected data loading error: {e}. Aborting.", exc_info=True)
            sys.exit(1)
    
    # Determine Label Names
    final_label_names = []
    if args.label_names:
        final_label_names = [name.strip() for name in args.label_names.split(',')]
        if n_labels is not None and len(final_label_names) != n_labels: # n_labels is known if calculation is happening
            logger.error(f"Provided label names count ({len(final_label_names)}) mismatches data labels ({n_labels}). Aborting.")
            sys.exit(1)
    elif n_labels is not None : # If data loaded for calculation
        final_label_names = [f"label_{i}" for i in range(n_labels)]
    elif args.aggregate_only_dir: # Infer for aggregation
        try:
            subdirs = [d for d in os.listdir(timestamped_run_output_dir) 
                       if os.path.isdir(os.path.join(timestamped_run_output_dir, d)) and 
                       not d.startswith('.') and not d.endswith(('_log', '.log')) and d != os.path.basename(log_file_path.split('.')[0])] # Avoid log file/dir
            if not subdirs:
                 logger.error(f"Cannot infer label names from {timestamped_run_output_dir} for aggregation. Use --label-names.")
                 sys.exit(1)
            final_label_names = sorted(subdirs)
            logger.info(f"Inferred label names for aggregation: {final_label_names}")
        except Exception as e:
            logger.error(f"Error inferring label names for aggregation: {e}. Use --label-names.", exc_info=True)
            sys.exit(1)
    else: # Fallback, should be caught by arg checks
        logger.error("Cannot determine label names. Provide --label-names or ensure data is loaded for calculation.")
        sys.exit(1)
    
    if not final_label_names:
        logger.error("No label names determined. This can happen if --skip-calculation is used without --aggregate-only-dir or --label-names. Aborting.")
        sys.exit(1)


    c_values_list_str = args.c_values.split(',')
    c_values_parsed = [float(c.strip()) for c in c_values_list_str]
    c_value_map = {}
    if len(c_values_parsed) == 1:
        c_value_map = {name: c_values_parsed[0] for name in final_label_names}
    elif len(c_values_parsed) == len(final_label_names):
        c_value_map = dict(zip(final_label_names, c_values_parsed))
    else:
        logger.error(f"C-values count ({len(c_values_parsed)}) must be 1 or match labels count ({len(final_label_names)}). Aborting.")
        sys.exit(1)
    logger.info(f"Using C-values: {c_value_map}")

    lr_base_config = {
        'random_state': args.random_seed, 'max_iter': args.lr_max_iter,
        'class_weight': 'balanced', 'solver': args.lr_solver,
        'penalty': args.lr_penalty, 'n_jobs': 1 
    }
    logger.info(f"Base Logistic Regression Params: {lr_base_config}")
    
    if not args.skip_calculation:
        logger.info("--- Starting Shapley Value Calculation Phase ---")
        calc_start_time = time.time()

        for label_name_iter in final_label_names: # Ensure directories exist
            os.makedirs(os.path.join(timestamped_run_output_dir, label_name_iter), exist_ok=True)
            
        tasks_for_pool = []
        # Distribute iterations for each label among workers.
        # Each conceptual worker (0 to num_workers-1) gets a portion of iterations for each label.
        for conceptual_worker_id in range(args.num_workers):
            # Determine how many iterations this conceptual worker is responsible for
            # This is a base assignment, will be summed up if a worker handles multiple "slots"
            # iterations_per_slot = args.iterations_per_label // args.num_workers
            # if conceptual_worker_id < (args.iterations_per_label % args.num_workers):
            #     iterations_per_slot += 1
            # if iterations_per_slot == 0: continue


            for lbl_idx, lbl_name in enumerate(final_label_names):
                # Determine the number of iterations this specific worker process instance will run for this specific label.
                # Each label has `args.iterations_per_label`. These are divided among `args.num_workers`.
                start_iter_for_worker = conceptual_worker_id * (args.iterations_per_label // args.num_workers)
                end_iter_for_worker = (conceptual_worker_id + 1) * (args.iterations_per_label // args.num_workers)
                
                # Distribute remainder iterations
                # Each of the first (remainder) workers gets one extra iteration
                remainder = args.iterations_per_label % args.num_workers
                if conceptual_worker_id < remainder:
                    start_iter_for_worker += conceptual_worker_id
                    end_iter_for_worker += conceptual_worker_id + 1
                else: # For workers after the remainder has been distributed
                    start_iter_for_worker += remainder
                    end_iter_for_worker += remainder
                
                num_iters_for_this_task = end_iter_for_worker - start_iter_for_worker

                if num_iters_for_this_task <= 0:
                    continue

                tasks_for_pool.append((
                    conceptual_worker_id, lbl_idx, lbl_name, num_iters_for_this_task,
                    timestamped_run_output_dir, x_train, y_train_all, x_test, y_test_all,
                    c_value_map[lbl_name], args.metric, lr_base_config.copy(),
                    args.truncation_tolerance, args.random_seed
                ))
        
        if not tasks_for_pool:
            logger.warning("No calculation tasks generated. Check iterations_per_label and num_workers configuration.")
        else:
            logger.info(f"Generated {len(tasks_for_pool)} tasks for the multiprocessing pool.")
            # Example log for one task:
            # task_example = tasks_for_pool[0]
            # logger.debug(f"Example Task: Worker {task_example[0]} for Label '{task_example[2]}' will run {task_example[3]} iterations.")

            stop_monitor_event = threading.Event()
            monitor_thread = threading.Thread(
                target=monitor_shapley_progress,
                args=(timestamped_run_output_dir, final_label_names, args.iterations_per_label, calc_start_time, stop_monitor_event),
                daemon=True, name="ProgressMonitor"
            )
            monitor_thread.start()
            logger.info("Progress monitoring thread initiated.")

            try:
                with multiprocessing.Pool(processes=args.num_workers) as pool:
                    pool.starmap(shapley_calculation_worker, tasks_for_pool)
                logger.info("All Shapley calculation tasks completed by the pool.")
            except Exception as e:
                logger.error(f"Multiprocessing pool error: {e}", exc_info=True)
            finally:
                logger.info("Signaling progress monitor to stop...")
                stop_monitor_event.set()
                monitor_thread.join(timeout=10) 
                if monitor_thread.is_alive(): logger.warning("Monitor thread did not terminate cleanly.")
                else: logger.info("Monitor thread stopped.")
        
        logger.info(f"Shapley Calculation Phase finished in {format_time_seconds(time.time() - calc_start_time)}.")
    else:
        logger.info("--- Shapley Value Calculation Phase SKIPPED ---")

    if not args.skip_aggregation:
        logger.info("--- Starting Results Aggregation Phase ---")
        agg_start_time = time.time()
        agg_output_file = os.path.join(timestamped_run_output_dir, f"{args.run_name_prefix}_aggregated_shapley_values.pkl")
        
        # Pass n_samples_train if available (i.e., if calculation was not skipped)
        # Otherwise, it will be inferred during aggregation.
        n_train_for_agg = n_samples_train if not args.skip_calculation and n_samples_train is not None else None

        agg_success = run_results_aggregation(
            timestamped_run_output_dir, agg_output_file, final_label_names, n_train_for_agg
        )
        
        if agg_success: logger.info(f"Aggregation Phase completed in {format_time_seconds(time.time() - agg_start_time)}.")
        else: logger.error(f"Aggregation Phase FAILED. Elapsed: {format_time_seconds(time.time() - agg_start_time)}.")
    else:
        logger.info("--- Results Aggregation Phase SKIPPED ---")

    logger.info("=" * 80)
    logger.info(f"Unified Shapley Calculator script execution finished. Main output directory: {timestamped_run_output_dir}")
    logger.info("=" * 80)

if __name__ == "__main__":
    # Only run the pipeline when executed as a script, not when this module is imported.
    main()