'''
    Module for uncertainty estimation.
    Contains the RunnerUncertainty class and a click-based
    `main` command-line function that trains it.
'''
import click

import jax_networks
import ensemble
from runner import Runner
import utils

class RunnerUncertainty(Runner):
    '''
        A runner object for uncertainty estimation.
        Based on: https://openreview.net/forum?id=BJlahxHYDS

        On a fresh run (``already_ran`` falsy after Runner's initialisation)
        it (re)builds ``self.network`` with an output size that depends on
        params['type'], builds a prior network with the same architecture
        arguments (``prior=True``), and wraps both in an
        ensemble.EnsembleLearner stored as ``self.learner``.
    '''
    def __init__(self, params):
        '''
            Initialise the base Runner and, unless it has already run,
            rebuild the networks and create the ensemble learner.

            Args:
                params: configuration dict forwarded to Runner. This class
                    additionally reads 'type', 'network', 'feature_dim',
                    'width', 'depth', 'seed', 'n_comp', 'noise_scale',
                    'prior_scale' and 'beta' from self.params (presumably
                    populated by Runner from ``params``).

            Raises:
                ValueError: if params['type'] is neither 's' nor 'sa'
                    (only checked when the runner has not already run).
        '''
        super().__init__(params)

        # Nothing to build when Runner reports this configuration already ran.
        if self.already_ran:
            return

        # Number of output groups: 1 for type 's', one per action for 'sa'.
        # Each group has feature_dim outputs.
        run_type = self.params['type']
        if run_type == 's':
            na = 1
        elif run_type == 'sa':
            na = self.num_actions
        else:
            # Without this branch `na` would be unbound and the failure would
            # surface later as an opaque UnboundLocalError.
            raise ValueError(
                f"params['type'] must be 's' or 'sa', got {run_type!r}")
        out_dim = na * self.params['feature_dim']

        # Rebuild the network with the output dim required here; the prior
        # uses the exact same positional arguments plus prior=True.
        net_args = (self.params['network'], self.dummy_obs, out_dim,
                    self.params['width'], self.params['depth'])
        self.network = jax_networks.build_network(*net_args)
        self.prior = jax_networks.build_network(*net_args, prior=True)

        # build learner
        self.learner = ensemble.EnsembleLearner(
                            net = self.network, opt = self.opt,
                            prior_net = self.prior,
                            dataset = self.dataset,
                            normalize_fn = self.normalize_fn,
                            seed = self.params['seed'],
                            n_comp = self.params['n_comp'],
                            num_actions = na,
                            feature_dim = self.params['feature_dim'],
                            noise_scale = self.params['noise_scale'],
                            prior_scale = self.params['prior_scale'],
                            beta = self.params['beta'],
                            logger = self.train_logger)

    def eval(self, step):
        '''
            Evaluate the learner on one batch from each evaluation dataset
            and write each resulting metrics object to self.eval_logger.

            Args:
                step: current training step; passed to the learner together
                    with the dataset's label.
        '''
        # zip stops at the shorter of eval_labels / eval_datasets.
        for label, data in zip(self.eval_labels, self.eval_datasets):
            # Each dataset is consumed as an iterator yielding pairs; only
            # the second element (the transition) is used here.
            _, transition = next(data)
            metrics = self.learner.eval(transition.data.observation,
                                        transition.data.action,
                                        {'label': label, 'step': step})
            self.eval_logger.write(metrics)

@click.command()
@click.option('--config', '-c', default='train_uncertainty', help='config file name')
@click.option('--options', '-o', multiple=True, nargs=2, type=click.Tuple([str, str]))
def main(config, options):
    # Command-line entry point: builds a RunnerUncertainty from the named
    # config plus any repeated `-o A B` string pairs (combined by
    # utils.config_and_options_to_dict), then trains it unless the runner
    # reports it has already run.
    # Documented with comments rather than a docstring because click would
    # display a docstring as this command's --help text.
    runner = RunnerUncertainty(utils.config_and_options_to_dict(config, options))
    if runner.already_ran:
        return
    runner.train()


if __name__ == '__main__':
    # main is a click command: calling it with no arguments makes click parse
    # the process's command-line arguments and then invoke the function.
    main()
