import os

import numpy as np
import pandas as pd
from statsmodels.tsa.arima_process import ArmaProcess
from sklearn.utils import check_random_state
import torch
from torch.utils.data import Dataset
from skorch.helper import predefined_split
from skorch.callbacks import EarlyStopping

from braindecode import EEGRegressor, EEGClassifier
from braindecode.augmentation.base import BaseDataLoader

from eeg_augment.utils import save_obj, find_device


class DummyDataset(Dataset):
    """Synthetic dataset whose samples and targets are generated up front.

    All ``n_samples`` inputs are drawn once at construction time by calling
    ``X_generator``, and the matching targets are computed by applying
    ``invariant_op`` to each generated input.

    Parameters
    ----------
    X_shape : tuple | list
        Shape of X samples (must have exactly two entries). It is only
        validated and stored here; the actual sample shape is whatever
        ``X_generator`` returns.
    n_samples : int
        Number of samples (strictly positive).
    X_generator : callable
        Callable taking no arguments and returning a new input sample X.
    invariant_op : callable | None
        Callable taking a sample X as input and returning corresponding target
        y. Should encode the desired invariance. If None, the mean of X is
        used as target.
    """
    def __init__(
        self,
        X_shape,
        n_samples,
        X_generator,
        invariant_op,
    ):
        assert (
            isinstance(X_shape, (tuple, list)) and len(X_shape) == 2
        ), "X_shape should be a tuple or list of size 2."
        assert isinstance(n_samples, int) and n_samples > 0,\
            "n_samples should be a positive int."
        assert callable(X_generator), "X_generator should be callable."

        # Resolve the default BEFORE checking callability, otherwise passing
        # None could never reach the fallback.
        if invariant_op is None:
            invariant_op = np.mean
        assert callable(invariant_op), "invariant_op should be callable."

        self.X_shape = X_shape
        self.n_samples = n_samples
        self.invariant_op = invariant_op

        # Inputs are materialized first, then targets are derived from them.
        self.X_list = [X_generator() for _ in range(self.n_samples)]
        self.y_list = [invariant_op(X) for X in self.X_list]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return the triplet ``(X, y, 1.0)`` for sample ``idx``."""
        # NOTE(review): the constant third element presumably fills the slot
        # that downstream loaders expect next to (X, y) -- confirm.
        return self.X_list[idx], self.y_list[idx], 1.0


class IIDGaussianDataset(DummyDataset):
    """Synthetic dataset made of Gaussian noise around per-channel offsets.

    Each sample is white Gaussian noise of standard deviation ``std`` to
    which an integer offset is added channel-wise. Offsets are drawn
    uniformly from the integers 1 to the number of channels (inclusive),
    independently for every channel.

    Parameters
    ----------
    X_shape : tuple | list
        Shape of X samples, as (n_channels, n_times).
    random_state : int | np.random.RandomState | None, optional
        Passed to ``sklearn.utils.check_random_state`` to build the RNG used
        for sampling.
    std : float, optional
        Standard deviation of Gaussian distribution.
    **kwargs
        Remaining arguments forwarded to ``DummyDataset`` (``n_samples`` and
        ``invariant_op``). ``X_generator`` must not be given.
    """
    def __init__(self, X_shape, random_state=None, std=0.2, **kwargs):
        assert 'X_generator' not in kwargs,\
            "X_generator is already set and should not be passed in arguments."

        shape_is_valid = (
            isinstance(X_shape, (tuple, list)) and len(X_shape) == 2
        )
        assert shape_is_valid, "X_shape should be a tuple or list of size 2."

        self.rng = check_random_state(random_state)

        assert isinstance(std, float), "std should be a float."
        self.std = std

        # The parent constructor calls the generator, so the RNG and std
        # must already be set at this point.
        super().__init__(X_shape, X_generator=self._generator, **kwargs)

    def _generator(self):
        """Draw one sample: Gaussian noise plus a random offset per channel."""
        n_channels, n_times = self.X_shape
        noise = self.rng.normal(scale=self.std, size=(n_channels, n_times))
        # One integer center per channel, shaped as a column so that it
        # broadcasts over the time axis.
        centers = self.rng.randint(1, n_channels + 1, n_channels)
        return noise + centers.reshape(-1, 1)


class AR1Dataset(DummyDataset):
    """Synthetic Dataset where each channel in X is an AR1 signal

    Parameters
    ----------
    X_shape : tuple
        Shape of X samples, as (n_channels, n_times).
    ar1_coefs : tuple | list
        First order AR coefficient for each channel.
    offsets : tuple | list
        Offsets for each channel.
    noise_std : tuple | list
        Std of the white AR noise component, for each channel of X.
    **kwargs
        Remaining arguments forwarded to ``DummyDataset`` (``n_samples`` and
        ``invariant_op``). ``X_generator`` must not be given.
    """
    def __init__(self, X_shape, ar1_coefs, offsets, noise_std, **kwargs):
        assert 'X_generator' not in kwargs,\
            "X_generator is already set and should not be passed in arguments."

        n_channels = X_shape[0]

        self._check_per_channel(offsets, n_channels, "offsets")
        self.offsets = offsets

        self._check_per_channel(ar1_coefs, n_channels, "ar1_coefs")
        # Coefficients are stored negated: they are later used as the lag-1
        # term of the AR lag polynomial given to ArmaProcess, which expects
        # the opposite sign of the usual AR coefficient.
        self.ar1_coefs = [-c for c in ar1_coefs]

        self._check_per_channel(noise_std, n_channels, "noise_std")
        self.noise_std = noise_std

        # The parent constructor calls the generator, so all per-channel
        # parameters must already be set at this point.
        super().__init__(X_shape, X_generator=self._generator, **kwargs)

    @staticmethod
    def _check_per_channel(values, n_channels, name):
        """Assert that ``values`` is a tuple/list of ``n_channels`` floats."""
        assert (
            isinstance(values, (tuple, list))
            and len(values) == n_channels
            and all(isinstance(value, float) for value in values)
        ), (
            f"{name} should be a tuple or list of float, whose size should "
            f"be equal to the number of channels in X_shape. Got {values}"
        )

    def _generator(self):
        """Draw one sample, generating every channel as an AR1 process.

        Returns
        -------
        torch.Tensor
            Float tensor of shape ``X_shape``.
        """
        _, n_times = self.X_shape
        X = np.empty(self.X_shape)

        for c, (ar1, offset, std) in enumerate(
            zip(self.ar1_coefs, self.offsets, self.noise_std)
        ):
            # AR polynomial [1, ar1]; the single MA term `std` scales the
            # driving white noise. The offset shifts the whole channel.
            X[c, :] = ArmaProcess(
                np.array([1, ar1]), np.array([std])
            ).generate_sample(nsample=n_times) + offset
        return torch.as_tensor(X).float()


class LinearRegression(torch.nn.Module):
    """Simple linear model

    Inputs are flattened (all dimensions except the first one) and fed to a
    single fully connected layer.

    Parameters
    ----------
    input_size : int | tuple | list
        Either the number of input features once a sample is flattened, or
        the shape of a single X sample, in which case the number of features
        is the product of its dimensions.
    output_size : int
        Number of outputs of the linear layer.
    """
    def __init__(self, input_size, output_size):
        super().__init__()
        self.flat = torch.nn.Flatten()
        # np.prod leaves a scalar unchanged and reduces a shape to its number
        # of elements, so both forms of input_size are supported.
        in_features = int(np.prod(input_size))
        self.linear = torch.nn.Linear(in_features, output_size)

    def forward(self, x):
        """Flatten ``x`` and apply the linear layer."""
        return self.linear(self.flat(x))


def make_invariant_function_from_transform(
    transform,
    order,
    weights,
    bias,
    std,
    random_state=None,
):
    """Build a noisy target function from repeated applications of transform

    The returned function maps a sample X to a target obtained by summing,
    for n = 0, ..., order - 1, a noisy affine mapping of X after ``transform``
    has been applied n times to it.

    Parameters
    ----------
    transform : callable
        Called as ``transform(X, y)`` and expected to return a pair whose
        first element is the transformed X.
    order : int
        Number of terms in the sum (strictly positive). The last term uses
        ``order - 1`` successive applications of ``transform``.
    weights : array-like
        Weights of the affine mapping, combined with the flattened
        (transformed) sample through ``np.dot``.
    bias : float | array-like
        Bias added to the mapping.
    std : float
        Standard deviation of the scalar Gaussian noise added to each term.
    random_state : int | np.random.RandomState | None, optional
        Passed to ``sklearn.utils.check_random_state`` to build the noise
        RNG.

    Returns
    -------
    callable
        Function taking a sample X. It returns the summed mapping when the
        mapping is a scalar, and the argmax of the summed mapping otherwise.

    Raises
    ------
    ValueError
        If ``order`` is smaller than 1.
    """
    # With no term in the sum there is nothing to return; fail early with a
    # clear message instead of an unbound-variable error on first call.
    if order < 1:
        raise ValueError(f"order should be a positive int. Got {order}")

    rng = check_random_state(random_state)

    def func(X):
        # The conversion does not depend on the loop, so do it once.
        if isinstance(X, np.ndarray):
            X = torch.as_tensor(X).float()

        ret = []
        for n_applications in range(order):
            transformed = X
            for _ in range(n_applications):
                # The transform API takes and returns targets as well; a
                # dummy target is passed and its output is dropped.
                transformed, _unused_y = transform(
                    transformed,
                    torch.zeros(transformed.shape[0])
                )
            mapped = np.dot(
                weights,
                transformed.flatten()
            ) + bias + rng.normal(0, scale=std)
            ret.append(mapped)

        total = np.sum(ret, axis=0)
        # Scalar mapping: regression target. Otherwise the mapping is
        # treated as class scores and the winning index is returned.
        if len(ret[-1].shape) == 0:
            return total
        return np.argmax(total)
    return func


def make_real_dataset_invariant(
    concat_dataset,
    invariant_op,
):
    """Overwrite the targets of a concatenated dataset with invariant ones

    Parameters
    ----------
    concat_dataset : object
        Object exposing a ``datasets`` attribute. Each element must be
        iterable, yield triplets whose first element is the input X, and
        accept a new ``y`` attribute.
    invariant_op : callable
        Callable taking a sample X as input and returning the new target.

    Returns
    -------
    object
        The same ``concat_dataset`` object, modified in place.
    """
    for sub_dataset in concat_dataset.datasets:
        new_targets = []
        for sample in sub_dataset:
            # Only the input is needed; the other two entries are ignored.
            X, _, _ = sample
            new_targets.append(invariant_op(X))
        # Targets are replaced only once every sample has been read.
        sub_dataset.y = np.array(new_targets)
    return concat_dataset


def sanity_checker(
    dataset_class,
    invariant_op,
    transform,
    X_shape,
    y_shape=1,
    train_size=1024,
    valid_size=256,
    n_folds=10,
    n_epochs=300,
    batch_size=128,
    to_extract=None,
    save_path=None,
    classif=False,
    patience=None,
    verbose=1,
    device=None,
    **kwargs
):
    """Generate synthetic data, train linear model on it with and without
    augmentation and evaluate it on validation set over several folds

    For every fold, fresh training and validation sets are generated and two
    linear models are trained on them: a reference one and one whose training
    iterator applies ``transform``. If an ``IndexError`` is raised while
    training a fold, new data is generated and the fold is attempted again.

    Parameters
    ----------
    dataset_class : DummyDataset
        Used to instantiate train and validation datasets.
    invariant_op : callable
        Used as decision function in datasets.
    transform : braindecode.augmentation.Transform
        Augmentation to test.
    X_shape : tuple
        Shape of inputs.
    y_shape : int, optional
        Shape of output, by default 1
    train_size : int, optional
        Number of training samples, by default 1024
    valid_size : int, optional
        Number of validation samples, by default 256
    n_folds : int, optional
        Number of repetitions, each on newly generated datasets. By default
        10
    n_epochs : int, optional
        By default 300
    batch_size : int, optional
        By default 128
    to_extract : list, optional
        Columns of the training history to keep. By default None, which
        means ['epoch', 'train_loss', 'valid_loss'].
    save_path : str, optional
        Path where results should be saved, by default None (nothing is
        saved). Results are saved again after each fold.
    classif : bool, optional
        Whether to create a classification task or not. By default False
        (regression).
    patience : int | None, optional
        If omitted, no EarlyStopping will be done. Otherwise, the value passed
        will be used to set it.
    verbose : boolean | int, optional
        Whether or not to print training history live. By default 1.
    device : str, optional
        Device to train on. By default None.
    **kwargs
        Extra arguments forwarded to ``dataset_class``.

    Returns
    -------
    list of dict
        One entry per fold, with keys 'reference' and 'augment', each
        holding a pandas.DataFrame with the ``to_extract`` columns of the
        corresponding model training history.
    """
    if to_extract is None:
        to_extract = ['epoch', 'train_loss', 'valid_loss']

    callbacks = []
    if patience is not None:
        callbacks.append(('early_stopping', EarlyStopping(patience=patience)))

    device, cuda = find_device(device)
    if cuda:
        print("---- CUDA device detected! GPU training starting ---")

    # The two tasks only differ by the estimator wrapper and its criterion.
    if classif:
        estimator_class = EEGClassifier
        estimator_params = {'criterion': torch.nn.CrossEntropyLoss}
    else:
        estimator_class = EEGRegressor
        estimator_params = {}

    def build_dataset(n_samples):
        # Fresh synthetic dataset sharing the same decision function.
        return dataset_class(
            X_shape=X_shape,
            n_samples=n_samples,
            invariant_op=invariant_op,
            **kwargs
        )

    def build_model(valid_set, **extra_params):
        # Linear model validated on the predefined validation set.
        return estimator_class(
            LinearRegression(np.prod(X_shape), y_shape),
            train_split=predefined_split(valid_set),
            batch_size=batch_size,
            iterator_train=BaseDataLoader,
            callbacks=callbacks,
            verbose=int(verbose),
            device=device,
            **estimator_params,
            **extra_params
        )

    results = []
    for k in range(1, n_folds + 1):
        print(f"-------- Fold {k} of {n_folds}")
        while True:
            train_set = build_dataset(train_size)
            valid_set = build_dataset(valid_size)

            reference_model = build_model(valid_set)
            augmented_model = build_model(
                valid_set,
                iterator_train__transforms=transform,
            )

            try:
                reference_model.fit(train_set, y=None, epochs=n_epochs)
                augmented_model.fit(train_set, y=None, epochs=n_epochs)

                fold_res = {
                    'reference': pd.DataFrame(
                        reference_model.history.to_list()
                    ).loc[:, to_extract],
                    'augment': pd.DataFrame(
                        augmented_model.history.to_list()
                    ).loc[:, to_extract],
                }
            except IndexError as err:
                # Best-effort retry on newly generated data, but do not hide
                # the failure from the user.
                print(
                    f"IndexError during fold {k}, retrying with new data: "
                    f"{err}"
                )
                continue

            # Recording and saving stay outside the try block so that a
            # failure here can never trigger a retraining of the fold.
            results.append(fold_res)
            if save_path is not None:
                save_dir = os.path.dirname(save_path)
                # dirname is empty for a bare file name, and makedirs('')
                # would raise.
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                save_obj(results, save_path)
            break

    return results
