import os
from os.path import join

from eeg_augment.training_utils import make_args_parser, handle_dataset_args,\
    check_grid, CrossvalModel, prepare_training, read_config
from eeg_augment.auto_augmentation import _make_transform_objects,\
    MASS_BEST_TRANSFORMS, PHYSIONET_BEST_TRANSFORMS


def launch_training(
    training_dir,
    epochs,
    windows_dataset,
    sfreq,
    n_classes=5,
    device=None,
    lr=1e-3,
    batch_size=128,
    num_workers=4,
    parallel=False,
    transforms=None,
    ordered_ch_names=None,
    randaugment_seq_len=None,
    randaugment_transform_collec=None,
    early_stop=True,
    model_to_use=None,
    data_ratio=None,
    max_ratios=None,
    grouped_subset=True,
    dataset_defaults_to_use="default",
    n_jobs=1,
    verbose=False,
    random_state=None,
    **kwargs
):
    """Run one cross-validated training with a given augmentation setting.

    The model and its parameters are obtained from ``prepare_training``, the
    transform objects from ``_make_transform_objects``, and the learning
    curve is then computed with a ``CrossvalModel``.

    Parameters
    ----------
    training_dir : str
        Path to training directory where all logs and results will be saved.
    epochs : int
        Number of epochs.
    windows_dataset : torch.data.utils.ConcatDataset
        Dataset to train on.
    sfreq : int
        Sampling frequency.
    n_classes : int, optional
        Range of classes to predict. By default 5.
    device : str, optional
        Device to train on. By default None.
    lr : float, optional
        Learning rate to use. By default 1e-3.
    batch_size : int, optional
        Batch size to use. By default 128.
    num_workers : int, optional
        Number of workers to use for data loading. By default 4.
    parallel : bool, optional
        Whether to parallelize model across GPUs using torch.nn.DataParallel.
        By default False.
    transforms : list | tuple | None, optional
        Tuple or list of tuples containing a transform name, a probability and
        a magnitude to apply. By default None.
    ordered_ch_names : list | None, optional
        List of strings representing the channels of the montage considered.
        Only used for instantiating transforms needing this information. Has to
        be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. These channels will be used to compute approximate sensors
        positions from a standard 10-20 montage. Defaults to None.
    randaugment_seq_len : int, optional
        Length of transforms sequence to sample in case RandAugment is used.
    randaugment_transform_collec : list | None, optional
        Collection of possible transforms to sample from when using
        RandAugment. Defaults to None (all available transforms).
    early_stop : bool, optional
        Whether to carry early stopping with patience=30. By default True.
    model_to_use : str | None, optional
        Defines which net should be used. By default (None) will use
        SleepStager. If set to 'lin' will use one layer linear net.
    data_ratio : list | float | str | None, optional
        Float or list of floats between 0 and 1 or a str (only "log2" and
        "lin" supported for now). Each element will be used to build a
        subset of the cross-validated training sets (valid and test sets
        are conserved). Omitting it or setting it to None, is equivalent to
        setting it to [1.] (using the whole training set). If "log2" is
        passed, then a log2 scale of training sizes will be used. By default
        None.
    max_ratios : int | None, optional
        Maximum number of subsets to be built. Useful when ratios are
        computed automatically. Ignored when data_ratio is omitted or
        a list.
    grouped_subset : bool, optional
        Whether to compute training subsets taking groups (subjects) into
        account or not. When False, stratified splitting will be used to
        build the subsets. By default True.
    dataset_defaults_to_use : str, optional
        What default parameters to use to instantiate the transforms. Can be
        "default", "edf" (for Physionet dataset default magnitudes) or "mass"
        (for MASS dataset default magnitudes). Defaults to "default".
    n_jobs : int, optional
        Number of workers to use for parallelizing across splits. By
        default 1.
    verbose : bool, optional
        By default False.
    random_state : int, optional
        Seed for random number generator. By default None.
    **kwargs
        Other arguments to pass to CrossvalModel object.

    Returns
    -------
    None
    """
    # The first returned element (the device) is not needed afterwards.
    _, rng, model, model_params, shared_callbacks = prepare_training(
        windows_dataset=windows_dataset,
        lr=lr,
        batch_size=batch_size,
        num_workers=num_workers,
        early_stop=early_stop,
        sfreq=sfreq,
        n_classes=n_classes,
        parallel=parallel,
        device=device,
        model_to_use=model_to_use,
        random_state=random_state,
    )

    # The transform objects receive the generator returned by
    # prepare_training, while CrossvalModel below receives the raw seed.
    transform_objects = _make_transform_objects(
        transforms,
        ordered_ch_names=ordered_ch_names,
        sfreq=sfreq,
        randaugment_seq_len=randaugment_seq_len,
        randaugment_transform_collec=randaugment_transform_collec,
        defaults_key=dataset_defaults_to_use,
        random_state=rng,
    )
    model_params['iterator_train__transforms'] = transform_objects

    crossval = CrossvalModel(
        training_dir,
        model,
        model_params=model_params,
        shared_callbacks=shared_callbacks,
        balanced_loss=True,  # Not settable for now
        monitor='valid_bal_acc_best',  # Not settable for now
        should_checkpoint=True,  # Not settable for now
        log_tensorboard=True,  # Not settable for now
        random_state=random_state,
        **kwargs
    )
    crossval.learning_curve(
        windows_dataset=windows_dataset,
        epochs=epochs,
        data_ratios=data_ratio,
        max_ratios=max_ratios,
        grouped_subset=grouped_subset,
        n_jobs=n_jobs,
        verbose=verbose
    )


def _build_transform_tuples(
    transform_key,
    probabilities=0.5,
    magnitudes=None,
    n_probas=None,
    n_mags=None,
):
    """Build the grid of (transform_name, probability, magnitude) triplets.

    Parameters
    ----------
    transform_key : str
        String encoding transform. Will be first element of all tuples created.
    probabilities : float | array-like | str, optional
        If string, will look for a grid generator with the given key and create
        the probabilities to use automatically. If an array or float, will use
        values to set the probability of transforms. By default 0.5
    magnitudes : float | array-like | str | None, optional
        If string, will look for a grid generator with the given key and create
        the magnitudes to use automatically. If an array or float, will use
        values to set the magnitude of transforms. By default None
    n_probas : int, optional
        Ignored if `probabilities` is not a string. Will be used to know the
        size of the grid of values to create otherwise. By default None.
    n_mags : int, optional
        Ignored if `magnitudes` is not a string. Will be used to know the
        size of the grid of values to create otherwise. By default None.

    Returns
    -------
    list of tuple
        One triplet per (probability, magnitude) pair. Probabilities vary
        slowest and magnitudes fastest.
    """
    proba_grid = check_grid(probabilities, max_value=1., n_values=n_probas)
    print(f"Grid of probabilities created: {proba_grid}")
    mag_grid = check_grid(magnitudes, max_value=1., n_values=n_mags)
    print(f"Grid of magnitudes created: {mag_grid}")
    return [
        (transform_key, proba, mag)
        for proba in proba_grid
        for mag in mag_grid
    ]


def train_with_different_settings(
    experiment_dir,
    transform_key,
    probabilities=0.5,
    magnitudes=None,
    n_probas=None,
    n_mags=None,
    **training_params,
):
    """Train once per (probability, magnitude) setting of a single transform.

    Each setting is trained through ``launch_training`` in its own
    sub-directory of `experiment_dir`, named "<probability>-<magnitude>".

    Parameters
    ----------
    experiment_dir : str
        Path where training data and results will be stored. Created if it
        does not exist.
    transform_key : str
        String encoding what transform to use.
    probabilities : float | array-like | str, optional
        If string, will look for a grid generator with the given key and create
        the probabilities to use automatically. If an array or float, will use
        values to set the probability of transforms. By default 0.5
    magnitudes : float | array-like | str | None, optional
        If string, will look for a grid generator with the given key and create
        the magnitudes to use automatically. If an array or float, will use
        values to set the magnitude of transforms. By default None
    n_probas : int, optional
        Ignored if `probabilities` is not a string. Will be used to know the
        size of the grid of values to create otherwise. By default None.
    n_mags : int, optional
        Ignored if `magnitudes` is not a string. Will be used to know the
        size of the grid of values to create otherwise. By default None.
    **training_params
        Parameters to be passed to launch_training function.

    Returns
    -------
    None
    """
    assert isinstance(experiment_dir, str), "experiment_dir should be a str."
    os.makedirs(experiment_dir, exist_ok=True)

    triplets = _build_transform_tuples(
        transform_key,
        probabilities=probabilities,
        magnitudes=magnitudes,
        n_probas=n_probas,
        n_mags=n_mags,
    )

    print("Transform triplets created:")
    for triplet in triplets:
        print(triplet)

    for triplet in triplets:
        # The transform name is dropped: only the setting values are used
        # to name the sub-directory.
        setting_name = "-".join(str(value) for value in triplet[1:])
        print(f"\n---------- {setting_name} ----------\n")
        launch_training(
            training_dir=join(experiment_dir, setting_name),
            transforms=triplet,
            **training_params
        )


def train_with_different_transforms(
    experiment_dir,
    transforms=None,
    probabilities=0.5,
    magnitudes=None,
    n_probas=None,
    n_mags=None,
    **training_params
):
    """Run the probability/magnitude grid training for several transforms.

    Each transform gets its own sub-directory of `experiment_dir`, in which
    ``train_with_different_settings`` is run.

    Parameters
    ----------
    experiment_dir : str
        Parent directory where all training directories will be created.
    transforms : list | str | None, optional
        String or list of strings encoding the transforms to apply
        separately. By default None, in which case the single key "no-aug"
        is used.
    probabilities : float | array-like | str, optional
        If string, will look for a grid generator with the given key and create
        the probabilities to use automatically. If an array or float, will use
        values to set the probability of transforms. By default 0.5
    magnitudes : float | array-like | str | None, optional
        If string, will look for a grid generator with the given key and create
        the magnitudes to use automatically. If an array or float, will use
        values to set the magnitude of transforms. By default None
    n_probas : int, optional
        Ignored if `probabilities` is not a string. Will be used to know the
        size of the grid of values to create otherwise. By default None.
    n_mags : int, optional
        Ignored if `magnitudes` is not a string. Will be used to know the
        size of the grid of values to create otherwise. By default None.
    **training_params
        Parameters to be passed to launch_training function. The key
        'randaugment_seq_len' is read here when 'randaugment' is one of the
        transforms.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `transforms` is neither None, a string nor a list.
    """
    # XXX: Change API to avoid ignored args, by moving probabilities
    # and magnitudes automatic grid creation to __main__
    assert isinstance(experiment_dir, str), "experiment_dir should be a str."
    os.makedirs(experiment_dir, exist_ok=True)

    # Normalise the accepted input forms into a list of transform keys.
    if transforms is None:
        transform_keys = ["no-aug"]
    elif isinstance(transforms, str):
        transform_keys = [transforms]
    elif isinstance(transforms, list):
        transform_keys = transforms
    else:  # UNTESTED
        raise ValueError(
           "transforms should be a string, list of strings or None."
        )

    for key in transform_keys:
        print(f"\n########## {key} ##########\n")

        folder = join(experiment_dir, key)
        if key == 'randaugment':
            # Sequence length is appended to the randaugment folder name.
            depth = training_params['randaugment_seq_len']
            folder = folder + "_depth" + str(depth)

        train_with_different_settings(
            folder,
            transform_key=key,
            probabilities=probabilities,
            magnitudes=magnitudes,
            n_probas=n_probas,
            n_mags=n_mags,
            **training_params,
        )


if __name__ == '__main__':
    # Command-line entry point: extend the shared argument parser with the
    # RandAugment-specific options, then run the transforms grid training.
    parser = make_args_parser()
    parser.add_argument(
        "--rand_n",
        type=int,
        default=2,
        # The trailing space matters: adjacent string literals are
        # concatenated as-is.
        help="Length of transforms sequence to sample in RandAugment. "
             "Ignored for other transforms."
    )

    parser.add_argument(
        "--rand_best_only",
        action="store_true",
        help="Whether to only use known best transforms in RandAugment.",
    )
    args = parser.parse_args()

    windows_dataset, ch_names, sfreq = handle_dataset_args(args)

    # Restrict the RandAugment collection only when requested and when a
    # list of best transforms is known for the dataset; otherwise keep None.
    transforms_collection = None
    if args.rand_best_only:
        if args.dataset == "mass":
            transforms_collection = MASS_BEST_TRANSFORMS
        elif args.dataset == "edf":
            transforms_collection = PHYSIONET_BEST_TRANSFORMS

    training_params = {
        'windows_dataset': windows_dataset,
        'epochs': args.epochs,
        'sfreq': sfreq,
        'device': args.device,
        'lr': args.lr,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'random_state': args.random_state,
        'early_stop': args.early_stop,
        'n_folds': args.nfolds,
        'should_load_state': args.should_load_state,
        'train_size_over_valid': args.train_size_over_valid,
        'model_to_use': args.model,
        'data_ratio': args.data_ratio,
        'max_ratios': args.max_ratios,
        'grouped_subset': args.grouped_subset,
        'n_jobs': args.n_jobs,
        'ordered_ch_names': ch_names,
        'randaugment_seq_len': args.rand_n,
        'randaugment_transform_collec': transforms_collection,
        "dataset_defaults_to_use": args.dataset
    }

    # Values from the config file take precedence over command-line values;
    # the "training" section is applied last.
    if args.config:
        config = read_config(args.config)
        training_params.update(config["split"])
        training_params.update(config["training"])

    train_with_different_transforms(
        experiment_dir=args.training_dir,
        transforms=args.augment,
        probabilities=args.proba,
        magnitudes=args.mag,
        n_probas=args.n_probas,
        n_mags=args.n_mags,
        **training_params
    )
