import os
from copy import deepcopy

import click
import numpy as np
from ruamel import yaml


class EasyAccessForNestedDict:
    """
    Wrap a nested dictionary so values can be read and written with a
    single dotted key, e.g. ``'exp_params.batch_size'``.

    Keys may be given either as a dotted string or as a sequence
    (list/tuple) of path components.
    """

    def __init__(self, obj):
        # the wrapped nested mapping; ``set`` modifies it in place
        self.obj = obj

    @staticmethod
    def _split(key):
        # dotted strings become a list of components; sequences are used as-is
        return key.split('.') if isinstance(key, str) else list(key)

    def get(self, key):
        """
        Return the value stored at ``key``.

        Raises:
            KeyError: if any component of the path is missing.
        """
        node = self.obj
        try:
            for part in self._split(key):
                node = node[part]
        except KeyError as exc:
            raise KeyError(f'Failed to get key {key}') from exc
        return node

    def set(self, key, value):
        """
        Store ``value`` at ``key`` and return the value previously stored
        there (``None`` if the last component was absent).

        Only the last component may be missing; it is then created.

        Raises:
            KeyError: if an intermediate component of the path is missing.
        """
        *parents, last = self._split(key)
        node = self.obj
        try:
            for part in parents:
                node = node[part]
        except KeyError as exc:
            # same error style as ``get`` so the failing dotted key is visible
            raise KeyError(f'Failed to set key {key}') from exc
        previous = node.get(last)
        node[last] = value
        return previous


def make_one_at_a_time_experiments(cfg, writer):
    """
    Vary one configuration entry at a time, leaving all others at their
    base value.

    For every dotted key below, one experiment is written per alternative
    value that differs from the value currently stored in ``cfg``. The
    experiment name is the last component of the key followed by the value.
    """
    alternatives = [
        ('exp_params.uncertainty_type', ('aleatoric', 'epistemic', 'total')),
        ('architecture.params.reset_to_same_weights', (True, False)),
        ('pseudolabeler.reassign_pseudo_labels', (True, False)),
        ('exp_params.reset_weights_after_pl', (True, False)),
        ('loss.positive_unlabeled_loss.prior',
         (0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.65, 0.7, 0.75, 0.8)),
        ('loss.pseudo_labeled_loss_weight',
         (0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 0.9)),
        ('pseudolabeler.new_labels_pos_neg_ratio', (None, 1, 0.667)),
        ('pseudolabeler.use_soft_labels', (True, False)),
        ('loss.positive_unlabeled_loss.class', ('nnpu', 'upu')),
    ]

    for dotted_key, options in alternatives:
        current = cfg.get(dotted_key)
        short_name = dotted_key.split(".")[-1]
        # the base value is already covered by the 'base' experiment
        for option in (o for o in options if o != current):
            writer(cfg, f'{short_name}_{option}', {dotted_key: option})


def make_ensemble_or_mcdrop_grid_experiments(cfg, writer):
    """
    Make a grid search to find best ensemble (if config has mc dropout)
    or mc dropout (if config has ensemble) parameters

    Every grid point is written once (``n=1``).
    """
    arch = cfg.get('architecture.class')

    if arch == 'ensemble':
        # base config is an ensemble: explore MC dropout on a single kiryo_cnn
        mc_grid = [
            (samples, rate)
            for samples in [5, 10, 25, 50]
            for rate in [0.02, 0.05, 0.1, 0.2, 0.25, 0.5]
        ]
        for samples, rate in mc_grid:
            writer(cfg, f'mcdrop_n{samples}_p{rate}', {
                'architecture.class': 'kiryo_cnn',
                'architecture.params.backbone': None,
                'architecture.params.mc_dropout_pl_samples': samples,
                'architecture.params.pdrop': rate,
            }, n=1)
        return

    # otherwise wrap the current architecture in an ensemble, dividing the
    # base batch size by the number of members
    total_batch = cfg.get('exp_params.batch_size')
    ensemble_grid = [
        (members, rate)
        for members in [2, 5, 10]
        for rate in [0.05, 0.1, 0.15, 0.2, 0.25]
    ]
    for members, rate in ensemble_grid:
        writer(cfg, f'ensemble_n{members}_p{rate}', {
            'architecture.class': 'ensemble',
            'architecture.params.backbone': arch,
            'architecture.params.mc_dropout_pl_samples': None,
            'architecture.params.pdrop': rate,
            'architecture.params.ensemble_size': members,
            'exp_params.batch_size': total_batch // members,
        }, n=1)


def make_ensemble_or_mcdrop_experiments(cfg, writer):
    """
    After the grid search above we know the best parameters,
    so make five repetitions of this

    Writes one ensemble experiment per size that differs from the base
    ensemble size (kiryo_cnn backbone, batch size 2048 split across members).
    If the base is an ensemble, also writes a single-network MC dropout run.
    """
    base_size = cfg.get('architecture.params.ensemble_size')

    # ensemble size
    for members in (s for s in (2, 3, 5, 10) if s != base_size):
        writer(cfg, f'ensemble_n{members}', {
            'architecture.class': 'ensemble',
            'architecture.params.backbone': 'kiryo_cnn',
            'architecture.params.mc_dropout_pl_samples': None,
            'architecture.params.pdrop': 0.15,
            'architecture.params.ensemble_size': members,
            'exp_params.batch_size': 2048 // members,
        })

    # mc dropout
    if cfg.get('architecture.class') != 'ensemble':
        return
    writer(cfg, 'ensemble_mcd', {
        'architecture.class': 'kiryo_cnn',
        'architecture.params.backbone': None,
        'architecture.params.mc_dropout_pl_samples': 25,
        'architecture.params.pdrop': 0.25,
        'architecture.params.ensemble_size': None,
        'exp_params.batch_size': 2048,
    })


def make_no_labeled_validation_experiments(cfg, writer):
    """
    Here we try using a PU validation set

    A single experiment is written with validation on true labels,
    post-processor fitting on labeled validation and temperature scaling
    all switched off.
    """
    pu_only = dict.fromkeys((
        'exp_params.validate_on_true_labels',
        'exp_params.fit_postprocessors_on_labeled_validation',
        'exp_params.temperature_scale',
    ), False)
    writer(cfg, 'pu_validation', pu_only)


def make_positive_bias_experiments(cfg, writer):
    """
    Here we test nnPU and nnPUSB on the biased dataset, as well as
    nnPUSB on the unbiased dataset

    The two nnPUSB runs also set the nested ``exp_params.post_processing``
    entry to use the PUSB scaler. The setting is passed to ``writer`` as a
    dotted update, so ``cfg`` itself is never modified. The earlier version
    assigned ``cfg.obj['exp_params.post_processing']``, which only created a
    literal top-level key with a dot in its name.
    """
    writer(cfg, 'positive_bias_bno_lnnpusb', {
        'dataset.class': 'cifar10',
        'loss.positive_unlabeled_loss.class': 'nn_pusb',
        'exp_params.post_processing': {'pusb_scaler': {'prior': 0.4}},
    })
    writer(cfg, 'positive_bias_byes_lnnpusb', {
        'dataset.class': 'cifar10_pusb',
        'loss.positive_unlabeled_loss.class': 'nn_pusb',
        'exp_params.post_processing': {'pusb_scaler': {'prior': 0.4}},
    })
    # nnPU keeps the base post-processing settings
    writer(cfg, 'positive_bias_byes_lnnpu', {
        'dataset.class': 'cifar10_pusb',
        'loss.positive_unlabeled_loss.class': 'nnpu',
    })


def make_temperature_experiments(cfg, writer):
    """
    Here we try three ways of doing temperature scaling:
     - not doing it at all
     - doing it on the assigned pseudo-labels
     - doing it on the labeled validation set

    The variant matching the base configuration is skipped.
    """
    base_scale = cfg.get('exp_params.temperature_scale')
    base_on_val = cfg.get('exp_params.fit_postprocessors_on_labeled_validation')

    # (temperature_scale, fit on labeled validation, name suffix);
    # "no scaling but fit on validation" is meaningless and not listed
    variants = (
        (True, True, 'onval'),
        (True, False, 'onpl'),
        (False, False, 'no'),
    )
    for scale, on_val, suffix in variants:
        if scale == base_scale and on_val == base_on_val:
            continue
        writer(cfg, f'temperature_on_{suffix}', {
            'exp_params.validate_on_true_labels': True,
            'exp_params.fit_postprocessors_on_labeled_validation': on_val,
            'exp_params.temperature_scale': scale,
        })


def make_pl_experiments(cfg, writer):
    """
    Here we try different values for the uncertainty threshold
    and maximum number of pseudo-labels to assign

    Three sweeps are written in order: the maximum uncertainty of new
    pseudo-labels, the minimum uncertainty for un-labeling, and the
    maximum number of new pseudo-labels.
    """
    max_uncertainties = [0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.25]
    unlabel_thresholds = [None, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55]
    label_budgets = [50, 100, 500, 1000, 2500]

    # only the uncertainty threshold is set; count and percentile are cleared
    for threshold in max_uncertainties:
        writer(cfg, f'pseudolabeler_unc_{threshold}', {
            'pseudolabeler.max_new_labels': None,
            'pseudolabeler.new_labels_uncertainty_percentile': None,
            'pseudolabeler.new_labels_max_uncertainty': threshold,
        })

    for threshold in unlabel_thresholds:
        writer(cfg, f'pseudolabeler_unlab_unc_{threshold}', {
            'pseudolabeler.unlabel_min_uncertainty': threshold
        })

    # only the count is set; percentile and uncertainty threshold are cleared
    for budget in label_budgets:
        writer(cfg, f'pseudolabeler_num_{budget}', {
            'pseudolabeler.max_new_labels': budget,
            'pseudolabeler.new_labels_uncertainty_percentile': None,
            'pseudolabeler.new_labels_max_uncertainty': None,
        })


def make_dataset_experiments(cfg, writer):
    """
    Here we run on the other datasets with the parameters
    from hyperband

    Each experiment combines dataset/backbone selection, the values fixed
    from the cifar10 ablations (shared by every dataset, built by
    ``ablation_fixed``) and dataset-specific values from hyperband. Within
    each dict literal, later entries override earlier ones.

    Args:
        cfg: base configuration, passed through unchanged to ``writer``.
        writer: callable ``writer(cfg, name, updates, n=None)`` that
            materializes one experiment from dotted-key ``updates``.
    """

    def ablation_fixed(prior=0.5, unlabel_min_uncertainty=0.35):
        # these values are fixed based on the cifar10 ablations
        return {
            'architecture.class': 'ensemble',
            'architecture.params.ensemble_size': 2,
            'loss.pseudo_labeled_loss_weight': 0.1,
            'loss.positive_unlabeled_loss.prior': prior,
            'pseudolabeler.max_new_labels': 1000,
            'pseudolabeler.new_labels_max_uncertainty': 0.05,
            'pseudolabeler.new_labels_uncertainty_percentile': None,
            'pseudolabeler.unlabel_min_uncertainty': unlabel_min_uncertainty,
        }

    # MLP backbone shared by mnist and fashion_mnist
    mlp_backbone = {
        'architecture.params.backbone': 'mlp',
        'architecture.params.input_shape': 784,
        'architecture.params.layer_shape': 300,
        'architecture.params.n_hidden_layers': 4,
    }

    # cnn parameters (from hyperband) used on stl10 by baseline and ensemble
    stl10_cnn = {
        'architecture.params.channels_in': 3,
        'architecture.params.n_conv_blocks': 3,
        'architecture.params.first_conv_block_channels': 32,
        'architecture.params.conv_block_channel_growth_factor': 2,
        'architecture.params.after_conv_flatten_filters': None,
        'architecture.params.head_units': 128,
        'architecture.params.head_layers': 2,
        'architecture.params.pdrop': 0.2,
    }

    writer(cfg, 'dataset_mnist', {
        'dataset.class': 'mnist',
        **mlp_backbone,
        **ablation_fixed(unlabel_min_uncertainty=0.45),

        # these values are from hyperband
        'exp_params.batch_size': 16384,
        'exp_params.learning_rate': 0.002,
        'exp_params.weight_decay': 0.0002,
        'exp_params.pseudolabel_every': 150,
        'architecture.params.pdrop': 0.05,
    })

    writer(cfg, 'dataset_fashion_mnist', {
        'dataset.class': 'fashion_mnist',
        **mlp_backbone,
        **ablation_fixed(),

        # these values are from hyperband
        'exp_params.batch_size': 8192,
        'exp_params.learning_rate': 0.002,
        'exp_params.weight_decay': 0.0002,
        'exp_params.pseudolabel_every': 50,
        'architecture.params.pdrop': 0.05,
    })

    writer(cfg, 'dataset_brain_cancer', {
        'dataset.class': 'brain_cancer',
        'architecture.params.backbone': 'cnn',
        **ablation_fixed(),

        # these values are from hyperband
        'architecture.params.channels_in': 3,
        'architecture.params.n_conv_blocks': 3,
        'architecture.params.first_conv_block_channels': 16,
        'architecture.params.conv_block_channel_growth_factor': 1.5,
        'architecture.params.after_conv_flatten_filters': None,
        'architecture.params.head_units': 64,
        'architecture.params.head_layers': 1,
        'architecture.params.pdrop': 0.5,
    })

    writer(cfg, 'dataset_imdb', {
        'dataset.class': 'imdb',
        # NOTE(review): machine-specific absolute path; adjust when running elsewhere
        'dataset.base_folder': '/storage/groups/imm01/workspace/emilio-new/data/',
        **ablation_fixed(),

        # these values are from hyperband
        'dataset.tokenizer_name': 'spacy',
        'dataset.tokenizer_language': 'en_core_web_sm',
        'dataset.glove_name': '6B',
        'dataset.glove_dim': 200,
        'architecture.params.backbone': 'lstm',
        'architecture.params.mc_dropout_pl_samples': None,
        'architecture.params.input_size': 200,
        'architecture.params.lstm_hidden_size': 128,
        'architecture.params.lstm_num_layers': 2,
        'architecture.params.bidirectional': True,
        'architecture.params.mlp_layers': 2,
        'architecture.params.mlp_batchnorm': True,
        'architecture.params.mlp_units': 196,
        'architecture.params.lstm_dropout': 0.25,
        'architecture.params.mlp_dropout': 0.2,
        'architecture.params.reset_to_same_weights': True,
        'exp_params.batch_size': 128,
        'exp_params.learning_rate': 0.002,
        'exp_params.weight_decay': 0.0002,
        'exp_params.pseudolabel_every': 100,
    })

    for i in range(10):
        for loss in ['nnpu', 'nn_pusb']:
            tag = loss.replace("_", "")
            stl10_common = {
                'dataset.class': 'stl10',
                'dataset.val_fold': i,
                'dataset.n_labeled': None,
                'architecture.params.backbone': 'cnn',
                'loss.positive_unlabeled_loss.class': loss,
                **ablation_fixed(prior=0.4),
                **stl10_cnn,
            }

            # baseline: a plain cnn (no ensemble) with pl_iterations set to 0
            writer(cfg, f'dataset_stl10_{tag}_baseline_r{i}', {
                **stl10_common,
                'architecture.class': 'cnn',
                'architecture.params.ensemble_size': None,
                'trainer_params.pl_iterations': 0,
            }, n=1)

            # our ensemble; pass a copy so each writer call gets its own dict
            writer(cfg, f'dataset_stl10_{tag}_r{i}', dict(stl10_common), n=1)


def make_predictions_experiments(cfg, writer):
    """
    Here we test whether using the raw predictions helps over using the uncertainty.

    Three experiments are written, each building on the previous one:
    predictions with the base ensemble size, predictions with
    ``ensemble_size`` 1, and the latter with temperature scaling fitted on
    the labeled validation set.
    """
    # use mean prediction of two networks
    mean_prediction = {'exp_params.uncertainty_type': 'predictions'}
    # use prediction of a single network
    single_network = {**mean_prediction, 'architecture.params.ensemble_size': 1}
    # use prediction of a single network and try calibrating
    single_calibrated = {
        **single_network,
        'exp_params.validate_on_true_labels': True,
        'exp_params.fit_postprocessors_on_labeled_validation': True,
        'exp_params.temperature_scale': True,
    }

    writer(cfg, 'predictions_mean', mean_prediction)
    writer(cfg, 'predictions_single', single_network)
    writer(cfg, 'predictions_single_temp', single_calibrated)


def make_skewed_experiments(cfg, writer):
    """
    These experiments use a variant of cifar10 where the classes
    making positive examples are sampled unevenly

    Post-processing is passed to ``writer`` as a dotted update instead of
    being assigned into ``cfg.obj``. The earlier version left
    ``exp_params.post_processing`` set to ``{}`` in the shared base config,
    which leaked into every experiment generated afterwards.
    """
    writer(cfg, 'skewed_cifar10_nnpusb', {
        'dataset.class': 'cifar10',
        'loss.positive_unlabeled_loss.class': 'nn_pusb',
        'exp_params.post_processing': {'pusb_scaler': {'prior': 0.4}},
    })
    writer(cfg, 'skewed_skewed_cifar10_nnpusb', {
        'dataset.class': 'skewed_cifar10',
        'loss.positive_unlabeled_loss.class': 'nn_pusb',
        'exp_params.post_processing': {'pusb_scaler': {'prior': 0.4}},
    })

    # nnPU runs with an explicitly empty post-processing section (no scaler)
    writer(cfg, 'skewed_skewed_cifar10_nnpu', {
        'dataset.class': 'skewed_cifar10',
        'loss.positive_unlabeled_loss.class': 'nnpu',
        'exp_params.post_processing': {},
    })


def make_imbalanced_experiments(cfg, writer):
    """
    These experiments change the ratio of positives to negatives
    in the data by using fewer vehicle classes.

    Experiment ``imbalanced_k`` (k = 1, 2, 3) keeps the first k of the
    positive classes, each with probability 1/k, and sets the PU prior
    to k/10.
    """
    vehicle_classes = [0, 1, 8, 9]

    for count in (1, 2, 3):
        chosen = vehicle_classes[:count]
        writer(cfg, f'imbalanced_{count}', {
            'dataset.class': 'skewed_cifar10',
            'loss.positive_unlabeled_loss.prior': count / 10,
            'dataset.positive_probs': dict.fromkeys(chosen, 1 / count)
        })


@click.command()
@click.argument('base-config', type=click.Path())
@click.option('--config-dir', type=click.Path(),
              default='ablations/configs', help='Where to save generated ablations')
@click.option('--output-dir', type=click.Path(),
              default='ablations/outputs', help='Where to save training outputs')
@click.option('--n-repetitions', default=5, help='How many runs for each configuration')
@click.option('--seed', default=1416269, help='Random generator seed')
@click.option('--dry-run', is_flag=True, help='Do not write any file')
def main(base_config, config_dir, output_dir, n_repetitions, seed, dry_run):
    """
    Creates ablation studies by modifying the given configuration.

    Every experiment is written as one or more YAML files inside
    CONFIG_DIR. Files whose parsed content would not change are left
    untouched.
    """
    # NOTE(review): ``output_dir`` is accepted but never used in this
    # function; confirm whether it should be written into the configs.
    # NOTE(review): ``yaml.safe_load``/``yaml.safe_dump`` are the legacy
    # ruamel.yaml module-level API, deprecated and removed in recent
    # ruamel.yaml releases; pin the version or migrate to YAML(typ='safe').

    # repetition i of each experiments use the same seeds
    rng = np.random.default_rng(seed)
    seeds = [int(rng.integers(0, 999999999)) for _ in range(2 * n_repetitions)]

    if not dry_run:
        # without this the first write fails when the directory is missing
        os.makedirs(config_dir, exist_ok=True)

    def update_config_and_write_files(cfg, base_name, updates, n=None):
        """
        Apply ``updates`` (dotted key -> value) to a deep copy of ``cfg`` and
        write ``n`` files (default ``n_repetitions``) that differ only in
        their seeds. Neither ``cfg`` nor ``updates`` is modified.
        """
        n = n or n_repetitions
        new_conf = EasyAccessForNestedDict(deepcopy(cfg.obj))
        for k, v in updates.items():
            new_conf.set(k, v)

        for i in range(n):
            fname = os.path.join(
                config_dir,
                f'{base_name}_r{i}.yaml' if n > 1 else base_name + '.yaml'
            )
            # repetition i always gets the same pair of seeds
            new_conf.set('logging_params.manual_seed', seeds[2 * i])
            new_conf.set('dataset.seed', seeds[2 * i + 1])

            # do not write to file if there's no change
            # this will prevent snakemake from running things unnecessarily
            if os.path.exists(fname):
                with open(fname, encoding='utf-8') as f:
                    old = yaml.safe_load(f)
            else:
                old = None

            if old is None or new_conf.obj != old:
                print('Writing', fname)
                if not dry_run:
                    with open(fname, 'w', encoding='utf-8') as f:
                        yaml.safe_dump(new_conf.obj, f)
            else:
                print('No change in', fname)

    with open(base_config, encoding='utf-8') as f:
        cfg = EasyAccessForNestedDict(yaml.safe_load(f))

    update_config_and_write_files(cfg, 'base', {})
    make_one_at_a_time_experiments(cfg, update_config_and_write_files)
    make_positive_bias_experiments(cfg, update_config_and_write_files)
    make_no_labeled_validation_experiments(cfg, update_config_and_write_files)
    make_temperature_experiments(cfg, update_config_and_write_files)
    make_dataset_experiments(cfg, update_config_and_write_files)
    make_pl_experiments(cfg, update_config_and_write_files)
    make_predictions_experiments(cfg, update_config_and_write_files)
    make_skewed_experiments(cfg, update_config_and_write_files)
    make_imbalanced_experiments(cfg, update_config_and_write_files)

    # step 1 - grid search some stuff (disabled; enable by calling
    # make_ensemble_or_mcdrop_grid_experiments here)

    # step 2 - perform five repetitions of best thing found above
    make_ensemble_or_mcdrop_experiments(cfg, update_config_and_write_files)

    if dry_run:
        print('Dry run - nothing changed')


if __name__ == "__main__":
    # script entry point: main is a click command and reads its own CLI args
    main()
