import numpy as np
from sklearn.metrics import mean_absolute_error
from src.utils.utils import get_shot_regions, assign_data_to_regions 
from src.utils.utils import get_shot_regions_kde, assign_data_to_regions_kde
from src.utils.deprecation import deprecated


def calculate_region_mae(y_true, y_pred, y_train_for_regions,
                         kde_bw_method=None,
                         region_definition_method='kde'
                        ):
    """
    Calculate MAE overall and per shot region, with regions defined by KDE.

    Regions are derived from the training targets with
    ``get_shot_regions_kde`` and the evaluation targets are assigned to them
    with ``assign_data_to_regions_kde``. If either helper raises, the
    per-region metrics are reported as NaN and only the all-samples MAE is
    computed (best-effort behaviour; the error is printed, not re-raised).

    Args:
        y_true (np.ndarray): Actual target values (flattened internally).
        y_pred (np.ndarray): Predicted target values (flattened internally).
        y_train_for_regions (np.ndarray): Training target values used for
            region definition (flattened internally).
        kde_bw_method (str, optional): Forwarded as ``bw_method`` to
            ``get_shot_regions_kde``.
        region_definition_method (str): Region definition method. Only
            'kde' is accepted; any other value (including 'bins') raises
            ``ValueError``.

    Returns:
        dict: Metrics keyed as follows:
            - 'mae_overall': MAE over the samples returned as assigned to a
              region, or NaN if there are none / region definition failed.
            - 'mae_<region>': MAE for every region returned by the
              assignment helper, NaN for empty regions.
            - 'mae_many', 'mae_medium', 'mae_few': always present; NaN when
              the helper did not report that region.
            - 'mae_overall_all_samples': MAE over all of y_true / y_pred.

    Raises:
        ValueError: If ``region_definition_method`` is not 'kde', or if
            ``y_true`` and ``y_pred`` differ in number of elements.
    """
    # Reject unsupported methods before doing any work.
    if region_definition_method != 'kde':
        raise ValueError(f"Unknown region_definition_method: {region_definition_method}")

    y_true = np.asarray(y_true).flatten()
    y_pred = np.asarray(y_pred).flatten()
    y_train_for_regions = np.asarray(y_train_for_regions).flatten()

    # Fail fast: mismatched inputs would otherwise surface later as an
    # IndexError (or a mis-indexed subset) while slicing y_pred per region.
    if y_true.size != y_pred.size:
        raise ValueError(
            "y_true and y_pred must have the same number of elements, "
            f"got {y_true.size} and {y_pred.size}"
        )

    # Keys that callers can rely on regardless of which path is taken.
    default_regions = ('many', 'medium', 'few')
    results = {}

    print("Region definition method: KDE")
    try:
        kde_calculator, density_thresholds = get_shot_regions_kde(
            y_train_for_regions,
            bw_method=kde_bw_method
        )
        region_indices, all_indices = assign_data_to_regions_kde(
            y_true, kde_calculator, density_thresholds
        )
    except Exception as e:
        # Deliberate best-effort: region metrics are optional, so report NaN
        # for them and still return the all-samples MAE.
        print(f"Error during KDE-based region definition/assignment: {e}. Skipping MAE calculation.")
        results['mae_overall'] = np.nan
        for region_name in default_regions:
            results[f'mae_{region_name}'] = np.nan
        results['mae_overall_all_samples'] = mean_absolute_error(y_true, y_pred)
        return results

    if all_indices is not None and len(all_indices) > 0:
        results['mae_overall'] = mean_absolute_error(y_true[all_indices], y_pred[all_indices])
    else:
        results['mae_overall'] = np.nan
        print("Warning: No data points were assigned to the defined regions.")

    if region_indices is not None:
        for region_name, indices in region_indices.items():
            if len(indices) > 0:
                mae = mean_absolute_error(y_true[indices], y_pred[indices])
                results[f'mae_{region_name}'] = mae
                print(f"Region '{region_name}': {len(indices)} samples, MAE = {mae:.4f}")
            else:
                results[f'mae_{region_name}'] = np.nan
                print(f"Region '{region_name}': 0 samples")

    # Guarantee the standard region keys exist even if the helper returned
    # None or omitted a region; existing values and key order are untouched.
    for region_name in default_regions:
        results.setdefault(f'mae_{region_name}', np.nan)

    results['mae_overall_all_samples'] = mean_absolute_error(y_true, y_pred)

    return results


def calculate_region_mae_with_thresholds(y_true, y_pred, fewshot_threshold, manyshot_threshold):
    """
    Compute MAE overall and for few/medium/many shot regions split by target value.

    A sample belongs to the few-shot region when ``y_true <= fewshot_threshold``,
    to the many-shot region when ``y_true >= manyshot_threshold``, and to the
    medium-shot region when it is in neither. The few and many masks are built
    independently, so if the thresholds overlap a sample can be counted in both.

    Args:
        y_true (np.ndarray): Actual target values (flattened internally).
        y_pred (np.ndarray): Predicted target values (flattened internally).
        fewshot_threshold (float): Upper bound (inclusive) of the few-shot region.
        manyshot_threshold (float): Lower bound (inclusive) of the many-shot region.

    Returns:
        dict: Keys 'mae_overall', 'mae_few', 'mae_medium', 'mae_many'. A region
        with no samples gets NaN.
    """
    actual = np.asarray(y_true).flatten()
    predicted = np.asarray(y_pred).flatten()

    results = {'mae_overall': mean_absolute_error(actual, predicted)}

    in_few = actual <= fewshot_threshold
    in_many = actual >= manyshot_threshold
    in_medium = ~(in_few | in_many)  # everything not claimed by few or many

    # Order matters: it fixes both the printed report and the dict key order.
    region_masks = (
        ("Few-shot", 'mae_few', in_few),
        ("Medium-shot", 'mae_medium', in_medium),
        ("Many-shot", 'mae_many', in_many),
    )
    for label, key, mask in region_masks:
        if not np.any(mask):
            results[key] = np.nan
            print(f"{label} region: 0 samples")
            continue
        results[key] = mean_absolute_error(actual[mask], predicted[mask])
        print(f"{label} region: {np.sum(mask)} samples, MAE = {results[key]:.4f}")

    return results

@deprecated(replacement="calculate_region_mae")
def calculate_region_mae_non_kde(y_true, y_pred, y_train_for_regions, n_bins=10, strategy='quantile'):
    """
    Compute MAE overall and per shot region using binned region definitions.

    Deprecated in favour of ``calculate_region_mae``. Regions come from
    ``get_shot_regions`` applied to the training targets, and the evaluation
    targets are mapped onto them with ``assign_data_to_regions``.

    Args:
        y_true (np.ndarray): Actual target values (flattened internally).
        y_pred (np.ndarray): Predicted target values (flattened internally).
        y_train_for_regions (np.ndarray): Training target values used for
            region definition (flattened internally).
        n_bins (int): Number of bins, forwarded to ``get_shot_regions``.
        strategy (str): Binning strategy, forwarded to ``get_shot_regions``.

    Returns:
        dict: 'mae_overall' (MAE over samples assigned to a region, NaN if
        none), 'mae_<region>' for each region returned by the assignment
        helper (NaN when empty), and 'mae_overall_all_samples' (MAE over
        every sample).
    """
    targets = np.asarray(y_true).flatten()
    predictions = np.asarray(y_pred).flatten()
    train_targets = np.asarray(y_train_for_regions).flatten()

    regions, discretizer = get_shot_regions(train_targets, n_bins=n_bins, strategy=strategy)
    region_indices, all_indices = assign_data_to_regions(targets, discretizer, regions)

    results = {}

    if len(all_indices) == 0:
        results['mae_overall'] = np.nan
    else:
        results['mae_overall'] = mean_absolute_error(targets[all_indices], predictions[all_indices])

    for region_name, indices in region_indices.items():
        key = f'mae_{region_name}'
        if len(indices) == 0:
            results[key] = np.nan
            print(f"Region '{region_name}': 0 samples")
            continue
        region_mae = mean_absolute_error(targets[indices], predictions[indices])
        results[key] = region_mae
        print(f"Region '{region_name}': {len(indices)} samples, MAE = {region_mae:.4f}")

    # Reference figure that ignores the region assignment entirely.
    results['mae_overall_all_samples'] = mean_absolute_error(targets, predictions)

    return results
