import time
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from pathlib import Path
from time import perf_counter

import torch
from BAE.training import get_EEGClassifier, parallel_learning_curve
from BAE.utils import (ClassWiseAugmentation, downsample, find_device,
                       get_dataset)
from braindecode.augmentation import (FrequencyShift, FTSurrogate,
                                      GaussianNoise, IdentityTransform,
                                      SignFlip, SmoothTimeMask, TimeReverse)
from braindecode.augmentation.transforms import (ChannelsDropout,
                                                 ChannelsShuffle,
                                                 ChannelsSymmetry)
from braindecode.util import set_random_seeds

# Transformation names accepted by ``--transformations``. ``main`` builds a
# braindecode augmentation for each of them ("custom" is a class-wise
# combination of augmentations).
_TRANSFORMATION_NAMES = [
    "FTSurrogate", "GaussianNoise", "SignFlip", "SmoothTimeMask",
    "TimeReverse", "FrequencyShift", "IdentityTransform",
    "ChannelsDropout", "ChannelsShuffle", "ChannelsSymmetry", "custom",
]

parser = ArgumentParser(
    description='Compute learning curves for various transformations '
                'implemented in braindecode.',
    formatter_class=ArgumentDefaultsHelpFormatter)

parser.add_argument('-e', '--epochs',
                    type=int,
                    default=2,
                    help="number of epochs used for each fit")
parser.add_argument('-k', '--folds',
                    type=int,
                    default=2,
                    help="number of folds used for "
                    "each cross-validation training")
parser.add_argument('-n', '--subjects',
                    type=int,
                    default=5,
                    help="number of subjects used to create "
                    "the whole dataset")
parser.add_argument('--prop',
                    type=float,
                    # np.logspace(3, 12, 10, base=1 / 2)
                    # slightly under 2^-9 --> 2^6 windows
                    default=[0.0055],
                    nargs='+',
                    help="dataset fractions used for the learning curve")
parser.add_argument('-t', '--transformations',
                    type=str,
                    default=["FTSurrogate", "GaussianNoise",
                             "SignFlip", "SmoothTimeMask", "TimeReverse",
                             "FrequencyShift", "IdentityTransform"],
                    nargs='+',
                    # Reject unknown names at parse time instead of letting
                    # them reach the training loop.
                    choices=_TRANSFORMATION_NAMES,
                    metavar='TRANSFORMATION',
                    help="list of transformations to plot; "
                    "can be chosen among: ["
                    + ", ".join(_TRANSFORMATION_NAMES) + "]")
parser.add_argument('-j', '--n_jobs',
                    type=int,
                    default=1,
                    help="number of jobs that will run simultaneously")
parser.add_argument('-p', '--proba',
                    type=float,
                    default=0.2,
                    help="probability to apply augmentations")
parser.add_argument('-r', '--random_state',
                    type=int,
                    default=19,
                    help='set random state for repro_learning_curve')
parser.add_argument('-o', '--output',
                    type=str,
                    default='./examples/learning_curves',
                    help='path to the output folder')
parser.add_argument('-d', '--dataset',
                    type=str,
                    default='SleepPhysionet',
                    help='Dataset to use. Can be either SleepPhysionet or '
                    'BCI')
parser.add_argument('--device',
                    type=str,
                    default=None,
                    help='Device to use, default None will use CPU or cuda:1')
parser.add_argument('--downsampling',
                    action='store_true',
                    default=False,
                    help='Whether to downsample the training set so that all '
                    'classes are balanced.')


# Transformation names ``main`` knows how to build. Anything else is rejected
# before the (potentially slow) dataset loading starts.
_SUPPORTED_TRANSFORMATIONS = (
    "FTSurrogate", "GaussianNoise", "SignFlip", "SmoothTimeMask",
    "TimeReverse", "FrequencyShift", "IdentityTransform",
    "ChannelsDropout", "ChannelsShuffle", "ChannelsSymmetry", "custom",
)


def _make_transforms(transfo, proba, random_state, sfreq, ch_names):
    """Build the training-time augmentation list for one transformation.

    Parameters
    ----------
    transfo : str
        One of ``_SUPPORTED_TRANSFORMATIONS``.
    proba : float
        Probability passed to each augmentation's ``probability`` argument.
    random_state : int
        Seed forwarded to every augmentation that takes one.
    sfreq : float
        Sampling frequency of the windows (used by ``FrequencyShift``).
    ch_names : list of str
        Channel names of the windows (used by ``ChannelsSymmetry``).

    Returns
    -------
    list
        Value for the ``iterator_train__transforms`` classifier parameter.

    Raises
    ------
    ValueError
        If ``transfo`` is not a supported transformation name.
    """
    if transfo == 'GaussianNoise':
        return [GaussianNoise(
            probability=proba, std=0.1, random_state=random_state)]
    if transfo == 'FTSurrogate':
        return [FTSurrogate(probability=proba, random_state=random_state)]
    if transfo == 'SignFlip':
        return [SignFlip(probability=proba, random_state=random_state)]
    if transfo == 'SmoothTimeMask':
        return [SmoothTimeMask(
            probability=proba,
            mask_len_samples=100,
            random_state=random_state)]
    if transfo == 'TimeReverse':
        return [TimeReverse(probability=proba, random_state=random_state)]
    if transfo == 'FrequencyShift':
        return [FrequencyShift(
            probability=proba,
            sfreq=sfreq,
            delta_freq_range=(0, 3),
            random_state=random_state)]
    if transfo == 'IdentityTransform':
        return [IdentityTransform()]
    if transfo == 'ChannelsDropout':
        return [ChannelsDropout(
            probability=proba, p_drop=0.6, random_state=random_state)]
    if transfo == 'ChannelsShuffle':
        return [ChannelsShuffle(
            probability=proba, p_shuffle=0.1, random_state=random_state)]
    if transfo == 'ChannelsSymmetry':
        return [ChannelsSymmetry(
            probability=proba,
            ordered_ch_names=ch_names,
            random_state=random_state)]
    if transfo == 'custom':
        # Keys 0-4 are presumably class labels consumed by
        # ClassWiseAugmentation — verify against its implementation.
        # Every entry is a distinct augmentation instance.
        aug_dict = {
            0: GaussianNoise(
                probability=proba, std=0.1, random_state=random_state),
            1: GaussianNoise(
                probability=proba, std=0.1, random_state=random_state),
            2: SignFlip(probability=proba, random_state=random_state),
            3: GaussianNoise(
                probability=proba, std=0.1, random_state=random_state),
            4: SignFlip(probability=proba, random_state=random_state),
        }
        return [ClassWiseAugmentation(aug_dict)]
    raise ValueError('Unknown transformation: {!r}'.format(transfo))


def main():
    """Compute and pickle one learning curve per requested transformation.

    Parses the command line, loads and optionally downsamples the dataset,
    then, for every name in ``--transformations``, builds a classifier with
    that training augmentation and runs ``parallel_learning_curve``. Each
    result is written to
    ``<output>/<transfo>-<folds>-<HH-MM>-p<prop[0]>.pkl``. The output folder
    is created if it does not exist.
    """
    args = parser.parse_args()
    print("Arguments: {}\n".format(args))

    # Validate up front so a typo fails immediately rather than after the
    # dataset has been loaded.
    unknown = [t for t in args.transformations
               if t not in _SUPPORTED_TRANSFORMATIONS]
    if unknown:
        parser.error('unknown transformation(s): {} (choose among: {})'.format(
            ', '.join(unknown), ', '.join(_SUPPORTED_TRANSFORMATIONS)))

    # ``is not None`` so that a seed of 0 is honoured as well.
    if args.random_state is not None:
        set_random_seeds(args.random_state, find_device()[0])

    # get preprocessed dataset
    windows = get_dataset(
        name=args.dataset,
        n_subjects=args.subjects,
        n_jobs=args.n_jobs)
    # Stays None unless downsampling is requested; passed through to the
    # learning curve either way.
    subjects_mask_d = None
    print('\nn_windows before DS: {}'.format(len(windows)))

    # Needed by ChannelsSymmetry and FrequencyShift respectively.
    ch_names = windows.datasets[0].windows.ch_names
    sfreq = windows.datasets[0].windows.info['sfreq']

    if args.downsampling:
        windows, subjects_mask_d = downsample(
            windows, random_state=args.random_state)
    print('\nn_windows after DS: {}'.format(len(windows)))

    # Loop-invariant setup: output folder and optional device override.
    output = Path(args.output)
    output.mkdir(parents=True, exist_ok=True)
    device = torch.device(args.device) if args.device else None

    for transfo in args.transformations:
        # Fresh parameters for every transformation so nothing carries over
        # from the previous iteration.
        clf_params = {'iterator_train__transforms': _make_transforms(
            transfo, args.proba, args.random_state, sfreq, ch_names)}
        if device is not None:
            clf_params['device'] = device

        clf = get_EEGClassifier(
            dataset_name=args.dataset,
            clf_params=clf_params,
            random_state=args.random_state)

        print(
            "\n-------\nDevice: {}\n-------\n".format(
                clf.get_params()['device']))

        print("\n---------------\nAugmentation: {}\n---------------\n".format(
            clf.get_params()['iterator_train__transforms']))
        t_start = perf_counter()

        lr_curve = parallel_learning_curve(
            windows,
            clf=clf,
            K=args.folds,
            proportions=args.prop,
            epochs=args.epochs,
            n_jobs=args.n_jobs,
            random_state=args.random_state,
            subjects_mask=subjects_mask_d)
        t_stop = perf_counter()

        print("\n---------------\n{} time: {}s\n---------------\n".format(
            transfo, round(t_stop - t_start, 3)))

        # NOTE(review): only the transformation, fold count, an HH-MM
        # timestamp and prop[0] identify the file, so identical runs started
        # within the same minute overwrite each other.
        current_time = time.strftime("%H-%M")
        lr_curve.to_pickle(
            output /
            '{}-{}-{}-p{}.pkl'.format(
                transfo,
                args.folds,
                current_time,
                args.prop[0]))


if __name__ == "__main__":
    main()
