import numpy as np
import pandas as pd
from sklearn.preprocessing import KBinsDiscretizer
from scipy.stats import gaussian_kde 
from sklearn.mixture import GaussianMixture
from scipy.ndimage import gaussian_filter1d
from scipy.ndimage import convolve1d
from scipy.signal.windows import triang

import torch


def get_lds_kernel_window(kernel, ks, sigma):
    """
    Generate kernel window for Label Distribution Smoothing (LDS).

    The 'gaussian' and 'laplace' windows are scaled so their peak value is 1;
    the 'triang' window (scipy's ``triang``) also peaks at 1 for odd ``ks``.

    Args:
        kernel (str): Kernel type ('gaussian', 'triang', 'laplace').
        ks (int): Kernel size (should be odd). For even ``ks`` the 'gaussian'
            and 'triang' windows have length ``ks`` while the 'laplace' window
            has length ``ks - 1``, and the 'gaussian' peak sits off-centre.
        sigma (float): Sigma parameter for Gaussian or Laplace kernel.
            Non-positive values produce a unit impulse (no smoothing).

    Returns:
        np.ndarray: Kernel window.

    Raises:
        ValueError: If ``kernel`` is not a supported kernel type.
    """
    if kernel not in ('gaussian', 'triang', 'laplace'):
        raise ValueError(
            f"Unsupported LDS kernel '{kernel}'; expected 'gaussian', 'triang' or 'laplace'."
        )
    half_ks = (ks - 1) // 2
    if kernel == 'gaussian':
        base_kernel = np.zeros(ks)
        base_kernel[half_ks] = 1.
        # A tiny positive sigma leaves the impulse (almost) unchanged, so
        # sigma <= 0 degrades gracefully to "no smoothing".
        kernel_window = gaussian_filter1d(base_kernel, sigma=max(sigma, 1e-6))
        kernel_window = kernel_window / max(np.max(kernel_window), 1e-6)
    elif kernel == 'triang':
        kernel_window = triang(ks)
    else:
        offsets = np.arange(-half_ks, half_ks + 1)
        if sigma > 0:
            # Laplace pdf exp(-|x|/sigma) / (2*sigma) divided by its peak
            # 1/(2*sigma). Computing the ratio analytically keeps the peak at
            # exactly 1 even for very large sigma, where a clamped numeric
            # max would under-normalise the window.
            kernel_window = np.exp(-np.abs(offsets) / sigma)
        else:
            kernel_window = (offsets == 0).astype(float)

    return kernel_window


def calculate_balanced_weights(labels_np, reweight='sqrt_inv',
                            binning_method='auto',
                            n_bins=None,
                            lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
    """
    Calculate per-sample weights that counteract label imbalance in continuous data.

    Labels are histogrammed into bins. Each sample's weight is the inverse of its
    bin's count (square-rooted for 'sqrt_inv', optionally LDS-smoothed), clipped
    to [1e-3, 1e3] and rescaled so the weights average to 1.

    Args:
        labels_np (array-like): Training labels; flattened to 1-D.
        reweight (str): Reweighting strategy ('none', 'inverse', 'sqrt_inv').
            'none' combined with ``lds=True`` is promoted to 'sqrt_inv'.
        binning_method (str):
            - 'quantile': Set bin boundaries to have equal sample counts
            - any other value ('auto', 'fixed_width'): equal-width bins
              spanning [min, max] of the labels
        n_bins (int): Number of bins. If None, derived from the
            Freedman-Diaconis rule and clamped to [10, 100].
        lds (bool): Whether to apply LDS
        lds_kernel (str): LDS kernel type
        lds_ks (int): LDS kernel size
        lds_sigma (float): LDS sigma parameter

    Returns:
        np.ndarray or None: float32 array with one weight per label, or None
        when ``reweight`` is 'none' without LDS or ``labels_np`` is empty.

    Raises:
        ValueError: If ``reweight`` is not a supported strategy.
    """
    if reweight == 'none' and not lds:
        print("Reweight method is 'none', skipping weight calculation.")
        return None

    if reweight not in {'none', 'inverse', 'sqrt_inv'}:
        raise ValueError(
            f"Unsupported reweight method '{reweight}'; expected 'none', 'inverse' or 'sqrt_inv'."
        )
    if lds and reweight == 'none':
        reweight = 'sqrt_inv'
        print(f"LDS enabled, defaulting reweight method to '{reweight}'.")

    # Accept lists / Series / (n, 1) columns uniformly.
    labels_np = np.asarray(labels_np, dtype=np.float64).ravel()
    if labels_np.size == 0:
        print("Warning: empty label array, skipping weight calculation.")
        return None

    if n_bins is None:
        n_samples = labels_np.size
        value_range = np.ptp(labels_np)

        iqr = np.percentile(labels_np, 75) - np.percentile(labels_np, 25)

        if iqr <= 1e-10:
            print(f"Warning: IQR is too small ({iqr}). Using default bin count.")
            n_bins = 10
        else:
            # Freedman-Diaconis bin width: 2 * IQR * n^(-1/3).
            bin_width = 2 * iqr * (n_samples ** (-1/3))
            n_bins = max(10, min(100, int(np.ceil(value_range / bin_width))))

        print(f"Auto-calculated bin count: {n_bins}")

    if binning_method == 'quantile':
        bin_edges = np.percentile(labels_np, np.linspace(0, 100, n_bins + 1))
    else:
        bin_edges = np.linspace(np.min(labels_np), np.max(labels_np), n_bins + 1)

    # digitize is 1-based and places the maximum label past the last edge;
    # clipping folds it back into the final bin.
    bin_indices = np.digitize(labels_np, bin_edges) - 1
    bin_indices = np.clip(bin_indices, 0, n_bins - 1)

    # Use float counts: convolve1d returns an array of its *input* dtype, so
    # smoothing raw integer counts would silently truncate the LDS result.
    bin_counts = np.bincount(bin_indices, minlength=n_bins).astype(np.float64)

    if reweight == 'sqrt_inv':
        bin_weights = np.sqrt(bin_counts)
    else:  # 'inverse'
        bin_weights = bin_counts

    if lds:
        lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
        print(f'Applying LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
        bin_weights = convolve1d(bin_weights, weights=lds_kernel_window, mode='constant')

    sample_weights = 1.0 / (bin_weights[bin_indices] + 1e-6)
    sample_weights = np.clip(sample_weights, 1/1000, 1000)

    # Rescale so the weights have mean 1.
    scaling = len(sample_weights) / np.sum(sample_weights)
    sample_weights *= scaling

    print(f"Balanced weights computed. Reweight: '{reweight}', LDS: {lds}. Shape: {sample_weights.shape}")
    return sample_weights.astype(np.float32)


def get_gmm(labels_np, n_components):
    """
    Fit a one-feature Gaussian Mixture Model to the target labels.

    Args:
        labels_np (array-like): Training labels (np.ndarray, pd.Series or list);
            reshaped to a single feature column before fitting.
        n_components (int): Number of GMM components.

    Returns:
        dict: GMM parameters with keys 'means' (``gmm.means_``),
        'weights' (``gmm.weights_``) and 'variances' (``gmm.covariances_``).
    """
    # np.asarray lets pandas Series and plain lists through (Series has no .reshape).
    all_labels = np.asarray(labels_np).reshape(-1, 1)
    print(f"Fitting GMM with {n_components} components...")
    # Fixed random_state keeps fits reproducible across runs.
    gmm = GaussianMixture(n_components=n_components, random_state=42, reg_covar=1e-6).fit(all_labels)
    gmm_dict = {'means': gmm.means_, 'weights': gmm.weights_, 'variances': gmm.covariances_}
    print("GMM fitting complete.")
    return gmm_dict


def get_shot_regions(y_train, n_bins=10, strategy='quantile'):
    """
    Determine shot regions (many, medium, few) based on target variable density.

    The labels are discretized with KBinsDiscretizer. Among populated bins,
    those whose sample count is at or above the ~66.6th percentile of non-zero
    bin counts are 'many', those at or below the ~33.3rd percentile are 'few',
    and the rest are 'medium'. Empty bins belong to no region.

    Args:
        y_train (pd.Series, np.ndarray or array-like): Target variable of training data.
        n_bins (int): Number of bins for discretization.
        strategy (str): KBinsDiscretizer strategy ('uniform', 'quantile', 'kmeans').
            If fitting raises ValueError, 'uniform' is tried instead.

    Returns:
        dict: Dictionary mapping region names ('many', 'medium', 'few') to bin indices.
        KBinsDiscretizer: Fitted discretizer instance.
    """
    # np.asarray handles Series, ndarrays and plain lists alike.
    y_train_arr = np.asarray(y_train).reshape(-1, 1)

    discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy=strategy, subsample=None)
    try:
        bin_assignments = discretizer.fit_transform(y_train_arr).flatten().astype(int)
    except ValueError as e:
        print(f"Warning: KBinsDiscretizer failed with strategy='{strategy}'. Trying 'uniform'. Error: {e}")
        discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform', subsample=None)
        bin_assignments = discretizer.fit_transform(y_train_arr).flatten().astype(int)

    bin_counts = np.bincount(bin_assignments, minlength=n_bins)
    non_zero_counts = bin_counts[bin_counts > 0]

    if len(non_zero_counts) < 3:
        # Too few populated bins for percentiles to separate three groups
        # (this also covers the no-populated-bins case): split at the median.
        print("Warning: Insufficient bin count variation to define 3 shot regions. Adjusting thresholds.")
        median_count = np.median(non_zero_counts) if len(non_zero_counts) > 0 else 0
        many_threshold = median_count
        few_threshold = median_count
    else:
        few_threshold = np.percentile(non_zero_counts, 33.3)
        many_threshold = np.percentile(non_zero_counts, 66.6)
        if few_threshold == many_threshold:
            # Percentiles collapsed (many equal counts): fall back to the order
            # statistics at the 1/3 and 2/3 ranks, then to min/max.
            sorted_counts = np.sort(non_zero_counts)
            few_threshold = sorted_counts[len(sorted_counts) // 3]
            many_threshold = sorted_counts[2 * len(sorted_counts) // 3]
            if few_threshold == many_threshold:
                # This branch guarantees >= 3 counts, so min/max are well defined.
                few_threshold = sorted_counts[0]
                many_threshold = sorted_counts[-1]

    regions = {'many': [], 'medium': [], 'few': []}
    for bin_idx, count in enumerate(bin_counts):
        if count == 0:
            continue
        if count >= many_threshold:
            regions['many'].append(bin_idx)
        elif count <= few_threshold:
            regions['few'].append(bin_idx)
        else:
            regions['medium'].append(bin_idx)

    # Never leave 'many' or 'few' empty while 'medium' has bins.
    if not regions['many'] and regions['medium']:
        regions['many'] = regions['medium']
        regions['medium'] = []
    if not regions['few'] and regions['medium']:
        regions['few'] = regions['medium']
        regions['medium'] = []

    print(f"Bin counts: {bin_counts}")
    print(f"Thresholds (count): Few <= {few_threshold:.2f} < Medium < {many_threshold:.2f} <= Many")
    print(f"Bins by region: Many={regions['many']}, Medium={regions['medium']}, Few={regions['few']}")

    return regions, discretizer

def assign_data_to_regions(y, discretizer, regions):
    """
    Assign data points to many/medium/few regions based on their label bin.

    Args:
        y (pd.Series, np.ndarray or array-like): Target values to assign.
        discretizer: Fitted discretizer whose ``transform`` maps an (n, 1)
            array to bin indices (e.g. the one returned by ``get_shot_regions``).
        regions (dict): Region name -> iterable of bin indices. If a bin is
            listed under several regions, the first region in dict order wins.

    Returns:
        dict: Region name ('many', 'medium', 'few') -> list of point indices.
        list: Ascending indices of all points that fell into some region;
            points whose bin is in no region are omitted.
    """
    y_arr = np.asarray(y).reshape(-1, 1)
    bin_assignments = discretizer.transform(y_arr).flatten().astype(int)

    # Build the bin -> region lookup once instead of scanning every region's
    # bin list for each point; setdefault keeps first-region-wins precedence.
    bin_to_region = {}
    for region_name, region_bins in regions.items():
        for bin_idx in region_bins:
            bin_to_region.setdefault(bin_idx, region_name)

    region_indices = {'many': [], 'medium': [], 'few': []}
    all_indices = []
    for i, bin_idx in enumerate(bin_assignments.tolist()):
        region_name = bin_to_region.get(bin_idx)
        if region_name is None:
            continue
        region_indices[region_name].append(i)
        all_indices.append(i)

    return region_indices, all_indices


def get_shot_regions_kde(y_train, bw_method=None):
    """
    Determine shot regions (many, medium, few) based on KDE of target variable.

    A Gaussian KDE is fitted to the training labels and evaluated at every
    training label. The ~33.3rd and ~66.6th percentiles of those densities
    become the 'few' and 'many' thresholds.

    Args:
        y_train (pd.Series, np.ndarray or array-like): Target variable of training data.
        bw_method (str, scalar or callable, optional): KDE bandwidth calculation
            method. If fitting with a custom value fails, the default bandwidth
            is tried once.

    Returns:
        tuple: (kde, density_thresholds) where kde is the fitted KDE object and
               density_thresholds is a dict with 'few' and 'many' density thresholds.
               With fewer than 3 training points the thresholds are -inf / +inf.

    Raises:
        ValueError: If ``y_train`` is empty, or if KDE fitting fails.
        numpy.linalg.LinAlgError: If KDE fitting fails on degenerate
            (e.g. constant) data.
    """
    y_train_arr = np.asarray(y_train).flatten()

    if y_train_arr.size == 0:
        raise ValueError("y_train data is empty for KDE region definition.")
    if len(np.unique(y_train_arr)) < 2:
        print("Warning: KDE region definition - too few unique values in y_train. KDE may be unstable.")

    try:
        kde = gaussian_kde(y_train_arr, bw_method=bw_method)
    except (ValueError, np.linalg.LinAlgError) as e:
        print(f"Error during KDE training: {e}. Check data.")
        if bw_method is None:
            # The default bandwidth is what just failed; retrying is pointless.
            raise
        try:
            print("Retrying KDE with default bandwidth...")
            kde = gaussian_kde(y_train_arr)
        except Exception as final_e:
            print(f"KDE training final failure: {final_e}")
            raise

    train_densities = kde(y_train_arr)

    if len(train_densities) < 3:
        print("Warning: Insufficient training data points to define density thresholds.")
        # Infinite thresholds put every point in 'medium' downstream.
        few_threshold = -np.inf
        many_threshold = np.inf
    else:
        few_threshold = np.percentile(train_densities, 33.3)
        many_threshold = np.percentile(train_densities, 66.6)

        if few_threshold == many_threshold:
            print("Warning: KDE density thresholds (33/66 percentile) are identical. Data distribution may be highly concentrated.")
            min_dens, max_dens = np.min(train_densities), np.max(train_densities)
            if min_dens < max_dens:
                # Fall back to 10% / 90% of the observed density range.
                few_threshold = min_dens + (max_dens - min_dens) * 0.1
                many_threshold = max_dens - (max_dens - min_dens) * 0.1
                if few_threshold >= many_threshold:
                    few_threshold = many_threshold = np.median(train_densities)

    density_thresholds = {'few': few_threshold, 'many': many_threshold}

    print(f"KDE density thresholds: Few <= {few_threshold:.4g} < Medium < {many_threshold:.4g} <= Many")

    return kde, density_thresholds

def assign_data_to_regions_kde(y, kde, density_thresholds):
    """
    Assign data points to many/medium/few regions based on KDE density.

    A point whose density is >= the 'many' threshold goes to 'many'; otherwise
    a density <= the 'few' threshold goes to 'few'; everything else (including
    NaN densities) goes to 'medium'. When both thresholds are equal, points
    below the shared threshold are therefore 'few' rather than unassigned.

    Args:
        y (pd.Series, np.ndarray or array-like): Target variable values to assign regions.
        kde (callable): Fitted KDE object (e.g. ``gaussian_kde``) mapping a 1-D
            array of values to their densities.
        density_thresholds (dict): Dict containing 'few' and 'many' density thresholds.

    Returns:
        dict: Dictionary mapping region names to lists of indices.
        list: List of all assigned indices (ascending). Every evaluated point
            lands in exactly one region. Both are empty if ``y`` is empty or
            the KDE evaluation raises ValueError.
    """
    y_arr = np.asarray(y).flatten()

    if y_arr.size == 0:
        return {'many': [], 'medium': [], 'few': []}, []

    try:
        point_densities = np.ravel(kde(y_arr))
    except ValueError as e:
        print(f"Error evaluating KDE density on data points: {e}. Check input data.")
        return {'many': [], 'medium': [], 'few': []}, []

    few_thresh = density_thresholds['few']
    many_thresh = density_thresholds['many']

    # 'many' takes precedence; NaN fails both comparisons and lands in 'medium'.
    is_many = point_densities >= many_thresh
    is_few = ~is_many & (point_densities <= few_thresh)
    is_medium = ~(is_many | is_few)

    region_indices = {
        'many': np.flatnonzero(is_many).tolist(),
        'medium': np.flatnonzero(is_medium).tolist(),
        'few': np.flatnonzero(is_few).tolist(),
    }
    all_indices = list(range(len(point_densities)))

    return region_indices, all_indices
