
import pickle
import os
import random

import numpy as np
from sklearn.model_selection import GroupShuffleSplit, StratifiedShuffleSplit
import torch
from torch import autograd

from braindecode.augmentation.base import IdentityTransform
from braindecode.augmentation.transforms import TimeReverse, FTSurrogate,\
    MissingChannels, ShuffleChannels, ChannelSymmetry, TimeMask,\
    GaussianNoise, SignFlip, FrequencyShift, BandstopFilter, RandomZRotation,\
    RandomYRotation, RandomXRotation
from braindecode.util import set_random_seeds


# Registry of available augmentations, keyed by a short name (presumably the
# name selected by the user/command line — confirm with callers).
# Each value is a pair ``(transform_class, presets)`` where ``presets`` maps a
# preset name to a dict of keyword arguments for ``transform_class``:
# - "default": present for every entry. An empty dict passes no kwargs,
#   presumably so the class's own defaults apply (assumes callers do
#   ``transform_class(**preset)`` — TODO confirm).
# - "edf" / "mass": alternative presets, presumably magnitudes tuned for the
#   Sleep-EDF and MASS datasets respectively (TODO confirm).
# - "old": legacy keyword arguments (marked ``# legacy`` below).
POSSIBLE_TRANSFORMS = {
    'no-aug': (IdentityTransform, {"default": {}}),
    'flip': (TimeReverse, {"default": {}}),
    'ft-surrogate': (FTSurrogate, {
        "default": {"magnitude": 1.0},
        "edf": {"magnitude": 0.9},
        "mass": {"magnitude": 1.0}
    }),
    'channel-dropout': (MissingChannels, {
        "default": {'magnitude': 0.2},
        "edf": {'magnitude': 0.5},
        "mass": {'magnitude': 0.8},
    }),
    'channel-shuffle': (ShuffleChannels, {
        "default": {'magnitude': 0.2},
        "edf": {'magnitude': 0.5},
        "mass": {'magnitude': 0.6},
    }),
    # CAREFUL: the channel names below are WRONG for the data used here.
    # This entry is for demonstration purposes only, as PhysioNet is not
    # compatible with this transform.
    'channel-sym': (ChannelSymmetry, {
        "default": {'ordered_ch_names': ['Fp1', 'Fp2']}
    }),
    'time-mask': (TimeMask, {
        "old": {'mask_len_samples': 100, "mag_range": (0, 100)},  # legacy
        "default": {"magnitude": 0.5},
        "edf": {"magnitude": 0.6},
        "mass": {"magnitude": 0.9},
    }),
    'noise': (GaussianNoise, {
        "old": {'noisy_ratio': 0.99, 'std': 0.1},  # legacy
        "default": {"magnitude": 0.5},
        "edf": {"magnitude": 0.5},
        "mass": {"magnitude": 1.0},
    }),
    'sign': (SignFlip, {"default": {}}),
    'bandstop': (BandstopFilter, {
        "old": {"mag_range": (0, 10)},  # legacy
        "default": {"magnitude": 0.5},
        "edf": {"magnitude": 0.6},
        "mass": {"magnitude": 0.1},
    }),
    'freq-shift': (FrequencyShift, {
        "default": {},
        "edf": {"magnitude": 0.1},
        "mass": {"magnitude": 0.5},
    }),
    'rotz': (RandomZRotation, {
        "default": {},
        "edf": {"magnitude": 0.1},
        "mass": {"magnitude": 0.2},
    }),
    'roty': (RandomYRotation, {
        "default": {},
        "edf": {"magnitude": 0.9},
        "mass": {"magnitude": 1.0},
    }),
    'rotx': (RandomXRotation, {
        "default": {},
        "edf": {"magnitude": 0.4},
        "mass": {"magnitude": 0.1},
    }),
}


# Transform instances created once, at import time, each paired with the
# integer 2.
# NOTE(review): the first positional argument 1.0 is presumably the
# probability of applying the transform (i.e. always applied) — confirm
# against the braindecode signatures.
# NOTE(review): the meaning of 2 is not visible here. Time reversal and
# left/right channel symmetry both give back the input when applied twice,
# so it is presumably the transform's order — confirm with the callers.
# The 'channel-sym' instance uses the same ['Fp1', 'Fp2'] list that is
# flagged as wrong channels in POSSIBLE_TRANSFORMS.
ENFORCEABLE_TRANSFORMS = {
    'flip': (TimeReverse(1.0), 2),
    'channel-sym': (ChannelSymmetry(1.0, ['Fp1', 'Fp2']), 2)
}


def flexible_int(x):
    """Convert ``x`` to an ``int`` when possible, otherwise return it as is.

    Permissive converter for values that may be either an integer or an
    arbitrary string.

    Parameters
    ----------
    x : str
        Value to convert.

    Returns
    -------
    int or str
        ``int(x)`` if ``x`` parses as an integer (a leading sign is
        allowed), otherwise ``x`` unchanged.
    """
    # EAFP instead of ``str.isnumeric()``: isnumeric() rejects signed values
    # such as "-3", and accepts characters like "½" or "²" that ``int``
    # cannot parse (those used to raise ValueError instead of falling back).
    try:
        return int(x)
    except ValueError:
        return x


def flexible_float(x):
    """Return ``float(x)`` if ``x`` parses as a float, else ``x`` unchanged.

    Parameters
    ----------
    x : str
        Value to convert.

    Returns
    -------
    float or str
        The parsed float, or the original value when parsing fails with
        ``ValueError``.
    """
    result = x
    try:
        result = float(x)
    except ValueError:
        # Not a number: keep the raw value.
        pass
    return result


def save_obj(obj, name):
    """Pickle ``obj`` to ``name + '.pkl'`` without overwriting existing files.

    If ``<name>.pkl`` already exists, a numbered variant is used instead:

    - if ``name`` ends with ``_<integer>``, that whole number is incremented
      (``run_9`` -> ``run_10``);
    - otherwise ``_1`` is appended (``run`` -> ``run_1``).

    The counter keeps increasing until a path that does not exist is found.

    Parameters
    ----------
    obj : object
        Any picklable object.
    name : str
        Target path without the ``.pkl`` extension.

    Returns
    -------
    str
        The path the object was actually written to.
    """
    saving_path = name + '.pkl'
    if os.path.isfile(saving_path):
        stem, sep, suffix = name.rpartition("_")
        if sep and suffix.isdecimal():
            # Name already carries a counter: bump the whole number, not
            # just its last character.
            base, count = stem, int(suffix) + 1
        else:
            base, count = name, 1
        saving_path = f"{base}_{count}.pkl"
        # Never clobber an earlier save: keep bumping until the path is free.
        while os.path.isfile(saving_path):
            count += 1
            saving_path = f"{base}_{count}.pkl"
    with open(saving_path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
    return saving_path


def load_obj(name):
    """Unpickle and return the object stored at ``name + '.pkl'``.

    Only use on trusted files: unpickling can execute arbitrary code.

    Parameters
    ----------
    name : str
        Path of the pickle file without the ``.pkl`` extension.

    Returns
    -------
    object
        The deserialized object.
    """
    pickle_path = f"{name}.pkl"
    with open(pickle_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


# TODO: the following functions could be moved to braindecode-wip
# (or directly to a public braindecode PR)


def get_groups(dataset):
    """Return one group label per example of ``dataset``.

    If ``dataset`` has a ``description`` containing a ``'subject'`` entry and
    a ``datasets`` attribute, every example of recording ``i``
    (``dataset.datasets[i]``) is labelled with row ``i`` of
    ``description['subject']``. This assumes the rows of ``description`` are
    ordered like ``datasets`` (presumably a braindecode concat dataset —
    confirm). Otherwise each example is its own group:
    ``np.arange(len(dataset))``.

    Parameters
    ----------
    dataset : object
        Dataset to derive groups from.

    Returns
    -------
    np.ndarray
        Group labels, one per example.
    """
    has_subject_info = (
        hasattr(dataset, "description")
        and hasattr(dataset, "datasets")
        and "subject" in dataset.description
    )
    if not has_subject_info:
        return np.arange(len(dataset))

    subjects = dataset.description['subject'].values
    labels_per_recording = []
    for rec_idx, subject in enumerate(subjects):
        n_examples = len(dataset.datasets[rec_idx])
        labels_per_recording.append([subject] * n_examples)
    return np.hstack(labels_per_recording)


def _structured_split(
    splitter_class,
    indices,
    ratio,
    groups,
    targets=None,
    random_state=None
):
    """Split ``indices`` once with a scikit-learn style shuffle splitter.

    Parameters
    ----------
    splitter_class : type
        Splitter class accepting ``n_splits``, ``train_size`` and
        ``random_state`` and exposing ``split(X, y=None, groups=None)``
        (e.g. ``GroupShuffleSplit`` or ``StratifiedShuffleSplit``).
    indices : np.ndarray
        Indices to split. They are indexed with the integer arrays returned
        by the splitter.
    ratio : float or int
        ``train_size`` forwarded to the splitter. A float must lie in (0, 1).
        ``ratio == 1`` (int ``1`` or float ``1.0``) is special-cased to mean
        "keep everything": all indices go to the first part and the second
        part is empty.
    groups : array-like or None
        Group labels forwarded to ``splitter.split``.
    targets : array-like, optional
        Labels forwarded as ``y`` to ``splitter.split`` (for stratification).
    random_state : int, RandomState instance or None
        Seed forwarded to the splitter.

    Returns
    -------
    tuple
        ``(first_part, second_part)``, both subsets of ``indices``.
    """
    if ratio == 1:
        # Empty slice keeps the type/dtype of ``indices``. ``np.array([])``
        # is float64, and numpy refuses float arrays as indices.
        return indices, indices[:0]

    if isinstance(ratio, float):
        assert (
            ratio > 0 and ratio < 1
        ), "When ratio is a float, it must be positive and <=1."
    else:
        assert isinstance(ratio, int), (
            f"ratio can be either int or float. Got {type(ratio)}: {ratio}"
        )
    splitter = splitter_class(
        n_splits=1,
        train_size=ratio,
        random_state=random_state
    )
    # n_splits=1, so the generator yields exactly one (train, test) pair.
    train_idx, test_idx = next(iter(splitter.split(
        indices,
        y=targets,
        groups=groups
    )))
    return indices[train_idx], indices[test_idx]


def grouped_split(indices, ratio, groups, random_state=None):
    """Split ``indices`` in two so that no group lands on both sides.

    Thin wrapper around ``_structured_split`` using ``GroupShuffleSplit``;
    see that function for the meaning of ``ratio`` (including the
    ``ratio == 1`` shortcut).

    Parameters
    ----------
    indices : np.ndarray
        Indices to split.
    ratio : float or int
        Size of the first part (fraction or absolute count).
    groups : array-like
        Group label of each index.
    random_state : int, RandomState instance or None
        Seed forwarded to the splitter.

    Returns
    -------
    tuple
        ``(first_part, second_part)`` subsets of ``indices``.
    """
    return _structured_split(
        GroupShuffleSplit, indices, ratio, groups, random_state=random_state
    )


def stratified_split(indices, ratio, targets, random_state=None):
    """Split ``indices`` in two while preserving the class proportions.

    Thin wrapper around ``_structured_split`` using
    ``StratifiedShuffleSplit`` (no groups); see that function for the
    meaning of ``ratio`` (including the ``ratio == 1`` shortcut).

    Parameters
    ----------
    indices : np.ndarray
        Indices to split.
    ratio : float or int
        Size of the first part (fraction or absolute count).
    targets : array-like
        Class label of each index, used for stratification.
    random_state : int, RandomState instance or None
        Seed forwarded to the splitter.

    Returns
    -------
    tuple
        ``(first_part, second_part)`` subsets of ``indices``.
    """
    no_groups = None
    return _structured_split(
        StratifiedShuffleSplit,
        indices,
        ratio,
        no_groups,
        targets=targets,
        random_state=random_state,
    )


def log2_grid(max_value, n_values=None):
    """Return ``n`` powers of two ending at 1: ``[2**(1-n), ..., 0.5, 1]``.

    Parameters
    ----------
    max_value : float
        Only used to choose the default ``n = int(log2(max_value))``. The
        returned values are not scaled by it. It is evaluated even when
        ``n_values`` is given.
    n_values : int, optional
        Number of grid points. Defaults to ``int(log2(max_value))``.

    Returns
    -------
    np.ndarray
        Geometric grid of ``n`` values with ratio 2, largest value 1.
    """
    default_count = int(np.log2(max_value))
    count = default_count if n_values is None else n_values
    return np.logspace(1 - count, 0, count, base=2)


def linear_grid(max_value, n_values):
    """Return ``n_values`` evenly spaced points from 0 to ``max_value``.

    Both endpoints are included.
    """
    return np.linspace(start=0, stop=max_value, num=n_values)


def find_device(device=None):
    """Pick the torch device to run on.

    Parameters
    ----------
    device : str | None
        Explicit device specification (e.g. ``"cpu"``, ``"cuda:0"``). When
        None, a CUDA device is used if available, otherwise the CPU. With
        more than one GPU, ``cuda:1`` is selected.

    Returns
    -------
    device : torch.device
        The selected device.
    cuda : bool
        Whether the selected device is a CUDA device.
    """
    if device is not None:
        assert isinstance(device, str), "device should be a str."
        device = torch.device(device)
        # Derive the flag from the device itself: an explicit "cuda"/"cuda:N"
        # request used to be reported as non-CUDA (False), which skipped
        # CUDA-specific handling downstream.
        return device, device.type == "cuda"
    if not torch.cuda.is_available():
        return torch.device('cpu'), False
    # Hard-coded choice: with several GPUs the second one is used
    # (NOTE(review): presumably to leave GPU 0 free — confirm).
    if torch.cuda.device_count() > 1:
        return torch.device('cuda:1'), True
    return torch.device('cuda'), True


def worker_init_fn(worker_id):
    """Seed NumPy's and Python's global RNGs from torch's initial seed.

    Presumably meant as a ``DataLoader`` ``worker_init_fn`` so that each
    worker derives its NumPy/Python seeds from its torch seed (confirm with
    callers). ``worker_id`` is accepted for that signature but unused.
    """
    # NumPy seeds must fit in 32 bits.
    seed = torch.initial_seed() % (2 ** 32)
    # NOTE(review): the original author flagged this as bad; it reseeds the
    # process-global NumPy RNG rather than a dedicated Generator.
    np.random.seed(seed)
    random.seed(seed)


def get_global_rngs_states(cuda=False):
    """Snapshot the global torch, NumPy and Python RNG states.

    Parameters
    ----------
    cuda : bool
        If True, also capture the states of all CUDA generators.

    Returns
    -------
    dict
        Keys ``"torch_rng"``, ``"numpy_rng"``, ``"python_rng"`` and, when
        ``cuda`` is True, ``"cuda_rng"``.
    """
    snapshot = dict(
        torch_rng=torch.get_rng_state(),
        numpy_rng=np.random.get_state(),
        python_rng=random.getstate(),
    )
    if cuda:
        snapshot.update(cuda_rng=torch.cuda.get_rng_state_all())
    return snapshot


def set_global_rngs_states(states=None, seed=None, cuda=False):
    """Optionally reseed, then restore global RNG states.

    Parameters
    ----------
    states : dict | None
        Mapping as produced by ``get_global_rngs_states``. Any of the keys
        ``"torch_rng"``, ``"cuda_rng"``, ``"numpy_rng"``, ``"python_rng"``
        that are present are restored; other keys are ignored.
    seed : int | None
        If not None, ``set_random_seeds(seed, cuda)`` is called first, so
        states given in ``states`` take precedence over the seed.
    cuda : bool
        Forwarded to ``set_random_seeds``.
    """
    if seed is not None:
        set_random_seeds(seed, cuda)
    if states is None:
        states = {}

    # Torch states are moved to CPU before restoring (presumably in case
    # they were saved/loaded on another device).
    restorers = (
        ("torch_rng", lambda s: torch.set_rng_state(s.to("cpu"))),
        ("cuda_rng", lambda s: torch.cuda.set_rng_state_all(
            [per_device.to("cpu") for per_device in s]
        )),
        ("numpy_rng", np.random.set_state),
        ("python_rng", random.setstate),
    )
    for key, restore in restorers:
        if key in states:
            restore(states[key])


def compute_Hv(v, gradients, inputs):
    """Hessian-vector product via double backpropagation.

    Differentiates ``sum_i <v_i, g_i>`` with respect to ``inputs``. The
    gradients ``g_i`` must still be attached to the autograd graph (e.g.
    obtained with ``create_graph=True``). The graph is retained so the
    product can be computed repeatedly.

    Parameters
    ----------
    v : sequence of torch.Tensor
        Vector to multiply, split like ``gradients``.
    gradients : sequence of torch.Tensor
        Differentiable gradients of the loss wrt ``inputs``.
    inputs : sequence of torch.Tensor
        Tensors the gradients were taken with respect to.

    Returns
    -------
    tuple of torch.Tensor
        ``H v``, one tensor per element of ``inputs``.
    """
    grad_dot_v = sum(
        torch.sum(vec * grad.clone()) for vec, grad in zip(v, gradients)
    )
    return autograd.grad(grad_dot_v, inputs=inputs, retain_graph=True)


def _tuple_norm(tensors):
    """Euclidean norm of a vector stored as a tuple of tensors."""
    return torch.sqrt(sum(torch.linalg.norm(t) ** 2 for t in tensors))


def estimate_Hessian_largest_eig(gradients, inputs, num_simulations: int):
    """Estimate the dominant Hessian eigenvalue magnitude by power iteration.

    Power method, copied and modified from
    https://en.wikipedia.org/wiki/Power_iteration. Hessian-vector products
    are obtained with ``compute_Hv``, so the Hessian is never materialized.

    Parameters
    ----------
    gradients : tuple
        Tuple of tensors representing the gradients of the loss wrt the
        weights. They must still be attached to the autograd graph (e.g.
        computed with ``create_graph=True``) so they can be differentiated
        again.
    inputs : iterable
        Tensors/parameters corresponding to the weights, in the same order
        and with the same shapes as ``gradients``. May be a one-shot
        iterable such as a generator.
    num_simulations : int
        Number of power-iteration steps after the initial one. Must be >= 0.

    Returns
    -------
    torch.Tensor
        Scalar tensor ``||H b||`` for the final unit vector ``b``. This is an
        estimate of the largest eigenvalue of H (derivative of gradients wrt
        inputs) in absolute value; the sign is not recovered. If ``H b``
        becomes exactly zero, iteration stops and 0 is returned instead of
        propagating NaNs.

    Raises
    ------
    ValueError
        If ``num_simulations`` is negative.
    """
    if num_simulations < 0:
        raise ValueError(
            f"num_simulations must be >= 0. Got {num_simulations}."
        )
    # Materialize once: ``inputs`` is passed to autograd.grad at every step
    # and could otherwise be an exhausted generator after the first one.
    inputs = tuple(inputs)

    # Random start vector, to decrease the chance that it is orthogonal to
    # the dominant eigenvector. Normalize it so the estimate is meaningful
    # even when num_simulations == 0.
    b_k = tuple(
        torch.rand(g.shape, device=g.device)
        for g in gradients
    )
    b_k_norm = _tuple_norm(b_k)
    b_k = tuple(b_k_i / b_k_norm for b_k_i in b_k)

    for _ in range(num_simulations + 1):
        # Matrix-by-vector product H b_k.
        b_k1 = compute_Hv(b_k, gradients, inputs)
        b_k1_norm = _tuple_norm(b_k1)
        if b_k1_norm == 0:
            # Re-normalizing would divide by zero and yield NaNs.
            break
        b_k = tuple(b_k1_i / b_k1_norm for b_k1_i in b_k1)

    # With b_k a unit vector, ||H b_k|| estimates |lambda_max|.
    return b_k1_norm
