import os

import numpy as np
import scipy.io as io
import torch


def load_ecog_data(session_path, channels, signal_length, start_time, end_time):
    """
    Load ECoG data for a given session and cut it into fixed-length windows.

    Reads ``ECoGTime.mat`` plus one ``ECoG_ch{chan}.mat`` file per requested
    channel from ``session_path``, keeps only the samples whose timestamp lies
    strictly between ``start_time`` and ``end_time``, and splits that segment
    into consecutive, non-overlapping windows of ``signal_length`` samples.
    Trailing samples that do not fill a complete window are dropped.

    Args:
        session_path: Path to session folder
        channels: Iterable of channels to load and arrange (output channel
            order follows the order given here)
        signal_length: Length of time windows to split into (positive integer)
        start_time: Time in seconds to start loading data from (exclusive)
        end_time: Time in seconds to stop loading data from (exclusive)

    Returns:
        Array of shape (num_time_windows, num_channels, signal_length).
        num_time_windows is 0 when the selected segment is shorter than
        signal_length.

    Raises:
        ValueError: If start_time is not smaller than end_time, or if
            signal_length is not positive.
    """

    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if not start_time < end_time:
        raise ValueError(
            f"start_time ({start_time}) must be smaller than end_time ({end_time})"
        )
    if signal_length <= 0:
        raise ValueError(f"signal_length must be positive, got {signal_length}")

    time_file = os.path.join(session_path, "ECoGTime.mat")
    time = io.loadmat(time_file)["ECoGTime"].squeeze()

    # NOTE(review): assumes every channel file stores an array of shape
    # (1, num_samples) so that stacking along axis 0 gives
    # (num_channels, num_samples) - confirm against the data files.
    traces = []
    for chan in channels:
        ecog_file = os.path.join(session_path, f"ECoG_ch{chan}.mat")
        traces.append(io.loadmat(ecog_file)[f"ECoGData_ch{chan}"])
    ecog = np.concatenate(traces, axis=0)

    # Restrict to the samples inside the requested (open) time interval.
    time_window = np.where((time > start_time) & (time < end_time))[0]
    ecog_time_window = ecog[:, time_window]

    # Window the *time-restricted* segment (not the full recording) and drop
    # the incomplete remainder at the end.
    num_windows = ecog_time_window.shape[1] // signal_length
    trimmed = ecog_time_window[:, : num_windows * signal_length]
    windows = trimmed.reshape(ecog.shape[0], num_windows, signal_length)

    # (channels, windows, length) -> (windows, channels, length)
    return np.ascontiguousarray(windows.transpose(1, 0, 2))


def standardize_array(arr, ax, set_mean=None, set_std=None, return_mean_std=False):
    """
    Standardize array along given axis. set_mean and set_std can be used to manually set mean and standard deviation.

    When not supplied, mean and standard deviation are computed over ``ax``
    with ``keepdims=True`` so that they broadcast against ``arr``.

    Args:
        arr: Array to be standardized.
        ax: Axis (or tuple of axes) along which to standardize.
        set_mean: If not None, use this value as mean.
        set_std: If not None, use this value as standard deviation.
        return_mean_std: If True, return mean and standard deviation that were used for standardization.

    Returns:
        Standardized array.
        If return_mean_std is True: Mean
        If return_mean_std is True: Standard deviation

    Raises:
        ValueError: If the smallest standard deviation is not strictly
            positive (zero, negative or NaN), since dividing by it would
            produce inf/NaN values.
    """

    arr_mean = np.mean(arr, axis=ax, keepdims=True) if set_mean is None else set_mean
    arr_std = np.std(arr, axis=ax, keepdims=True) if set_std is None else set_std

    # Explicit check rather than `assert`: asserts vanish under `python -O`,
    # which would let a zero std silently turn the output into inf/NaN.
    # `not (x > 0)` also rejects NaN.
    if not np.min(arr_std) > 0.0:
        raise ValueError("standard deviation must be strictly positive to standardize")

    standardized = (arr - arr_mean) / arr_std
    if return_mean_std:
        return standardized, arr_mean, arr_std
    return standardized


class TychoUnconditionalDataset(torch.utils.data.Dataset):
    """
    Dataset for the macaque ECoG data.

    Unconditional version: Only awake data is used.

    The requested channels are loaded from the ``Session1`` folder below
    ``filepath``, cut into windows of ``signal_length`` samples and
    standardized jointly over axes 0 and 2 of the loaded array (windows and
    time steps), i.e. separately for each channel.
    """

    def __init__(
        self,
        signal_length,
        channels=(1, 2, 3, 4),
        session_one_start=0.0,
        session_one_end=10000.0,
        filepath=None,
    ):
        """
        Load and standardize the Session1 recording.

        Args:
            signal_length: Number of samples per returned window.
            channels: Iterable of channel numbers to load, or the string
                "all" to load channels 1 through 128.
            session_one_start: Start time passed to load_ecog_data.
            session_one_end: End time passed to load_ecog_data.
            filepath: Directory that contains the "Session1" folder. Required.

        Raises:
            TypeError: If filepath is None.
        """
        super().__init__()
        if filepath is None:
            # Fail with an actionable message instead of the opaque error
            # os.path.join raises for a None component.
            raise TypeError(
                "filepath must be the directory containing 'Session1', got None"
            )
        self.signal_length = signal_length

        # The isinstance guard keeps the comparison well-defined for array-like
        # channel specifications, where `== "all"` would be elementwise.
        if isinstance(channels, str) and channels == "all":
            channels = list(range(1, 129))
        else:
            # Materialize so that any iterable (tuple, range, generator) works.
            channels = list(channels)
        self.num_channels = len(channels)

        filepath_sess1 = os.path.join(filepath, "Session1")
        self.array = standardize_array(
            load_ecog_data(
                filepath_sess1,
                channels,
                signal_length,
                session_one_start,
                session_one_end,
            ),
            ax=(0, 2),
        )

    def __getitem__(self, index):
        """Return a dict with the float32 "signal" tensor and, if available, "cond"."""
        return_dict = {"signal": torch.from_numpy(np.float32(self.array[index]))}
        cond = self.get_cond()
        if cond is not None:
            return_dict["cond"] = cond
        return return_dict

    def get_cond(self):
        """Return the conditioning value for a sample; always None for this unconditional dataset."""
        # NOTE(review): presumably a hook that conditional variants override - confirm.
        return None

    def __len__(self):
        """Return the number of time windows."""
        return len(self.array)


def get_tycho_dataset(
    filepath,
    split_frac=0.9,
    random_seed=42,
    signal_length=500,
    batch_size=32,
    train_fraction=1.0,
):
    """
    Build train, validation and test dataloaders for the all-channel dataset.

    Args:
        filepath: Dataset root passed on to TychoUnconditionalDataset.
        split_frac: Fraction of all windows assigned to the training split.
            Of the remaining windows, floor(2/3) go to the test split and the
            rest to the validation split.
        random_seed: Seed of the generator used for the random split.
        signal_length: Number of samples per window.
        batch_size: Batch size used by all three dataloaders.
        train_fraction: If below 1.0, keep only this leading fraction of the
            (already randomly split) training set.

    Returns:
        Tuple (train_dataloader, val_dataloader, test_dataloader). Only the
        training dataloader shuffles.
    """
    full_dataset = TychoUnconditionalDataset(
        signal_length=signal_length,
        channels="all",
        filepath=filepath,
    )

    # Split sizes: train first, then divide the held-out part into test / val.
    n_total = len(full_dataset)
    n_train = int(split_frac * n_total)
    n_holdout = n_total - n_train
    n_test = (2 * n_holdout) // 3
    n_val = n_holdout - n_test

    split_rng = torch.Generator().manual_seed(random_seed)
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        full_dataset,
        [n_train, n_val, n_test],
        generator=split_rng,
    )

    # Optionally shrink the training set to its first `train_fraction` part.
    if train_fraction < 1.0:
        n_kept = int(len(train_dataset) * train_fraction)
        train_dataset = torch.utils.data.Subset(
            train_dataset,
            np.arange(n_kept, dtype=int),
        )

    # Report the size of each split.
    for label, subset in (
        ("Train", train_dataset),
        ("Val", val_dataset),
        ("Test", test_dataset),
    ):
        print(f"{label} dataset length: {len(subset)}")

    loader_cls = torch.utils.data.DataLoader
    train_dataloader = loader_cls(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = loader_cls(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = loader_cls(test_dataset, batch_size=batch_size, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader
