# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)


import os
import random

import numpy as np
import mne
import h5py
import torch
from sklearn.utils import check_random_state


def set_random_seeds(seed, cuda):
    """Seed Python's ``random``, ``numpy.random`` and torch generators.

    Parameters
    ----------
    seed: int
        Seed value applied to every generator.
    cuda: bool
        If True, additionally seed all CUDA devices via
        ``torch.cuda.manual_seed_all``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not cuda:
        return
    torch.cuda.manual_seed_all(seed)


def np_to_var(
    X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
):
    """
    Convenience function to transform numpy array to `torch.Tensor`.

    Converts `X` to ndarray using asarray if necessary. Inputs without a
    ``__len__`` attribute (e.g. Python scalars) are first wrapped in a list,
    so they become 1-element tensors.

    Parameters
    ----------
    X: ndarray or list or number
        Input arrays
    requires_grad: bool
        passed on to `torch.tensor`
    dtype: numpy dtype, optional
        If given, `X` is cast to this dtype with ``ndarray.astype`` before
        the conversion.
    pin_memory: bool
        If True, the resulting tensor is moved to pinned memory via
        ``Tensor.pin_memory()``.
    tensor_kwargs:
        Further keyword arguments passed on to `torch.tensor`.

    Returns
    -------
    var: `torch.Tensor`
    """
    # Scalars have no __len__; wrap them so the result is at least 1-d.
    if not hasattr(X, "__len__"):
        X = [X]
    X = np.asarray(X)
    if dtype is not None:
        X = X.astype(dtype)
    X_tensor = torch.tensor(X, requires_grad=requires_grad, **tensor_kwargs)
    if pin_memory:
        X_tensor = X_tensor.pin_memory()
    return X_tensor


def var_to_np(var):
    """Convert a `torch.Tensor` into a numpy array.

    The tensor is first copied to the CPU, so this works for tensors on
    both CPU and GPU, and is detached from the autograd graph before the
    conversion.
    """
    on_cpu = var.cpu()
    return on_cpu.detach().numpy()


def corr(a, b):
    """
    Compute correlations between the rows of a and the rows of b only.

    Unlike ``np.corrcoef``, correlations within a (row of a vs. row of a)
    and within b are not computed.

    Parameters
    ----------
    a, b: 2darray, features x samples
        Both arrays must have the same number of samples (columns).

    Returns
    -------
    Correlation matrix of shape (n_features_a, n_features_b).
    """
    cross_cov = cov(a, b)
    return _cov_to_corr(cross_cov, a, b)


def cov(a, b):
    """
    Compute covariances between the rows of a and the rows of b only.

    Covariances within a or within b are not computed. The result is
    normalized by ``n_samples - 1`` (unbiased estimate).

    Parameters
    ----------
    a, b: 2darray, features x samples

    Returns
    -------
    Covariance matrix of shape (n_features_a, n_features_b).
    """
    n_samples = b.shape[1]
    centered_a = a - np.mean(a, axis=1, keepdims=True)
    centered_b = b - np.mean(b, axis=1, keepdims=True)
    cross_products = np.dot(centered_a, centered_b.T)
    return cross_products / (n_samples - 1)


def _cov_to_corr(this_cov, a, b):
    """Turn a cross-covariance matrix into a correlation matrix.

    Divides by the row-wise sample standard deviations of ``a`` and ``b``.
    """
    # ddof=1 -> unbiased variance, matching the (n - 1) normalization of cov.
    row_var_a = np.var(a, axis=1, ddof=1)
    row_var_b = np.var(b, axis=1, ddof=1)
    return _cov_and_var_to_corr(this_cov, row_var_a, row_var_b)


def _cov_and_var_to_corr(this_cov, var_a, var_b):
    divisor = np.outer(np.sqrt(var_a), np.sqrt(var_b))
    return this_cov / divisor


def _reshape_to_features_x_samples(arr, sample_axes):
    """Move ``sample_axes`` of ``arr`` to the end and flatten it to 2d.

    Returns the 2d array (flattened remaining axes x flattened sample axes),
    the shape of the remaining axes and the shape of the sample axes.
    """
    if not hasattr(sample_axes, "__len__"):
        sample_axes = [sample_axes]
    # Normalize negative axes so that e.g. -1 refers to the last axis.
    sample_axes = [ax + arr.ndim if ax < 0 else ax for ax in sample_axes]
    other_axes = [i for i in range(arr.ndim) if i not in sample_axes]
    other_shape = tuple(arr.shape[i] for i in other_axes)
    sample_shape = tuple(arr.shape[i] for i in sample_axes)
    transposed = arr.transpose(tuple(other_axes) + tuple(sample_axes))
    # np.prod of an empty shape is the float 1.0; reshape needs ints.
    flat = transposed.reshape(
        int(np.prod(other_shape)), int(np.prod(sample_shape))
    )
    return flat, other_shape, sample_shape


def wrap_reshape_apply_fn(stat_fn, a, b, axis_a, axis_b):
    """
    Reshape two nd-arrays into 2d-arrays, apply function and reshape
    result back.

    Parameters
    ----------
    stat_fn: function
        Function to apply to 2d-arrays (features x samples). It receives the
        flattened `a` and `b` and must return an array with
        n_features_a * n_features_b elements.
    a: nd-array
    b: nd-array
    axis_a: int or list of int
        sample axis (or axes) of `a`; negative values count from the end.
    axis_b: int or list of int
        sample axis (or axes) of `b`; negative values count from the end.

    Returns
    -------
    result: nd-array
        The result reshaped to remaining_dims_a + remaining_dims_b
    """
    flat_a, other_shape_a, sample_shape_a = _reshape_to_features_x_samples(
        a, axis_a)
    flat_b, other_shape_b, sample_shape_b = _reshape_to_features_x_samples(
        b, axis_b)
    assert np.array_equal(sample_shape_a, sample_shape_b), (
        "Sample axes must have identical shapes, got "
        f"{sample_shape_a} and {sample_shape_b}")
    stat_result = stat_fn(flat_a, flat_b)
    return stat_result.reshape(other_shape_a + other_shape_b)


def get_balanced_batches(
    n_trials, rng, shuffle, n_batches=None, batch_size=None
):
    """Create indices for batches balanced in size
    (batches will have maximum size difference of 1).

    Supply either batch size or number of batches. If only `batch_size` is
    given, the number of batches is ``np.round(n_trials / batch_size)``
    (numpy rounds halves to the nearest even integer), so the resulting
    batches may be somewhat larger or smaller than `batch_size`. If the
    number of batches is 0 or smaller, a single batch with all trials is
    returned. If `n_batches` exceeds `n_trials`, trailing batches are empty.

    Parameters
    ----------
    n_trials : int
        Size of set.
    rng : RandomState
        Used to shuffle the indices; only accessed if `shuffle` is True.
    shuffle : bool
        Whether to shuffle indices before splitting set.
    n_batches : int, optional
    batch_size : int, optional

    Returns
    -------
    batches: list of ndarray of int
        Indices for each batch.
    """
    assert batch_size is not None or n_batches is not None
    if n_batches is None:
        n_batches = int(np.round(n_trials / float(batch_size)))

    if n_batches > 0:
        # The first `n_batches_with_extra_trial` batches get one more trial.
        min_batch_size, n_batches_with_extra_trial = divmod(
            n_trials, n_batches)
    else:
        n_batches = 1
        min_batch_size = n_trials
        n_batches_with_extra_trial = 0

    all_inds = np.arange(n_trials)
    if shuffle:
        rng.shuffle(all_inds)

    batches = []
    i_start_trial = 0
    for i_batch in range(n_batches):
        this_batch_size = min_batch_size + int(
            i_batch < n_batches_with_extra_trial)
        i_stop_trial = i_start_trial + this_batch_size
        batches.append(all_inds[i_start_trial:i_stop_trial])
        i_start_trial = i_stop_trial
    assert i_start_trial == n_trials
    return batches


def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
                         description=None, savedir=None, save_format='fif',
                         overwrite=True, random_state=None):
    """Create an mne.io.RawArray with fake data, and optionally save it.

    The data are drawn uniformly from [0, 1) using `random_state`. All
    channels are named ``ch{i}`` and typed as EEG.

    Parameters
    ----------
    n_channels : int
        Number of channels.
    n_times : int
        Number of samples.
    sfreq : float
        Sampling frequency.
    include_anns : bool
        If True, also create 10 annotations of duration 1, with onsets evenly
        spaced between sample ``int(2 * sfreq)`` and
        ``int(n_times - 2 * sfreq)``.
    description : list | None
        List of descriptions used for creating annotations. It should contain
        10 elements. Defaults to ``'test'`` for every annotation.
    savedir : str | None
        If provided as a string, the file will be saved under that directory
        as ``fake_eeg_raw.fif`` and/or ``fake_eeg_raw.h5``.
    save_format : str | list
        If `savedir` is provided, this specifies the file format the data should
        be saved to. Can be 'fif' or 'hdf5', or a list containing both.
    overwrite : bool
        Whether existing files may be overwritten. If False, saving to an
        already existing file raises an error (for both formats).
    random_state : int | RandomState
        Random state for the generation of random data.

    Returns
    -------
    raw : mne.io.Raw
        The created Raw object.
    save_fname : dict
        Maps each saved format ('fif' / 'hdf5') to the file name the data
        were saved to. Empty if `savedir` is None.
    """
    random_state = check_random_state(random_state)
    data = random_state.rand(n_channels, n_times)
    ch_names = [f'ch{i}' for i in range(n_channels)]
    ch_types = ['eeg'] * n_channels
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)

    raw = mne.io.RawArray(data, info)

    if include_anns:
        n_anns = 10
        inds = np.linspace(
            int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns).astype(int)
        onset = raw.times[inds]
        duration = [1] * n_anns
        if description is None:
            description = ['test'] * n_anns
        anns = mne.Annotations(onset, duration, description)
        raw = raw.set_annotations(anns)

    save_fname = dict()
    if savedir is not None:
        if not isinstance(save_format, list):
            save_format = [save_format]
        fname = os.path.join(savedir, 'fake_eeg_raw')

        if 'fif' in save_format:
            fif_fname = fname + '.fif'
            raw.save(fif_fname, overwrite=overwrite)
            save_fname['fif'] = fif_fname
        if 'hdf5' in save_format:
            h5_fname = fname + '.h5'
            # 'w' truncates an existing file, 'w-' fails if it exists; this
            # makes the HDF5 output honor `overwrite` just like raw.save.
            h5_mode = 'w' if overwrite else 'w-'
            with h5py.File(h5_fname, h5_mode) as f:
                f.create_dataset(
                    'fake_raw', dtype='f8', data=raw.get_data())
            save_fname['hdf5'] = h5_fname

    return raw, save_fname


class ThrowAwayIndexLoader(object):
    """Wrap a loader and yield only ``(x, y)`` pairs from its batches.

    Batches of length 3 are assumed to be ``(x, y, window_inds)``. The
    indices are stored on ``net._last_window_inds`` (for scoring callbacks)
    instead of being yielded. Any other batch is unpacked as ``(x, y)``.
    If ``x`` has a ``type`` method (e.g. a torch tensor), ``x`` is cast to
    float32 and ``y`` to float32 (regression) or int64 (classification).
    """

    def __init__(self, net, loader, is_regression):
        self.net = net
        self.loader = loader
        # Kept for API compatibility; not read within this class.
        self.last_i = None
        self.is_regression = is_regression

    def __iter__(self):
        for batch in self.loader.__iter__():
            if len(batch) != 3:
                x, y = batch
            else:
                x, y, window_inds = batch
                # Store for scoring callbacks
                self.net._last_window_inds = window_inds

            # TODO: should be on dataset side
            if not hasattr(x, 'type'):
                yield x, y
                continue
            target_dtype = (
                torch.float32 if self.is_regression else torch.int64)
            yield x.type(torch.float32), y.type(target_dtype)


def update_estimator_docstring(base_class, docstring):
    """Merge the docstring of ``base_class`` into ``docstring``.

    The base docstring (with ``' : '`` normalized to ``': '``) loses its
    ``callbacks`` entry. Its Parameters section is then placed before the
    parameters listed in ``docstring``, and its Attributes section is
    appended at the end.
    """
    base_doc = base_class.__doc__.replace(' : ', ': ')
    # remove callback description already included in braindecode docstring
    cb_start = base_doc.find('callbacks:')
    cb_end = cb_start + base_doc[cb_start:].find('\n\n')
    # +6 skips the blank line ('\n\n') plus the 4-space indent of the next
    # parameter entry.
    filtered_doc = base_doc[:cb_start] + base_doc[cb_end + 6:]

    pieces = docstring.split('Parameters\n    ----------\n    ')
    params_start = filtered_doc.find('Parameters')
    attrs_start = filtered_doc.find('Attributes')
    return ''.join([
        pieces[0],
        filtered_doc[params_start:attrs_start],
        pieces[1],
        filtered_doc[attrs_start:],
    ])
