# ---- Imports ---- #
import os
from pathlib import Path

import h5py
import numpy as np
from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors
from nlb_tools.nwb_interface import NWBDataset

# ---- Run params ---- #
# Root of the local DANDI archive checkout. The DANDI_ROOT environment variable
# overrides the machine-specific default, so the script can run on another
# machine without editing this line; "~" is expanded either way.
DANDI_ROOT = Path(
    os.environ.get("DANDI_ROOT", "/home/anon/lvm/code/dandi/")
).expanduser()
dataset_name = "mc_maze"  # must be a key of datapath_dict below
bin_size_ms = 5  # passed to dataset.resample(); also part of the output filename

# ---- Dataset locations ---- #
# Dandiset subdirectory (relative to DANDI_ROOT) for each supported dataset name.
datapath_dict = {
    "mc_maze": DANDI_ROOT / "000128/sub-Jenkins/",
    "mc_rtt": DANDI_ROOT / "000129/sub-Indy/",
    "area2_bump": DANDI_ROOT / "000127/sub-Han/",
    "dmfc_rsg": DANDI_ROOT / "000130/sub-Haydn/",
    "mc_maze_large": DANDI_ROOT / "000138/sub-Jenkins/",
    "mc_maze_medium": DANDI_ROOT / "000139/sub-Jenkins/",
    "mc_maze_small": DANDI_ROOT / "000140/sub-Jenkins/",
}
# Prefix handed to NWBDataset for the mc_maze variants; every other dataset
# falls back to "" below.
prefix_dict = {
    "mc_maze": "*full",
    "mc_maze_large": "*large",
    "mc_maze_medium": "*medium",
    "mc_maze_small": "*small",
}
if dataset_name not in datapath_dict:
    # Fail early with the list of valid choices instead of a bare KeyError.
    raise KeyError(
        f"Unknown dataset_name {dataset_name!r}; "
        f"expected one of {sorted(datapath_dict)}"
    )
datapath = datapath_dict[dataset_name]
prefix = prefix_dict.get(dataset_name, "")
save_path = f"./{dataset_name}_{bin_size_ms}ms.h5"

# ---- Load dataset ---- #
# Open the dataset at `datapath` (with the per-dataset `prefix`) and resample it
# to `bin_size_ms` before any tensors are extracted.
dataset = NWBDataset(datapath, prefix)
dataset.resample(bin_size_ms)

# ---- Extract data ---- #
# One tensor dict per NLB trial split, built with identical options: behavior
# included, no forward-prediction targets, nothing saved to disk. The
# generator is consumed in order, so "train" is extracted before "val".
train_dict, val_dict = (
    make_train_input_tensors(
        dataset,
        dataset_name,
        split_name,
        save_file=False,
        include_forward_pred=False,
        include_behavior=True,
    )
    for split_name in ("train", "val")
)

# Join held-in and held-out spike tensors along the third axis (np.dstack);
# presumably the neuron axis -- confirm. Both dicts are read with "train_..."
# keys, including the one built from the "val" split.
train_spikes, val_spikes = (
    np.dstack([tensors["train_spikes_heldin"], tensors["train_spikes_heldout"]])
    for tensors in (train_dict, val_dict)
)

# ---- Re-split trials into train / valid / test ---- #
def _split_trials(first, second):
    """Concatenate two arrays along axis 0 and re-split them three ways.

    ``train`` receives every row of ``first`` plus the first half of ``second``
    (floor division). Of the rows left over, the first quarter (floor) becomes
    ``valid`` and the remainder becomes ``test``.

    Args:
        first: array split along its leading axis (here, the NLB "train" split).
        second: array with the same trailing shape (here, the NLB "val" split).

    Returns:
        Tuple ``(train, valid, test)`` of slices of the concatenated array.
    """
    combined = np.concatenate([first, second], axis=0)
    train_end = len(first) + len(second) // 2
    # NOTE(review): valid gets only 1/4 of the held-back rows and test the
    # other 3/4 -- confirm this ratio is intended.
    valid_end = train_end + (len(combined) - train_end) // 4
    return combined[:train_end], combined[train_end:valid_end], combined[valid_end:]


train_behavior = train_dict["train_behavior"]
val_behavior = val_dict["train_behavior"]

# Spikes and behavior are split by the same row counts; a mismatch would
# silently pair spike rows with the wrong behavior rows.
if len(train_spikes) != len(train_behavior) or len(val_spikes) != len(val_behavior):
    raise ValueError(
        "Spike and behavior tensors have different trial counts; "
        "train/valid/test splits would be misaligned."
    )

train_data, valid_data, test_data = _split_trials(train_spikes, val_spikes)
train_beh, valid_beh, test_beh = _split_trials(train_behavior, val_behavior)

# ---- Save to lfads-torch compatible format ---- #
# Each spike split is written twice, as "*_encod_data" and "*_recon_data"
# (the same array for both), followed by one "*_behavior" dataset per split.
# Datasets are created in the order listed.
_h5_datasets = [
    ("train_encod_data", train_data),
    ("train_recon_data", train_data),
    ("valid_encod_data", valid_data),
    ("valid_recon_data", valid_data),
    ("test_encod_data", test_data),
    ("test_recon_data", test_data),
    ("train_behavior", train_beh),
    ("valid_behavior", valid_beh),
    ("test_behavior", test_beh),
]
with h5py.File(save_path, mode="w") as h5_out:
    for _ds_name, _ds_array in _h5_datasets:
        h5_out.create_dataset(_ds_name, data=_ds_array)

# ---- Print summary ---- #
# Shapes of the three spike splits; the leading axis is the one the
# train/valid/test split slices along.
_summary_lines = [
    f"Train data shape: {train_data.shape}",
    f"Valid data shape: {valid_data.shape}",
    f"Test data shape: {test_data.shape}",
]
print("\n".join(_summary_lines))
