#!/usr/bin/env python

import numpy as np

from scipy.interpolate import griddata, Rbf, RBFInterpolator
from scipy.signal import find_peaks

# RecordEventType selects which recorded event type counts as a neuron's activation
from model import RecordEventType



def plot_neuron_activation_contour(ax, coords, activation_times,
                                   method='rbf', resolution=100, vmin=None, vmax=None,
                                   nlevels=10, cmap='GnBu', xlims=None, ylims=None, xpad=0.1, ypad=0.1,
                                   smoothing=1.0, epsilon=1.0):
    """
    Interpolate per-neuron activation times onto a regular grid and draw a labelled contour map.

    Args:
        ax: matplotlib axis to draw on
        coords: Nx2 array of (x, y) neuron coordinates
        activation_times: length-N array of activation times (contour labels are suffixed "ms")
        method: interpolation scheme, one of
            'rbf', 'rbf_linear', 'rbf_gaussian'                       -> scipy RBFInterpolator
                (multiquadric / linear / gaussian kernel),
            'rbf_legacy', 'rbf_legacy_linear', 'rbf_legacy_gaussian'  -> legacy scipy Rbf,
            'griddata_linear', 'griddata_cubic', 'nearest'            -> scipy griddata
                (points outside the convex hull are NaN for linear/cubic),
            'inverse_distance'                                        -> inverse_distance_weighting
        resolution: number of grid samples along each axis
        vmin, vmax: colour limits; default to the min / max of activation_times
        nlevels: contour levels (count or explicit levels), passed to contourf and contour
        cmap: colormap for the filled contours
        xlims, ylims: (min, max) data extent of the grid; default to the extent of coords
        xpad, ypad: padding added on each side, as a fraction of the x / y extent
        smoothing: smoothing for the RBF methods (``smoothing`` of RBFInterpolator,
            ``smooth`` of the legacy Rbf)
        epsilon: shape parameter of the 'rbf' (multiquadric) and 'rbf_gaussian' kernels;
            not used by the other methods

    Returns:
        (contourf, contour_lines): the objects returned by ``ax.contourf`` and ``ax.contour``.

    Raises:
        ValueError: if ``method`` is not recognised.
    """
    x = coords[:, 0]
    y = coords[:, 1]

    # Colour limits default to the data range.
    if vmin is None:
        vmin = np.min(activation_times)
    if vmax is None:
        vmax = np.max(activation_times)

    # Data extent of the interpolation grid, then padded on every side.
    x_min, x_max = (np.min(x), np.max(x)) if xlims is None else xlims
    y_min, y_max = (np.min(y), np.max(y)) if ylims is None else ylims

    x_pad = (x_max - x_min) * xpad
    y_pad = (y_max - y_min) * ypad
    x_lo, x_hi = x_min - x_pad, x_max + x_pad
    y_lo, y_hi = y_min - y_pad, y_max + y_pad

    Xi, Yi = np.meshgrid(np.linspace(x_lo, x_hi, resolution),
                         np.linspace(y_lo, y_hi, resolution))

    rbf_kernels = {'rbf': 'multiquadric', 'rbf_linear': 'linear', 'rbf_gaussian': 'gaussian'}
    legacy_functions = {'rbf_legacy': 'multiquadric', 'rbf_legacy_linear': 'linear',
                        'rbf_legacy_gaussian': 'gaussian'}
    griddata_methods = {'griddata_linear': 'linear', 'griddata_cubic': 'cubic'}

    if method in rbf_kernels:
        kernel = rbf_kernels[method]
        # The linear kernel has always been built without epsilon (RBFInterpolator's default).
        shape_kwargs = {} if kernel == 'linear' else {'epsilon': epsilon}
        rbf = RBFInterpolator(np.column_stack([x, y]), activation_times,
                              kernel=kernel, smoothing=smoothing, **shape_kwargs)
        grid_points = np.column_stack([Xi.ravel(), Yi.ravel()])
        Zi = rbf(grid_points).reshape(Xi.shape)

    elif method in legacy_functions:
        rbf = Rbf(x, y, activation_times, function=legacy_functions[method], smooth=smoothing)
        Zi = rbf(Xi, Yi)

    elif method in griddata_methods:
        Zi = griddata((x, y), activation_times, (Xi, Yi),
                      method=griddata_methods[method], fill_value=np.nan)

    elif method == 'nearest':
        Zi = griddata((x, y), activation_times, (Xi, Yi), method='nearest')

    elif method == 'inverse_distance':
        Zi = inverse_distance_weighting(x, y, activation_times, Xi, Yi)

    else:
        raise ValueError(f"Unknown method: {method}")

    # Filled contours plus labelled black contour lines on top.
    contourf = ax.contourf(Xi, Yi, Zi, levels=nlevels, cmap=cmap, vmin=vmin, vmax=vmax)
    contour_lines = ax.contour(Xi, Yi, Zi, levels=nlevels, colors='black', linewidths=1.0, alpha=0.6)
    ax.clabel(contour_lines, inline=True, fontsize=10, fmt='%.1f ms')

    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.set_aspect('equal')

    return contourf, contour_lines


def inverse_distance_weighting(x, y, values, Xi, Yi, power=2, radius=None):
    """
    Inverse-distance-weighted (Shepard) interpolation of scattered samples onto query points.

    Each query point receives the weighted mean of ``values`` with weights ``1 / d**power``.
    Here ``d`` is the Euclidean distance to each sample. ``d`` is clamped below at 1e-10, so a
    query point that coincides with a sample is dominated by it instead of dividing by zero.

    Args:
        x, y: 1-D sample coordinates (length N)
        values: sample values (length N)
        Xi, Yi: query coordinates of matching shape (e.g. the 2-D output of np.meshgrid;
            any shape is accepted)
        power: exponent applied to the distances
        radius: if given, only samples within this (clamped) distance contribute; query
            points with no sample in range are NaN

    Returns:
        float64 array with the shape of ``Xi`` holding the interpolated values. Entries are
        NaN where no sample contributed, including when there are no samples at all.
    """
    sx = np.asarray(x, dtype=float).ravel()
    sy = np.asarray(y, dtype=float).ravel()
    sv = np.asarray(values, dtype=float).ravel()
    Xi = np.asarray(Xi, dtype=float)
    qx = Xi.ravel()
    qy = np.asarray(Yi, dtype=float).ravel()

    # Always a float result: an integer query grid must neither truncate the weighted
    # means nor fail when NaN is stored for uncovered points.
    Zi = np.full(qx.shape, np.nan)
    if sx.size == 0:
        return Zi.reshape(Xi.shape)

    # Vectorize over query points in chunks so the (chunk, N) distance matrix stays bounded.
    chunk = max(1, 1_000_000 // sx.size)
    for start in range(0, qx.size, chunk):
        stop = start + chunk
        d = np.hypot(qx[start:stop, None] - sx[None, :], qy[start:stop, None] - sy[None, :])
        d = np.maximum(d, 1e-10)  # avoid division by zero at sample locations
        w = 1.0 / d ** power
        if radius is not None:
            w[d > radius] = 0.0
        wsum = w.sum(axis=1)
        covered = wsum > 0
        # Zi[start:stop] is a view, so the masked assignment writes into Zi.
        Zi[start:stop][covered] = (w[covered] @ sv) / wsum[covered]

    return Zi.reshape(Xi.shape)


def determine_spike_wave_count(activation_times, min_levels=7, max_levels=25, extra_levels=1):
    """
    Heuristically choose the number of contour levels from the structure of activation times.

    Two estimates are averaged:
      * gap-based: the number of gaps between consecutive sorted times that exceed the
        75th-percentile gap, plus ``extra_levels``;
      * histogram-based: the number of peaks in a histogram of the times
        (``min(20, n // 2)`` bins), plus ``extra_levels``.
    The average is truncated to an int and clamped to [min_levels, max_levels]. A short
    diagnostic summary is printed.

    Args:
        activation_times: sequence of activation times; non-finite entries (NaN/inf) are ignored
        min_levels: minimum number of levels to return
        max_levels: maximum number of levels to return
        extra_levels: added to each of the two estimates. For smooth visualization we usually
            want several contour lines per wave to show the gradient within each wave, and
            this adds that extra resolution. A value of 1 roughly yields the naturally
            detected wave count, while 3 can produce somewhat better results.

    Returns:
        int: suggested number of contour levels. ``min_levels`` is returned directly when
        fewer than four finite times are given or all finite times are equal.
    """
    times = np.asarray(activation_times, dtype=float).ravel()
    # Non-finite entries carry no timing information and would make np.histogram raise
    # on a non-finite range.
    times = times[np.isfinite(times)]
    if times.size <= 3:
        return min_levels

    times_sorted = np.sort(times)
    time_range = times_sorted[-1] - times_sorted[0]

    # All times identical: no structure to detect.
    if time_range == 0:
        return min_levels

    # Method 1: gap-based detection. Large gaps between consecutive times suggest wave boundaries.
    gaps = np.diff(times_sorted)
    gap_threshold = np.percentile(gaps, 75)
    gap_based_levels = int(np.sum(gaps > gap_threshold)) + extra_levels

    # Method 2: histogram-based detection. Peaks in the histogram correspond to wave fronts.
    n_bins = min(20, times.size // 2)
    hist, _ = np.histogram(times, bins=n_bins)
    peaks, _ = find_peaks(hist, height=1, distance=1)
    histogram_levels = len(peaks) + extra_levels

    final_levels = int(0.5 * gap_based_levels + 0.5 * histogram_levels)
    final_levels = max(min_levels, min(final_levels, max_levels))

    print(f"Auto-determined {final_levels} contour levels from {len(activation_times)} neurons")
    print(f"  Time range: {time_range:.1f}, Significant gaps: {gap_based_levels}, Histogram peaks: {len(peaks)}")

    return final_levels


def _neighbour_times(neurons, all_times, i):
    """Return the known (non-NaN) activation times of the neighbours of neuron ``i``."""
    known = []
    for j in neurons[i].nbrs:
        t = all_times[neurons[j].id]
        if not np.isnan(t):
            known.append(t)
    return known


def _fill_missing_activation_times(neurons, all_times):
    """Fill NaN entries of ``all_times`` in place using the neurons' neighbourhoods."""
    n_neurons = len(neurons)

    # Pass 1, a single sweep in index order: take the earliest time among neighbours that
    # already have one. Entries filled earlier in the sweep are visible to later neurons.
    for i in range(n_neurons):
        if np.isnan(all_times[i]):
            known = _neighbour_times(neurons, all_times, i)
            if known:
                all_times[i] = min(known)

    # Pass 2: repeatedly fill holes, i.e. neurons all of whose neighbours have a time,
    # until a sweep changes nothing.
    changed = True
    while changed:
        changed = False
        for i in range(n_neurons):
            if not np.isnan(all_times[i]):
                continue
            known = _neighbour_times(neurons, all_times, i)
            if known and len(known) == len(neurons[i].nbrs):
                all_times[i] = min(known)
                changed = True

    # Pass 3: any neuron still unreached gets the latest known time.
    missing = np.isnan(all_times)
    if missing.any():
        all_times[missing] = np.max(all_times[~missing])


def make_contour_from_recorder(ax, neurons, recorder, coords, fill_in_times=False, event_filter=RecordEventType.SPIKE, method='rbf', **kwargs):
    """
    Draw an activation-time contour map from the first activations reported by ``recorder``.

    Args:
        ax: matplotlib axis to draw on
        neurons: sequence of neuron objects. With ``fill_in_times`` each must expose ``nbrs``
            (indices into ``neurons``) and ``id``. ``id`` is used as an index into the
            per-neuron time array, so it is assumed to equal the neuron's position.
        recorder: object whose ``get_first_activation_times(event_filter)`` returns a mapping
            neuron id -> activation time
        coords: Nx2 array of neuron (x, y) coordinates, indexed by neuron id
        fill_in_times: if True, estimate times for neurons that never activated from their
            neighbours and plot every neuron; otherwise plot only the activated neurons
        event_filter: event type passed to ``recorder.get_first_activation_times``
        method: interpolation method forwarded to plot_neuron_activation_contour
        **kwargs:
            nlevels: contour levels, or 'auto' (default) to estimate them with
                determine_spike_wave_count
            min_levels: lower bound for the 'auto' estimate (default 7)
            start_idx: accepted for backward compatibility; unused
            everything else is forwarded to plot_neuron_activation_contour

    Returns:
        (contourf, contour_lines) as returned by plot_neuron_activation_contour.

    Raises:
        ValueError: if the recorder reports no activations for ``event_filter``.
    """
    # The grid always spans every neuron, including those that never activated.
    xlims = np.min(coords[:, 0]), np.max(coords[:, 0])
    ylims = np.min(coords[:, 1]), np.max(coords[:, 1])

    activation_times_dict = recorder.get_first_activation_times(event_filter)
    if not activation_times_dict:
        raise ValueError(f"No activation times recorded for event type {event_filter}")

    # start_idx only fed a disabled distance heuristic. Always drop it so it is never
    # forwarded to plot_neuron_activation_contour, which does not accept it.
    kwargs.pop('start_idx', None)

    if fill_in_times:
        all_times = np.full(len(neurons), np.nan)
        for neuron_id, timestamp in activation_times_dict.items():
            all_times[neuron_id] = timestamp
        _fill_missing_activation_times(neurons, all_times)
    else:
        # Keep only the neurons that actually activated.
        items = list(activation_times_dict.items())
        coords = np.array([coords[neuron_id] for neuron_id, _ in items])
        all_times = np.array([timestamp for _, timestamp in items])

    nlevels = kwargs.pop('nlevels', 'auto')
    min_levels = kwargs.pop('min_levels', 7)
    # isinstance guard: explicit level arrays must not be compared elementwise against 'auto'.
    if isinstance(nlevels, str) and nlevels == 'auto':
        nlevels = determine_spike_wave_count(all_times, min_levels=min_levels)

    return plot_neuron_activation_contour(ax, coords, all_times, method=method, nlevels=nlevels,
                                          xlims=xlims, ylims=ylims, **kwargs)

