import math
import os

import spikeforest as sf
import spikeinterface.full as si
import spikeinterface.qualitymetrics as qm
import numpy as np
import json

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from scipy.interpolate import CubicSpline, interp1d
from spikeinterface.preprocessing import bandpass_filter, whiten, common_reference, remove_artifacts, \
    highpass_spatial_filter, highpass_filter, phase_shift, detect_bad_channels
from spikeinterface.sortingcomponents.motion import interpolate_motion
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.extractors import NumpySorting
from spikeinterface.comparison import compare_sorter_to_ground_truth
from spikeinterface import get_noise_levels, compute_sparsity, ChannelSparsity
from spikeinterface.core import split_recording
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.widgets import plot_agreement_matrix, plot_confusion_matrix
from spikeinterface.preprocessing import get_motion_parameters_preset, get_motion_presets, correct_motion, compute_motion
from spikeinterface.sortingcomponents.motion import estimate_motion


def get_dataset_details(uri=None):
    """Print a short summary of every SpikeForest recording found at ``uri``.

    Args:
        uri: Forwarded unchanged to ``sf.load_spikeforest_recordings``.

    Returns:
        None. For each recording the study path, channel count, duration,
        sampling frequency, number of true units and the JSON-serialised
        ground-truth sorting object are written to stdout.
    """
    for rec in sf.load_spikeforest_recordings(uri):
        # One print call per recording; the trailing '' yields the blank
        # separator line after each summary.
        print(
            '=========================================================',
            f'{rec.study_set_name}/{rec.study_name}/{rec.recording_name}',
            f'Num. channels: {rec.num_channels}',
            f'Duration (sec): {rec.duration_sec}',
            f'Sampling frequency (Hz): {rec.sampling_frequency}',
            f'Num. true units: {rec.num_true_units}',
            f'Sorting true object: {json.dumps(rec.sorting_true_object)}',
            '',
            sep='\n',
        )

def get_dataset(study_name,
                recording_name,
                uri=None,
                save_path='./resources/spikeforest_waveform',
                ms_before=1.0,
                ms_after=1.0,
                num_per_units=500):
    """Load one SpikeForest recording and extract ground-truth waveforms.

    A sorting analyzer is created in ``<save_path>/<recording_name>``
    (``binary_folder`` format, overwriting any previous content) and the
    ``random_spikes``, ``waveforms``, ``noise_levels`` and ``templates``
    extensions are computed on it.

    Args:
        study_name: SpikeForest study name.
        recording_name: Recording name inside the study.
        uri: Forwarded to ``sf.load_spikeforest_recording``.
        save_path: Parent folder for the analyzer folder.
        ms_before: Waveform window before each spike, in milliseconds.
        ms_after: Waveform window after each spike, in milliseconds.
        num_per_units: Maximum number of spikes sampled per unit.

    Returns:
        Tuple ``(analyzer, waveforms_extension, recording_name)``.
    """
    R = sf.load_spikeforest_recording(study_name=study_name,
                                      recording_name=recording_name,
                                      uri=uri)
    print(f'{R.study_set_name}/{R.study_name}/{R.recording_name}')
    print(f'Num. channels: {R.num_channels}')
    print(f'Duration (sec): {R.duration_sec}')
    print(f'Sampling frequency (Hz): {R.sampling_frequency}')
    print(f'Num. true units: {R.num_true_units}')
    print('')

    recording = R.get_recording_extractor()
    sorting_true = R.get_sorting_true_extractor()

    print(
        f'Recording extractor info: {recording.get_num_channels()} channels, {recording.get_sampling_frequency()} Hz, {recording.get_total_duration()} sec')
    print(
        f'Sorting extractor info: unit ids = {sorting_true.get_unit_ids()}, {sorting_true.get_sampling_frequency()} Hz')
    print('')

    waveforms_folder = save_path + '/' + recording_name
    analyzer = si.create_sorting_analyzer(
        sorting_true,
        recording,
        format='binary_folder',
        folder=waveforms_folder,
        overwrite=True,
        return_scaled=True,
    )
    # The per-unit spike cap is a parameter of the 'random_spikes' extension
    # (as used by get_waveforms in this file); the 'waveforms' extension only
    # takes the window sizes.
    analyzer.compute('random_spikes', max_spikes_per_unit=num_per_units)
    analyzer.compute('waveforms', ms_before=ms_before, ms_after=ms_after)
    analyzer.compute("noise_levels")
    analyzer.compute("templates")

    return analyzer, analyzer.get_extension('waveforms'), recording_name

def calculate_snrs(analyzer):
    """Return the per-unit SNRs computed by ``qm.compute_snrs`` for ``analyzer``."""
    return qm.compute_snrs(analyzer)

def generate_dataset(waveforms,
                     SNRs,
                     recording_name,
                     save_path='./resources/spikeforest_waveforms/data',
                     threshold=3.0):
    """Collect waveforms of high-SNR units and save them as ``.npy`` files.

    Args:
        waveforms: Object exposing ``get_waveforms_one_unit(unit_id)`` that
            returns an array whose first axis indexes spikes.
        SNRs: Mapping ``unit_id -> SNR``.
        recording_name: Used as the file-name prefix.
        save_path: Directory that receives ``<recording_name>_data.npy`` and
            ``<recording_name>_labels.npy``; created if missing.
        threshold: Units with ``SNR >= threshold`` are kept.

    Returns:
        Tuple ``(data, labels)``: waveforms of all kept units concatenated
        along axis 0, and the matching integer unit id for every waveform.

    Raises:
        ValueError: If no unit reaches ``threshold``.
    """
    units = []
    for unit_id, snr in SNRs.items():
        if snr >= threshold:
            units.append(unit_id)
            print(unit_id, snr)

    if not units:
        # np.concatenate would fail with an unrelated-looking message.
        raise ValueError(
            f'no unit has SNR >= {threshold}; nothing to save for {recording_name!r}')

    data = []
    labels = []
    for unit_id in units:
        wfs = waveforms.get_waveforms_one_unit(unit_id)
        print(unit_id, ":", wfs.shape)
        data.append(wfs)
        labels.append(np.full(wfs.shape[0], unit_id, dtype=int))
    data = np.concatenate(data, axis=0)
    labels = np.concatenate(labels, axis=0)

    # Only the parent directory has to exist: the files are written next to
    # each other as '<save_path>/<recording_name>_*.npy', not inside a
    # per-recording sub-directory.
    os.makedirs(save_path, exist_ok=True)
    file_prefix = save_path + '/' + recording_name
    np.save(file_prefix + '_data.npy', data)
    np.save(file_prefix + '_labels.npy', labels)
    return data, labels


def interpolate_spike_pytorch(data, target_length=121, mode='linear'):
    """Resample waveforms along the time axis with ``F.interpolate``.

    Args:
        data: Tensor of shape [batch_size, time, channels].
        target_length: Number of time steps in the output.
        mode: Interpolation mode passed to ``F.interpolate``.

    Returns:
        Tensor of shape [batch_size, target_length, channels].
    """
    # F.interpolate expects [B, C, T], so move time to the last axis.
    channels_first = data.permute(0, 2, 1)

    # align_corners may only be given for the interpolating modes; passing it
    # with e.g. 'nearest' or 'area' makes F.interpolate raise a ValueError.
    extra = {}
    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        extra['align_corners'] = False

    out = F.interpolate(channels_first, size=target_length, mode=mode, **extra)

    return out.permute(0, 2, 1)  # back to [B, target_length, C]

def pad_with_minval_channel(x: torch.Tensor, target_channels: int) -> torch.Tensor:
    """
    Widen the channel axis by copying, per sample, the channel that holds the
    deepest valley (the overall minimum value).

    Args:
        x: Tensor of shape [bs, time, channels]
        target_channels: Channel count of the result (must exceed x.shape[2])

    Returns:
        Tensor of shape [bs, time, target_channels]; ``padding // 2`` copies
        are placed before the original channels and the rest after them.
    """
    batch, n_time, n_ch = x.shape
    n_extra = target_channels - n_ch
    assert n_extra > 0, "target_channels must be greater than input channels"

    # Per-sample index of the channel with the smallest value over time.
    valley_per_channel = x.min(dim=1).values  # [bs, channels]
    deepest = valley_per_channel.argmin(dim=1)  # [bs]

    # Gather that channel for every sample in one indexing operation.
    rows = torch.arange(batch, device=x.device)
    filler = x[rows, :, deepest].unsqueeze(-1)  # [bs, time, 1]

    n_left = n_extra // 2
    n_right = n_extra - n_left
    return torch.cat(
        [filler.expand(batch, n_time, n_left), x, filler.expand(batch, n_time, n_right)],
        dim=2,
    )

def cubic_spline_interpolate_spikes(x: torch.Tensor, target_len: int = 121) -> torch.Tensor:
    """
    Resample spike waveforms along time with a SciPy ``CubicSpline``.

    Args:
        x: Input tensor of shape [bs, times, channels]
        target_len: Number of time steps in the result

    Returns:
        float32 tensor of shape [bs, target_len, channels] on ``x.device``
    """
    assert x.dim() == 3, "Input must be 3D tensor [bs, times, channels]"

    samples = x.detach().cpu().numpy()  # SciPy works on NumPy arrays
    n_batch, n_times, n_channels = samples.shape

    # Both grids span [0, 1] so the first and last samples stay aligned.
    src_grid = np.linspace(0, 1, n_times)
    dst_grid = np.linspace(0, 1, target_len)

    resampled = np.empty((n_batch, target_len, n_channels), dtype=np.float32)
    for b, c in np.ndindex(n_batch, n_channels):
        spline = CubicSpline(src_grid, samples[b, :, c])
        resampled[b, :, c] = spline(dst_grid)

    return torch.from_numpy(resampled).to(x.device)

def repeat_channels(x: torch.Tensor, repeat_factor: int, max_channels=11) -> torch.Tensor:
    """
    Duplicate every channel ``repeat_factor`` times, keeping the copies of a
    channel adjacent (c0, c0, ..., c1, c1, ...).

    Args:
        x: Tensor of shape [bs, time, channels]
        repeat_factor: Number of copies of each channel
        max_channels: Not used by this function.

    Returns:
        Tensor of shape [bs, time, channels * repeat_factor]
    """
    # Element-wise repetition along the channel axis gives the same layout as
    # unsqueeze -> repeat on a trailing axis -> reshape.
    return torch.repeat_interleave(x, repeat_factor, dim=2)

def detect_spike(study_name='hybrid_static_tetrode',
                 recording_name='rec_4c_600s_12',
                 uri=None,
                 detect_threshold=5,
                 exclude_sweep_ms=1,
                 chunk_size=60000,
                 peaks=None,
                 verbose=False,
                 drift=False,
                 chunk_duration=200,
                 recording=None,
                 sorting_true=None,):
    """Filter a recording, detect (or accept) peaks and group them per channel.

    Steps:
      1. Load the SpikeForest recording unless ``recording`` is supplied.
      2. Band-pass (300-5000 Hz) and 60 Hz notch filter it.
      3. Detect negative peaks with ``locally_exclusive_torch`` -- on the
         whole recording, or chunk by chunk when ``drift`` is True -- unless
         ``peaks`` is supplied.
      4. Build a ``NumpySorting`` in which every channel index that has peaks
         becomes one unit.

    Motion estimation/correction was prototyped here earlier
    (``compute_motion`` / ``interpolate_motion``) and is not applied.

    Args:
        study_name: SpikeForest study name (used only when ``recording`` is None).
        recording_name: SpikeForest recording name (same condition).
        uri: Forwarded to ``sf.load_spikeforest_recording``.
        detect_threshold: Threshold passed to ``detect_peaks``.
        exclude_sweep_ms: Exclusion window passed to ``detect_peaks``.
        chunk_size: Processing chunk size passed to ``detect_peaks``.
        peaks: Optional pre-computed structured peak array with the fields
            ``sample_index``, ``channel_index``, ``amplitude`` and
            ``segment_index``; when given, detection is skipped.
        verbose: Print recording / ground-truth details.
        drift: Detect peaks independently in consecutive chunks so that noise
            levels are re-estimated for every chunk.
        chunk_duration: Chunk length in seconds for ``drift`` mode.
        recording: Optional recording to use instead of loading one.
        sorting_true: Optional ground-truth sorting that accompanies
            ``recording``; replaced by the loaded one when the recording is
            loaded here.

    Returns:
        Tuple ``(recording_filtered, sorting_true, sorting, peaks)`` where
        ``peaks`` is sorted by ``channel_index``.

    Raises:
        ValueError: In drift mode, if ``chunk_duration`` spans less than one
            sample.
    """
    peak_dtype = [('sample_index', 'int64'), ('channel_index', 'int64'),
                  ('amplitude', 'float32'), ('segment_index', 'int64')]
    # Drift mode skips chunks shorter than this many samples.
    min_chunk_samples = 100

    if recording is None:
        R = sf.load_spikeforest_recording(study_name=study_name,
                                          recording_name=recording_name,
                                          uri=uri)
        print(f'{R.study_set_name}/{R.study_name}/{R.recording_name}')
        if verbose:
            print(f'Num. channels: {R.num_channels}')
            print(f'Duration (sec): {R.duration_sec}')
            print(f'Sampling frequency (Hz): {R.sampling_frequency}')
            print(f'Num. true units: {R.num_true_units}')
        print('')

        recording = R.get_recording_extractor()
        sorting_true = R.get_sorting_true_extractor()

    if verbose:
        print(
            f'Recording extractor info: {recording.get_num_channels()} channels, {recording.get_sampling_frequency()} Hz, {recording.get_total_duration()} sec')
        # A caller may pass a recording without ground truth.
        if sorting_true is not None:
            print(
                f'Sorting extractor info: unit ids = {sorting_true.get_unit_ids()}, {sorting_true.get_sampling_frequency()} Hz')
            print('')
            total_peaks = 0
            for unit_id in sorting_true.get_unit_ids():
                st = sorting_true.get_unit_spike_train(unit_id=unit_id)
                total_peaks += len(st)
                print(f'Unit {unit_id}: {len(st)} events')
            print('')
            print(f'Total peaks: {total_peaks}')
        else:
            print('')

    # Step 1: Preprocessing
    recording_filtered = bandpass_filter(recording, freq_min=300, freq_max=5000, dtype='float32')
    recording_filtered = si.notch_filter(recording_filtered, freq=60, dtype='float32')

    # Detection settings shared by the whole-recording and per-chunk paths.
    detect_kwargs = dict(method='locally_exclusive_torch',
                         detect_threshold=detect_threshold,
                         exclude_sweep_ms=exclude_sweep_ms,
                         chunk_size=chunk_size,
                         n_jobs=4,
                         progress_bar=True,
                         peak_sign='neg')

    if peaks is None:
        if not drift:
            noise_levels = get_noise_levels(recording_filtered, return_scaled=True)
            print(noise_levels)
            # Step 2: Spike detection
            peaks = detect_peaks(recording_filtered, noise_levels=noise_levels, **detect_kwargs)
            peaks = np.array(peaks, dtype=peak_dtype)
        else:
            sampling_rate = recording_filtered.get_sampling_frequency()
            total_samples = int(recording_filtered.get_num_samples())
            samples_per_chunk = int(chunk_duration * sampling_rate)
            if samples_per_chunk <= 0:
                raise ValueError('chunk_duration must span at least one sample')

            # Integer frame boundaries: the same start value is used both for
            # slicing and for shifting peak indices back to recording time,
            # so the two can never disagree.
            bounds = []
            for start in range(0, total_samples, samples_per_chunk):
                end = min(start + samples_per_chunk, total_samples)
                if end - start >= min_chunk_samples:
                    bounds.append((start, end))

            all_peaks = []
            for i, (start, end) in enumerate(bounds):
                print(f"Detecting chunk {i + 1}/{len(bounds)}...")
                chunk = recording_filtered.frame_slice(start_frame=start, end_frame=end)
                # NOTE(review): unscaled noise levels here vs. scaled ones in
                # the non-drift path -- confirm this difference is intended.
                noise_levels = get_noise_levels(chunk, return_scaled=False)
                chunk_peaks = detect_peaks(chunk, noise_levels=noise_levels, **detect_kwargs)
                chunk_peaks = np.array(chunk_peaks, dtype=peak_dtype)
                # Convert chunk-relative sample indices to recording-relative.
                chunk_peaks['sample_index'] += start
                all_peaks.append(chunk_peaks)

            if all_peaks:
                peaks = np.concatenate(all_peaks)
            else:
                peaks = np.empty(0, dtype=peak_dtype)
    else:
        # Copy field by field so the caller's array may use other field
        # dtypes or a different field order.
        target_peaks = np.empty(peaks.shape, dtype=peak_dtype)
        for field, _ in peak_dtype:
            target_peaks[field] = peaks[field]
        peaks = target_peaks

    peaks = np.sort(peaks, order=['channel_index'])
    unit_spike_trains = {}
    for ch in np.unique(peaks['channel_index']):
        unit_spike_trains[ch] = peaks['sample_index'][peaks['channel_index'] == ch]

    # Build the sorting object, using the channel index as the unit id.
    sorting = NumpySorting.from_unit_dict(unit_spike_trains, sampling_frequency=recording.get_sampling_frequency())

    return recording_filtered, sorting_true, sorting, peaks

def get_waveforms(recording, sorting, format='memory', save_path='', SNRs_threshold=3, neighbors=None):
    """Extract sparse waveforms of all units whose SNR reaches a threshold.

    A first analyzer is built only to derive a "best_channels" sparsity with
    11 channels per unit; a second analyzer using that sparsity provides the
    SNRs and the waveforms that are returned.

    Args:
        recording: Recording passed to ``si.create_sorting_analyzer``.
        sorting: Sorting passed to ``si.create_sorting_analyzer``.
        format: Analyzer storage format.
        save_path: Not used by this function.
        SNRs_threshold: Units with ``SNR >= SNRs_threshold`` are kept.
        neighbors: Not used by this function.

    Returns:
        Tuple ``(data, labels)``: waveforms of the kept units concatenated
        along axis 0 and the matching integer unit id per waveform.
    """

    def _build_analyzer(**analyzer_kwargs):
        # Create an analyzer and run the extension chain both passes need.
        built = si.create_sorting_analyzer(
            sorting,
            recording,
            format=format,
            overwrite=True,
            return_scaled=True,
            **analyzer_kwargs,
        )
        built.compute('random_spikes', method="uniform", max_spikes_per_unit=100000, )
        built.compute('waveforms', ms_before=1.0, ms_after=1.0, chunk_duration="100s", n_jobs=4)
        built.compute("noise_levels")
        built.compute("templates", operators=["average", "median", "std"])
        return built

    # Pass 1: templates for the sparsity estimate.
    first_pass = _build_analyzer()
    sparsity = compute_sparsity(first_pass, method="best_channels", num_channels=11)

    # Pass 2: same pipeline restricted to each unit's best channels.
    analyzer = _build_analyzer(sparsity=sparsity)

    kept_units = []
    for unit_id, snr in qm.compute_snrs(analyzer).items():
        if snr >= SNRs_threshold:
            kept_units.append(unit_id)
            print(unit_id, snr)

    waveform_ext = analyzer.get_extension("waveforms")
    per_unit_data = []
    per_unit_labels = []
    for unit_id in kept_units:
        unit_wfs = waveform_ext.get_waveforms_one_unit(unit_id)
        print(unit_id, ":", unit_wfs.shape)
        per_unit_data.append(unit_wfs)
        per_unit_labels.append(np.full(unit_wfs.shape[0], unit_id, dtype=int))

    return np.concatenate(per_unit_data, axis=0), np.concatenate(per_unit_labels, axis=0)

def get_result(pred_labels, peaks, recording):
    """Turn per-peak cluster labels into a ``NumpySorting``.

    Args:
        pred_labels: One predicted label per entry of ``peaks``, in the same
            order.
        peaks: Indexable collection whose items expose ``'sample_index'``.
        recording: Provides the sampling frequency of the resulting sorting.

    Returns:
        ``NumpySorting`` with one unit per distinct label; each spike train
        holds that unit's sample indices in ascending order.
    """
    labels = np.array(pred_labels)  # expected to have the same length as peaks
    # Labels and peaks must line up one-to-one.
    assert len(labels) == len(peaks), "预测标签与peaks长度不一致"

    # One spike train per predicted unit id (i.e. per cluster).
    trains = {
        int(label): np.sort(
            np.array([peaks[i]['sample_index'] for i in (labels == label).nonzero()[0]])
        )
        for label in np.unique(labels)
    }

    # The cluster labels serve as unit ids of the sorting.
    return NumpySorting.from_unit_dict(trains, sampling_frequency=recording.get_sampling_frequency())

def analyze_result(sorting_true, sorting_pred, delta_time=0.4, match_score=0.1, exhaustive_gt=True, draw_matrix=False):
    """Compare a predicted sorting against the ground truth.

    Args:
        sorting_true: Ground-truth sorting.
        sorting_pred: Sorting to evaluate.
        delta_time: Forwarded to ``compare_sorter_to_ground_truth``.
        match_score: Forwarded to ``compare_sorter_to_ground_truth``.
        exhaustive_gt: Forwarded to ``compare_sorter_to_ground_truth``.
        draw_matrix: When True, also plot the confusion and agreement
            matrices of the comparison.

    Returns:
        The comparison object returned by ``compare_sorter_to_ground_truth``.
    """
    comparison = compare_sorter_to_ground_truth(
        sorting_true,
        sorting_pred,
        delta_time=delta_time,
        exhaustive_gt=exhaustive_gt,
        match_score=match_score,
    )
    if not draw_matrix:
        return comparison

    # Visualise how predicted units correspond to ground-truth units.
    for plot_matrix in (plot_confusion_matrix, plot_agreement_matrix):
        plot_matrix(comparison, count_text=True, figsize=(32, 32))
    return comparison

def preprocess_spikeforest_data(data,
                                normalize=True,
                                interpolate=90,
                                repeat=True,
                                repeat_times=3,
                                max_channels=11,
                                pad_len=None):
    """Normalise, resample and channel-expand waveforms for model input.

    Args:
        data: Array-like of waveforms, indexed here as
            [batch, time, channels]; converted to a float tensor.
        normalize: Apply ``_zscore_normalize`` when truthy.
        interpolate: Target number of time steps for linear resampling with
            ``interpolate_spike_pytorch``; a falsy value skips resampling.
        repeat: When truthy, expand and/or centre-crop the channel axis to
            ``max_channels``; when falsy the channel axis is left untouched.
        repeat_times: Overwritten before every use inside this function, so
            the value passed by the caller has no effect.
        max_channels: Channel count targeted by the repeat/crop step.
        pad_len: Optional per-sample count of valid leading channels; when
            given, each sample is cut to its first ``pad_len[i]`` channels
            and processed on its own.

    Returns:
        The processed tensor.
    """
    data = torch.tensor(data, dtype=torch.float).detach()
    if normalize:
        data = _zscore_normalize(data)
    if interpolate:
        data = interpolate_spike_pytorch(data, mode='linear', target_length=interpolate)
    if repeat:
        if pad_len is not None:
            # Per-sample path: every sample may have a different number of
            # valid channels.
            new_data = []
            for i in range(data.shape[0]):
                # Keep only the valid channels; unsqueeze restores the batch axis.
                temp_data = data[i, :, :pad_len[i]].unsqueeze(0)
                # print(temp_data.shape)
                if temp_data.shape[-1] < max_channels:
                    # Smallest repeat count whose result exceeds max_channels.
                    repeat_times = (max_channels // temp_data.shape[2]) + 1
                    temp_data = repeat_channels(temp_data, repeat_times)
                # else:
                #     center = temp_data.shape[2] // 2
                #     left = center - 6 // 2
                #     right = left + 6
                #     temp_data = temp_data[:, :, left:right]
                #     temp_data = repeat_channels(temp_data, 2)
                # Centre crop of width max_channels along the channel axis.
                center = temp_data.shape[2] // 2
                left = center - max_channels // 2
                right = left + max_channels
                temp_data = temp_data[:, :, left:right]
                new_data.append(temp_data)
            new_data = torch.cat(new_data, dim=0)
            # Early return: the batch-wide path below is skipped.
            return new_data

        if data.shape[2] < max_channels:
            repeat_times  = (max_channels // data.shape[2]) + 1
            # Number of channels the centre crop would discard after a full repeat.
            residual = data.shape[2] * repeat_times - max_channels
            if residual > repeat_times + 1:
                # Too much would be cropped: repeat one time less and fill the
                # remaining channels by replicating the edge channels.
                pad_length = max_channels - data.shape[2] * (repeat_times - 1)
                data = repeat_channels(data, repeat_times - 1)
                # F.pad addresses trailing axes, so move channels to axis 1.
                data = data.permute(0, 2, 1)
                data = F.pad(data, pad=(0, 0, pad_length // 2, pad_length - (pad_length // 2)), mode='replicate')
                data = data.permute(0, 2, 1)
            else:
                data = repeat_channels(data, repeat_times)
            # data = repeat_channels(data, repeat_times)
        # Centre crop of width max_channels along the channel axis.
        center = data.shape[2] // 2
        left = center - max_channels // 2
        right = left + max_channels
        data = data[:, :, left:right,]
    return data

def _zscore_normalize(data, dim=(0, 1)):
    """Z-score ``data`` using the mean and standard deviation over ``dim``.

    A small constant (1e-6) is added to the standard deviation to avoid
    division by zero. Statistics are computed with ``keepdim=True`` so they
    broadcast against ``data``.
    """
    centered = data - data.mean(dim=dim, keepdim=True)
    spread = data.std(dim=dim, keepdim=True) + 1e-6
    return centered / spread

def minmax_scale_to_range(x, min_val=-25, max_val=25, eps=1e-8):
    """Linearly rescale ``x`` along its last axis into [min_val, max_val].

    The minimum and maximum are taken over the last dimension; ``eps`` keeps
    the denominator non-zero when a slice is constant.
    """
    lowest = x.min(dim=-1, keepdim=True).values
    highest = x.max(dim=-1, keepdim=True).values
    unit_scaled = (x - lowest) / (highest - lowest + eps)
    return unit_scaled * (max_val - min_val) + min_val


def cubic_interpolate_times(x: torch.Tensor, new_times: int) -> torch.Tensor:
    """
    Cubic interpolation of the input tensor along its time dimension.

    Args:
        x: input tensor of shape [bs, times, channels]
        new_times: number of time steps after interpolation

    Returns:
        tensor of shape [bs, new_times, channels] with the dtype and device
        of ``x``
    """
    n_batch, n_old, n_channels = x.shape
    samples = x.detach().cpu().numpy()  # SciPy works on NumPy arrays

    # Source and target grids both cover [0, 1].
    src_grid = np.linspace(0, 1, n_old)
    dst_grid = np.linspace(0, 1, new_times)

    result = np.empty((n_batch, new_times, n_channels), dtype=np.float32)
    for b, c in np.ndindex(n_batch, n_channels):
        result[b, :, c] = interp1d(src_grid, samples[b, :, c], kind='cubic')(dst_grid)

    return torch.tensor(result, dtype=x.dtype, device=x.device)

def interpolate_channels(x: torch.Tensor, new_channels: int, mode='linear') -> torch.Tensor:
    """
    Interpolate along the channel dimension; the input is [bs, times, channels].

    ``mode`` is passed to ``interp1d`` as ``kind``. The result has shape
    [bs, times, new_channels] and keeps the dtype and device of ``x``.
    """
    n_batch, n_times, n_old = x.shape
    samples = x.detach().cpu().numpy()

    # Old and new channel positions both span [0, 1].
    src_pos = np.linspace(0, 1, n_old)
    dst_pos = np.linspace(0, 1, new_channels)

    result = np.empty((n_batch, n_times, new_channels), dtype=np.float32)
    for b, t in np.ndindex(n_batch, n_times):
        across_channels = interp1d(src_pos, samples[b, t, :], kind=mode, fill_value="extrapolate")
        result[b, t, :] = across_channels(dst_pos)

    return torch.tensor(result, dtype=x.dtype, device=x.device)


def get_channels_with_min_value(x):
    """Return, per sample, the index along dim 2 of the channel holding the minimum.

    The minimum is first taken over dim 1 for every channel, then the channel
    with the smallest of those minima is selected.
    """
    per_channel_min, _ = torch.min(x, dim=1)
    return torch.argmin(per_channel_min, dim=1)

def shift_channels(data, window_size=4):
    """Cut a per-sample channel window positioned relative to the valley channel.

    For every sample the channel with the minimum value is located with
    ``get_channels_with_min_value`` and ``window_size // 2`` channels are
    selected, starting ``window_size // 2`` channels before it. The window is
    moved back inside the valid range when it would start before channel 0 or
    end past the last channel.

    NOTE(review): the selected window is half of ``window_size`` wide and
    ends just before the valley channel itself -- confirm this is intended.

    Args:
        data: Tensor indexed as [batch, time, channels].
        window_size: Twice the number of channels kept per sample.

    Returns:
        Tensor with the selected channels of every sample, concatenated along
        the batch axis.
    """
    valley_channels = get_channels_with_min_value(data)
    half = window_size // 2
    n_channels = data.shape[2]

    windows = []
    for sample, valley in zip(data, valley_channels):
        lo = int(valley) - half
        hi = lo + half
        if lo < 0:
            lo, hi = 0, half
        if hi > n_channels:
            lo, hi = n_channels - half, n_channels
        windows.append(sample[:, lo:hi][None])

    return torch.cat(windows, dim=0)


def align_label(sorting_pred, sorting_true):
    """Build one label per ground-truth spike from the unit matching.

    The two sortings are compared with ``compare_sorter_to_ground_truth``
    (``exhaustive_gt=False``). For every ground-truth unit found in the
    matching, its matched unit id is repeated once per spike of that unit;
    units without an entry are skipped.

    NOTE(review): this treats the result of ``comparison.get_matching()`` as
    a mapping keyed by ground-truth unit id -- confirm against the
    SpikeInterface version in use.

    Args:
        sorting_pred: Predicted sorting.
        sorting_true: Ground-truth sorting.

    Returns:
        List of matched unit ids, ordered by ground-truth unit and then by
        spike.
    """
    comparison = compare_sorter_to_ground_truth(
        sorting_true,
        sorting_pred,
        exhaustive_gt=False
    )
    matching = comparison.get_matching()

    aligned = []
    for gt_unit in sorting_true.get_unit_ids():
        if gt_unit in matching:
            matched_unit = matching[gt_unit]
            n_spikes = len(sorting_true.get_unit_spike_train(unit_id=gt_unit))
            aligned += [matched_unit] * n_spikes
        # Units without a match contribute no labels.
    return aligned