import argparse
import copy
import os
from os.path import join, basename
import sqlite3
from sqlite3 import Error

import numpy as np
import pandas as pd
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import optuna
from optuna.pruners import MedianPruner, NopPruner
from optuna.integration import SkorchPruningCallback
from optuna.samplers import RandomSampler
from joblib import Parallel, delayed

from mnist_auto_aug.classwise_transforms import KEY_TO_TRANSFORM,\
    make_transform_objects, ClasswiseTransform
from mnist_auto_aug.training_utils import fit_and_predict



def _sample_classwise_transform(trial, train_set, classes, transforms_family):
    """Draw one (operation, probability) pair per class and attach the
    resulting classwise transform to the training set.

    For every class ``c`` the trial suggests a categorical parameter
    ``operation_{c}`` taken from ``transforms_family`` and a parameter
    ``probability_{c}`` in [0, 1] with a step of 0.25. The ``train_set`` is
    modified in place (its ``transform`` attribute is overwritten) and
    returned.
    """
    # Tuple elements are evaluated left to right, so for each class the
    # operation is suggested before its probability.
    per_class_params = [
        (
            trial.suggest_categorical(f'operation_{c}', transforms_family),
            trial.suggest_discrete_uniform(f'probability_{c}', 0., 1., 0.25),
        )
        for c in classes
    ]

    transform_objects = make_transform_objects(
        per_class_params,
        compose=False,
    )
    # One transform per class, paired in the order of `classes`
    train_set.transform = ClasswiseTransform(
        dict(zip(classes, transform_objects))
    )
    return train_set


def _sample_regular_transform(trial, train_set, transforms_family):
    """Draw a single (operation, probability) pair shared by all classes and
    attach the resulting transform to the training set.

    The trial suggests a categorical parameter ``operation`` taken from
    ``transforms_family`` and a parameter ``probability`` in [0, 1] with a
    step of 0.25. The ``train_set`` is modified in place (its
    ``img_transform`` attribute is overwritten) and returned.
    """
    chosen_operation = trial.suggest_categorical(
        'operation', transforms_family
    )
    chosen_probability = trial.suggest_discrete_uniform(
        'probability', 0., 1., 0.25
    )

    train_set.img_transform = make_transform_objects(
        [(chosen_operation, chosen_probability)],
        compose=True,
    )
    return train_set


def _sample_identity_transform(train_set):
    """Baseline placeholder: nothing is sampled, the transform built from
    ``make_transform_objects(None)`` is used for all classes.

    The ``train_set`` is modified in place (its ``img_transform`` attribute is
    overwritten) and returned.
    """
    train_set.img_transform = make_transform_objects(None)
    return train_set


def make_train_and_see_objective_skorch(
    training_path,
    model,
    train_set,
    valid_set,
    test_loader,
    model_params,
    callbacks,
    transforms_family=None,
    epochs=30,
    classes=None,
    classwise=True,
    baseline=False,
):
    """Objective factory where for each new trial, the augmentations are used
    to retrain the model, which is then evaluated on the validation set.
    Corresponds to the metric used in AutoAugment [1]_

    Parameters
    ----------
    training_path : str
        Path to directory where everything is saved (checkpoints, results, etc)
    model : torch.nn.Module
        torch Module modeling a neural net classifier. A deep copy is trained
        in each trial, so the passed instance is left untouched.
    train_set : torch.utils.data.Dataset
        Dataset used for training the model. Note that its transform
        attribute is overwritten in place at the beginning of each trial.
    valid_set : torch.utils.data.Dataset
        Dataset used for early stopping and for assessing the subpolicy used
        during training.
    test_loader : torch.utils.data.DataLoader
        Used mainly for logging final performance of each trial.
    model_params : dict
        Parameters to pass to NeuralNetClassifier other than module, callbacks
        and split.
    callbacks : list
        Skorch callbacks to use during training. The list itself is never
        modified: a pruning callback monitoring "valid_acc" is appended to a
        copy of it in each trial.
    transforms_family : list, optional
        List of strings encoding transformations from
        classwise_transforms.KEY_TO_TRANSFORM to sample from. If omitted, all
        transformations except 'no-aug' will be considered. Defaults to None.
    epochs : int, optional
        Maximum number of epochs to train in each split. By default 30.
    classes : list, optional
        List of classes in the dataset. Used to know which classes to sample
        transforms for in case of classwise augmentation. Defaults to None (all
        MNIST classes).
    classwise : bool, optional
        Whether to sample classwise transforms or not. Defaults to True.
    baseline : bool, optional
        Whether to not transform (always sample the Identity). Takes
        precedence over ``classwise``. Defaults to False.

    Returns
    -------
    callable
        Objective function taking a trial as input and outputting the
        validation accuracy obtained with the sampled augmentation.

    References
    ----------
    .. [1] Cubuk, E. D., Zoph, B., Mane, D., Vasudevan, V., & Le, Q. V. (2019).
    Autoaugment: Learning augmentation strategies from data. In Proceedings of
    the IEEE/CVF Conference on Computer Vision and Pattern Recognition
    (pp. 113-123).
    """
    if transforms_family is None:
        # Filtering (rather than list.remove) does not fail when 'no-aug' is
        # absent from the registry.
        transforms_family = [
            key for key in KEY_TO_TRANSFORM if key != 'no-aug'
        ]

    if classes is None:
        classes = list(range(10))

    def _objective(trial):
        if baseline:
            train_set_w_tf = _sample_identity_transform(train_set)
        elif classwise:
            train_set_w_tf = _sample_classwise_transform(
                trial,
                train_set,
                classes,
                transforms_family,
            )
        else:
            train_set_w_tf = _sample_regular_transform(
                trial,
                train_set,
                transforms_family,
            )

        # Build a fresh list for every trial so that the caller's callbacks
        # are not mutated and each trial gets its own pruning callback.
        updated_callbacks = list(callbacks)
        updated_callbacks.append(
            ("pruning", SkorchPruningCallback(trial, "valid_acc"))
        )

        (test_loss, test_acc), (valid_loss, valid_acc) = fit_and_predict(
            copy.deepcopy(model),
            train_set_w_tf,
            valid_set,
            test_loader,
            epochs=epochs,
            model_params=model_params,
            # Pass the list containing the pruning callback: handing over the
            # raw `callbacks` would silently disable pruning.
            callbacks=updated_callbacks,
            return_valid_perf=True,
        )

        # One row per trial: number, test loss/acc, validation loss/acc
        trial_perf = pd.DataFrame(
            [[trial.number, test_loss, test_acc, valid_loss, valid_acc]],
            index=[1],  columns=range(5),
        )
        # NOTE: the file name (including its spelling) is kept unchanged since
        # other code may read it.
        trial_perf.to_csv(
            join(training_path, "perf_per_tial.csv"),
            index=False,
            header=False,
            mode='a'  # Append to the end of the file
        )
        return valid_acc
    return _objective


def _start_db(experiment_path):
    """Open a connection to the SQLite database of an experiment.

    The database file lives in ``experiment_path`` and is named after the
    study name derived from that path, with a ``.db`` extension.

    Returns
    -------
    tuple
        ``(connection, db_file_path)`` on success. If sqlite3 raises an
        ``Error``, it is printed and ``(None, None)`` is returned instead.
    """
    db_file = join(
        experiment_path,
        _derive_study_name(experiment_path) + '.db',
    )

    conn = None
    try:
        conn = sqlite3.connect(db_file)
    except Error as e:
        print(e)
        _close_db(conn)
        return None, None
    else:
        return conn, db_file


def _close_db(conn):
    """Close ``conn`` when it is truthy; do nothing otherwise (e.g. None)."""
    if not conn:
        return
    conn.close()


def _derive_study_name(experiment_path):
    """Return the study name of an experiment: the last component of its path.

    The path is normalized first so that a trailing separator (e.g.
    ``"runs/exp1/"``) still yields ``"exp1"`` rather than an empty name.
    """
    return basename(os.path.normpath(experiment_path))


def _search(
    experiment_path,
    objective_factory,
    objective_params,
    model,
    sampler_class,
    seed,
    job_idx,
    n_trials,
    storage=None,
):
    """Worker of the parallel augmentation search.

    Builds its own objective (writing into ``<experiment_path>/job_<job_idx>``
    so that jobs don't erase each other's results), loads the existing study
    named after ``experiment_path`` from ``storage``, runs ``n_trials`` trials
    and pickles the study's trials dataframe in the job folder.
    """
    job_folder = join(experiment_path, f"job_{job_idx}")
    os.makedirs(job_folder, exist_ok=True)

    objective = objective_factory(
        training_path=job_folder,
        model=copy.deepcopy(model),
        **objective_params
    )

    # Offsetting the seed by the job index makes sure jobs don't all sample
    # the same augmentations
    worker_seed = seed + job_idx
    print(f">>> Seeding optuna sampler with {worker_seed}")

    study = optuna.load_study(
        study_name=_derive_study_name(experiment_path),
        storage=storage,
        sampler=sampler_class(seed=worker_seed),
    )
    study.optimize(objective, n_trials=n_trials)

    trials_df = study.trials_dataframe()
    trials_df.to_pickle(join(job_folder, "trials.pkl"))


def _split_n_trials(tot_trials, n_jobs):
    """Split ``tot_trials`` into ``n_jobs`` integer shares summing to the
    total, as a numpy array. The first ``tot_trials % n_jobs`` jobs get one
    more trial than the others.
    """
    quotient, remainder = divmod(tot_trials, n_jobs)
    base_share = int(quotient)
    n_bigger_shares = int(remainder)
    return np.array([
        base_share + 1 if job < n_bigger_shares else base_share
        for job in range(n_jobs)
    ])


def search_subpolicies(
    experiment_path,
    objective_factory,
    model,
    objective_params,
    sampler_class,
    seed,
    n_trials,
    pruning=True,
    n_jobs=1,
    job_idx=0
):
    """Run an augmentation search, possibly split across parallel jobs.

    A study named after ``experiment_path`` is created (or loaded if it
    already exists) in an SQLite database stored in that directory. The
    ``n_trials`` trials are then split between ``n_jobs`` worker processes,
    each one optimizing the same study with its own sampler.

    Parameters
    ----------
    experiment_path : str
        Directory where the database and the per-job results are saved.
        Created if missing.
    objective_factory : callable
        Called in each job with ``training_path``, ``model`` and
        ``objective_params`` to build the objective to optimize.
    model : object
        Model passed to ``objective_factory`` (deep-copied in each job).
    objective_params : dict
        Extra keyword arguments forwarded to ``objective_factory``.
    sampler_class : type
        Optuna sampler class, instantiated in each job with a ``seed``
        keyword argument.
    seed : int
        Base seed. Each job seeds its sampler with ``seed`` plus its job
        index so that workers sample differently.
    n_trials : int
        Total number of trials, split between the jobs.
    pruning : bool, optional
        Whether to prune trials with a MedianPruner (otherwise a NopPruner
        is used). Defaults to True.
    n_jobs : int, optional
        Number of parallel jobs. Must be at least 1, since trials are
        explicitly split between jobs (joblib's ``-1`` convention is not
        supported). Defaults to 1.
    job_idx : int, optional
        Offset added to the index of each job. Defaults to 0.

    Raises
    ------
    ValueError
        If ``n_jobs`` is smaller than 1.
    RuntimeError
        If the SQLite database could not be opened.
    """
    if n_jobs < 1:
        # A non-positive n_jobs leads to an empty or invalid trial split,
        # i.e. to a search that silently does nothing or fails obscurely.
        raise ValueError(
            f"n_jobs must be a positive integer, got {n_jobs}."
        )

    pruner = MedianPruner() if pruning else NopPruner()
    os.makedirs(experiment_path, exist_ok=True)

    # Bound before the try block so that the finally clause is always valid
    conn = None
    try:
        conn, db_file = _start_db(experiment_path)
        if db_file is None:
            # Without this check the storage URL would be "sqlite:///None"
            raise RuntimeError(
                f"Could not open the SQLite database in {experiment_path}."
            )
        storage = f"sqlite:///{db_file}"

        # Only needed for its side effect: the study must exist in the
        # storage before the jobs load it.
        optuna.create_study(
            direction='maximize',
            pruner=pruner,
            study_name=_derive_study_name(experiment_path),
            storage=storage,
            load_if_exists=True,
        )

        n_trials_per_job = _split_n_trials(n_trials, n_jobs)

        parallel = Parallel(n_jobs=n_jobs, prefer="processes")
        parallel(
            delayed(_search)(
                experiment_path=experiment_path,
                objective_factory=objective_factory,
                objective_params=objective_params,
                model=model,
                sampler_class=sampler_class,
                seed=seed,
                # Important to ensure workers sample differently: the job
                # index is added to the seed in each worker.
                job_idx=job_idx + i,
                n_trials=job_n_trials,
                storage=storage,
            ) for i, job_n_trials in enumerate(n_trials_per_job)
        )
    finally:
        _close_db(conn)

