import copy
import datetime
import enum
import os
import sys
import warnings
from abc import ABC, abstractmethod
from collections import deque
from contextlib import contextmanager

import datahugger
import fsspec
import h5py
import matplotlib
import numpy as np
import pandas as pd
import pims
from dandi.dandiapi import DandiAPIClient
from fsspec.implementations.cached import CachingFileSystem
from PIL import Image
from pynwb import NWBHDF5IO
from scipy.io import loadmat

import adaptive_latents.transformer
from adaptive_latents import CONFIG
from adaptive_latents.timed_data_source import ArrayWithTime
from adaptive_latents.utils import save_to_cache

# Root directory under which each dataset class below looks for (or caches/downloads) its files;
# the value comes from the package-level CONFIG. Code below joins onto it with `/` and calls
# `.glob`/`.is_dir` on the result, so it is expected to be a pathlib-style path.
DATA_BASE_PATH = CONFIG.dataset_path


class ModelOrganism(enum.Enum):
    """Species a dataset was recorded from.

    Each member's value is the species' binomial (Latin) name.
    """

    FLY = 'Drosophila melanogaster'
    MONKEY = 'Macaca mulatta'
    RAT = 'Rattus rattus'
    MOUSE = 'Mus musculus'
    FINCH = 'Taeniopygia castanotis'
    FISH = 'Danio rerio'


class Dataset(ABC):
    """Common interface shared by every dataset loader in this module.

    Concrete subclasses populate ``neural_data`` and ``behavioral_data`` as
    ``ArrayWithTime`` instances and describe the provenance of their data.
    """

    neural_data: ArrayWithTime
    behavioral_data: ArrayWithTime

    @property
    @abstractmethod
    def doi(self):
        """Reference (DOI or URL) for where the data can be found."""

    @property
    @abstractmethod
    def automatically_downloadable(self):
        """Whether ``acquire`` can obtain the data without manual steps."""

    @property
    @abstractmethod
    def model_organism(self) -> ModelOrganism:
        """Species the recordings come from."""

    @abstractmethod
    def acquire(self, *args, **kwargs):
        """Locate (and, where supported, fetch) the raw data."""


class DandiDataset(Dataset):
    """Base class for datasets streamed from the DANDI archive.

    Subclasses provide ``dandiset_id`` and ``version_id``; ``acquire`` then
    streams one NWB asset over HTTP through an fsspec cache stored under
    ``DATA_BASE_PATH / "nwb_cache"``.
    """
    # TODO: should name be DANDIDataset?
    automatically_downloadable = True

    @property
    @abstractmethod
    def dandiset_id(self):
        """Identifier of the dandiset (e.g. ``"000129"``)."""

    @property
    @abstractmethod
    def version_id(self):
        """Dandiset version passed to ``DandiAPIClient.get_dandiset``."""

    @contextmanager
    def acquire(self, asset_path):
        """Open ``asset_path`` from this dandiset for reading.

        Parameters
        ----------
        asset_path : str
            Path of the asset inside the dandiset.

        Yields
        ------
        NWBHDF5IO
            Reader for the asset. It, the HDF5 file, and the HTTP file handle
            are all closed when the ``with`` block exits, so any data must be
            read (e.g. ``.data[:]``) inside that block.
        """
        # https://pynwb.readthedocs.io/en/latest/tutorials/advanced_io/streaming.html
        with DandiAPIClient() as client:
            asset = client.get_dandiset(self.dandiset_id, version_id=self.version_id).get_asset_by_path(asset_path)
            s3_url = asset.get_content_url(follow_redirects=1, strip_query=True)

        fs = CachingFileSystem(
            fs=fsspec.filesystem("http"),
            cache_storage=[DATA_BASE_PATH / "nwb_cache"],
        )

        with fs.open(s3_url, "rb") as f:
            with h5py.File(f) as file:
                # Use the reader as a context manager (as in the pynwb streaming tutorial)
                # so it is closed even when the caller's block raises.
                with NWBHDF5IO(file=file, load_namespaces=True) as fhan:
                    yield fhan


class Odoherty21Dataset(DandiDataset):
    """Binned spike counts and finger kinematics from dandiset 000129 (subject Indy).

    Attributes set by ``__init__``
    ------------------------------
    neural_data : ArrayWithTime
        Spike counts per unit in bins of ``bin_width`` seconds, timestamped at
        each bin's end plus ``neural_lag``.
    behavioral_data, beh_pos : ArrayWithTime
        Finger position.
    beh_vel : ArrayWithTime
        Finger velocity.
    beh_pos_vel : ArrayWithTime
        Position and velocity stacked column-wise.
    """
    doi = 'https://dandiarchive.org/dandiset/000129/draft/'
    model_organism = ModelOrganism.MONKEY
    dandiset_id = "000129"
    version_id = None

    dataset_base_path = DATA_BASE_PATH / "odoherty21"
    automatically_downloadable = True

    def __init__(self, bin_width=0.03, downsample_behavior=True, neural_lag=0, drop_third_coord=True, pos_rescale_factor=1, vel_rescale_factor=1, supress_warnings=CONFIG.supress_dandi_warnings):
        """
        Parameters
        ----------
        bin_width : float
            Width of the spike-count bins, in the units of the spike times.
        downsample_behavior : bool
            If True, keep every 4th behavioral sample.
        neural_lag : float
            Non-negative offset added to the neural bin timestamps.
        drop_third_coord : bool
            If True, keep only the first two position/velocity columns.
        pos_rescale_factor, vel_rescale_factor : float
            Multipliers applied to finger position and velocity.
        supress_warnings : bool
            Stored for the (currently disabled) warning filtering in ``construct``.
        """
        self.bin_width = bin_width
        self.downsample_behavior = downsample_behavior
        self.drop_third_coord = drop_third_coord
        self.neural_lag = neural_lag
        self.pos_rescale_factor = pos_rescale_factor
        self.vel_rescale_factor = vel_rescale_factor
        self.supress_warnings = supress_warnings
        assert self.neural_lag >= 0

        self.units, self.finger_pos, self.finger_vel, self.finger_t, A, bin_ends = self.construct()
        self.neural_data = ArrayWithTime(A, bin_ends)
        self.behavioral_data = ArrayWithTime(self.finger_pos, self.finger_t)

        self.beh_pos = ArrayWithTime(self.finger_pos, self.finger_t)
        self.beh_vel = ArrayWithTime(self.finger_vel, self.finger_t)
        self.beh_pos_vel = ArrayWithTime(np.hstack([self.finger_pos, self.finger_vel]), self.finger_t)

    def construct(self):
        """Download/read the NWB file and bin it.

        Returns
        -------
        units : pandas.DataFrame
            The NWB units table.
        finger_pos, finger_vel : ndarray
            Behavioral samples after optional downsampling, column dropping and rescaling.
        finger_t : ndarray
            Shared timestamps for position and velocity.
        A : ndarray, shape (n_bins, n_units)
            Spike counts per bin.
        bin_ends : ndarray
            Right bin edges plus ``neural_lag``.
        """
        # TODO: get the warnings to work again
        # with warnings.catch_warnings(record=True) as warning_list:
        with self.acquire("sub-Indy/sub-Indy_desc-train_behavior+ecephys.nwb") as fhan:
            ds = fhan.read()
            units = ds.units.to_dataframe()
            finger_pos = ds.processing['behavior'].data_interfaces['finger_pos'].data[:]
            # NOTE(review): NWB `conversion` is a data-scaling factor, not a sampling period;
            # confirm it really equals the sample spacing for this file.
            finger_pos_t = np.arange(finger_pos.shape[0]) * ds.processing['behavior'].data_interfaces['finger_pos'].conversion
            finger_vel = ds.processing['behavior'].data_interfaces['finger_vel'].data[:]
            finger_vel_t = np.arange(finger_vel.shape[0]) * ds.processing['behavior'].data_interfaces['finger_vel'].conversion

            # for w in warning_list:
            #     if self.supress_warnings and "Ignoring cached namespace" in str(w):
            #         continue
            #     warnings.warn_explicit(message=w.message, category=w.category, filename=w.filename, lineno=w.lineno, source=w.source)

        # Column 2 of the units table is assumed to hold observation intervals; only the
        # first unit's are used to set the binning range — TODO confirm column order.
        start_time = units.iloc[0, 2].min()
        end_time = units.iloc[0, 2].max()
        bins = np.arange(start_time, end_time, self.bin_width)
        bin_ends = bins[1:]

        A = np.zeros(shape=(bins.shape[0] - 1, len(units)))

        for i, (_, row) in enumerate(units.iterrows()):
            A[:, i], _ = np.histogram(row['spike_times'], bins=bins)

        factor = 4
        if self.downsample_behavior:
            finger_pos = finger_pos[::factor]
            finger_pos_t = finger_pos_t[::factor]
            finger_vel = finger_vel[::factor]
            finger_vel_t = finger_vel_t[::factor]

        bin_ends = bin_ends + self.neural_lag
        assert (finger_pos_t == finger_vel_t).all()
        finger_t = finger_pos_t

        if self.drop_third_coord:
            finger_pos = finger_pos[:, :2]
            finger_vel = finger_vel[:, :2]

        finger_pos = finger_pos * self.pos_rescale_factor
        finger_vel = finger_vel * self.vel_rescale_factor

        return units, finger_pos, finger_vel, finger_t, A, bin_ends

    def plot_variances(self, ax):
        """Compare the spread of each neural channel with that of each behavioral coordinate.

        Draws a histogram of per-column standard deviations of the neural data,
        plus one vertical line per position (C1) and velocity (C2) coordinate at
        its NaN-ignoring standard deviation. (The method name is historical: the
        plotted quantity is the standard deviation, not the variance.)

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw on.
        """
        ax.hist(np.squeeze(self.neural_data.a).std(axis=0), bins=50, label='neural')
        # Label only the first line of each group so the legend has one entry per group;
        # matplotlib excludes labels that start with an underscore.
        for i, x in enumerate(np.nanstd(np.squeeze(self.beh_pos.a), axis=0)):
            ax.axvline(x, color='C1', label='pos' if i == 0 else '_nolegend_')

        for i, x in enumerate(np.nanstd(np.squeeze(self.beh_vel.a), axis=0)):
            ax.axvline(x, color='C2', label='vel' if i == 0 else '_nolegend_')

        ax.legend()
        ax.set_xlabel('standard deviation')
        ax.set_ylabel('count')


class Schaffer23Datset(Dataset):
    """Fly ΔF/F traces with a behavioral-state signal, from NWB files hosted on figshare.

    The whole collection is fetched with ``datahugger`` the first time no
    ``*.nwb`` file is present in ``dataset_base_path``. Neural and behavioral
    data share the behavioral-state timestamps.
    """
    # TODO: make this subclass DandiDataset
    doi = 'https://doi.org/10.6084/m9.figshare.23749074'
    model_organism = ModelOrganism.FLY
    dataset_base_path = DATA_BASE_PATH / 'schaffer23'
    automatically_downloadable = True
    sub_datasets = (
        '2019_06_28_fly2.nwb', '2019_07_01_fly2.nwb', '2019_08_07_fly2.nwb', '2019_08_14_fly1.nwb',
        '2019_08_14_fly2.nwb', '2019_08_14_fly3_2.nwb', '2019_08_20_fly2.nwb', '2019_08_20_fly3.nwb',
        '2019_10_02_fly2.nwb', '2019_10_10_fly3.nwb', '2019_10_14_fly2.nwb', '2019_10_14_fly3.nwb',
        '2019_10_14_fly4.nwb', '2019_10_18_fly2.nwb', '2019_10_18_fly3.nwb', '2019_10_21_fly1.nwb'
    )

    def __init__(self, sub_dataset_identifier=sub_datasets[0]):
        """Load one session, chosen by file name or by index into ``sub_datasets``."""
        session = (
            self.sub_datasets[sub_dataset_identifier]
            if isinstance(sub_dataset_identifier, int)
            else sub_dataset_identifier
        )
        self.sub_dataset = session
        dff, state, _, shared_t = self.construct(session)
        self.neural_data = ArrayWithTime(dff, shared_t)
        self.behavioral_data = ArrayWithTime(state, shared_t)

    def construct(self, sub_dataset_identifier):
        """Return ``(dff, state, t, t)`` read from the given NWB file."""
        with self.acquire(sub_dataset_identifier) as io:
            nwb = io.read()
            dff_series = nwb.processing["ophys"].data_interfaces["DfOverF"].roi_response_series['RoiResponseSeries']
            state_series = nwb.processing['behavioral state'].data_interfaces['behavioral state']
            dff = dff_series.data[:]
            state = state_series.data[:]
            t = state_series.timestamps[:]
            # t = file.processing["behavior"].data_interfaces["ball_motion"].timestamps[:]
        return dff, state, t, t

    def acquire(self, sub_dataset_identifier):
        """Download the collection if needed and open one file.

        Returns an ``NWBHDF5IO`` opened read-only, or ``None`` when
        ``sub_dataset_identifier`` is ``None`` (download only).
        """
        nothing_downloaded = next(iter(self.dataset_base_path.glob("*.nwb")), None) is None
        if nothing_downloaded:
            datahugger.get(self.doi, self.dataset_base_path)

        if sub_dataset_identifier is None:
            return None
        return NWBHDF5IO(self.dataset_base_path / sub_dataset_identifier, mode="r", load_namespaces=True)


class Churchland10Dataset(DandiDataset):
    """Binned spike counts and hand position from dandiset 000128 (subject Jenkins).

    Any bin that overlaps the first unit's recorded intervals for less than
    90% of ``bin_width`` is set to NaN for every unit.
    """
    doi = 'https://doi.org/10.48324/dandi.000128/0.220113.0400'
    model_organism = ModelOrganism.MONKEY
    dandiset_id = '000128'
    version_id = '0.220113.0400'
    automatically_downloadable = True

    def __init__(self, bin_width=0.03):
        self.bin_width = bin_width
        counts, hand_position, neural_t, hand_t = self.construct()
        self.neural_data = ArrayWithTime(counts, neural_t)
        self.behavioral_data = ArrayWithTime(hand_position, hand_t)

    def construct(self):
        """Return ``(counts, hand_pos, bin_ends, hand_t)``."""
        with self.acquire('sub-Jenkins/sub-Jenkins_ses-full_desc-train_behavior+ecephys.nwb') as fhan:
            nwb_in = fhan.read()
            units = nwb_in.units.to_dataframe()
            hand_series = nwb_in.processing['behavior'].data_interfaces['hand_pos']
            hand_pos = np.array(hand_series.data)
            hand_t = np.array(hand_series.timestamps)

        # Column 2 holds the (start, stop) recording intervals; column 1 the spike times.
        recorded_intervals = units.iloc[0, 2]
        bin_edges = np.arange(recorded_intervals[0, 0], recorded_intervals[-1, -1] + self.bin_width, self.bin_width)
        n_units = units.shape[0]

        counts = np.zeros((len(bin_edges) - 1, n_units))
        for unit_index in range(n_units):
            counts[:, unit_index] = np.histogram(units.iloc[unit_index, 1], bin_edges)[0]

        # Sweep bins and intervals together; intervals wholly before the current bin are
        # never revisited because bins only move forward in time.
        first_candidate = 0
        for bin_index, (bin_start, bin_stop) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
            covered = 0
            for j in range(first_candidate, recorded_intervals.shape[0]):
                interval_start, interval_stop = recorded_intervals[j]
                if bin_start > interval_stop:
                    first_candidate += 1
                    continue
                if interval_start > bin_stop:
                    break
                covered += max(min(bin_stop, interval_stop) - max(bin_start, interval_start), 0)

            if covered / self.bin_width < .9:
                counts[bin_index, :] = np.nan

        return counts, hand_pos, bin_edges[1:], hand_t


class TostadoMarcos24Dataset(DandiDataset):
    """Finch threshold-crossing counts with per-trial spectrograms as behavior (dandiset 001046).

    Attributes set by ``__init__``
    ------------------------------
    tx : ndarray
        Raw ``tx`` acquisition data, one column per channel.
    vocalizations : ArrayWithTime
        The ``vocalizations`` acquisition series.
    neural_data : ArrayWithTime
        Per-channel count of nonzero ``tx`` samples in bins of roughly
        ``bin_size`` seconds, timestamped at each bin's end.
    behavioral_data : ArrayWithTime
        Spectrogram frames from all trials, with times offset by each trial's ``start_time``.
    """
    doi = 'https://dandiarchive.org/dandiset/001046/draft'
    dandiset_id = '001046'
    version_id = 'draft'
    automatically_downloadable = True
    model_organism = ModelOrganism.FINCH
    sub_datasets = ['27', '26', '28']

    def __init__(self, sub_dataset_identifier=sub_datasets[0], bin_size=0.03):
        if isinstance(sub_dataset_identifier, int):
            sub_dataset_identifier = self.sub_datasets[sub_dataset_identifier]
        self.sub_dataset = sub_dataset_identifier
        self.bin_size = bin_size
        self.tx, self.vocalizations, self.neural_data, self.behavioral_data = self.construct(self.sub_dataset)

    def construct(self, sub_dataset_identifier):
        """Read one session and return ``(tx, vocalizations, neural_data, behavioral_data)``."""
        with self.acquire(f"sub-Finch-z-r12r13-21-held-in-calib/sub-Finch-z-r12r13-21-held-in-calib_ses-202106{sub_dataset_identifier}.nwb") as fhan:
            nwb = fhan.read()
            # TODO: make it possible to pass around TimeSeries with the HDF5 dereferenced
            tx = nwb.acquisition['tx'].data[:]
            tx_t = nwb.acquisition['tx'].timestamps[:]
            vocalizations = ArrayWithTime.from_nwb_timeseries(nwb.acquisition['vocalizations'])
            trials = nwb.intervals['trials'].to_dataframe()

        # make FR matrix for neural data
        dt_s = np.diff(tx_t)
        dt = np.median(dt_s)
        assert dt_s.std() / dt < .0001  # tx must be (nearly) uniformly sampled

        bins = np.linspace(tx_t[0], tx_t[-1], int((tx_t[-1] - tx_t[0]) // self.bin_size) + 1)
        A = np.empty(shape=(bins.size - 1, tx.shape[1]))
        for i in range(A.shape[1]):
            # counts samples where the channel is nonzero (not the magnitude of tx)
            counts, _ = np.histogram(tx_t[tx[:, i].nonzero()[0]], bins=bins)
            A[:, i] = counts
        # Timestamp each bin at its right edge, like the other binned datasets in this module,
        # so no bin is stamped earlier than spikes it contains.
        neural_data = ArrayWithTime(A, bins[1:])

        # make spectrogram matrix
        times = []
        spectral_data = []
        for _, row in trials.iterrows():
            times.extend(row['spectrogram_times'] + row['start_time'])
            spectral_data.extend(row['spectrogram_values'].T)

        spectral_data = np.array(spectral_data)
        times = np.array(times)
        behavioral_data = ArrayWithTime(spectral_data, times)

        return tx, vocalizations, neural_data, behavioral_data

    def play_audio(self):
        """Return an IPython audio widget for the vocalization signal.

        The sample rate is the inverse of the median spacing of the vocalization timestamps.

        Examples
        --------
        >>> d = TostadoMarcos24Dataset()
        >>> d.play_audio()
        <IPython.lib.display.Audio object>
        """

        import IPython.display as ipd
        x = self.vocalizations.flatten()
        t = self.vocalizations.t.flatten()
        return ipd.Audio(x, rate=round(1 / np.median(np.diff(t))))

    def plot_recalculated_spectrogram(self, ax):
        """Plot an STFT magnitude spectrogram of the vocalizations on ``ax``.

        Uses a 10 ms Tukey window with hop equal to the window length, shades the
        border regions affected by the window, and adds a colorbar to the figure.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> d = TostadoMarcos24Dataset()
        >>> fig, ax = plt.subplots()
        >>> d.plot_recalculated_spectrogram(ax)
        """
        import scipy.signal as ss

        x = self.vocalizations.flatten()
        t = self.vocalizations.t.flatten()

        dt = np.median(np.diff(t))
        Fs = 1 / dt

        window_length_in_s = .01
        window_length_in_samples = int(window_length_in_s // dt)
        window = ss.windows.tukey(window_length_in_samples)
        SFT = ss.ShortTimeFFT(win=window, hop=window_length_in_samples, fs=Fs)

        Sx = SFT.stft(x)

        N = len(t)
        # fig1, ax1 = plt.subplots(figsize=(6., 4.))  # enlarge plot a bit
        t_lo, t_hi = SFT.extent(N)[:2]  # time range of plot
        ax.set(xlabel=f"Time $t$ in seconds ({SFT.p_num(N)} slices, " +
                       rf"$\Delta t = {SFT.delta_t:g}\,$s)",
                ylabel=f"Freq. $f$ in Hz ({SFT.f_pts} bins, " +
                       rf"$\Delta f = {SFT.delta_f:g}\,$Hz)",
                xlim=(t_lo, t_hi))

        im1 = ax.imshow((abs(Sx)), origin='lower', aspect='auto',
                         extent=SFT.extent(N), cmap='viridis')

        for t0_, t1_ in [(t_lo, SFT.lower_border_end[0] * SFT.T),
                         (SFT.upper_border_begin(N)[0] * SFT.T, t_hi)]:
            ax.axvspan(t0_, t1_, color='w', linewidth=0, alpha=.2)

        for t_ in [0, N * SFT.T]:  # mark signal borders with vertical line:
            ax.axvline(t_, color='y', linestyle='--', alpha=0.5)

        fig = ax.get_figure()
        fig.colorbar(im1, label="Magnitude $|S_x(t, f)|$")
        fig.tight_layout()




class Nason20Dataset(Dataset):
    """Spiking-band power with finger angle from ``OnlineTrainingData.mat``.

    The file must be downloaded manually from ``doi`` into ``dataset_base_path``.
    Trials are concatenated on their common ``ExperimentTime`` clock, with NaN
    samples filling the gaps between trials, and then averaged into
    ``bin_width``-second bins timestamped at each bin's last sample.
    """
    doi = 'https://doi.org/10.7302/wwya-5q86'
    directory_name = 'nason20'
    model_organism = ModelOrganism.MONKEY
    dataset_base_path = DATA_BASE_PATH / directory_name
    automatically_downloadable = False

    def __init__(self, bin_width=0.15):
        self.bin_width = bin_width
        a, beh, t, t = self.construct()
        self.neural_data = ArrayWithTime(a, t)
        self.behavioral_data = ArrayWithTime(beh, t)

    def acquire(self):
        """Load the ``.mat`` file, raising ``FileNotFoundError`` if it has not been downloaded."""
        file = self.dataset_base_path / 'OnlineTrainingData.mat'
        if not file.is_file():
            print(f"""\
Please manually download the OnlineTrainingData.mat file from {self.doi}.
Then put it in '{self.dataset_base_path}'.
""")
            raise FileNotFoundError(f"{file} not found")
        return loadmat(file, squeeze_me=True, simplify_cells=True)

    def construct(self):
        """Return ``(A, beh, t, t)`` with ``t`` in seconds.

        Assumes ``ExperimentTime`` is in integer milliseconds sampled once per ms
        (the /1000 conversion and the reshape-based binning below both rely on
        this) — TODO confirm against the dataset documentation.
        """
        # Round rather than truncate: int(0.029 * 1000) would give 28.
        bin_width_in_ms = int(round(self.bin_width * 1000))

        mat = self.acquire()
        data = mat['OnlineTrainingData']
        n_channels = data[0]['SpikingBandPower'].shape[1]

        for i in range(len(data) - 1):
            assert data[i + 1]['ExperimentTime'][0] - data[i]['ExperimentTime'][-1] == 3

        A = []
        t = []
        beh = []
        for i, trial in enumerate(data):
            A.append(trial['SpikingBandPower'])
            t.append(trial['ExperimentTime'])
            beh.append(trial['FingerAngle'])
            if i + 1 < len(data):
                # Fill only the timestamps strictly between this trial's last sample and the
                # next trial's first one; including the next trial's start would duplicate a
                # timestamp and shift the 1-row-per-ms grid the binning depends on.
                t_spacer = np.arange(trial['ExperimentTime'][-1] + 1, data[i + 1]['ExperimentTime'][0])
                A.append(np.full((t_spacer.size, n_channels), np.nan))
                t.append(t_spacer)
                beh.append(np.full(t_spacer.size, np.nan))
        A = np.vstack(A)
        t = np.hstack(t) / 1000  # converts to seconds
        beh = np.hstack(beh)

        s = t > 1.260  # there's an early dead zone
        A, beh, t = A[s], beh[s], t[s]

        # Drop the earliest rows so the sample count divides evenly, then group
        # bin_width_in_ms consecutive rows per bin.
        aug = np.column_stack([t, beh, A])
        binned_aug = aug[aug.shape[0] % bin_width_in_ms:, :].reshape((-1, bin_width_in_ms, aug.shape[1]))
        t = binned_aug[:, :, 0].max(axis=1)
        beh = np.nanmean(binned_aug[:, :, 1], axis=1)
        A = np.nanmean(binned_aug[:, :, 2:], axis=1)

        return A, beh, t, t


class Peyrache15Dataset(Dataset):
    """Spike-sorted units and position tracking read from raw ``.clu``/``.res``/``.whl`` files.

    Each session directory must be downloaded manually from ``doi`` into
    ``dataset_base_path``. Results of ``construct`` are cached via
    ``save_to_cache``.
    """
    doi = 'http://dx.doi.org/10.6080/K0G15XS1'
    model_organism = ModelOrganism.MOUSE
    dataset_base_path = DATA_BASE_PATH / 'peyrache15'
    automatically_downloadable = False
    sub_datasets = ("Mouse12-120806", "Mouse12-120807", "Mouse24-131216")

    def __init__(self, sub_dataset_identifier=sub_datasets[0], bin_width=0.03):
        if isinstance(sub_dataset_identifier, int):
            sub_dataset_identifier = self.sub_datasets[sub_dataset_identifier]
        self.sub_dataset = sub_dataset_identifier
        self.bin_width = bin_width
        A, raw_behavior, a_t, beh_t = self.construct(sub_dataset_identifier)
        self.neural_data = ArrayWithTime(A, a_t)
        self.behavioral_data = ArrayWithTime(raw_behavior, beh_t)

    def acquire(self, sub_dataset_identifier):
        """Check that the session directory exists; raise ``FileNotFoundError`` otherwise."""
        session_dir = self.dataset_base_path / sub_dataset_identifier
        if not session_dir.is_dir():
            print(f"""\
Please download {sub_dataset_identifier} from {self.doi} and put it in {self.dataset_base_path}.
""")
            raise FileNotFoundError(f"{session_dir} is not a directory")

    def construct(self, sub_dataset_identifier):
        """Return ``(A, raw_behavior, bin_ends, behavior_t)`` for one session.

        ``A`` holds spike counts (bins x units); units in clusters 0 and 1 of each
        shank are excluded. ``raw_behavior`` holds the first four ``.whl`` columns,
        with -1 replaced by NaN.
        """
        self.acquire(sub_dataset_identifier)

        @save_to_cache("peyrache15_data")
        def static_construct(sub_dataset_identifier, bin_width):
            def read_int_file(fname):
                with open(fname) as fhan:
                    ret = []
                    for line in fhan:
                        line = int(line.strip())
                        ret.append(line)
                    return ret

            session_dir = self.dataset_base_path / sub_dataset_identifier

            shanks = []
            for n in range(30):
                shanks.append(os.path.isfile(session_dir / f"{sub_dataset_identifier}.clu.{n}"))

            assert not any(shanks[20:])
            shanks = np.nonzero(shanks)[0]

            spike_sampling_rate = 20_000
            clusters_to_ignore = {0, 1}

            shank_datas = []
            cluster_mapping = {}  # this will be a bijective dictionary between the (shank, cluster) and unit_number (also nan entries)

            min_time = float("inf")
            max_time = 0
            used_columns = 0
            for shank in shanks:
                clusters = read_int_file(session_dir / f"{sub_dataset_identifier}.clu.{shank}")
                n_clusters = clusters[0]  # the .clu header line; not otherwise used
                clusters = clusters[1:]

                # TODO: check if I should exclude the hash unit
                for cluster in np.unique(clusters):
                    if cluster not in clusters_to_ignore:
                        cluster_mapping[(shank, cluster)] = used_columns
                        used_columns += 1
                        cluster_mapping[cluster_mapping[(shank, cluster)]] = (shank, cluster)
                    else:
                        cluster_mapping[(shank, cluster)] = np.nan

                clusters = [cluster_mapping[(shank, c)] for c in clusters]
                times = read_int_file(session_dir / f"{sub_dataset_identifier}.res.{shank}")

                pairs = np.array([times, clusters]).T
                pairs = pairs[~np.isnan(pairs[:, 1]), :]  # drop ignored clusters

                if len(pairs):
                    pairs[:, 0] /= spike_sampling_rate  # samples -> seconds

                    min_time = min(min_time, pairs[:, 0].min())
                    max_time = max(max_time, pairs[:, 0].max())

                    shank_datas.append(pairs)

            bins = np.arange(min_time, max_time + bin_width, bin_width)
            bin_ends = bins[1:]
            A = np.zeros((len(bins) - 1, used_columns))

            last_bin = len(bins) - 2
            for shank_data in shank_datas:
                max_lower_bound = 0
                last_time = 0
                for time, cluster in shank_data:
                    assert time >= last_time
                    # Bins are right-closed. The `last_bin` guard keeps a spike at max_time in
                    # the final bin even if float rounding left bins[-1] just below max_time.
                    while max_lower_bound < last_bin and time > bins[max_lower_bound + 1]:
                        max_lower_bound += 1
                    A[max_lower_bound, int(cluster)] += 1
                    last_time = time

            with open(session_dir / f"{sub_dataset_identifier}.whl", "r") as fhan:
                coords = [[] for _ in range(4)]
                for line in fhan:
                    # rstrip rather than line[:-1]: a final line with no trailing newline
                    # would otherwise lose its last digit.
                    line = [float(x) for x in line.rstrip("\r\n").split("\t")]
                    for i in range(4):
                        coords[i].append(line[i])

            raw_behavior = np.array(coords).T

            behavior_sampling_rate = 39.06
            t = np.arange(raw_behavior.shape[0]) / behavior_sampling_rate

            raw_behavior[raw_behavior == -1] = np.nan  # -1 marks missing tracking samples

            return A, raw_behavior, bin_ends, t

        return static_construct(sub_dataset_identifier, self.bin_width)



# class Musall19Dataset(Dataset):
#     doi = 'https://doi.org/10.1038/s41593-019-0502-4'
#     model_organism = ModelOrganism.MOUSE
#     dataset_base_path = DATA_BASE_PATH / 'musall19'
#     inner_data_path = dataset_base_path / "their_data/2pData/Animals/mSM49/SpatialDisc/30-Jul-2018"
#     automatically_downloadable = False
#
#     def __init__(self, cam=1, video_target_dim=100, resize_factor=1):
#         self.cam = cam  # either 1 or 2
#         self.video_target_dim = video_target_dim
#         self.resize_factor = resize_factor
#
#         A, d, ca_times, t = self.construct()
#         self.neural_data = ArrayWithTime(A, ca_times)
#         self.behavioral_data = ArrayWithTime(d, t)
#
#     def construct(self):
#         self.acquire()
#
#         @save_to_cache("musall19_data")
#         def static_construct(cam, video_target_dim, resize_factor):
#             ca_sampling_rate = 31
#             video_sampling_rate = 30
#
#             #### load A
#             variables = loadmat(self.inner_data_path / 'data.mat', squeeze_me=True, simplify_cells=True)
#             A = variables["data"]['dFOF']
#             _, n_samples_per_trial, _ = A.shape
#             A = np.vstack(A.T)
#
#             #### load trial start and end times, in video frames
#             def read_floats(file):
#                 with open(file) as fhan:
#                     text = fhan.read()
#                     return [float(x) for x in text.split(",")]
#
#             on_times = read_floats(self.dataset_base_path / "trialOn.txt")
#             off_times = read_floats(self.dataset_base_path / "trialOff.txt")
#             trial_edges = np.array([on_times, off_times]).T
#             trial_edges = trial_edges[np.all(np.isfinite(trial_edges), axis=1)].astype(int)
#
#             #### load video
#             root_dir = self.inner_data_path / "BehaviorVideo"
#
#             start_V = 0  # 29801
#             end_V = trial_edges.max()  # 89928
#             used_V = end_V - start_V
#
#             Wid, Hei = 320, 240
#             Wid0, Hei0 = Wid // 4, Hei // 4
#
#             # resized by half
#             Data = np.zeros((used_V, Wid // resize_factor, Hei // resize_factor))
#
#             for k in tqdm(range(16)):
#                 name = f'{root_dir}/SVD_Cam{cam}-Seg{k + 1}.mat'
#                 # Load MATLAB .mat file
#                 mat_contents = loadmat(name)
#                 V = mat_contents['V']  # (89928, 500)
#                 U = mat_contents['U']  # (500, 4800)
#
#                 VU = V[start_V:end_V, :].dot(U)  # (T, 4800)
#                 seg = VU.reshape((used_V, Wid0, Hei0))
#                 Wid1, Hei1 = Wid0 // resize_factor, Hei0 // resize_factor
#                 seg = resize(seg, (used_V, Wid1, Hei1), mode='constant')
#
#                 i, j = k // 4, (k % 4)
#                 Data[:, i * Wid1:(i+1) * Wid1, j * Hei1:(j+1) * Hei1] = seg
#
#             #### dimension reduce video
#             t = np.arange(Data.shape[0]) / video_sampling_rate
#             d = np.array(Data.reshape(Data.shape[0], -1))
#             del Data
#             d = proSVD.apply_and_cache(input_arr=d, output_d=video_target_dim, init_size=video_target_dim)
#             t, d = clip(t, d)
#
#             #### define times
#             ca_times = np.hstack([np.linspace(*trial_edges[i], n_samples_per_trial) for i in range(len(trial_edges))])
#             ca_times = ca_times / video_sampling_rate
#
#             return A, d, ca_times, t
#
#         return static_construct(self.cam, self.video_target_dim, self.resize_factor)
#
#     def acquire(self):
#         if not self.inner_data_path.is_dir():
#             # TODO: I think this is actually publicly downloadable
#             raise FileNotFoundError()


class Naumann24uDataset(Dataset):
    """Zebrafish neural traces with photostimulation and visual-stimulus logs plus tail tracking.

    There is no DOI and no automatic download: each sub-dataset must already
    exist as a directory under ``dataset_base_path``.

    Parameters
    ----------
    sub_dataset_identifier : str or int
        Directory name from ``sub_datasets``, or an index into it.
    beh_type : {'angle', 'bout', 'whole tail', 'offset'}
        Which tail-derived signal to expose as ``behavioral_data``.
    """
    doi = None
    automatically_downloadable = False
    model_organism = ModelOrganism.FISH
    dataset_base_path = DATA_BASE_PATH / "naumann24u"
    sub_datasets = (
        "output_020424_ds1",
        "output_012824_ds3",
        "output_012824_ds6_fish3",
    )

    class BehaviorClassifier(adaptive_latents.transformer.StreamingTransformer):
        """Streaming classifier mapping tail angle to a bout class.

        Each output looks at the last 15 inputs (including the current one):
        0 if all are within [-threshold, threshold]; 1 if some exceed
        threshold and none are below -threshold; 2 if some are below
        -threshold and none exceed threshold; 3 if both occur; NaN if the
        window contains a NaN.
        """

        def __init__(self, threshold=.3, input_streams=None, output_streams=None, log_level=None):
            input_streams = input_streams or {0: 'X'}
            super().__init__(input_streams=input_streams, output_streams=output_streams, log_level=log_level)
            self.history = deque(maxlen=15)
            self.threshold = threshold

        def _partial_fit_transform(self, data, stream, return_output_stream):
            if self.input_streams[stream] == 'X':

                output = []
                for angle in data:
                    self.history.append(angle)
                    h = np.squeeze(self.history)
                    if np.isnan(h).any():
                        # Emit a NaN so the output stays the same length as data.t.
                        output.append(np.nan)
                    elif (h > self.threshold).any():
                        if (h < -self.threshold).any():
                            output.append(3)
                        else:
                            output.append(1)
                    elif (h < -self.threshold).any():
                        output.append(2)
                    else:
                        output.append(0)
                data = ArrayWithTime(output, data.t)

            stream = self.output_streams[stream]
            # Parenthesized: without them this returned a tuple in both branches.
            return (data, stream) if return_output_stream else data

        def get_params(self, deep=True):
            return dict(threshold=self.threshold) | super().get_params(deep=deep)

        # def expected_data_streams(self, rng, DIM, cycles=1):
        #     for _ in range(cycles):
        #         for s in self.input_streams:
        #             yield rng.normal(size=(10, DIM)), s

        def expected_data_streams(self, rng, DIM, cycles=1):
            for i in range(cycles):
                for s in self.input_streams:
                    yield ArrayWithTime(rng.normal(size=(10, DIM)), i), s

    def __init__(self, sub_dataset_identifier=sub_datasets[0], beh_type='angle'):
        # Validate before the slow file loading; otherwise behavioral_data would silently be left unset.
        if beh_type not in ('bout', 'angle', 'whole tail', 'offset'):
            raise ValueError(f"unknown beh_type {beh_type!r}; expected 'bout', 'angle', 'whole tail', or 'offset'")
        if isinstance(sub_dataset_identifier, int):
            sub_dataset_identifier = self.sub_datasets[sub_dataset_identifier]
        self.sub_dataset = sub_dataset_identifier
        (
            self.C,
            self.opto_stimulations,
            self.neuron_df,
            self.visual_stimuli,
            self.tail_position,
            self.frame_times,
            self.tail_times,
            self.tail_angle,
            self.pose_class,
            self.background_image,
            self.neuron_locations  # TODO: join this with neuron_df
        ) = self.construct(self.sub_dataset)

        self.neural_data = ArrayWithTime(self.C.T, self.frame_times)

        self.end_of_visual_period_sample = self.opto_stimulations['sample'].min() - 1
        self.end_of_visual_period_time = self.frame_times[self.end_of_visual_period_sample]
        self.n_neurons_in_opto = np.isfinite(self.neural_data[self.end_of_visual_period_sample, :]).sum()

        self.bin_width = np.median(np.diff(self.frame_times))
        warnings.warn("bin width is actually improper here")
        if beh_type == 'bout':
            self.behavioral_data = ArrayWithTime(self.pose_class, self.tail_times).reshape(-1, 1)
        elif beh_type == 'angle':
            self.behavioral_data = ArrayWithTime(self.tail_angle, self.tail_times).reshape(-1, 1)
        elif beh_type == 'whole tail':
            self.behavioral_data = ArrayWithTime(self.tail_position, self.tail_times)
        elif beh_type == 'offset':
            self.behavioral_data = ArrayWithTime(self.tail_position[:, -1, :] - self.tail_position[:, 0, :], self.tail_times)

    def construct(self, sub_dataset_identifier):
        """Parse the raw files into arrays and DataFrames (see ``__init__`` for the order)."""
        visual_stimuli, optical_stimulations, C, string_tail_position, frame_times, tail_times, background_image, neuron_locations = self.acquire(sub_dataset_identifier)

        # Before a neuron's first nonzero value its trace is marked NaN.
        C[np.cumsum(C, axis=1) == 0] = np.nan

        # convert the dates from strings to offsets in seconds
        assert abs(tail_times[0] - frame_times[0]) < datetime.timedelta(minutes=3), 'Check start times/timezones match'
        experiment_start = min(tail_times[0], frame_times[0])
        one_second = datetime.timedelta(seconds=1)
        tail_times = np.array([(t - experiment_start) / one_second for t in tail_times])
        frame_times = np.array([(t - experiment_start) / one_second for t in frame_times])
        if frame_times.size > (n_recorded_frames := C.shape[1]):
            warnings.warn('chopping last frames because C is too small')
            frame_times = frame_times[:n_recorded_frames]

        # convert the tail positions from strings to arrays
        tail_position = []
        for sample in string_tail_position:
            rows = sample[1:-1].split('[')[1:]
            rows = [row.split(']')[0].split(',') for row in rows]
            rows = [[int(x) for x in row] for row in rows]
            tail_position.append(rows)
        tail_position = np.array(tail_position)

        # make DF's
        visual_stimuli_df = pd.DataFrame({'sample': visual_stimuli[:, 0].astype(int), 'time': frame_times[visual_stimuli[:, 0].astype(int)], 'l_angle': visual_stimuli[:, 2], 'r_angle': visual_stimuli[:, 3]})

        optical_stimulation_df = pd.DataFrame({'sample': optical_stimulations[:, 0].astype(int), 'time': frame_times[optical_stimulations[:, 0].astype(int)], 'target_neuron': optical_stimulations[:, 2].astype(int)})

        # Consecutive stimulations of the same neuron form a group, named A0, A1, ..., B0, ...
        target_neuron = optical_stimulation_df.target_neuron
        stim_groups = [0]
        group_sub_stim = [0]
        stim_name = ['A0']
        for i in range(1, len(target_neuron)):
            if target_neuron[i - 1] != target_neuron[i]:
                stim_groups.append(stim_groups[-1] + 1)
                group_sub_stim.append(0)
            else:
                stim_groups.append(stim_groups[-1])
                group_sub_stim.append(group_sub_stim[-1] + 1)
            stim_name.append(chr(stim_groups[-1] + 65) + str(group_sub_stim[-1]))

        optical_stimulation_df['stim_group'] = stim_groups
        optical_stimulation_df['group_sub_stim'] = group_sub_stim
        optical_stimulation_df['stim_name'] = stim_name

        neurons = {}
        for neuron_id in optical_stimulation_df['target_neuron']:
            locations = optical_stimulations[optical_stimulations[:, 2] == neuron_id, 3:]
            assert np.all(np.std(locations, axis=0) == 0)
            neurons[neuron_id] = locations[0, :]
        neuron_df = pd.DataFrame.from_dict(neurons, orient='index', columns=['x', 'y'])

        displacement = tail_position[:, -1, :] - tail_position[:, 0, :]
        # np.arctan2 rather than np.atan2: the latter is only an alias added in NumPy 2.0.
        tail_angle = np.arctan2(*(-displacement[:, ::-1]).T)

        pose_class = self.BehaviorClassifier().offline_run_on(ArrayWithTime(tail_angle[:, None, None], tail_times))

        return C, optical_stimulation_df, neuron_df, visual_stimuli_df, tail_position, frame_times, tail_times, tail_angle, pose_class, background_image, neuron_locations

    def acquire(self, sub_dataset_identifier):
        """Read the raw files for one sub-dataset directory.

        Raises ``FileNotFoundError`` if the directory does not exist.
        """
        base = self.dataset_base_path / sub_dataset_identifier
        if not base.is_dir():
            raise FileNotFoundError(f"{base} is not a directory")
        optical_stimulations = np.load(base / 'photostims.npy')
        visual_stimuli = np.loadtxt(base / 'stimmed.txt')
        tail_position = np.load(base / 'tails.npy')

        frame_times = []
        with open(base / 'timing' / 'framesendtimes.txt') as fhan:
            for line in fhan:
                # NOTE(review): drops two trailing characters and shifts by -5 h; presumably a
                # timezone suffix/offset — confirm against the file format.
                frame_times.append(datetime.datetime.fromisoformat(line[:-2]) - datetime.timedelta(hours=5))

        tail_times = []
        with open(base / 'timing' / 'tailsendtimes.txt') as fhan:
            for line in fhan:
                # rstrip rather than line[:-1] so a final line without a newline parses intact.
                tail_times.append(datetime.datetime.strptime(line.rstrip('\n'), '%I:%M:%S.%f %p %m/%d/%Y'))

        c_filename = 'raw_C.txt'
        if sub_dataset_identifier == 'output_020424_ds1':
            c_filename = 'analysis_proc_C.txt'
        C = np.loadtxt(base / c_filename)

        neuron_locations = np.loadtxt(base / 'contours.txt')
        background_image = np.loadtxt(base / 'image.txt')

        return visual_stimuli, optical_stimulations, C, tail_position, frame_times, tail_times, background_image, neuron_locations

    def plot_colors(self, ax):
        """Draw the unit circle colored by ``a2c`` as a legend for angle colors."""
        theta = np.linspace(0, 360)
        ax.scatter(np.cos(theta * np.pi / 180), np.sin(theta * np.pi / 180), c=self.a2c(theta))
        ax.axis('equal')

    @staticmethod
    def a2c(a):
        """Map angles in degrees to RGBA colors on the hsv colormap (rotated by 30 degrees)."""
        a = (a + 30) % 360
        return matplotlib.cm.ScalarMappable(matplotlib.colors.Normalize(vmin=0, vmax=360), cmap=matplotlib.cm.hsv).to_rgba(a)

    def get_rectangular_block(self, n_neurons=150):
        # type: (Naumann24uDataset, int) -> ArrayWithTime
        """Return a NaN-free copy of the neural data, trimmed in time and in neurons.

        The start is the first sample at which column ``n_neurons`` has a positive
        running sum. Trailing neurons that are still inactive at that sample are
        dropped, and any rows that still contain NaNs at the start are skipped.
        """
        cutoff1 = np.nonzero(np.nancumsum(self.neural_data[:, n_neurons]) > 0)[0][0]
        cutoff2 = np.nonzero(np.nancumsum(self.neural_data[cutoff1, ::-1]))[0][0]
        # Slice to an explicit stop: `[:, :-cutoff2]` would be empty when cutoff2 == 0.
        n_kept_columns = self.neural_data.shape[1] - cutoff2
        neural_data = self.neural_data.slice(cutoff1, -1)[:, :n_kept_columns]
        additional_cutoff_info = np.where(np.isnan(neural_data).any(axis=1))[0]
        if additional_cutoff_info.size > 0:
            cutoff3 = additional_cutoff_info[-1] + 1
            neural_data = neural_data.slice(cutoff3, -1)
        assert not np.isnan(neural_data).any()
        return copy.deepcopy(neural_data)


class Leventhal24uDataset(Dataset):
    """Binned spike counts plus trial metadata for locally stored rat recording sessions.

    Not automatically downloadable: each session must already be present under
    ``dataset_base_path`` (see ``acquire``).
    """
    doi = None
    model_organism = ModelOrganism.RAT
    dataset_base_path = DATA_BASE_PATH / 'leventhal24u'
    automatically_downloadable = False

    # session paths, relative to `dataset_base_path`
    sub_datasets = (
        "R0493/R0493_20230720_ChAdvanced_230720_105441",
        # "R0466/R0466_20230403_ChoiceEasy_230403_095044",
        # "R0544_20240625_ChoiceEasy_240625_100738"
    )

    def __init__(self, sub_dataset_identifier=sub_datasets[0], bin_size=0.03):
        """Load and bin one session.

        Args:
            sub_dataset_identifier: a path from ``sub_datasets`` or an int index into it.
            bin_size: width of each spike-count bin, in the units of the spike times
                (sample index divided by the amplifier sample rate, see ``acquire``).
        """
        if isinstance(sub_dataset_identifier, int):
            sub_dataset_identifier = self.sub_datasets[sub_dataset_identifier]
        self.sub_dataset = sub_dataset_identifier
        self.bin_size = bin_size
        A, t, spike_times, clusters, trial_data, unflattened_trial_data = self.construct(sub_dataset_identifier)
        self.neural_data = ArrayWithTime(A, t)
        self.spike_times = spike_times
        self.spike_clusters = clusters
        self.trial_data = trial_data
        self.unflattened_trial_data = unflattened_trial_data
        # self.behavioral_data = ArrayWithTime(beh,t)

    def construct(self, sub_dataset_identifier):
        """Bin the session's spikes and flatten its trial table.

        Returns:
            A: float array of shape ``(n_bins - 1, n_units)``; ``A[i, j]`` counts the spikes
                of the j-th cluster (in ``np.unique(clusters)`` order) that fall in bin ``i``.
            bin_ends: right edge of every bin, used as the time stamp of the matching row of ``A``.
            spike_times, clusters: per-spike times and cluster ids, as returned by ``acquire``.
            flattened_trial_data: trial table with ``'tone'`` renamed to ``'toneType'`` and the
                ``'timing'`` / ``'timestamps'`` columns expanded via ``pd.Series`` (expanded
                ``'timing'`` columns get a ``'relative_'`` prefix).
            trial_data: the trial table before that renaming / expansion.
        """
        spike_times, clusters, trials = self.acquire(sub_dataset_identifier)

        # for neuron_to_drop in [846]:
        #     warnings.warn(f"dropping neuron {neuron_to_drop} because I suspect it was an 'other' unit; I need to check this")
        #     s = clusters != neuron_to_drop
        #     spike_times = spike_times[s]
        #     clusters = clusters[s]

        unique_clusters = np.unique(clusters)
        n_units = unique_clusters.size

        # Edges start at the first spike and step by `bin_size`; the ceil puts the last edge
        # (up to floating-point rounding) at or beyond the final spike.
        n_bins = np.ceil((spike_times.max() - spike_times.min()) / self.bin_size).astype(int) + 1
        bin_edges = np.arange(n_bins) * self.bin_size + spike_times.min()

        A = np.zeros((n_bins - 1, n_units))
        for i, c in enumerate(unique_clusters):
            A[:, i] = np.histogram(spike_times[clusters == c], bins=bin_edges)[0]

        bin_ends = bin_edges[1:]

        trial_data = pd.DataFrame(trials['trials'])

        df = trial_data.rename(columns={'tone': 'toneType'})
        for column in ['timing', 'timestamps']:
            sub_df = df[column].apply(pd.Series)
            if column == 'timing':
                sub_df = sub_df.rename(columns=lambda x: "relative_" + x)
            df = pd.concat([df.drop(column, axis=1), sub_df], axis=1)
        flattened_trial_data = df

        # expected_trial_data_keys = {'Time', 'Attempt', 'Center', 'Target', 'Tone', 'RT', 'MT', 'pretone', 'outcome', 'SideNP', 'CenterNoseIn', 'SideInToFood'}
        # discovered_trial_data_keys = {k for k, v in log_data.items() if hasattr(v, '__len__') and len(v) == len(log_data['Time'])}
        # assert expected_trial_data_keys.difference(discovered_trial_data_keys) == set()
        # assert discovered_trial_data_keys.difference(expected_trial_data_keys) == set()
        #
        # trial_data = pd.DataFrame({key: log_data[key] for key in expected_trial_data_keys})

        return A, bin_ends, spike_times, clusters, flattened_trial_data, trial_data

    def acquire(self, sub_dataset_identifier):
        """Load spike times, cluster ids and trial metadata for one session from disk.

        Returns:
            spike_times: contents of ``spike_times.npy`` (flattened) divided by the amplifier
                sample rate read from ``info.rhd``.
            clusters: flattened contents of ``spike_clusters.npy``.
            trials: ``trials.mat`` (from the session's parent directory) loaded with ``loadmat``.

        Raises:
            FileNotFoundError: if the session's ``info.rhd`` is missing.
        """
        subset_base_path = self.dataset_base_path / sub_dataset_identifier

        lib_directory = (self.dataset_base_path / 'load-rhd-notebook-python').resolve()
        if not lib_directory.is_dir():
            print(f"""\
Please download (clone) `https://github.com/Intan-Technologies/load-rhd-notebook-python` into {self.dataset_base_path.resolve()}.""")

        info_file = subset_base_path / 'info.rhd'
        if not info_file.is_file():
            print(f"""\
Please place a copy of '{sub_dataset_identifier}' into '{self.dataset_base_path}'.""")
            raise FileNotFoundError(info_file)

        # todo: this is really bad -- the Intan reader is imported straight from the cloned repo.
        # Only add the directory once, so repeated calls don't keep growing sys.path.
        if str(lib_directory) not in sys.path:
            sys.path.append(str(lib_directory))
        import importrhdutilities as rhd

        result, data_present = rhd.load_file(info_file)
        assert not data_present
        sampling_frequency = result['frequency_parameters']['amplifier_sample_rate']

        spike_times = np.load(subset_base_path / 'spike_times.npy').flatten() / sampling_frequency
        clusters = np.load(subset_base_path / 'spike_clusters.npy').flatten()

        trials = loadmat(subset_base_path.parent / 'trials.mat', squeeze_me=True, simplify_cells=True)

        # t = np.fromfile(subset_base_path/'time.dat', dtype='int32') / sampling_frequency
        return spike_times, clusters, trials


class Zong22Dataset(Dataset):
    """Mouse imaging sessions (suite2p output) paired with pose-tracking CSVs; needs a local copy.

    Several recordings from the same day share one suite2p folder. Each row of
    ``sub_datset_info`` selects its own segment of that shared ``F.npy`` via ``part_of_F``.
    """
    doi = "https://dx.doi.org/10.11582/2022.00008"
    automatically_downloadable = False
    model_organism = ModelOrganism.MOUSE
    dataset_base_path = DATA_BASE_PATH / 'zong22'

    # The two `make_*_entry` helpers are called while the class body executes (to build
    # `sub_datset_info`), so they are plain functions rather than staticmethods.
    def make_cookie_entry(area, animal_id, date, f_part, f_total, cookie_status, filtered):
        """Build the file-name record for a cookie / no-cookie session.

        ``part_of_F = (f_part, f_total)``: this session is segment ``f_part`` (1-based) of
        ``f_total`` equal-length segments of the day's suite2p ``F.npy``.
        """
        assert type(area) == str  # this can become static after 3.10
        cookie_status = 'with' if cookie_status else 'no'
        filtered = 'filtered' if filtered else ''
        return {
            'basepath':       f'{area}_recordings/{animal_id}/{date}/',
            'raw_frames':     f'{animal_id}_imaging_{date}_{cookie_status}cookies_00001.tif',
            'behavior_csv':   f'{animal_id}_imaging_{date}_{cookie_status}cookies_00001_trackingVideoDLC_resnet50_OPENMINI2P_bottomcameraAug26shuffle1_1030000{filtered}.csv',
            'behavior_video': f'{animal_id}_imaging_{date}_{cookie_status}cookies_00001_trackingVideo.avi',
            'part_of_F': (f_part,f_total)
        }

    def make_object_entry(area, animal_id, date, f_part, f_total, object_n, filtered):
        """Build the file-name record for an object / no-object session (no behaviour video).

        ``part_of_F`` has the same meaning as in ``make_cookie_entry``.
        """
        assert type(area) == str  # this can become static after 3.10
        object_str = f'object{object_n}' if object_n is not None else 'noobject'
        filtered = 'filtered' if filtered else ''
        return {
            'basepath':       f'{area}_recordings/{animal_id}/{date}/',
            'raw_frames':     f'{animal_id}_imaging_{date}_{object_str}_00001.tif',
            'behavior_csv':   f'{animal_id}_imaging_{date}_{object_str}_00001_trackingVideoDLC_resnet50_OPENMINI2P_bottomcameraAug26shuffle1_1030000{filtered}.csv',
            # 'behavior_video': f'{animal_id}_imaging_{date}_{object_str}_00001_trackingVideo.avi',
            'part_of_F': (f_part,f_total)
        }

    sub_datset_info = pd.DataFrame([
        make_cookie_entry('VC', '93562', '20200817', 1, 2, False, True),
        make_cookie_entry('VC', '93562', '20200817', 2, 2, True, True),

        make_cookie_entry('MEC', '94557', '20200822', 1, 2, False, False),
        # Part 2, not 1: both sessions share one suite2p folder, and (1, 2) would give this
        # session the same F segment as the no-cookie session above. Ordering follows the VC
        # entries ('nocookies' before 'withcookies'); assumes suite2p concatenated the tiffs
        # in file-name order -- TODO confirm.
        make_cookie_entry('MEC', '94557', '20200822', 2, 2, True, False),

        make_object_entry('MEC', '94557', '20201008', 1, 3, None, True),
        make_object_entry('MEC', '94557', '20201008', 2, 3, 1, True),
        make_object_entry('MEC', '94557', '20201008', 3, 3, 2, True),
    ])

    sub_datasets = list(sub_datset_info.index)

    def __init__(self, sub_dataset_identifier=sub_datasets[0], neural_lag=0, neural_scale=1, pos_scale=1, hd_scale=1, h2b_scale=1):
        """Load one session.

        Args:
            sub_dataset_identifier: a row index of ``sub_datset_info`` (ints are also looked up
                through ``sub_datasets``).
            neural_lag: offset added to the neural time stamps.
            neural_scale: factor applied to the neural traces.
            pos_scale, hd_scale, h2b_scale: factors applied to the ``x``/``y``, ``hd`` and ``h2b``
                behavioural columns respectively.
        """
        if isinstance(sub_dataset_identifier, int):
            sub_dataset_identifier = self.sub_datasets[sub_dataset_identifier]

        self.sub_dataset = sub_dataset_identifier
        self.neural_Fs = 15
        self.neural_lag = neural_lag
        self.neural_scale = neural_scale
        self.bin_width = 1/self.neural_Fs  # todo: make this universal?
        self.F, self.raw_images, self.behavior_video, self.behavior_df, self.n_cells, self.stat, self.ops = self.acquire()

        self.neural_data = ArrayWithTime(self.F.T * self.neural_scale, (np.arange(self.F.shape[1]) * 1 / self.neural_Fs) + self.neural_lag)
        self.behavioral_data = ArrayWithTime(self.behavior_df.loc[:, ['x', 'y', 'hd', 'h2b']] * np.array([pos_scale, pos_scale, hd_scale, h2b_scale]), self.behavior_df.loc[:, 't'])

        self.video_t = np.squeeze(self.behavioral_data.t)

    def acquire(self):
        """Load suite2p outputs, raw frames, optional behaviour video and tracking for this session.

        Returns:
            ``(F, img, video, beh, n_cells, stat, ops)``:
            ``F`` is this session's segment (``part_of_F``) of the dF/F traces, shaped
            ROIs x frames (all ROIs of ``F.npy``; not filtered by ``iscell``); ``img`` is the
            opened raw-frame tiff; ``video`` is a ``pims.Video`` or None; ``beh`` is the tracking
            table with added ``hd``, ``h2b``, ``x`` and ``y`` columns; ``n_cells`` is the sum of
            the first column of ``iscell.npy``; ``stat`` / ``ops`` are the suite2p metadata.

        Raises:
            FileNotFoundError: if the session directory does not exist.
        """
        sub_dataset_base_path = self.dataset_base_path / self.sub_datset_info.basepath[self.sub_dataset]
        if not sub_dataset_base_path.is_dir():
            print(f"Go download the dataset from {self.doi}. (Or remount the external drive on Tycho)")
            raise FileNotFoundError(sub_dataset_base_path)

        plane_path = sub_dataset_base_path / 'suite2p' / 'plane0'
        iscell = np.load(plane_path / 'iscell.npy')
        F_all = np.load(plane_path / 'F.npy')
        n_cells = int(iscell[:, 0].sum())

        stat = np.load(plane_path / 'stat.npy', allow_pickle=True)
        ops = np.load(plane_path / 'ops.npy', allow_pickle=True).item()

        def make_beh(fpath):
            # The CSV carries two extra header rows (point name, coordinate); merge them into
            # names such as "nose_x" and call the first column "t".
            pre_beh = pd.read_csv(fpath)
            names = ["t"] + [f"{point}_{coord}" for point, coord in zip(pre_beh.iloc[0, 1:], pre_beh.iloc[1, 1:])]
            beh = pre_beh.rename(columns=dict(zip(pre_beh.columns, names))).iloc[2:].astype(float).reset_index(drop=True)
            # NOTE(review): treats the first column as a frame index sampled at the imaging
            # rate -- confirm.
            beh['t'] = beh['t'] / self.neural_Fs
            return beh

        part, total = self.sub_datset_info.part_of_F[self.sub_dataset]
        block_length = F_all.shape[1] // total

        F_all = F_all - F_all.min(axis=1, keepdims=True)
        # F_all = F_all / np.median(F_all, axis=1, keepdims=True)

        # dF/F relative to each ROI's median. After the min-shift F_all >= 0, so a row whose
        # median is 0 divides by zero (0/0 -> nan, x/0 -> inf); every non-finite value is zeroed.
        F_all_0 = np.median(F_all, axis=1, keepdims=True)
        with np.errstate(divide='ignore', invalid='ignore'):
            F_all = (F_all - F_all_0) / F_all_0
        F_all[~np.isfinite(F_all)] = 0

        F = F_all[:, (part - 1) * block_length: part * block_length]
        img = Image.open(sub_dataset_base_path / self.sub_datset_info.raw_frames[self.sub_dataset])
        video = None
        # object sessions have no 'behavior_video' entry, so the DataFrame holds NaN there
        if isinstance(video_filename := self.sub_datset_info.behavior_video[self.sub_dataset], str):
            video = pims.Video(sub_dataset_base_path / video_filename)
        beh = make_beh(sub_dataset_base_path / self.sub_datset_info.behavior_csv[self.sub_dataset])

        nose = self.get_behavior_trace(beh, 'nose')
        body = self.get_behavior_trace(beh, 'bodycenter')
        head = self.get_behavior_trace(beh, 'mouse')

        # heading: np.arctan2(dx, dy) of the head -> nose vector
        beh['hd'] = np.arctan2(*(nose - head).T)
        beh['h2b'] = np.linalg.norm(head - body, axis=1)
        beh['x'] = head[:, 0]
        beh['y'] = head[:, 1]

        return F, img, video, beh, n_cells, stat, ops

    def show_stim_pattern(self, ax, desired_stim):
        """Show the mean of the first 500 raw frames with each ROI's median pixel coloured by ``desired_stim``."""
        planes = []
        for i in range(500):
            self.raw_images.seek(i)
            planes.append(np.array(self.raw_images))

        im = np.mean(planes, axis=0)

        # 'Greys' is matplotlib's canonical colormap name ('Grays' may not exist in older releases)
        ax.matshow(-im, cmap='Greys')
        # `med` is indexed (row, col), so it is plotted as (x=col, y=row) over the image
        rows, cols = zip(*[cell['med'] for cell in self.stat])
        scatter = ax.scatter(cols, rows, s=7, c=desired_stim)

        ax.get_figure().colorbar(scatter)

    @staticmethod
    def get_behavior_trace(beh, point_str, threshold=.999):
        """Return an (n_rows, 2) array of ``point_str``'s x/y columns, NaN where likelihood < ``threshold``."""
        point_trace = np.array([beh.loc[:, point_str + '_x'].to_numpy(),
                                beh.loc[:, point_str + '_y'].to_numpy()]).T
        low_confidence = beh.loc[:, point_str + '_likelihood'].to_numpy() < threshold
        point_trace[low_confidence] = np.nan
        return point_trace


class DummyCircleDataset(Dataset):
    """Synthetic dataset meant to mock Odoherty21.

    ``neural_data`` and ``behavioral_data`` each hold three copies of the same sine wave,
    sampled on ``self.t``, with independent Gaussian noise (scale 0.1) added to every entry.
    One sine period spans 15 samples.
    """
    doi = None
    automatically_downloadable = False
    model_organism = None

    def __init__(self, Fs=33, total_time=600, seed=None):
        """
        Args:
            Fs: samples per unit of time; ``bin_width`` is ``1 / Fs``.
            total_time: end of the time axis; ``self.t`` runs from 0 to ``total_time`` inclusive.
            seed: optional seed for ``np.random.default_rng`` to make the noise reproducible;
                ``None`` (default) draws fresh entropy.
        """
        self.rng = np.random.default_rng(seed)
        self.Fs = Fs
        self.bin_width = 1 / self.Fs
        # round so a non-integer Fs * total_time still gives np.linspace an integer count
        n_samples = int(round(total_time * self.Fs)) + 1
        self.t = np.linspace(0, total_time, n_samples)

        neural_data, behavioral_data = self.construct()

        self.neural_data = ArrayWithTime(neural_data, self.t)
        self.behavioral_data = ArrayWithTime(behavioral_data, self.t)

    def acquire(self):
        # nothing to download or load: the data are generated in `construct`
        pass

    def construct(self):
        """Return ``(neural_data, behavioral_data)``, each of shape ``(len(self.t), 3)``."""
        self.acquire()
        # sin(t * Fs * 2pi / 15): one period lasts 15 / Fs time units, i.e. 15 samples
        speed_factor = self.Fs * np.pi * 2 / 15
        neural_data = np.sin(np.vstack(3 * [self.t]).T * speed_factor)
        neural_data = neural_data + self.rng.normal(scale=.1, size=neural_data.shape)

        behavioral_data = np.sin(np.vstack(3 * [self.t]).T * speed_factor)
        behavioral_data = behavioral_data + self.rng.normal(scale=.1, size=behavioral_data.shape)
        return neural_data, behavioral_data


"""
class Low21Dataset(MultiDataset):
    doi = "https://doi.org/10.17632/hntn6m2pgk.1"
    automatically_downloadable = True
    dataset_base_path = DATA_BASE_PATH / 'low21'

    def acquire(self, sub_dataset_identifier=None):
        if len(list(self.dataset_base_path.glob("*.npy"))) == 0:
            datahugger.get(self.doi, self.dataset_base_path)

    # def construct(self, sub_dataset_identifier=None):
    #     pass
    #
    # def get_sub_datasets(self):
    #     pass
"""


